Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/JuMP/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,68 @@ function Base.:(+)(
@assert size(x) == size(y)
return GenericArrayExpr{V,N}(:+, Any[x, y], size(x), false)
end

# ── User-defined array operators ─────────────────────────────────────────────
#
# `add_operator(model, arity, f; name)` registers `f` on the `ArrayDiff.Model`
# via the [`UserDefinedArrayOperator`](@ref) attribute and returns a
# `JuMP.NonlinearOperator` wrapping `f`. When the returned operator is called
# with at least one `AbstractJuMPArray` argument, the dispatch methods below
# build either a `GenericArrayExpr` (when `f` returns an array) or a
# `JuMP.GenericNonlinearExpr` (scalar output) whose head is `name`. The
# reverse-mode derivative is pulled from `ChainRulesCore.rrule` at evaluator
# time.

"""
    add_operator(model::Model, arity::Int, f::Function; name::Symbol = Symbol(f))

Register the function `f` on `model` as a user-defined array operator taking
`arity` arguments, and hand back a `JuMP.NonlinearOperator` that wraps it.

The registration is performed with
`MOI.set(model, UserDefinedArrayOperator(name; arity), f)`, in the same spirit
as `JuMP.add_nonlinear_operator`. `name` is used both for the registration and
as the second argument of the returned `JuMP.NonlinearOperator`; it defaults to
`Symbol(f)`.

Calling the returned operator on `AbstractJuMPArray` arguments yields a
`GenericArrayExpr` when the result is an array and a
`JuMP.GenericNonlinearExpr` when it is a scalar; [`infer_sizes`](@ref) decides
the output shape.
"""
function add_operator(
    model::Model,
    arity::Int,
    f::Function;
    name::Symbol = Symbol(f),
)
    # Register first, so the operator is known to the model before a handle
    # to it is returned.
    attribute = UserDefinedArrayOperator(name; arity = arity)
    MOI.set(model, attribute, f)
    return JuMP.NonlinearOperator(f, name)
end

# Build the expression obtained by applying the user-defined operator `op` to
# `args`, with `V` as the variable-reference type parameter of the result.
#
# The output shape comes from `infer_sizes(op.func, ...)` applied to the `size`
# of every argument. An empty shape means a scalar result and produces a
# `JuMP.GenericNonlinearExpr{V}`; any other shape produces a
# `GenericArrayExpr{V,N}` with `N = length(shape)`.
function _build_user_op_expr(op::JuMP.NonlinearOperator, V::Type, args::Tuple)
    children = Any[args...]
    result_size = infer_sizes(op.func, map(size, args)...)
    if !isempty(result_size)
        # NOTE(review): the trailing `false` looks like the `broadcasted`
        # flag of `GenericArrayExpr` — confirm against the struct definition.
        return GenericArrayExpr{V,length(result_size)}(
            op.head,
            children,
            result_size,
            false,
        )
    end
    return JuMP.GenericNonlinearExpr{V}(op.head, children)
end

# Unary call of a `JuMP.NonlinearOperator` on a JuMP array. The variable
# reference type of the resulting expression is taken from `x`.
function (op::JuMP.NonlinearOperator)(x::AbstractJuMPArray)
    return _build_user_op_expr(op, JuMP.variable_ref_type(x), (x,))
end

# Binary call whose first argument is a JuMP array. The second argument may be
# a real number, an array of reals, or another JuMP array. The variable
# reference type of the resulting expression is taken from `x`.
function (op::JuMP.NonlinearOperator)(
    x::AbstractJuMPArray,
    y::Union{Real,AbstractArray{<:Real},AbstractJuMPArray},
)
    operands = (x, y)
    return _build_user_op_expr(op, JuMP.variable_ref_type(x), operands)
end

# Binary call with a constant (real number or array of reals) first and a JuMP
# array second. Only `y` carries variables, so the variable reference type of
# the resulting expression is taken from `y`.
function (op::JuMP.NonlinearOperator)(
    x::Union{Real,AbstractArray{<:Real}},
    y::AbstractJuMPArray,
)
    operands = (x, y)
    return _build_user_op_expr(op, JuMP.variable_ref_type(y), operands)
end
59 changes: 45 additions & 14 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,42 @@ function _assert_scalar_children(sizes, children_arr, children_indices, op)
end
end

"""
    infer_sizes(op, child_sizes::Tuple...) -> Tuple

Compute the shape of the value produced by calling `op` on arguments whose
shapes are `child_sizes`.

An entry of `child_sizes` equal to `()` denotes a scalar argument; any other
entry is the size tuple of an array argument. The result follows the same
convention: `()` means that `op` returns a scalar.

This generic fallback actually runs `op`: every array argument is replaced by
`zeros(sz)` and every scalar argument by `0.0`, and the size of whatever comes
back is reported. Add a method for `typeof(op)` when that probe is undesirable,
for instance to skip the allocation, because `op` throws on all-zero input, or
because the shape can be derived directly from `child_sizes`.

## Example

A method for multiplication could look like this
```julia
function infer_sizes(::typeof(*), lhs, rhs)
    if isempty(lhs)
        return rhs
    end
    if isempty(rhs)
        return lhs
    end
    return (lhs[1:end-1]..., rhs[2:end]...)
end
```
"""
function infer_sizes(op, child_sizes::Tuple...)
    # One placeholder per argument: `0.0` for scalars, a zero array otherwise.
    probes = map(sz -> isempty(sz) ? 0.0 : zeros(sz), child_sizes)
    result = op(probes...)
    if result isa AbstractArray
        return size(result)
    end
    return ()
end

function _infer_sizes(
nodes::Vector{Node},
adj::SparseArrays.SparseMatrixCSC{Bool,Int},
Expand Down Expand Up @@ -373,20 +409,15 @@ function _infer_sizes(
op_sym = operators.multivariate_operators[node.index]
if haskey(operators.chainrules_operators, op_sym)
f = operators.chainrules_operators[op_sym]
# Determine output shape by calling `f` with
# zero-initialised arrays sized like each child.
args = Any[]
for c_idx in children_indices
child = children_arr[c_idx]
child_shape = ntuple(
d -> _size(sizes, child, d),
sizes.ndims[child],
)
push!(args, zeros(child_shape))
end
y = f(args...)
if y isa AbstractArray
_add_size!(sizes, k, size(y))
child_shapes = Tuple(
ntuple(
d -> _size(sizes, children_arr[c_idx], d),
sizes.ndims[children_arr[c_idx]],
) for c_idx in children_indices
)
out_sz = infer_sizes(f, child_shapes...)
if !isempty(out_sz)
_add_size!(sizes, k, out_sz)
end
# Scalar output → ndims = 0 (already initialised).
continue
Expand Down
190 changes: 190 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,51 @@ function ChainRulesCore.rrule(
return val, pullback
end

# `my_crossentropy1` and `my_crossentropy2` exercise the `infer_sizes` API:
# the first relies on the default zeros-probe, the second has an explicit
# override. To prove that the override is what runs, `my_crossentropy2`
# asserts `all(>(0), p)`. The default probe would call it with
# `p = zeros(...)` and trip that assertion, so reaching the test's final
# `@test` confirms that `infer_sizes` was specialised.
# Cross-entropy of `q` against `p`, with `p` shifted by `1e-3` before taking
# the logarithm so that the value stays finite when `p` contains zeros.
function my_crossentropy1(p, q)
    shifted = p .+ 1e-3
    return -sum(q .* log.(shifted))
end
# Cross-entropy of `q` against `p` without any shift inside the logarithm.
# Every entry of `p` must be strictly positive; this is asserted up front, so
# calling the function with an all-zero `p` throws instead of returning.
function my_crossentropy2(p, q)
    @assert all(>(0), p)
    log_p = log.(p)
    return -sum(q .* log_p)
end

# Shape override for `my_crossentropy2`: report a scalar result (`()`) for any
# pair of argument shapes without calling the function itself.
ArrayDiff.infer_sizes(::typeof(my_crossentropy2), ::Tuple, ::Tuple) = ()

# Reverse-mode rule for `my_crossentropy1`. Returns the primal value together
# with a pullback that maps the output cotangent to
# `(NoTangent(), cotangent w.r.t. p, cotangent w.r.t. q)`. The `1e-3` shift
# matches the one used inside `my_crossentropy1`.
function ChainRulesCore.rrule(
    ::typeof(my_crossentropy1),
    p::AbstractArray,
    q::AbstractArray,
)
    shift = 1e-3
    loss = my_crossentropy1(p, q)
    function crossentropy1_pullback(cotangent)
        grad_p = cotangent .* (-q ./ (p .+ shift))
        grad_q = cotangent .* (-log.(p .+ shift))
        return ChainRulesCore.NoTangent(), grad_p, grad_q
    end
    return loss, crossentropy1_pullback
end

# Reverse-mode rule for `my_crossentropy2`. Returns the primal value together
# with a pullback that maps the output cotangent to
# `(NoTangent(), cotangent w.r.t. p, cotangent w.r.t. q)`. The primal call runs
# `my_crossentropy2`, so its positivity assertion on `p` applies here too.
function ChainRulesCore.rrule(
    ::typeof(my_crossentropy2),
    p::AbstractArray,
    q::AbstractArray,
)
    loss = my_crossentropy2(p, q)
    function crossentropy2_pullback(cotangent)
        grad_p = cotangent .* (-q ./ p)
        grad_q = cotangent .* (-log.(p))
        return ChainRulesCore.NoTangent(), grad_p, grad_q
    end
    return loss, crossentropy2_pullback
end

function runtests()
for name in names(@__MODULE__; all = true)
if startswith("$(name)", "test_")
Expand Down Expand Up @@ -721,6 +766,51 @@ function test_chainrules_crossentropy_of_relu()
return
end

# End-to-end check of `ArrayDiff.add_operator`: register `my_crossentropy`
# through it, build `my_crossentropy(my_relu.(W * X), target)` by calling the
# returned operator, then compare the objective value and gradient computed by
# the evaluator against closed-form expressions.
function test_add_operator_crossentropy_of_relu()
    dim = 2
    X = [1.0 0.5; 0.3 0.8]
    target = [0.5 0.2; 0.1 0.7]
    model = Model()
    @variable(model, W[1:dim, 1:dim], container = ArrayDiff.ArrayOfVariables)
    mode = ArrayDiff.Mode()
    ad = ArrayDiff.model(mode)
    # `my_relu` is registered directly through the MOI attribute ...
    relu_attribute = ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1)
    MOI.set(ad, relu_attribute, my_relu)
    # ... whereas the cross-entropy goes through the API under test.
    op_crossentropy = ArrayDiff.add_operator(ad, 2, my_crossentropy)
    @test op_crossentropy isa JuMP.NonlinearOperator
    @test op_crossentropy.head == :my_crossentropy
    Z = my_relu.(W * X)
    loss = op_crossentropy(Z, target)
    # Scalar output, so the operator must build a scalar nonlinear expression
    # that keeps both arguments as they were passed.
    @test loss isa JuMP.NonlinearExpr
    @test loss.head == :my_crossentropy
    @test loss.args[1] === Z
    @test loss.args[2] === target
    MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss))
    variables = JuMP.index.(JuMP.all_variables(model))
    evaluator = MOI.Nonlinear.Evaluator(ad, mode, variables)
    MOI.initialize(evaluator, [:Grad])
    W_val = [0.3 -0.2; 0.1 0.4]
    x_in = vec(W_val)
    objective_value = MOI.eval_objective(evaluator, x_in)
    @test objective_value ≈ my_crossentropy(my_relu.(W_val * X), target)
    gradient = zeros(length(x_in))
    MOI.eval_objective_gradient(evaluator, gradient, x_in)
    # Reference gradient via the chain rule through the cross-entropy, the
    # ReLU mask and the product `W * X`.
    ε = 1e-3
    pre_activation = W_val * X
    activation = my_relu.(pre_activation)
    grad_activation = -target ./ (activation .+ ε)
    grad_pre_activation = grad_activation .* Float64.(pre_activation .> 0)
    @test gradient ≈ vec(grad_pre_activation * X')
    return
end

function test_chainrules_broadcasted_relu()
n = 2
X = [1.0 0.5; 0.3 0.8]
Expand Down Expand Up @@ -873,6 +963,106 @@ function test_broadcast_divide_gradient()
return
end

# Shared driver for the `infer_sizes` tests. Registers `f` under `head` as a
# binary user-defined array operator, uses `f(my_relu.(W * X), target)` as the
# objective, and checks the evaluator's objective value against a direct call
# of `f` and its gradient against central finite differences.
function _run_infer_sizes_op(f, head::Symbol)
    dim = 2
    X = [1.0 0.5; 0.3 0.8]
    target = [0.5 0.2; 0.1 0.7]
    model = Model()
    @variable(model, W[1:dim, 1:dim], container = ArrayDiff.ArrayOfVariables)
    mode = ArrayDiff.Mode()
    ad = ArrayDiff.model(mode)
    MOI.set(ad, ArrayDiff.UserDefinedArrayOperator(head; arity = 2), f)
    # ReLU keeps `p` positive at the evaluation point used below, which
    # `my_crossentropy2`'s assertion requires.
    Z = my_relu.(W * X)
    relu_attribute = ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1)
    MOI.set(ad, relu_attribute, my_relu)
    objective = MOI.ScalarNonlinearFunction(
        head,
        Any[JuMP.moi_function(Z), target],
    )
    MOI.Nonlinear.set_objective(ad, objective)
    variables = JuMP.index.(JuMP.all_variables(model))
    evaluator = MOI.Nonlinear.Evaluator(ad, mode, variables)
    MOI.initialize(evaluator, [:Grad])
    W_val = [0.3 0.2; 0.1 0.4]
    x_in = vec(W_val)
    objective_value = MOI.eval_objective(evaluator, x_in)
    @test objective_value ≈ f(my_relu.(W_val * X), target)
    gradient = zeros(length(x_in))
    MOI.eval_objective_gradient(evaluator, gradient, x_in)
    # Central finite differences, one coordinate at a time.
    h = 1e-6
    gradient_fd = map(eachindex(x_in)) do i
        forward = copy(x_in)
        forward[i] += h
        backward = copy(x_in)
        backward[i] -= h
        value_forward = MOI.eval_objective(evaluator, forward)
        value_backward = MOI.eval_objective(evaluator, backward)
        return (value_forward - value_backward) / (2h)
    end
    @test isapprox(gradient, gradient_fd; rtol = 1e-4)
    return
end

# No `infer_sizes` method is defined for `my_crossentropy1`, so its output
# shape is discovered by the generic fallback, which calls the function on
# zero-filled arrays. That call succeeds because `log(0 + 1e-3)` is finite.
function test_infer_sizes_default_probe()
    operator = my_crossentropy1
    return _run_infer_sizes_op(operator, :my_crossentropy1)
end

# `my_crossentropy2` asserts that every entry of `p` is positive, so probing
# it with zero-filled arrays would throw. Its `infer_sizes` method returns `()`
# without calling the function; getting as far as the final `@test` therefore
# shows that the override was used.
function test_infer_sizes_user_override()
    operator = my_crossentropy2
    return _run_infer_sizes_op(operator, :my_crossentropy2)
end

# Unary operator acting on a whole array: returns an array of the same shape
# with every entry doubled. Used to exercise the unary dispatch and the
# array-output branch of the user-defined operator machinery.
function my_double(x::AbstractArray)
    return 2 .* x
end

# Exercises the one-argument `(op::NonlinearOperator)(x::AbstractJuMPArray)`
# method together with the array-output path of `_build_user_op_expr`: for
# `my_double`, `infer_sizes` reports a non-empty shape, so the call must
# produce a `GenericArrayExpr` and not a `GenericNonlinearExpr`.
function test_op_unary_jump_array()
    dim = 2
    model = Model()
    @variable(model, W[1:dim, 1:dim], container = ArrayDiff.ArrayOfVariables)
    double_op = JuMP.NonlinearOperator(my_double, :my_double)
    expr = double_op(W)
    @test expr isa ArrayDiff.MatrixExpr
    @test expr.head == :my_double
    @test size(expr) == (dim, dim)
    @test !expr.broadcasted
    @test expr.args[1] === W
    return
end

# Covers the
# `(op::NonlinearOperator)(x::Union{Real,AbstractArray{<:Real}}, y::AbstractJuMPArray)`
# dispatch — constant array first, JuMP array second. The result is a scalar
# nonlinear expression whose arguments keep the order in which they were
# passed.
function test_op_reversed_args()
    n = 2
    model = Model()
    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
    # Fixed data (the same matrix the other tests use) instead of an unseeded
    # `rand(n, n)`: the assertions below only inspect the structure of the
    # expression, so random values would add non-determinism without adding
    # coverage.
    target = [0.5 0.2; 0.1 0.7]
    op = JuMP.NonlinearOperator(my_crossentropy, :my_crossentropy)
    result = op(target, W)
    @test result isa JuMP.NonlinearExpr
    @test result.head == :my_crossentropy
    @test result.args[1] === target
    @test result.args[2] === W
    return
end

end # module

TestJuMP.runtests()
Loading