@@ -21,7 +21,22 @@ function SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27))
2121 return SimpleLimitedMemoryBroyden {SciMLBase._unwrap_val(threshold)} ()
2222end
2323
24- @views function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SimpleLimitedMemoryBroyden ,
# Entry point for `SimpleLimitedMemoryBroyden`: picks between the specialised `SArray`
# implementation and the generic one based on `prob.u0` and `termination_condition`.
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
        args...; termination_condition = nothing, kwargs...)
    # Anything that is not an `SArray` goes straight to the generic implementation.
    prob.u0 isa SArray ||
        return __generic_solve(prob, alg, args...; termination_condition, kwargs...)

    # The static implementation is only used with the default termination condition or
    # an `AbsNormTerminationMode`.
    use_static = termination_condition === nothing ||
                 termination_condition isa AbsNormTerminationMode
    use_static &&
        return __static_solve(prob, alg, args...; termination_condition, kwargs...)

    # `SArray` input with an unsupported termination condition: warn once, then fall
    # back to the generic implementation.
    @warn "Specifying `termination_condition = $(termination_condition)` for \
           `SimpleLimitedMemoryBroyden` with `SArray` is not non-allocating. Use \
           either `termination_condition = AbsNormTerminationMode()` or \
           `termination_condition = nothing`." maxlog=1
    return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
end
38+
39+ @views function __generic_solve (prob:: NonlinearProblem , alg:: SimpleLimitedMemoryBroyden ,
2540 args... ; abstol = nothing , reltol = nothing , maxiters = 1000 , alias_u0 = false ,
2641 termination_condition = nothing , kwargs... )
2742 x = __maybe_unaliased (prob. u0, alias_u0)
3651
3752 fx = _get_fx (prob, x)
3853
39- U, Vᵀ = __init_low_rank_jacobian (x, fx, threshold)
54+ U, Vᵀ = __init_low_rank_jacobian (x, fx, x isa StaticArray ? threshold : Val (η) )
4055
4156 abstol, reltol, tc_cache = init_termination_cache (abstol, reltol, fx, x,
4257 termination_condition)
4863 @bb δf = copy (fx)
4964
5065 @bb vᵀ_cache = copy (x)
51- Tcache = __lbroyden_threshold_cache (x, threshold)
66+ Tcache = __lbroyden_threshold_cache (x, x isa StaticArray ? threshold : Val (η) )
5267 @bb mat_cache = copy (x)
5368
5469 for i in 1 : maxiters
8398 return build_solution (prob, alg, x, fx; retcode = ReturnCode. MaxIters)
8499end
85100
101+ # Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
102+ # finicky, so we'll implement it separately from the generic version
103+ # Ignore termination_condition. Don't pass things into internal functions
# Solve path for `SArray` initial guesses. Only `abstol` and `maxiters` are read; every
# other keyword (including `termination_condition`) lands in `kwargs` and is ignored.
# Convergence is declared when `maximum(abs, fx) ≤ abstol`.
#
# Returns the result of `build_solution` with `ReturnCode.Success` on convergence and
# `ReturnCode.MaxIters` otherwise.
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
        args...; abstol = nothing, maxiters = 1000, kwargs...)
    x = prob.u0
    fx = _get_fx(prob, x)
    threshold = __get_threshold(alg)
    # Plain integer value of the `Val` threshold; loop invariant, so unwrap it once.
    η = SciMLBase._unwrap_val(threshold)

    U, Vᵀ = __init_low_rank_jacobian(vec(x), vec(fx), threshold)

    abstol = DiffEqBase._get_tolerance(abstol, eltype(x))

    xo, δx, fo = x, -fx, fx

    # The first `threshold` iterations are unrolled so that each one only touches the
    # slots of `U` / `Vᵀ` that have already been written.
    converged, res = __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U, Vᵀ,
        threshold)

    converged &&
        return build_solution(prob, alg, res.x, res.fx; retcode = ReturnCode.Success)

    xo, fo, δx = res.x, res.fx, res.δx
    # `U` and `Vᵀ` are rebuilt with `Base.setindex` inside the unrolled iterations rather
    # than mutated in place, so the updated values must be read back from the result.
    # Otherwise the loop below would continue from the initial all-zero factors.
    U, Vᵀ = res.U, res.Vᵀ

    for i in 1:(maxiters - η)
        x = xo .+ δx
        fx = prob.f(x, prob.p)
        δf = fx - fo

        maximum(abs, fx) ≤ abstol &&
            return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

        vᵀ = _restructure(x, _rmatvec!!(U, Vᵀ, vec(δx)))
        mvec = _restructure(x, _matvec!!(U, Vᵀ, vec(δf)))

        d = dot(vᵀ, δf)
        δx = @. (δx - mvec) / d

        # Overwrite the slots cyclically once all `η` of them have been filled.
        slot = mod1(i, η)
        U = Base.setindex(U, vec(δx), slot)
        Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), slot)

        δx = -_restructure(fx, _matvec!!(U, Vᵀ, vec(fx)))

        xo = x
        fo = fx
    end

    return build_solution(prob, alg, xo, fo; retcode = ReturnCode.MaxIters)
end

# Generates `threshold` unrolled copies of the update step. Copy `i` reads only the first
# `i - 1` entries of `U` / `Vᵀ` for the update and the first `i` entries for the next step.
#
# Returns `(converged, res)`, where `res` is a NamedTuple with fields `x`, `fx`, `δx`, `U`
# and `Vᵀ`. Every return path produces the same NamedTuple layout, so the caller can always
# pick up the updated `U` / `Vᵀ`.
@generated function __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U,
        Vᵀ, ::Val{threshold}) where {threshold}
    calls = Expr[]
    for i in 1:threshold
        static_idx, static_idx_p1 = Val(i - 1), Val(i)
        push!(calls,
            quote
                x = xo .+ δx
                fx = prob.f(x, prob.p)
                δf = fx - fo

                maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx, U, Vᵀ)

                # Entries written by the previous iterations (`nothing` when there are
                # none yet).
                _U = __first_n_getindex(U, $(static_idx))
                _Vᵀ = __first_n_getindex(Vᵀ, $(static_idx))

                vᵀ = _restructure(x, _rmatvec!!(_U, _Vᵀ, vec(δx)))
                mvec = _restructure(x, _matvec!!(_U, _Vᵀ, vec(δf)))

                d = dot(vᵀ, δf)
                δx = @. (δx - mvec) / d

                # `Base.setindex` returns a new container; rebind the local names.
                U = Base.setindex(U, vec(δx), $(i))
                Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), $(i))

                _U = __first_n_getindex(U, $(static_idx_p1))
                _Vᵀ = __first_n_getindex(Vᵀ, $(static_idx_p1))
                δx = -_restructure(fx, _matvec!!(_U, _Vᵀ, vec(fx)))

                xo = x
                fo = fx
            end)
    end
    push!(calls, quote
        # Termination Check
        maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx, U, Vᵀ)

        return false, (; x, fx, δx, U, Vᵀ)
    end)
    return Expr(:block, calls...)
end
191+
86192function _rmatvec!! (y, xᵀU, U, Vᵀ, x)
87193 # xᵀ × (-I + UVᵀ)
88194 η = size (U, 2 )
@@ -98,6 +204,9 @@ function _rmatvec!!(y, xᵀU, U, Vᵀ, x)
98204 return y
99205end
100206
# xᵀ × (-I + UVᵀ) without preallocated caches.
# With `U === nothing` only the `-I` term is applied.
@inline function _rmatvec!!(::Nothing, Vᵀ, x)
    return -x
end
@inline function _rmatvec!!(U, Vᵀ, x)
    xᵀU = __mapdot(x, U)
    return __mapTdot(xᵀU, Vᵀ) .- x
end
209+
101210function _matvec!! (y, Vᵀx, U, Vᵀ, x)
102211 # (-I + UVᵀ) × x
103212 η = size (U, 2 )
@@ -113,7 +222,56 @@ function _matvec!!(y, Vᵀx, U, Vᵀ, x)
113222 return y
114223end
115224
# (-I + UVᵀ) × x without preallocated caches.
# With `U === nothing` only the `-I` term is applied.
@inline function _matvec!!(::Nothing, Vᵀ, x)
    return -x
end
@inline function _matvec!!(U, Vᵀ, x)
    Vᵀx = __mapdot(x, Vᵀ)
    return __mapTdot(Vᵀx, U) .- x
end
227+
# Dot product of `x` with every inner vector of `Y`, collected by `map`.
function __mapdot(x::SVector{S1}, Y::SVector{S2, <:SVector{S1}}) where {S1, S2}
    dot_with_x = Base.Fix1(dot, x)
    return map(dot_with_x, Y)
end
# Generates an unrolled `x[1] .* Y[1] .+ ... .+ x[S1] .* Y[S1]`: each product is bound to
# its own temporary and the temporaries are combined with a single broadcasted `+`.
@generated function __mapTdot(x::SVector{S1}, Y::SVector{S1, <:SVector{S2}}) where {S1, S2}
    terms = [gensym("m$(k)") for k in 1:S1]
    assignments = [:($(terms[k]) = x[$(k)] .* Y[$k]) for k in 1:S1]
    return Expr(:block, assignments..., :(return .+($(terms...))))
end
240+
# Builds an `SVector` from the first `N` entries of `x`; yields `nothing` when `N == 0`.
@generated function __first_n_getindex(x::SVector{L, T}, ::Val{N}) where {L, T, N}
    @assert N ≤ L
    # Nothing to extract: signal "no entries" with `nothing` instead of an empty vector.
    N == 0 && return :(return nothing)
    picks = ntuple(k -> :(x[$k]), N)
    return :(return SVector{$N, $T}(($(picks...))))
end
247+
# Generic fallback: a length-`threshold` buffer created with `similar`, so it shares
# `x`'s element type and array family.
function __lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold}
    return similar(x, threshold)
end
117- function __lbroyden_threshold_cache (x:: SArray , :: Val{threshold} ) where {threshold}
118- return SArray {Tuple{threshold}, eltype(x)} (ntuple (_ -> zero (eltype (x)), threshold))
# Static arrays in general get a zero-initialised `MArray` of length `threshold`.
function __lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
    CacheT = MArray{Tuple{threshold}, eltype(x)}
    return zeros(CacheT)
end
# `SArray` inputs get no cache at all.
function __lbroyden_threshold_cache(x::SArray, ::Val{threshold}) where {threshold}
    return nothing
end
253+
# Static-array variant: uninitialised `MArray` storage for both factors, with the promoted
# element type of `u` and `fu`.
#   U  has size (prod(Size(fu)), threshold)
#   Vᵀ has size (threshold, prod(Size(u)))
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
        ::Val{threshold}) where {S1, S2, T1, T2, threshold}
    T = promote_type(T1, T2)
    U = MArray{Tuple{prod(Size(fu)), threshold}, T}(undef)
    Vᵀ = MArray{Tuple{threshold, prod(Size(u))}, T}(undef)
    return U, Vᵀ
end
# `SVector` variant: both factors are `SVector`s holding `threshold` zero `SVector`s.
# The inner vectors have length `Lfu` for `U` and `Lu` for `Vᵀ`, with the promoted
# element type of `u` and `fu`.
@generated function __init_low_rank_jacobian(u::SVector{Lu, T1}, fu::SVector{Lfu, T2},
        ::Val{threshold}) where {Lu, Lfu, T1, T2, threshold}
    T = promote_type(T1, T2)
    zero_rows_U = [:(zeros(SVector{$Lfu, $T})) for _ in 1:threshold]
    zero_rows_Vᵀ = [:(zeros(SVector{$Lu, $T})) for _ in 1:threshold]
    return quote
        Vᵀ = SVector($(zero_rows_Vᵀ...))
        U = SVector($(zero_rows_U...))
        return U, Vᵀ
    end
end
# Generic fallback: both factors are allocated with `similar(u, ...)`.
#   U  has size (length(fu), threshold)
#   Vᵀ has size (threshold, length(u))
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
    n_u, n_fu = length(u), length(fu)
    U = similar(u, n_fu, threshold)
    Vᵀ = similar(u, threshold, n_u)
    return U, Vᵀ
end
0 commit comments