From 655a3bd75fcb36c1810beaef75b5c8b5757c6ed7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 09:29:28 +0200 Subject: [PATCH 1/4] Add support for mat-vec product --- src/JuMP/JuMP.jl | 1 + src/JuMP/operators.jl | 21 +++++++++++++ src/reverse_mode.jl | 68 +++++++++++++++++++++++++++++++++++-------- 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/JuMP/JuMP.jl b/src/JuMP/JuMP.jl index 9ed23d4..ddb62e1 100644 --- a/src/JuMP/JuMP.jl +++ b/src/JuMP/JuMP.jl @@ -4,6 +4,7 @@ import JuMP # Equivalent of `AbstractJuMPScalar` but for arrays abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end +const AbstractJuMPVector{T} = AbstractJuMPArray{T,1} const AbstractJuMPMatrix{T} = AbstractJuMPArray{T,2} include("variables.jl") diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index b531488..4b17d1f 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -12,6 +12,27 @@ function Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix) return _matmul(JuMP.variable_ref_type(A), A, B) end +# Matrix-vector products: output is a 1-D `GenericArrayExpr` of length +# `size(A, 1)`. Allows `W * x` when `W`, `x`, or both are `AbstractJuMPArray`s. 
+function _matvec(::Type{V}, A, b) where {V} + return GenericArrayExpr{V,1}(:*, Any[A, b], (size(A, 1),), false) +end + +function Base.:(*)(A::AbstractJuMPMatrix, b::Vector) + return _matvec(JuMP.variable_ref_type(A), A, b) +end + +function Base.:(*)(A::Matrix, b::AbstractJuMPVector{T}) where {T} + return _matvec(JuMP.variable_ref_type(b), A, b) +end + +function Base.:(*)( + A::AbstractJuMPMatrix, + b::AbstractJuMPVector{T}, +) where {T} + return _matvec(JuMP.variable_ref_type(A), A, b) +end + function __broadcast( ::Type{V}, axes::NTuple{N,Base.OneTo{Int}}, diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index db93906..7a9484e 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -4,6 +4,46 @@ # Use of this source code is governed by an MIT-style license that can be found # in the LICENSE.md file or at https://opensource.org/licenses/MIT. +# Reverse-mode contribution for a matmul node `k` with children `ix1`, `ix2`. +# `f.sizes.ndims[k]` may be 1 (mat-vec) or 2 (mat-mat); `_reshape_call` picks +# the right view type for each node and `LinearAlgebra.mul!` covers both +# shape combinations. 
+function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) + _reshape_call( + f.forward_storage, + f.sizes, + (ix1, ix2), + _matmul_reverse_outer, + (f.reverse_storage, f.sizes, ix1, ix2, k), + ) + return +end + +function _matmul_reverse_outer( + reverse_storage, + sizes::Sizes, + ix1::Int, + ix2::Int, + k::Int, + v1, + v2, +) + _reshape_call( + reverse_storage, + sizes, + (ix1, ix2, k), + _matmul_reverse_inner!, + (v1, v2), + ) + return +end + +function _matmul_reverse_inner!(v1, v2, rev_v1, rev_v2, rev_parent) + LinearAlgebra.mul!(rev_v1, rev_parent, transpose(v2)) + LinearAlgebra.mul!(rev_v2, transpose(v1), rev_parent) + return +end + """ _reverse_mode(d::NLPEvaluator, x) @@ -177,10 +217,17 @@ function _forward_eval( idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - v1 = _view_matrix(f.forward_storage, f.sizes, ix1) - v2 = _view_matrix(f.forward_storage, f.sizes, ix2) - out = _view_matrix(f.forward_storage, f.sizes, k) - LinearAlgebra.mul!(out, v1, v2) + # `_reshape_call` dispatches each node to the right view + # type based on its `ndims`. `LinearAlgebra.mul!` then + # picks the matching method — mat-mat for `ndims[k] == 2`, + # mat-vec for `ndims[k] == 1`. + _reshape_call( + f.forward_storage, + f.sizes, + (k, ix1, ix2), + LinearAlgebra.mul!, + (), + ) # We deliberately don't write v1/v2 into partials_storage # here: the matmul reverse branch reads forward_storage # directly, so those writes were dead. @@ -787,18 +834,15 @@ function _reverse_eval( # straight from forward_storage (the matmul forward # branch deliberately doesn't snapshot them into # partials_storage), and the reverse views are written - # in place. + # in place. Two nested `_reshape_call`s dispatch each + # node to the right view type based on `ndims`, so the + # same code path covers mat-mat (`ndims[k] == 2`) and + # mat-vec (`ndims[k] == 1`). 
idx1 = first(children_indices) idx2 = last(children_indices) ix1 = children_arr[idx1] ix2 = children_arr[idx2] - v1 = _view_matrix(f.forward_storage, f.sizes, ix1) - v2 = _view_matrix(f.forward_storage, f.sizes, ix2) - rev_parent = _view_matrix(f.reverse_storage, f.sizes, k) - rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1) - rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2) - LinearAlgebra.mul!(rev_v1, rev_parent, v2') - LinearAlgebra.mul!(rev_v2, v1', rev_parent) + _matmul_reverse!(f, k, ix1, ix2) continue end elseif op == :vect From c2d44cf321d7dfc89d971c852f526dca91b0c67a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 09:31:00 +0200 Subject: [PATCH 2/4] Add test --- test/JuMP.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/JuMP.jl b/test/JuMP.jl index 35ba6f6..9575a46 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -1063,6 +1063,32 @@ function test_op_reversed_args() return end +function test_matvec_gradient() + # `W * x` for a 1-D `ArrayOfVariables` `x` exercises the mat-vec branch + # of `_matmul_reverse!`. Loss is `sum((W*x - target).^2)` so the analytic + # gradient is `2 * W' * (W*x - target)`. 
+ m, n = 3, 4 + W = [ + 0.4 -0.2 0.1 0.3 + -0.3 0.5 0.2 -0.1 + 0.1 0.1 -0.4 0.2 + ] + target = [0.5, -0.2, 0.1] + model = Model() + @variable(model, x[1:n], container = ArrayDiff.ArrayOfVariables) + y = W * x + @test y isa ArrayDiff.GenericArrayExpr + @test ndims(y) == 1 + @test size(y) == (m,) + @test y.head == :* + loss = sum((y .- target) .^ 2) + x_val = [0.6, -0.3, 0.4, -0.1] + _, val, g, _ = _eval(model, loss, x_val; x_grad = x_val) + @test val ≈ sum((W * x_val .- target) .^ 2) + @test g ≈ 2 * W' * (W * x_val .- target) + return +end + end # module TestJuMP.runtests() From 39a2619939712f1d04bfc2a702a8b495cc672993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 09:31:07 +0200 Subject: [PATCH 3/4] Fix format --- src/JuMP/operators.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index 4b17d1f..0d88b6f 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -26,10 +26,7 @@ function Base.:(*)(A::Matrix, b::AbstractJuMPVector{T}) where {T} return _matvec(JuMP.variable_ref_type(b), A, b) end -function Base.:(*)( - A::AbstractJuMPMatrix, - b::AbstractJuMPVector{T}, -) where {T} +function Base.:(*)(A::AbstractJuMPMatrix, b::AbstractJuMPVector{T}) where {T} return _matvec(JuMP.variable_ref_type(A), A, b) end From 22c30d50c63f90a21e754ac40682a026e01cd75f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 09:47:33 +0200 Subject: [PATCH 4/4] Add tests --- test/JuMP.jl | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/JuMP.jl b/test/JuMP.jl index 9575a46..08ca93b 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -1089,6 +1089,66 @@ function test_matvec_gradient() return end +function test_matvec_jump_matrix_times_const_vector_gradient() + # `W * x` where `W` is an `AbstractJuMPMatrix` and `x` is a constant + # `Vector`. Loss = sum((W*x - target).^2). 
Analytic gradient w.r.t. W is + # `2 * (W*x - target) * x'` (outer product). JuMP stores `W` column-major, + # so the flat gradient vector is `vec(2 * (W*x - target) * x')`. + m, n = 3, 4 + x_const = [0.6, -0.3, 0.4, -0.1] + target = [0.5, -0.2, 0.1] + model = Model() + @variable(model, W[1:m, 1:n], container = ArrayDiff.ArrayOfVariables) + y = W * x_const + @test y isa ArrayDiff.GenericArrayExpr + @test ndims(y) == 1 + @test size(y) == (m,) + @test y.head == :* + loss = sum((y .- target) .^ 2) + W_val = [ + 0.4 -0.2 0.1 0.3 + -0.3 0.5 0.2 -0.1 + 0.1 0.1 -0.4 0.2 + ] + flat_W = vec(W_val) + _, val, g, _ = _eval(model, loss, flat_W; x_grad = flat_W) + @test val ≈ sum((W_val * x_const .- target) .^ 2) + @test g ≈ vec(2 * (W_val * x_const .- target) * x_const') + return +end + +function test_matvec_jump_matrix_times_jump_vector_gradient() + # `W * x` where both `W` and `x` are `ArrayOfVariables`. Loss is + # `sum((W*x - target).^2)`. Gradients: ∂/∂W = 2 (Wx-t) x', + # ∂/∂x = 2 W' (Wx-t). The flat variable layout is `[vec(W); x]` because + # `W` is declared first. + m, n = 3, 4 + target = [0.5, -0.2, 0.1] + model = Model() + @variable(model, W[1:m, 1:n], container = ArrayDiff.ArrayOfVariables) + @variable(model, x[1:n], container = ArrayDiff.ArrayOfVariables) + y = W * x + @test y isa ArrayDiff.GenericArrayExpr + @test ndims(y) == 1 + @test size(y) == (m,) + @test y.head == :* + loss = sum((y .- target) .^ 2) + W_val = [ + 0.4 -0.2 0.1 0.3 + -0.3 0.5 0.2 -0.1 + 0.1 0.1 -0.4 0.2 + ] + x_val = [0.6, -0.3, 0.4, -0.1] + flat = [vec(W_val); x_val] + _, val, g, _ = _eval(model, loss, flat; x_grad = flat) + @test val ≈ sum((W_val * x_val .- target) .^ 2) + residual = W_val * x_val .- target + grad_W = 2 * residual * x_val' + grad_x = 2 * W_val' * residual + @test g ≈ [vec(grad_W); grad_x] + return +end + end # module TestJuMP.runtests()