Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/JuMP/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import JuMP

# Equivalent of `AbstractJuMPScalar` but for arrays
# `T` is the element type and `N` the number of dimensions, exactly as for the
# `AbstractArray{T,N}` supertype.
abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end
# Alias for the one-dimensional case (`N == 1`).
const AbstractJuMPVector{T} = AbstractJuMPArray{T,1}
# Alias for the two-dimensional case (`N == 2`).
const AbstractJuMPMatrix{T} = AbstractJuMPArray{T,2}

include("variables.jl")
Expand Down
18 changes: 18 additions & 0 deletions src/JuMP/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ function Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix)
return _matmul(JuMP.variable_ref_type(A), A, B)
end

# Matrix-vector products: output is a 1-D `GenericArrayExpr` of length
# `size(A, 1)`. Allows users to write `W * x` for a vector variable `x`.
#
# `V` is forwarded as the first type parameter of the resulting expression; the
# callers in this file pass `JuMP.variable_ref_type` of the JuMP operand.
#
# Throws a `DimensionMismatch` when `size(A, 2) != length(b)`, so that an
# ill-shaped product is rejected when the expression is built rather than when
# it is later evaluated.
function _matvec(::Type{V}, A, b) where {V}
    if size(A, 2) != length(b)
        throw(
            DimensionMismatch(
                "matrix A has dimensions $(size(A)), " *
                "vector b has length $(length(b))",
            ),
        )
    end
    # NOTE(review): the trailing `false` is passed through unchanged from the
    # original code; its meaning is defined by `GenericArrayExpr`.
    return GenericArrayExpr{V,1}(:*, Any[A, b], (size(A, 1),), false)
end

# `A * b` for a JuMP matrix `A` and a plain `Vector` `b`: builds the
# matrix-vector expression via `_matvec`, taking the variable reference type
# from `A`.
function Base.:(*)(A::AbstractJuMPMatrix, b::Vector)
    V = JuMP.variable_ref_type(A)
    return _matvec(V, A, b)
end

# `A * b` for a plain `Matrix` `A` and a JuMP vector `b`: builds the
# matrix-vector expression via `_matvec`, taking the variable reference type
# from `b` (the only JuMP operand).
function Base.:(*)(A::Matrix, b::AbstractJuMPVector{T}) where {T}
    V = JuMP.variable_ref_type(b)
    return _matvec(V, A, b)
end

# `A * b` where both operands are JuMP arrays: builds the matrix-vector
# expression via `_matvec`, taking the variable reference type from `A`.
function Base.:(*)(A::AbstractJuMPMatrix, b::AbstractJuMPVector{T}) where {T}
    V = JuMP.variable_ref_type(A)
    return _matvec(V, A, b)
end

function __broadcast(
::Type{V},
axes::NTuple{N,Base.OneTo{Int}},
Expand Down
68 changes: 56 additions & 12 deletions src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,46 @@
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

# Reverse-mode contribution for a matmul node `k` whose two children are the
# nodes `ix1` and `ix2`. The result node may be 1-D (mat-vec) or 2-D (mat-mat)
# according to `f.sizes.ndims[k]`; `_reshape_call` selects the view type per
# node and `LinearAlgebra.mul!` handles both shape combinations.
#
# This is the first of two stages: it asks `_reshape_call` for views of the
# children in `f.forward_storage` and hands off to `_matmul_reverse_outer`,
# which receives the tuple built below as its leading arguments.
function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int)
    sizes = f.sizes
    operands = (ix1, ix2)
    outer_args = (f.reverse_storage, sizes, ix1, ix2, k)
    _reshape_call(f.forward_storage, sizes, operands, _matmul_reverse_outer, outer_args)
    return nothing
end

# Second stage of the matmul reverse pass. `v1` and `v2` are the forward values
# of the two children (`ix1`, `ix2`). A second `_reshape_call` requests views
# into `reverse_storage` for both children and for the matmul node `k`, then
# runs `_matmul_reverse_inner!`.
# NOTE(review): judging by the parameter order of `_matmul_reverse_inner!`, the
# views are appended after `(v1, v2)` in the order `(ix1, ix2, k)` — confirm
# against `_reshape_call`.
function _matmul_reverse_outer(
    reverse_storage,
    sizes::Sizes,
    ix1::Int,
    ix2::Int,
    k::Int,
    v1,
    v2,
)
    nodes = (ix1, ix2, k)
    forward_views = (v1, v2)
    _reshape_call(reverse_storage, sizes, nodes, _matmul_reverse_inner!, forward_views)
    return nothing
end

# Innermost step of the matmul reverse pass for `parent = v1 * v2`.
#
# Overwrites the two child adjoints in place:
#   rev_v1 = rev_parent * transpose(v2)
#   rev_v2 = transpose(v1) * rev_parent
# `v1`/`v2` are the forward values of the children and `rev_parent` is the
# adjoint of the product node. Returns `nothing`.
function _matmul_reverse_inner!(v1, v2, rev_v1, rev_v2, rev_parent)
    rhs_t = transpose(v2)
    LinearAlgebra.mul!(rev_v1, rev_parent, rhs_t)
    lhs_t = transpose(v1)
    LinearAlgebra.mul!(rev_v2, lhs_t, rev_parent)
    return nothing
end

"""
_reverse_mode(d::NLPEvaluator, x)

Expand Down Expand Up @@ -177,10 +217,17 @@ function _forward_eval(
idx2 = last(children_indices)
@inbounds ix1 = children_arr[idx1]
@inbounds ix2 = children_arr[idx2]
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
out = _view_matrix(f.forward_storage, f.sizes, k)
LinearAlgebra.mul!(out, v1, v2)
# `_reshape_call` dispatches each node to the right view
# type based on its `ndims`. `LinearAlgebra.mul!` then
# picks the matching method — mat-mat for `ndims[k] == 2`,
# mat-vec for `ndims[k] == 1`.
_reshape_call(
f.forward_storage,
f.sizes,
(k, ix1, ix2),
LinearAlgebra.mul!,
(),
)
# We deliberately don't write v1/v2 into partials_storage
# here: the matmul reverse branch reads forward_storage
# directly, so those writes were dead.
Expand Down Expand Up @@ -787,18 +834,15 @@ function _reverse_eval(
# straight from forward_storage (the matmul forward
# branch deliberately doesn't snapshot them into
# partials_storage), and the reverse views are written
# in place.
# in place. Two nested `_reshape_call`s dispatch each
# node to the right view type based on `ndims`, so the
# same code path covers mat-mat (`ndims[k] == 2`) and
# mat-vec (`ndims[k] == 1`).
idx1 = first(children_indices)
idx2 = last(children_indices)
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
rev_parent = _view_matrix(f.reverse_storage, f.sizes, k)
rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1)
rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2)
LinearAlgebra.mul!(rev_v1, rev_parent, v2')
LinearAlgebra.mul!(rev_v2, v1', rev_parent)
_matmul_reverse!(f, k, ix1, ix2)
continue
end
elseif op == :vect
Expand Down
86 changes: 86 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,92 @@ function test_op_reversed_args()
return
end

function test_matvec_gradient()
    # Constant matrix times a 1-D `ArrayOfVariables`: `W * x` goes through the
    # mat-vec path of `_matmul_reverse!`. With the loss
    # `sum((W*x - target).^2)` the expected gradient is
    # `2 * W' * (W*x - target)`.
    W = [
        0.4 -0.2 0.1 0.3
        -0.3 0.5 0.2 -0.1
        0.1 0.1 -0.4 0.2
    ]
    m, n = size(W)
    target = [0.5, -0.2, 0.1]
    x_val = [0.6, -0.3, 0.4, -0.1]
    model = Model()
    @variable(model, x[1:n], container = ArrayDiff.ArrayOfVariables)
    y = W * x
    # The product must be a 1-D array expression with head `:*`.
    @test y isa ArrayDiff.GenericArrayExpr
    @test ndims(y) == 1
    @test size(y) == (m,)
    @test y.head == :*
    loss = sum((y .- target) .^ 2)
    _, loss_val, grad, _ = _eval(model, loss, x_val; x_grad = x_val)
    residual = W * x_val .- target
    @test loss_val ≈ sum(residual .^ 2)
    @test grad ≈ 2 * W' * residual
    return
end

function test_matvec_jump_matrix_times_const_vector_gradient()
    # JuMP matrix `W` times a constant `Vector`. With the loss
    # `sum((W*x - target).^2)` the gradient with respect to `W` is the outer
    # product `2 * (W*x - target) * x'`; since `W` is laid out column-major,
    # the flat gradient is `vec` of that matrix.
    W_val = [
        0.4 -0.2 0.1 0.3
        -0.3 0.5 0.2 -0.1
        0.1 0.1 -0.4 0.2
    ]
    m, n = size(W_val)
    x_const = [0.6, -0.3, 0.4, -0.1]
    target = [0.5, -0.2, 0.1]
    model = Model()
    @variable(model, W[1:m, 1:n], container = ArrayDiff.ArrayOfVariables)
    y = W * x_const
    # The product must be a 1-D array expression with head `:*`.
    @test y isa ArrayDiff.GenericArrayExpr
    @test ndims(y) == 1
    @test size(y) == (m,)
    @test y.head == :*
    loss = sum((y .- target) .^ 2)
    flat_W = vec(W_val)
    _, loss_val, grad, _ = _eval(model, loss, flat_W; x_grad = flat_W)
    residual = W_val * x_const .- target
    @test loss_val ≈ sum(residual .^ 2)
    @test grad ≈ vec(2 * residual * x_const')
    return
end

function test_matvec_jump_matrix_times_jump_vector_gradient()
    # Both operands of `W * x` are `ArrayOfVariables`. For the loss
    # `sum((W*x - target).^2)`: ∂/∂W = 2 (Wx-t) x' and ∂/∂x = 2 W' (Wx-t).
    # `W` is declared before `x`, so the flat variable vector is `[vec(W); x]`.
    W_val = [
        0.4 -0.2 0.1 0.3
        -0.3 0.5 0.2 -0.1
        0.1 0.1 -0.4 0.2
    ]
    m, n = size(W_val)
    x_val = [0.6, -0.3, 0.4, -0.1]
    target = [0.5, -0.2, 0.1]
    model = Model()
    @variable(model, W[1:m, 1:n], container = ArrayDiff.ArrayOfVariables)
    @variable(model, x[1:n], container = ArrayDiff.ArrayOfVariables)
    y = W * x
    # The product must be a 1-D array expression with head `:*`.
    @test y isa ArrayDiff.GenericArrayExpr
    @test ndims(y) == 1
    @test size(y) == (m,)
    @test y.head == :*
    loss = sum((y .- target) .^ 2)
    flat = [vec(W_val); x_val]
    _, loss_val, grad, _ = _eval(model, loss, flat; x_grad = flat)
    residual = W_val * x_val .- target
    @test loss_val ≈ sum(residual .^ 2)
    grad_W = 2 * residual * x_val'
    grad_x = 2 * W_val' * residual
    @test grad ≈ [vec(grad_W); grad_x]
    return
end

end # module

TestJuMP.runtests()
Loading