From 2d85cb67ba220be3de8760546a57f8c02e744b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 08:11:29 +0200 Subject: [PATCH 1/8] Add add_operator --- src/JuMP/operators.jl | 68 +++++++++++++++++++++++++++++++++++++++++++ test/JuMP.jl | 50 +++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index 276e284..10d5a7a 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -147,3 +147,71 @@ function Base.:(+)( @assert size(x) == size(y) return GenericArrayExpr{V,N}(:+, Any[x, y], size(x), false) end + +# ── User-defined array operators ───────────────────────────────────────────── +# +# `add_operator(f)` wraps `f` in a `JuMP.NonlinearOperator` whose `head` is +# `Symbol(f)`. When the operator is called with at least one `AbstractJuMPArray` +# argument, the dispatch methods below build either a `GenericArrayExpr` (when +# `f` returns an array) or a `JuMP.GenericNonlinearExpr` (scalar output) with +# that `head`. The user is still responsible for registering `f` with the +# `ArrayDiff.Model` via [`UserDefinedArrayOperator`](@ref) so the evaluator can +# call `f` and pull its reverse-mode derivative from `ChainRulesCore.rrule`. + +""" + add_operator(f::Function; head::Symbol = Symbol(f)) + +Return a `JuMP.NonlinearOperator` wrapping `f`. When the returned operator is +called with `AbstractJuMPArray` arguments, it builds a JuMP expression whose +`head` is `head` — a `GenericArrayExpr` if `f` returns an array, otherwise a +`JuMP.GenericNonlinearExpr`. The output shape is determined by probing `f` +with zero arrays sized like the JuMP-array arguments. +""" +function add_operator(f::Function; head::Symbol = Symbol(f)) + return JuMP.NonlinearOperator(f, head) +end + +function _user_op_probe_arg(a::AbstractJuMPArray) + return zeros(Float64, size(a)) +end +_user_op_probe_arg(a::AbstractArray{<:Real}) = Float64.(a) +_user_op_probe_arg(a::Real) = Float64(a) + +function _build_user_op_expr( + op::JuMP.NonlinearOperator, + V::Type, + args::Tuple, +) + probe = map(_user_op_probe_arg, args) + y = op.func(probe...) + if y isa AbstractArray + return GenericArrayExpr{V,ndims(y)}( + op.head, + Any[args...], + size(y), + false, + ) + end + return JuMP.GenericNonlinearExpr{V}(op.head, Any[args...]) +end + +function (op::JuMP.NonlinearOperator)(x::AbstractJuMPArray) + V = JuMP.variable_ref_type(x) + return _build_user_op_expr(op, V, (x,)) +end + +function (op::JuMP.NonlinearOperator)( + x::AbstractJuMPArray, + y::Union{Real,AbstractArray{<:Real},AbstractJuMPArray}, +) + V = JuMP.variable_ref_type(x) + return _build_user_op_expr(op, V, (x, y)) +end + +function (op::JuMP.NonlinearOperator)( + x::Union{Real,AbstractArray{<:Real}}, + y::AbstractJuMPArray, +) + V = JuMP.variable_ref_type(y) + return _build_user_op_expr(op, V, (x, y)) +end diff --git a/test/JuMP.jl b/test/JuMP.jl index bfcf48e..62b5151 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -721,6 +721,56 @@ function test_chainrules_crossentropy_of_relu() return end +function test_add_operator_crossentropy_of_relu() + n = 2 + X = [1.0 0.5; 0.3 0.8] + target = [0.5 0.2; 0.1 0.7] + model = Model() + @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + mode = ArrayDiff.Mode() + ad = ArrayDiff.model(mode) + MOI.set( + ad, + ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1), + my_relu, + ) + MOI.set( + ad, + ArrayDiff.UserDefinedArrayOperator(:my_crossentropy; arity = 2), + my_crossentropy, + ) + op_crossentropy = ArrayDiff.add_operator(my_crossentropy) + @test op_crossentropy isa JuMP.NonlinearOperator + @test op_crossentropy.head == :my_crossentropy + Y = W * X + Z = my_relu.(Y) + loss = op_crossentropy(Z, target) + @test loss isa JuMP.NonlinearExpr + @test loss.head == :my_crossentropy + @test loss.args[1] === Z + @test loss.args[2] === target + MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss)) + evaluator = MOI.Nonlinear.Evaluator( + ad, + mode, + JuMP.index.(JuMP.all_variables(model)), + ) + MOI.initialize(evaluator, [:Grad]) + W_val = [0.3 -0.2; 0.1 0.4] + x_in = vec(W_val) + val = MOI.eval_objective(evaluator, x_in) + @test val ≈ my_crossentropy(my_relu.(W_val * X), target) + g = zeros(length(x_in)) + MOI.eval_objective_gradient(evaluator, g, x_in) + ε = 1e-3 + Y_val = W_val * X + Z_val = my_relu.(Y_val) + dL_dZ = -target ./ (Z_val .+ ε) + dL_dY = dL_dZ .* Float64.(Y_val .> 0) + @test g ≈ vec(dL_dY * X') + return +end + function test_chainrules_broadcasted_relu() n = 2 X = [1.0 0.5; 0.3 0.8] From 7525bae824fe1b71a2fa15f07daf9bd1a7d8e940 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 10:35:46 +0200 Subject: [PATCH 2/8] Simplify --- src/JuMP/operators.jl | 32 ++++-------- src/sizes.jl | 48 +++++++++++++----- test/JuMP.jl | 115 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 35 deletions(-) diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index 10d5a7a..e85f43f 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -171,28 +171,18 @@ function add_operator(f::Function; head::Symbol = Symbol(f)) return JuMP.NonlinearOperator(f, head) end -function _user_op_probe_arg(a::AbstractJuMPArray) - return zeros(Float64, size(a)) -end -_user_op_probe_arg(a::AbstractArray{<:Real}) = Float64.(a) -_user_op_probe_arg(a::Real) = Float64(a) - -function _build_user_op_expr( - op::JuMP.NonlinearOperator, - V::Type, - args::Tuple, -) - probe = map(_user_op_probe_arg, args) - y = op.func(probe...) - if y isa AbstractArray - return GenericArrayExpr{V,ndims(y)}( - op.head, - Any[args...], - size(y), - false, - ) +function _build_user_op_expr(op::JuMP.NonlinearOperator, V::Type, args::Tuple) + shapes = map(size, args) + out_sz = infer_sizes(JuMP.value_type(V), op.func, shapes...) + if isempty(out_sz) + return JuMP.GenericNonlinearExpr{V}(op.head, Any[args...]) end - return JuMP.GenericNonlinearExpr{V}(op.head, Any[args...]) + return GenericArrayExpr{V,length(out_sz)}( + op.head, + Any[args...], + out_sz, + false, + ) end function (op::JuMP.NonlinearOperator)(x::AbstractJuMPArray) diff --git a/src/sizes.jl b/src/sizes.jl index 3b5a5a4..58a9763 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -339,6 +339,27 @@ function _assert_scalar_children(sizes, children_arr, children_indices, op) end end +""" + infer_sizes(::Type{T}, op, child_sizes::Tuple...) -> Tuple + +Return the output shape of applying `op` to arguments of shapes `child_sizes`. +Each `child_sizes[i]` is `()` if argument `i` is a scalar, or a tuple of +positive integers if it is an array. The returned shape is `()` for a scalar +result. + +The default implementation constructs dummy arguments with `zeros(T, sz)` +(or `zero(T)` for scalars) and calls `op(args...)`. Specialise on `op`'s +`typeof` to avoid the allocation, to support operators that error on zero +inputs, or to compute the output shape symbolically. +""" +function infer_sizes(::Type{T}, op, child_sizes::Tuple...) where {T} + args = map(child_sizes) do sz + return isempty(sz) ? zero(T) : zeros(T, sz) + end + y = op(args...) + return y isa AbstractArray ? size(y) : () +end + function _infer_sizes( nodes::Vector{Node}, adj::SparseArrays.SparseMatrixCSC{Bool,Int}, @@ -373,20 +394,19 @@ function _infer_sizes( op_sym = operators.multivariate_operators[node.index] if haskey(operators.chainrules_operators, op_sym) f = operators.chainrules_operators[op_sym] - # Determine output shape by calling `f` with - # zero-initialised arrays sized like each child. - args = Any[] - for c_idx in children_indices - child = children_arr[c_idx] - child_shape = ntuple( - d -> _size(sizes, child, d), - sizes.ndims[child], - ) - push!(args, zeros(child_shape)) - end - y = f(args...) - if y isa AbstractArray - _add_size!(sizes, k, size(y)) + child_shapes = Tuple( + ntuple( + d -> _size( + sizes, + children_arr[c_idx], + d, + ), + sizes.ndims[children_arr[c_idx]], + ) for c_idx in children_indices + ) + out_sz = infer_sizes(Float64, f, child_shapes...) + if !isempty(out_sz) + _add_size!(sizes, k, out_sz) end # Scalar output → ndims = 0 (already initialised). continue diff --git a/test/JuMP.jl b/test/JuMP.jl index 62b5151..47a788e 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -35,6 +35,56 @@ function ChainRulesCore.rrule( return val, pullback end +# `my_crossentropy1` and `my_crossentropy2` exercise the `infer_sizes` API: +# the first relies on the default zeros-probe, the second has an explicit +# override. To prove the override is what runs, `my_crossentropy2` asserts +# `all(p .> 0)` — the default probe would build `p = zeros(...)` and assert- +# fail, so reaching the test's final `@test` confirms `infer_sizes` was +# specialised. +my_crossentropy1(p, q) = -sum(q .* log.(p .+ 1e-3)) +function my_crossentropy2(p, q) + @assert all(>(0), p) + return -sum(q .* log.(p)) +end + +function ArrayDiff.infer_sizes( + ::Type{T}, + ::typeof(my_crossentropy2), + ::Tuple, + ::Tuple, +) where {T} + return () +end + +function ChainRulesCore.rrule( + ::typeof(my_crossentropy1), + p::AbstractArray, + q::AbstractArray, +) + ε = 1e-3 + val = my_crossentropy1(p, q) + function pullback(δ) + dp = δ .* (-q ./ (p .+ ε)) + dq = δ .* (-log.(p .+ ε)) + return ChainRulesCore.NoTangent(), dp, dq + end + return val, pullback +end + +function ChainRulesCore.rrule( + ::typeof(my_crossentropy2), + p::AbstractArray, + q::AbstractArray, +) + val = my_crossentropy2(p, q) + function pullback(δ) + dp = δ .* (-q ./ p) + dq = δ .* (-log.(p)) + return ChainRulesCore.NoTangent(), dp, dq + end + return val, pullback +end + function runtests() for name in names(@__MODULE__; all = true) if startswith("$(name)", "test_") @@ -923,6 +973,71 @@ function test_broadcast_divide_gradient() return end +function _run_infer_sizes_op(f, head::Symbol) + n = 2 + X = [1.0 0.5; 0.3 0.8] + target = [0.5 0.2; 0.1 0.7] + model = Model() + @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + mode = ArrayDiff.Mode() + ad = ArrayDiff.model(mode) + MOI.set(ad, ArrayDiff.UserDefinedArrayOperator(head; arity = 2), f) + Y = W * X + Z = my_relu.(Y) # ensure p > 0 for `my_crossentropy2`'s assertion + MOI.set( + ad, + ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1), + my_relu, + ) + loss_moi = MOI.ScalarNonlinearFunction( + head, + Any[JuMP.moi_function(Z), target], + ) + MOI.Nonlinear.set_objective(ad, loss_moi) + evaluator = MOI.Nonlinear.Evaluator( + ad, + mode, + JuMP.index.(JuMP.all_variables(model)), + ) + MOI.initialize(evaluator, [:Grad]) + W_val = [0.3 0.2; 0.1 0.4] + x_in = vec(W_val) + val = MOI.eval_objective(evaluator, x_in) + @test val ≈ f(my_relu.(W_val * X), target) + g = zeros(length(x_in)) + MOI.eval_objective_gradient(evaluator, g, x_in) + h = 1e-6 + g_fd = zeros(length(x_in)) + for i in eachindex(x_in) + xp = copy(x_in) + xp[i] += h + xm = copy(x_in) + xm[i] -= h + g_fd[i] = + ( + MOI.eval_objective(evaluator, xp) - + MOI.eval_objective(evaluator, xm) + ) / (2h) + end + @test isapprox(g, g_fd; rtol = 1e-4) + return +end + +# `my_crossentropy1` has no `infer_sizes` override, so output shape is found +# by probing the function with `zeros(Float64, ...)`. The probe runs fine +# because `log(0 + 1e-3)` is finite. +function test_infer_sizes_default_probe() + return _run_infer_sizes_op(my_crossentropy1, :my_crossentropy1) +end + +# `my_crossentropy2` asserts `all(p .> 0)`, so the default zeros-probe would +# fail. The `infer_sizes` override returns `()` symbolically and avoids +# touching the function — reaching the final `@test` proves the override is +# what ran. +function test_infer_sizes_user_override() + return _run_infer_sizes_op(my_crossentropy2, :my_crossentropy2) +end + end # module TestJuMP.runtests() From 1a7a5fade3a88a1141fa2bc3deba6863a82313e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 10:35:52 +0200 Subject: [PATCH 3/8] Fix format --- src/sizes.jl | 6 +----- test/JuMP.jl | 6 ++---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index 58a9763..8788ee1 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -396,11 +396,7 @@ function _infer_sizes( f = operators.chainrules_operators[op_sym] child_shapes = Tuple( ntuple( - d -> _size( - sizes, - children_arr[c_idx], - d, - ), + d -> _size(sizes, children_arr[c_idx], d), sizes.ndims[children_arr[c_idx]], ) for c_idx in children_indices ) diff --git a/test/JuMP.jl b/test/JuMP.jl index 47a788e..798e542 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -989,10 +989,8 @@ function _run_infer_sizes_op(f, head::Symbol) ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1), my_relu, ) - loss_moi = MOI.ScalarNonlinearFunction( - head, - Any[JuMP.moi_function(Z), target], - ) + loss_moi = + MOI.ScalarNonlinearFunction(head, Any[JuMP.moi_function(Z), target]) MOI.Nonlinear.set_objective(ad, loss_moi) evaluator = MOI.Nonlinear.Evaluator( ad, From a62cdbc7e3989ae5012e5559fe5dcc7efe7ad7a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 11:13:37 +0200 Subject: [PATCH 4/8] Fix --- src/JuMP/operators.jl | 39 +++++++++++++++++++++++---------------- test/JuMP.jl | 7 +------ 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index e85f43f..ed2b169 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -150,25 +150,32 @@ end # ── User-defined array operators ───────────────────────────────────────────── # -# `add_operator(f)` wraps `f` in a `JuMP.NonlinearOperator` whose `head` is -# `Symbol(f)`. When the operator is called with at least one `AbstractJuMPArray` -# argument, the dispatch methods below build either a `GenericArrayExpr` (when -# `f` returns an array) or a `JuMP.GenericNonlinearExpr` (scalar output) with -# that `head`. The user is still responsible for registering `f` with the -# `ArrayDiff.Model` via [`UserDefinedArrayOperator`](@ref) so the evaluator can -# call `f` and pull its reverse-mode derivative from `ChainRulesCore.rrule`. +# `add_operator(model, arity, f)` registers `f` on the `ArrayDiff.Model` via +# the [`UserDefinedArrayOperator`](@ref) attribute and returns a +# `JuMP.NonlinearOperator` wrapping `f`. When the returned operator is called +# with at least one `AbstractJuMPArray` argument, the dispatch methods below +# build either a `GenericArrayExpr` (when `f` returns an array) or a +# `JuMP.GenericNonlinearExpr` (scalar output) with that `name`. The reverse- +# mode derivative is pulled from `ChainRulesCore.rrule` at evaluator time. """ - add_operator(f::Function; head::Symbol = Symbol(f)) - -Return a `JuMP.NonlinearOperator` wrapping `f`. When the returned operator is -called with `AbstractJuMPArray` arguments, it builds a JuMP expression whose -`head` is `head` — a `GenericArrayExpr` if `f` returns an array, otherwise a -`JuMP.GenericNonlinearExpr`. The output shape is determined by probing `f` -with zero arrays sized like the JuMP-array arguments. + add_operator(model::Model, arity::Int, f::Function; name::Symbol = Symbol(f)) + +Register `f` as a user-defined array operator on `model` and return a +`JuMP.NonlinearOperator` wrapping it. Mirrors `JuMP.add_nonlinear_operator`: +the call internally does `MOI.set(model, UserDefinedArrayOperator(name; arity), f)`. +The returned operator can be called with `AbstractJuMPArray` arguments to +build a `GenericArrayExpr` (array result) or `JuMP.GenericNonlinearExpr` +(scalar result). The output shape is determined by [`infer_sizes`](@ref). """ -function add_operator(f::Function; head::Symbol = Symbol(f)) - return JuMP.NonlinearOperator(f, head) +function add_operator( + model::Model, + arity::Int, + f::Function; + name::Symbol = Symbol(f), +) + MOI.set(model, UserDefinedArrayOperator(name; arity), f) + return JuMP.NonlinearOperator(f, name) end function _build_user_op_expr(op::JuMP.NonlinearOperator, V::Type, args::Tuple) diff --git a/test/JuMP.jl b/test/JuMP.jl index 798e542..08c11bd 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -784,12 +784,7 @@ function test_add_operator_crossentropy_of_relu() ArrayDiff.UserDefinedArrayOperator(:my_relu; arity = 1), my_relu, ) - MOI.set( - ad, - ArrayDiff.UserDefinedArrayOperator(:my_crossentropy; arity = 2), - my_crossentropy, - ) - op_crossentropy = ArrayDiff.add_operator(my_crossentropy) + op_crossentropy = ArrayDiff.add_operator(ad, 2, my_crossentropy) @test op_crossentropy isa JuMP.NonlinearOperator @test op_crossentropy.head == :my_crossentropy Y = W * X From 080a9c07176d29e2e65fd754e8c9cd4c7aa2c08f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 11:31:41 +0200 Subject: [PATCH 5/8] Simplify --- src/JuMP/operators.jl | 2 +- src/sizes.jl | 23 ++++++++++++++++++----- test/JuMP.jl | 7 +------ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index ed2b169..b531488 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -180,7 +180,7 @@ end function _build_user_op_expr(op::JuMP.NonlinearOperator, V::Type, args::Tuple) shapes = map(size, args) - out_sz = infer_sizes(JuMP.value_type(V), op.func, shapes...) + out_sz = infer_sizes(op.func, shapes...) if isempty(out_sz) return JuMP.GenericNonlinearExpr{V}(op.head, Any[args...]) end diff --git a/src/sizes.jl b/src/sizes.jl index 8788ee1..1493f3a 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -340,21 +340,34 @@ function _assert_scalar_children(sizes, children_arr, children_indices, op) end """ - infer_sizes(::Type{T}, op, child_sizes::Tuple...) -> Tuple + infer_sizes(op, child_sizes::Tuple...) -> Tuple Return the output shape of applying `op` to arguments of shapes `child_sizes`. Each `child_sizes[i]` is `()` if argument `i` is a scalar, or a tuple of positive integers if it is an array. The returned shape is `()` for a scalar result. -The default implementation constructs dummy arguments with `zeros(T, sz)` -(or `zero(T)` for scalars) and calls `op(args...)`. Specialise on `op`'s +The default implementation constructs dummy arguments with `zeros(sz)` +(or `0.0` for scalars) and calls `op(args...)`. Specialise on `op`'s `typeof` to avoid the allocation, to support operators that error on zero inputs, or to compute the output shape symbolically. + +## Example + +For mulitplication, `infer_sizes` can be implemented as follows +```julia +infer_sizes(::typeof(*), ::Tuple{}, ::Tuple{}) = () +infer_sizes(::typeof(*), ::Tuple{}, rhs::Tuple) = rhs +infer_sizes(::typeof(*), lhs::Tuple, ::Tuple{}) = lhs +function infer_sizes(::typeof(*), lhs, rhs) + return (lhs[1:end-1]..., rhs[2:end]...) +end +``` +""" """ -function infer_sizes(::Type{T}, op, child_sizes::Tuple...) where {T} +function infer_sizes(op, child_sizes::Tuple...) where {T} args = map(child_sizes) do sz - return isempty(sz) ? zero(T) : zeros(T, sz) + return isempty(sz) ? 0.0 : zeros(sz) end y = op(args...) return y isa AbstractArray ? size(y) : () diff --git a/test/JuMP.jl b/test/JuMP.jl index 08c11bd..2bbc2c8 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -47,12 +47,7 @@ function my_crossentropy2(p, q) return -sum(q .* log.(p)) end -function ArrayDiff.infer_sizes( - ::Type{T}, - ::typeof(my_crossentropy2), - ::Tuple, - ::Tuple, -) where {T} +function ArrayDiff.infer_sizes(::typeof(my_crossentropy2), ::Tuple, ::Tuple) return () end From 20bee30782180efd352384d05c50f92193d202cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 11:38:18 +0200 Subject: [PATCH 6/8] Better example --- src/sizes.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index 1493f3a..39ad59b 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -354,18 +354,21 @@ inputs, or to compute the output shape symbolically. ## Example -For mulitplication, `infer_sizes` can be implemented as follows +For multiplication, `infer_sizes` can be implemented as follows ```julia -infer_sizes(::typeof(*), ::Tuple{}, ::Tuple{}) = () -infer_sizes(::typeof(*), ::Tuple{}, rhs::Tuple) = rhs -infer_sizes(::typeof(*), lhs::Tuple, ::Tuple{}) = lhs function infer_sizes(::typeof(*), lhs, rhs) + if isempty(lhs) + return rhs + end + if isempty(rhs) + return lhs + end return (lhs[1:end-1]..., rhs[2:end]...) end ``` """ """ -function infer_sizes(op, child_sizes::Tuple...) where {T} +function infer_sizes(op, child_sizes::Tuple...) args = map(child_sizes) do sz return isempty(sz) ? 0.0 : zeros(sz) end From e422d46f11fd337e244233dc95b77e958620ace1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 11:40:41 +0200 Subject: [PATCH 7/8] Fix parsing --- src/sizes.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index 39ad59b..f250a0d 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -367,7 +367,6 @@ function infer_sizes(::typeof(*), lhs, rhs) end ``` """ -""" function infer_sizes(op, child_sizes::Tuple...) args = map(child_sizes) do sz return isempty(sz) ? 0.0 : zeros(sz) @@ -416,7 +415,7 @@ function _infer_sizes( sizes.ndims[children_arr[c_idx]], ) for c_idx in children_indices ) - out_sz = infer_sizes(Float64, f, child_shapes...) + out_sz = infer_sizes(f, child_shapes...) if !isempty(out_sz) _add_size!(sizes, k, out_sz) end From 4f32237de78943f7fbe23792043510c84dcc383c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 12:31:19 +0200 Subject: [PATCH 8/8] Add tests --- test/JuMP.jl | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/JuMP.jl b/test/JuMP.jl index 2bbc2c8..35ba6f6 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -1026,6 +1026,43 @@ function test_infer_sizes_user_override() return _run_infer_sizes_op(my_crossentropy2, :my_crossentropy2) end +# Whole-array unary op used to exercise the array-output / unary dispatch. +my_double(x::AbstractArray) = 2 .* x + +# Covers the unary `(op::NonlinearOperator)(x::AbstractJuMPArray)` dispatch +# and the array-output branch in `_build_user_op_expr` (where `infer_sizes` +# returns a non-empty shape and we build a `GenericArrayExpr` rather than a +# `GenericNonlinearExpr`). +function test_op_unary_jump_array() + n = 2 + model = Model() + @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + op = JuMP.NonlinearOperator(my_double, :my_double) + result = op(W) + @test result isa ArrayDiff.MatrixExpr + @test result.head == :my_double + @test size(result) == (n, n) + @test !result.broadcasted + @test result.args[1] === W + return +end + +# Covers the `(op::NonlinearOperator)(x::Union{Real,AbstractArray{<:Real}}, y::AbstractJuMPArray)` +# dispatch — constant array first, JuMP array second. +function test_op_reversed_args() + n = 2 + model = Model() + @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + target = rand(n, n) + op = JuMP.NonlinearOperator(my_crossentropy, :my_crossentropy) + result = op(target, W) + @test result isa JuMP.NonlinearExpr + @test result.head == :my_crossentropy + @test result.args[1] === target + @test result.args[2] === W + return +end + end # module TestJuMP.runtests()