diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl
index 276e284..b531488 100644
--- a/src/JuMP/operators.jl
+++ b/src/JuMP/operators.jl
@@ -147,3 +147,68 @@ function Base.:(+)(
     @assert size(x) == size(y)
     return GenericArrayExpr{V,N}(:+, Any[x, y], size(x), false)
 end
+
+# ── User-defined array operators ─────────────────────────────────────────────
+#
+# `add_operator(model, arity, f; name)` registers `f` on the `ArrayDiff.Model`
+# via the [`UserDefinedArrayOperator`](@ref) attribute and returns a
+# `JuMP.NonlinearOperator` wrapping `f`. The call methods below handle one- and
+# two-argument calls with at least one `AbstractJuMPArray` argument: they build
+# a `GenericArrayExpr` when `infer_sizes` reports a non-empty output shape and
+# a `JuMP.GenericNonlinearExpr` otherwise, both with head `name`. The reverse-
+# mode derivative is pulled from `ChainRulesCore.rrule` at evaluator time.
+
+"""
+    add_operator(model::Model, arity::Int, f::Function; name::Symbol = Symbol(f))
+
+Register `f` under `name` as a user-defined array operator on `model` and return
+`JuMP.NonlinearOperator(f, name)`. Mirrors `JuMP.add_nonlinear_operator`: the
+call internally does `MOI.set(model, UserDefinedArrayOperator(name; arity), f)`.
+The returned operator can be called with `AbstractJuMPArray` arguments to
+build a `GenericArrayExpr` (array result) or `JuMP.GenericNonlinearExpr`
+(scalar result). The output shape is determined by [`infer_sizes`](@ref).
+"""
+function add_operator(
+    model::Model,
+    arity::Int,
+    f::Function;
+    name::Symbol = Symbol(f),
+)
+    MOI.set(model, UserDefinedArrayOperator(name; arity), f)
+    return JuMP.NonlinearOperator(f, name)
+end
+
+function _build_user_op_expr(op::JuMP.NonlinearOperator, V::Type, args::Tuple)
+    shapes = map(size, args) # `size` of a `Real` is `()`, the scalar shape
+    out_sz = infer_sizes(op.func, shapes...) # `()` signals a scalar result
+    if isempty(out_sz)
+        return JuMP.GenericNonlinearExpr{V}(op.head, Any[args...])
+    end
+    return GenericArrayExpr{V,length(out_sz)}(
+        op.head,
+        Any[args...],
+        out_sz,
+        false, # NOTE(review): presumably the `broadcasted` flag — confirm
+    )
+end
+
+function (op::JuMP.NonlinearOperator)(x::AbstractJuMPArray)
+    V = JuMP.variable_ref_type(x)
+    return _build_user_op_expr(op, V, (x,)) # unary call
+end
+
+function (op::JuMP.NonlinearOperator)(
+    x::AbstractJuMPArray,
+    y::Union{Real,AbstractArray{<:Real},AbstractJuMPArray},
+)
+    V = JuMP.variable_ref_type(x) # `y` may be constant data or a JuMP array
+    return _build_user_op_expr(op, V, (x, y))
+end
+
+function (op::JuMP.NonlinearOperator)(
+    x::Union{Real,AbstractArray{<:Real}},
+    y::AbstractJuMPArray,
+)
+    V = JuMP.variable_ref_type(y) # `x` is constant data here
+    return _build_user_op_expr(op, V, (x, y))
+end
diff --git a/src/sizes.jl b/src/sizes.jl
index 3b5a5a4..f250a0d 100644
--- a/src/sizes.jl
+++ b/src/sizes.jl
@@ -339,6 +339,42 @@ function _assert_scalar_children(sizes, children_arr, children_indices, op)
     end
 end
 
+"""
+    infer_sizes(op, child_sizes::Tuple...) -> Tuple
+
+Return the output shape of applying `op` to arguments of shapes `child_sizes`.
+Each `child_sizes[i]` is `()` if argument `i` is a scalar, or a tuple of
+positive integers if it is an array. The returned shape is `()` for a scalar
+result.
+
+The default implementation constructs dummy `Float64` arguments with
+`zeros(sz)` (or `0.0` for scalars) and evaluates `op(args...)` once. Add a
+method for `typeof(op)` to avoid that evaluation and its allocations, to
+support operators that error on zero inputs, or to compute the shape symbolically.
+
+## Example
+
+For multiplication, `infer_sizes` can be implemented as follows:
+```julia
+function infer_sizes(::typeof(*), lhs, rhs)
+    if isempty(lhs)
+        return rhs
+    end
+    if isempty(rhs)
+        return lhs
+    end
+    return (lhs[1:end-1]..., rhs[2:end]...)
+end
+```
+"""
+function infer_sizes(op, child_sizes::Tuple...)
+    args = map(child_sizes) do sz
+        return isempty(sz) ? 0.0 : zeros(sz)
+    end
+    y = op(args...)
+    return y isa AbstractArray ? size(y) : () # any non-array result is scalar
+end
+
 function _infer_sizes(
     nodes::Vector{Node},
     adj::SparseArrays.SparseMatrixCSC{Bool,Int},
@@ -373,20 +409,15 @@ function _infer_sizes(
                 op_sym = operators.multivariate_operators[node.index]
                 if haskey(operators.chainrules_operators, op_sym)
                     f = operators.chainrules_operators[op_sym]
-                    # Determine output shape by calling `f` with
-                    # zero-initialised arrays sized like each child.
-                    args = Any[]
-                    for c_idx in children_indices
-                        child = children_arr[c_idx]
-                        child_shape = ntuple(
-                            d -> _size(sizes, child, d),
-                            sizes.ndims[child],
-                        )
-                        push!(args, zeros(child_shape))
-                    end
-                    y = f(args...)
-                    if y isa AbstractArray
-                        _add_size!(sizes, k, size(y))
+                    child_shapes = Tuple(
+                        ntuple(
+                            d -> _size(sizes, children_arr[c_idx], d),
+                            sizes.ndims[children_arr[c_idx]],
+                        ) for c_idx in children_indices
+                    )
+                    out_sz = infer_sizes(f, child_shapes...) # `()` ⇒ scalar
+                    if !isempty(out_sz)
+                        _add_size!(sizes, k, out_sz)
                     end
                     # Scalar output → ndims = 0 (already initialised).
                     continue
diff --git a/test/JuMP.jl b/test/JuMP.jl
index bfcf48e..35ba6f6 100644
--- a/test/JuMP.jl
+++ b/test/JuMP.jl
@@ -35,6 +35,51 @@ function ChainRulesCore.rrule(
     return val, pullback
 end
 
+# `my_crossentropy1` and `my_crossentropy2` exercise the `infer_sizes` API:
+# the first relies on the default zeros-probe, the second has an explicit
+# override. To prove the override is what runs, `my_crossentropy2` asserts
+# `all(>(0), p)`: the default probe would call it with `p = zeros(...)` and
+# throw an `AssertionError`, so reaching the test's final `@test` confirms
+# that the `infer_sizes` specialisation was used.
+my_crossentropy1(p, q) = -sum(q .* log.(p .+ 1e-3))
+function my_crossentropy2(p, q)
+    @assert all(>(0), p) # fails on the all-zeros probe of the default `infer_sizes`
+    return -sum(q .* log.(p))
+end
+
+function ArrayDiff.infer_sizes(::typeof(my_crossentropy2), ::Tuple, ::Tuple)
+    return () # scalar result, reported without calling `my_crossentropy2`
+end
+
+function ChainRulesCore.rrule(
+    ::typeof(my_crossentropy1),
+    p::AbstractArray,
+    q::AbstractArray,
+)
+    ε = 1e-3 # same offset as the literal hard-coded in `my_crossentropy1`
+    val = my_crossentropy1(p, q)
+    function pullback(δ)
+        dp = δ .* (-q ./ (p .+ ε)) # ∂/∂p of -Σ q log(p + ε)
+        dq = δ .* (-log.(p .+ ε)) # ∂/∂q of -Σ q log(p + ε)
+        return ChainRulesCore.NoTangent(), dp, dq
+    end
+    return val, pullback
+end
+
+function ChainRulesCore.rrule(
+    ::typeof(my_crossentropy2),
+    p::AbstractArray,
+    q::AbstractArray,
+)
+    val = my_crossentropy2(p, q)
+    function pullback(δ)
+        dp = δ .* (-q ./ p) # ∂/∂p of -Σ q log(p)
+        dq = δ .* (-log.(p)) # ∂/∂q of -Σ q log(p)
+        return ChainRulesCore.NoTangent(), dp, dq
+    end
+    return val, pullback
+end
+
 function runtests()
     for name in names(@__MODULE__; all = true)
         if startswith("$(name)", "test_")
@@ -721,6 +766,51 @@ function test_chainrules_crossentropy_of_relu()
     return
 end
 
+function test_add_operator_crossentropy_of_relu()
+    n = 2
+    X = [1.0 0.5; 0.3 0.8]
+    target = [0.5 0.2; 0.1 0.7]
+    model = Model()
+    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
+    mode = ArrayDiff.Mode()
+    ad = ArrayDiff.model(mode)
+    MOI.set(
+        ad,
+        ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1),
+        my_relu,
+    )
+    op_crossentropy = ArrayDiff.add_operator(ad, 2, my_crossentropy)
+    @test op_crossentropy isa JuMP.NonlinearOperator
+    @test op_crossentropy.head == :my_crossentropy # default `name = Symbol(f)`
+    Y = W * X
+    Z = my_relu.(Y)
+    loss = op_crossentropy(Z, target)
+    @test loss isa JuMP.NonlinearExpr
+    @test loss.head == :my_crossentropy
+    @test loss.args[1] === Z
+    @test loss.args[2] === target
+    MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss))
+    evaluator = MOI.Nonlinear.Evaluator(
+        ad,
+        mode,
+        JuMP.index.(JuMP.all_variables(model)),
+    )
+    MOI.initialize(evaluator, [:Grad])
+    W_val = [0.3 -0.2; 0.1 0.4]
+    x_in = vec(W_val)
+    val = MOI.eval_objective(evaluator, x_in)
+    @test val ≈ my_crossentropy(my_relu.(W_val * X), target)
+    g = zeros(length(x_in))
+    MOI.eval_objective_gradient(evaluator, g, x_in)
+    ε = 1e-3 # NOTE(review): presumably the offset inside `my_crossentropy` — confirm
+    Y_val = W_val * X
+    Z_val = my_relu.(Y_val)
+    dL_dZ = -target ./ (Z_val .+ ε)
+    dL_dY = dL_dZ .* Float64.(Y_val .> 0) # zero the gradient where Y ≤ 0
+    @test g ≈ vec(dL_dY * X')
+    return
+end
+
 function test_chainrules_broadcasted_relu()
     n = 2
     X = [1.0 0.5; 0.3 0.8]
@@ -873,6 +963,106 @@ function test_broadcast_divide_gradient()
     return
 end
 
+function _run_infer_sizes_op(f, head::Symbol)
+    n = 2
+    X = [1.0 0.5; 0.3 0.8]
+    target = [0.5 0.2; 0.1 0.7]
+    model = Model()
+    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
+    mode = ArrayDiff.Mode()
+    ad = ArrayDiff.model(mode)
+    MOI.set(ad, ArrayDiff.UserDefinedArrayOperator(head; arity = 2), f)
+    Y = W * X
+    Z = my_relu.(Y) # meant to keep p > 0 for `my_crossentropy2`'s `@assert`
+    MOI.set(
+        ad,
+        ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1),
+        my_relu,
+    )
+    loss_moi =
+        MOI.ScalarNonlinearFunction(head, Any[JuMP.moi_function(Z), target])
+    MOI.Nonlinear.set_objective(ad, loss_moi)
+    evaluator = MOI.Nonlinear.Evaluator(
+        ad,
+        mode,
+        JuMP.index.(JuMP.all_variables(model)),
+    )
+    MOI.initialize(evaluator, [:Grad])
+    W_val = [0.3 0.2; 0.1 0.4]
+    x_in = vec(W_val)
+    val = MOI.eval_objective(evaluator, x_in)
+    @test val ≈ f(my_relu.(W_val * X), target)
+    g = zeros(length(x_in))
+    MOI.eval_objective_gradient(evaluator, g, x_in)
+    h = 1e-6 # step of the central finite-difference check below
+    g_fd = zeros(length(x_in))
+    for i in eachindex(x_in)
+        xp = copy(x_in)
+        xp[i] += h
+        xm = copy(x_in)
+        xm[i] -= h
+        g_fd[i] =
+            (
+                MOI.eval_objective(evaluator, xp) -
+                MOI.eval_objective(evaluator, xm)
+            ) / (2h)
+    end
+    @test isapprox(g, g_fd; rtol = 1e-4) # analytic gradient vs finite differences
+    return
+end
+
+# `my_crossentropy1` has no `infer_sizes` override, so its output shape is found
+# by the default probe, which calls the function on `zeros(...)` inputs. The
+# probe runs fine because `log(0 + 1e-3)` is finite.
+function test_infer_sizes_default_probe()
+    return _run_infer_sizes_op(my_crossentropy1, :my_crossentropy1)
+end
+
+# `my_crossentropy2` asserts `all(>(0), p)`, so the default probe, which calls
+# the function on `zeros(...)` inputs, would throw. The `infer_sizes` override
+# returns `()` without calling the function, so reaching the final `@test`
+# shows that the override is what ran.
+function test_infer_sizes_user_override()
+    return _run_infer_sizes_op(my_crossentropy2, :my_crossentropy2)
+end
+
+# Whole-array unary op (doubles every entry) for the array-output / unary dispatch.
+my_double(x::AbstractArray) = 2 .* x
+
+# Covers the unary `(op::NonlinearOperator)(x::AbstractJuMPArray)` dispatch
+# and the array-output branch of `_build_user_op_expr`: `infer_sizes` returns
+# the non-empty shape `(n, n)`, so a `GenericArrayExpr` is built rather than a
+# `GenericNonlinearExpr`. The operator is constructed directly, without `add_operator`.
+function test_op_unary_jump_array()
+    n = 2
+    model = Model()
+    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
+    op = JuMP.NonlinearOperator(my_double, :my_double)
+    result = op(W)
+    @test result isa ArrayDiff.MatrixExpr
+    @test result.head == :my_double
+    @test size(result) == (n, n)
+    @test !result.broadcasted
+    @test result.args[1] === W
+    return
+end
+
+# Covers the `(op::NonlinearOperator)(x::Union{Real,AbstractArray{<:Real}}, y::AbstractJuMPArray)`
+# dispatch — constant array first, JuMP array second; the result is a scalar expression.
+function test_op_reversed_args()
+    n = 2
+    model = Model()
+    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
+    target = rand(n, n) # values are irrelevant: only head and argument identity are checked
+    op = JuMP.NonlinearOperator(my_crossentropy, :my_crossentropy)
+    result = op(target, W)
+    @test result isa JuMP.NonlinearExpr
+    @test result.head == :my_crossentropy
+    @test result.args[1] === target
+    @test result.args[2] === W
+    return
+end
+
 end # module
 
 TestJuMP.runtests()