From 56133565bf5c42d357d30f27d3c31e67d06159e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Fri, 5 Jun 2026 14:24:27 +0100 Subject: [PATCH] Fix Optimisers.jl wrapper on GPU --- test/OptimisersSolver.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/test/OptimisersSolver.jl b/test/OptimisersSolver.jl index 78b6164..64329c7 100644 --- a/test/OptimisersSolver.jl +++ b/test/OptimisersSolver.jl @@ -1,3 +1,4 @@ +import LinearAlgebra import SolverCore import NLPModels import Optimisers @@ -6,24 +7,28 @@ import Optimisers # the variable vector of an unconstrained `AbstractNLPModel` using `obj` and # `grad!`. Designed to be plugged into `NLPModelsJuMP.Optimizer` via # `set_attribute(model, "solver", OptimisersSolver)`. -mutable struct OptimisersSolver{R<:Optimisers.AbstractRule} <: +mutable struct OptimisersSolver{R<:Optimisers.AbstractRule,V<:AbstractVector} <: SolverCore.AbstractOptimizationSolver rule::R - x::Vector{Float64} - g::Vector{Float64} + x::V + g::V end function OptimisersSolver( - nlp::NLPModels.AbstractNLPModel; - rule::Optimisers.AbstractRule = Optimisers.Adam(0.05), -) + nlp::NLPModels.AbstractNLPModel{T,V}; + rule::Optimisers.AbstractRule = Optimisers.Adam(T(0.05)), +) where {T,V<:AbstractVector{T}} nvar = NLPModels.get_nvar(nlp.meta) - return OptimisersSolver(rule, zeros(Float64, nvar), zeros(Float64, nvar)) + x = similar(NLPModels.get_x0(nlp.meta), nvar) + g = similar(x) + fill!(x, zero(T)) + fill!(g, zero(T)) + return OptimisersSolver(rule, x, g) end function SolverCore.reset!(solver::OptimisersSolver) - fill!(solver.x, 0.0) - fill!(solver.g, 0.0) + fill!(solver.x, zero(eltype(solver.x))) + fill!(solver.g, zero(eltype(solver.g))) return solver end