1- function SciMLBase. solve (
2- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
3- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
4- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
5- args... ;
6- kwargs... ) where {T, V, P, iip}
7- sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
8- dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
9- return SciMLBase. build_solution (
10- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
11- end
12-
13- function SciMLBase. solve (
14- prob:: NonlinearLeastSquaresProblem {
15- <: AbstractArray , iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
16- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
17- args... ;
18- kwargs... ) where {T, V, P, iip}
19- sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
20- dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
21- return SciMLBase. build_solution (
22- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1+ for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2+ @eval function SciMLBase. solve (
3+ prob:: $ (pType){<: Union{Number, <:AbstractArray} , iip,
4+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
5+ alg:: AbstractSimpleNonlinearSolveAlgorithm ,
6+ args... ;
7+ kwargs... ) where {T, V, P, iip}
8+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
9+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
10+ return SciMLBase. build_solution (
11+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
12+ end
2313end
2414
2515for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -47,8 +37,7 @@ function __nlsolve_ad(
4737 tspan = value .(prob. tspan)
4838 newprob = IntervalNonlinearProblem (prob. f, tspan, p; prob. kwargs... )
4939 else
50- u0 = value (prob. u0)
51- newprob = NonlinearProblem (prob. f, u0, p; prob. kwargs... )
40+ newprob = remake (prob; p, u0 = value (prob. u0))
5241 end
5342
5443 sol = solve (newprob, alg, args... ; kwargs... )
@@ -73,20 +62,16 @@ function __nlsolve_ad(
7362end
7463
7564function __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
76- p = value (prob. p)
77- u0 = value (prob. u0)
78- newprob = NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
79-
65+ newprob = remake (prob; p = value (prob. p), u0 = value (prob. u0))
8066 sol = solve (newprob, alg, args... ; kwargs... )
81-
8267 uu = sol. u
8368
8469 # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
8570 # nested autodiff as the last resort
8671 if SciMLBase. has_vjp (prob. f)
8772 if isinplace (prob)
8873 _F = @closure (du, u, p) -> begin
89- resid = similar (du, length (sol. resid))
74+ resid = __similar (du, length (sol. resid))
9075 prob. f (resid, u, p)
9176 prob. f. vjp (du, resid, u, p)
9277 du .*= 2
@@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
10186 elseif SciMLBase. has_jac (prob. f)
10287 if isinplace (prob)
10388 _F = @closure (du, u, p) -> begin
104- J = similar (du, length (sol. resid), length (u))
89+ J = __similar (du, length (sol. resid), length (u))
10590 prob. f. jac (J, u, p)
106- resid = similar (du, length (sol. resid))
91+ resid = __similar (du, length (sol. resid))
10792 prob. f (resid, u, p)
10893 mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
10994 return nothing
@@ -116,43 +101,40 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
116101 else
117102 if isinplace (prob)
118103 _F = @closure (du, u, p) -> begin
119- resid = similar (du, length (sol. resid))
120- res = DiffResults. DiffResult (
121- resid, similar (du, length (sol. resid), length (u)))
122104 _f = @closure (du, u) -> prob. f (du, u, p)
123- ForwardDiff . jacobian! (res, _f, resid, u )
124- mul! ( reshape (du, 1 , :), vec (DiffResults . value (res)) ' ,
125- DiffResults . jacobian (res) , 2 , false )
105+ resid = __similar (du, length (sol . resid) )
106+ v, J = DI . value_and_jacobian (_f, resid, AutoForwardDiff (), u)
107+ mul! ( reshape (du, 1 , :), vec (v) ' , J , 2 , false )
126108 return nothing
127109 end
128110 else
129111 # For small problems, nesting ForwardDiff is actually quite fast
130112 if __is_extension_loaded (Val (:Zygote )) && (length (uu) + length (sol. resid) ≥ 50 )
131- _F = @closure (u, p) -> __zygote_compute_nlls_vjp (prob. f, u, p)
113+ # TODO : Remove once DI has the value_and_pullback_split defined
114+ _F = @closure (u, p) -> begin
115+ _f = Base. Fix2 (prob. f, p)
116+ return __zygote_compute_nlls_vjp (_f, u, p)
117+ end
132118 else
133119 _F = @closure (u, p) -> begin
134- T = promote_type (eltype (u), eltype (p))
135- res = DiffResults. DiffResult (similar (u, T, size (sol. resid)),
136- similar (u, T, length (sol. resid), length (u)))
137- ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
138- return reshape (
139- 2 .* vec (DiffResults. value (res))' * DiffResults. jacobian (res),
140- size (u))
120+ _f = Base. Fix2 (prob. f, p)
121+ v, J = DI. value_and_jacobian (_f, AutoForwardDiff (), u)
122+ return reshape (2 .* vec (v)' * J, size (u))
141123 end
142124 end
143125 end
144126 end
145127
146- f_p = __nlsolve_∂f_∂p (prob, _F, uu, p)
147- f_x = __nlsolve_∂f_∂u (prob, _F, uu, p)
128+ f_p = __nlsolve_∂f_∂p (prob, _F, uu, newprob . p)
129+ f_x = __nlsolve_∂f_∂u (prob, _F, uu, newprob . p)
148130
149131 z_arr = - f_x \ f_p
150132
151133 pp = prob. p
152134 sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
153135 if uu isa Number
154136 partials = sum (sumfun, zip (z_arr, pp))
155- elseif p isa Number
137+ elseif pp isa Number
156138 partials = sumfun ((z_arr, pp))
157139 else
158140 partials = sum (sumfun, zip (eachcol (z_arr), pp))
164146@inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
165147 if isinplace (prob)
166148 __f = p -> begin
167- du = similar (u, promote_type (eltype (u), eltype (p)))
149+ du = __similar (u, promote_type (eltype (u), eltype (p)))
168150 f (du, u, p)
169151 return du
170152 end
@@ -182,16 +164,12 @@ end
182164
183165@inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
184166 if isinplace (prob)
185- du = similar (u)
186- __f = (du, u) -> f (du, u, p)
187- ForwardDiff. jacobian (__f, du, u)
167+ __f = @closure (du, u) -> f (du, u, p)
168+ return ForwardDiff. jacobian (__f, __similar (u), u)
188169 else
189170 __f = Base. Fix2 (f, p)
190- if u isa Number
191- return ForwardDiff. derivative (__f, u)
192- else
193- return ForwardDiff. jacobian (__f, u)
194- end
171+ u isa Number && return ForwardDiff. derivative (__f, u)
172+ return ForwardDiff. jacobian (__f, u)
195173 end
196174end
197175
0 commit comments