1- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2- f = prob. f
1+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
4+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
5+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
6+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
7+ sol. original)
8+ end
9+
10+ # Handle Ambiguities
11+ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
12+ @eval begin
13+ function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
14+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
15+ alg:: $ (algType), args... ; kwargs... ) where {uType, T, V, P, iip}
16+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
17+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
18+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
19+ sol. stats, sol. original, left = Dual {T, V, P} (sol. left, partials),
20+ right = Dual {T, V, P} (sol. right, partials))
21+ end
22+ end
23+ end
24+
25+ function __nlsolve_ad (prob, alg, args... ; kwargs... )
326 p = value (prob. p)
427 if prob isa IntervalNonlinearProblem
528 tspan = value .(prob. tspan)
6- newprob = IntervalNonlinearProblem (f, tspan, p; prob. kwargs... )
29+ newprob = IntervalNonlinearProblem (prob . f, tspan, p; prob. kwargs... )
730 else
831 u0 = value (prob. u0)
9- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
32+ newprob = NonlinearProblem (prob . f, u0, p; prob. kwargs... )
1033 end
1134
1235 sol = solve (newprob, alg, args... ; kwargs... )
1336
1437 uu = sol. u
15- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
16- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
38+ f_p = __nlsolve_ ∂f_∂p (prob, prob . f, uu, p)
39+ f_x = __nlsolve_ ∂f_∂u (prob, prob . f, uu, p)
1740
18- z_arr = - inv ( f_x) * f_p
41+ z_arr = - f_x \ f_p
1942
2043 pp = prob. p
2144 sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -30,60 +53,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3053 return sol, partials
3154end
3255
33- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
34- false , <: Dual{T, V, P} }, alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ;
35- kwargs... ) where {T, V, P}
36- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
37- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
38- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
39- end
40-
41- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
42- false , <: AbstractArray{<:Dual{T, V, P}} },
43- alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P}
44- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
45- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
46- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
47- end
48-
49- function scalar_nlsolve_∂f_∂p (f, u, p)
50- ff = p isa Number ? ForwardDiff. derivative :
51- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
52- return ff (Base. Fix1 (f, u), p)
56+ @inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
57+ if isinplace (prob)
58+ __f = p -> begin
59+ du = similar (u, promote_type (eltype (u), eltype (p)))
60+ f (du, u, p)
61+ return du
62+ end
63+ else
64+ __f = Base. Fix1 (f, u)
65+ end
66+ if p isa Number
67+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
68+ elseif u isa Number
69+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
70+ else
71+ return ForwardDiff. jacobian (__f, p)
72+ end
5373end
5474
55- function scalar_nlsolve_∂f_∂u (f, u, p)
56- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
57- return ff (Base. Fix2 (f, p), u)
75+ @inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
76+ if isinplace (prob)
77+ du = similar (u)
78+ __f = (du, u) -> f (du, u, p)
79+ ForwardDiff. jacobian (__f, du, u)
80+ else
81+ __f = Base. Fix2 (f, p)
82+ if u isa Number
83+ return ForwardDiff. derivative (__f, u)
84+ else
85+ return ForwardDiff. jacobian (__f, u)
86+ end
87+ end
5888end
5989
60- function scalar_nlsolve_dual_soln (u:: Number , partials,
90+ @inline function __nlsolve_dual_soln (u:: Number , partials,
6191 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
6292 return Dual {T, V, P} (u, partials)
6393end
6494
65- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
95+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
6696 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
67- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
68- end
69-
70- # avoid ambiguities
71- for Alg in [Bisection]
72- @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
73- <: Dual{T, V, P} }, alg:: $Alg , args... ; kwargs... ) where {uType, iip, T, V, P}
74- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
75- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
76- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
77- left = Dual {T, V, P} (sol. left, partials),
78- right = Dual {T, V, P} (sol. right, partials))
79- end
80- @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
81- <: AbstractArray{<:Dual{T, V, P}} }, alg:: $Alg , args... ;
82- kwargs... ) where {uType, iip, T, V, P}
83- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
84- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
85- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
86- left = Dual {T, V, P} (sol. left, partials),
87- right = Dual {T, V, P} (sol. right, partials))
88- end
97+ _partials = _restructure (u, partials)
98+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
8999end
0 commit comments