diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 2c176d85f..5310a4f91 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -495,14 +495,15 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) vtmp4 .= λ Enzyme.remake_zero!(tmp3) Enzyme.remake_zero!(out) - + + dp = isscimlstructure(p) ? repack(out) : out if SciMLBase.isinplace(sol.prob.f) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) else function g(du, u, p, t) du .= f(u, p, t) @@ -512,7 +513,10 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) + end + if isscimlstructure(p) + out .= canonicalize(Tunable(), dp)[1] end elseif sensealg.autojacvec isa MooncakeVJP _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl index 303fa7a9c..043495077 100644 --- a/test/scimlstructures_interface.jl +++ b/test/scimlstructures_interface.jl @@ -159,3 +159,4 @@ end run_diff(initialize()) @test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps) @test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = false))[1].ps) +@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = EnzymeVJP()))[1].ps)