Skip to content

Commit 99472de

Browse files
committed
[WIP] Try to generalize abstract solver and add workspace variable
1 parent 483716b commit 99472de

File tree

3 files changed

+53
-44
lines changed

3 files changed

+53
-44
lines changed

src/optsolver.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
abstract type AbstractOptSolver{T} <: AbstractSolver{T}
2+
3+
#=
4+
Constructors:
5+
- Solver(T, Val(:nosolve), nlp)
6+
- Solver(T, nlp)
7+
- Solver(meta)
8+
- Solver(Val(:nosolve), nlp)
9+
- Solver(nlp)
10+
=#
11+
function (::Type{S})(::Type{T}, nlp :: AbstractNLPModel) where {T, S <: AbstractOptSolver}
12+
solver = S(T, nlp.meta)
13+
output = solve!(solver, nlp)
14+
return output, solver
15+
end
16+
(::Type{S})(::Type{T}, ::Val{:nosolve}, nlp :: AbstractNLPModel) where {T, S <: AbstractOptSolver} = S(T, nlp.meta)
17+
(::Type{S})(::Val{:nosolve}, nlp :: AbstractNLPModel) where {S <: AbstractOptSolver} = S(Float64, Val(:nosolve), nlp)
18+
(::Type{S})(nlp :: AbstractNLPModel) where {S <: AbstractOptSolver} = S(Float64, nlp)
19+
(::Type{S})(meta :: AbstractNLPModelMeta) where {S <: AbstractOptSolver} = S(Float64, meta)

src/solver.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ export AbstractSolver, solve!, parameters
44
AbstractSolver
55
66
Base type for JSO-compliant solvers.
7+
A solver must have three members:
8+
- `initialized :: Bool`, indicating whether the solver was initialized
9+
- `params :: Dict`, a dictionary of solvers
10+
- `workspace`, a named tuple with arrays used by the solver.
711
"""
812
abstract type AbstractSolver{T} end
913

@@ -15,9 +19,8 @@ end
1519
output = solve!(solver, problem)
1620
1721
Solve `problem` with `solver`.
18-
This modifies internal
1922
"""
20-
function solve!(::AbstractSolver, ::AbstractNLPModel) end
23+
function solve!(::AbstractSolver, ::Any) end
2124

2225
"""
2326
named_tuple = parameters(solver)
@@ -40,4 +43,9 @@ Each key of `named_tuple` is the name of a parameter, and its value is a NamedTu
4043
function parameters(::Type{AbstractSolver{T}}) where T end
4144

4245
parameters(::Type{S}) where S <: AbstractSolver = parameters(S{Float64})
43-
parameters(solver :: AbstractSolver) = parameters(typeof(solver))
46+
parameters(solver :: AbstractSolver) = parameters(typeof(solver))
47+
48+
# To be removed in the future
49+
50+
include("optsolver.jl")
51+
# include("linearsolver.jl")

test/dummy_solver.jl

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,40 @@
11
mutable struct DummySolver{T} <: AbstractSolver{T}
22
initialized :: Bool
33
params :: Dict
4-
x :: Vector{T}
5-
xt :: Vector{T}
6-
gx :: Vector{T}
7-
dual :: Vector{T}
8-
y :: Vector{T}
9-
cx :: Vector{T}
10-
ct :: Vector{T}
4+
workspace
115
end
126

137
function SolverCore.parameters(::Type{DummySolver{T}}) where T
148
(
159
α = (default=T(1e-2), type=:log, min=√√eps(T), max=one(T) / 2),
1610
δ = (default=eps(T), type=:log, min=eps(T), max=√√√eps(T)),
17-
reboot_y = (default=false, type=:bool)
11+
reboot_y = (default=false, type=:bool),
1812
)
1913
end
2014

15+
# function for validating given parameters. Instead of using constraints.
16+
2117
function DummySolver(::Type{T}, meta :: AbstractNLPModelMeta; kwargs...) where T
2218
nvar, ncon = meta.nvar, meta.ncon
2319
params = parameters(DummySolver{T})
2420
solver = DummySolver{T}(true,
2521
Dict(k => v[:default] for (k,v) in pairs(params)),
26-
zeros(T, nvar),
27-
zeros(T, nvar),
28-
zeros(T, nvar),
29-
zeros(T, nvar),
30-
zeros(T, ncon),
31-
zeros(T, ncon),
32-
zeros(T, ncon),
22+
( # workspace
23+
x = zeros(T, nvar),
24+
xt = zeros(T, nvar),
25+
gx = zeros(T, nvar),
26+
dual = zeros(T, nvar),
27+
y = zeros(T, ncon),
28+
cx = zeros(T, ncon),
29+
ct = zeros(T, ncon),
30+
)
3331
)
3432
for (k,v) in kwargs
3533
solver.params[k] = v
3634
end
3735
solver
3836
end
3937

40-
function DummySolver(::Type{T}, ::Val{:nosolve}, nlp :: AbstractNLPModel) where T
41-
solver = DummySolver(T, nlp.meta)
42-
return solver
43-
end
44-
45-
function DummySolver(::Type{T}, nlp :: AbstractNLPModel) where T
46-
solver = DummySolver(T, nlp.meta)
47-
output = solve!(solver, nlp)
48-
return output, solver
49-
end
50-
51-
DummySolver(meta :: AbstractNLPModelMeta) = DummySolver(Float64, meta :: AbstractNLPModelMeta)
52-
DummySolver(::Val{:nosolve}, nlp :: AbstractNLPModel) = DummySolver(Float64, Val(:nosolve), nlp :: AbstractNLPModel)
53-
DummySolver(nlp :: AbstractNLPModel) = DummySolver(Float64, nlp :: AbstractNLPModel)
54-
55-
5638
function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
5739
x :: AbstractVector{T} = T.(nlp.meta.x0),
5840
atol :: Real = sqrt(eps(T)),
@@ -61,6 +43,7 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
6143
max_time :: Float64 = 30.0,
6244
kwargs...
6345
) where T
46+
# Check dim
6447
solver.initialized || error("Solver not initialized.")
6548
nvar, ncon = nlp.meta.nvar, nlp.meta.ncon
6649
for (k,v) in kwargs
@@ -72,18 +55,18 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
7255

7356
start_time = time()
7457
elapsed_time = 0.0
75-
solver.x .= x # Copy values
76-
x = solver.x # Change reference
58+
solver.workspace.x .= x # Copy values
59+
x = solver.workspace.x # Change reference
7760

78-
cx = solver.cx .= ncon > 0 ? cons(nlp, x) : zeros(T, 0)
79-
ct = solver.ct = zeros(T, ncon)
80-
grad!(nlp, x, solver.gx)
81-
gx = solver.gx
61+
cx = solver.workspace.cx .= ncon > 0 ? cons(nlp, x) : zeros(T, 0)
62+
ct = solver.workspace.ct = zeros(T, ncon)
63+
grad!(nlp, x, solver.workspace.gx)
64+
gx = solver.workspace.gx
8265
Jx = ncon > 0 ? jac(nlp, x) : zeros(T, 0, nvar)
83-
y = solver.y .= -Jx' \ gx
66+
y = solver.workspace.y .= -Jx' \ gx
8467
Hxy = ncon > 0 ? hess(nlp, x, y) : hess(nlp, x)
8568

86-
dual = solver.dual .= gx .+ Jx' * y
69+
dual = solver.workspace.dual .= gx .+ Jx' * y
8770

8871
iter = 0
8972

@@ -106,7 +89,7 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
10689

10790
AΔx = Jx * Δx
10891
ϕx = ϕ(fx, cx, y)
109-
xt = solver.xt .= x + Δx
92+
xt = solver.workspace.xt .= x + Δx
11093
if ncon > 0
11194
cons!(nlp, xt, ct)
11295
end
@@ -124,7 +107,6 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
124107

125108
x .= xt
126109

127-
128110
fx = ft
129111
grad!(nlp, x, gx)
130112
Jx = ncon > 0 ? jac(nlp, x) : zeros(T, 0, nvar)

0 commit comments

Comments
 (0)