Skip to content

Commit 801464d

Browse files
committed
Create Output struct
1 parent 99472de commit 801464d

File tree

8 files changed

+118
-70
lines changed

8 files changed

+118
-70
lines changed

src/SolverCore.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ module SolverCore
33
# stdlib
44
using Logging, Printf
55

6-
# our packages
7-
using NLPModels
8-
6+
# include("stats.jl")
7+
include("logger.jl")
8+
include("output.jl")
99
include("solver.jl")
10+
include("traits.jl")
11+
1012
include("grid-search-tuning.jl")
11-
include("logger.jl")
12-
include("stats.jl")
13+
14+
include("optsolver.jl")
1315

1416
end

src/optsolver.jl

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
abstract type AbstractOptSolver{T} <: AbstractSolver{T}
1+
using NLPModels
2+
3+
export AbstractOptSolver, OptSolverOutput
4+
5+
abstract type AbstractOptSolver{T} <: AbstractSolver{T} end
26

37
#=
48
Constructors:
@@ -14,6 +18,54 @@ function (::Type{S})(::Type{T}, nlp :: AbstractNLPModel) where {T, S <: Abstract
1418
return output, solver
1519
end
1620
(::Type{S})(::Type{T}, ::Val{:nosolve}, nlp :: AbstractNLPModel) where {T, S <: AbstractOptSolver} = S(T, nlp.meta)
17-
(::Type{S})(::Val{:nosolve}, nlp :: AbstractNLPModel) where {S <: AbstractOptSolver} = S(Float64, Val(:nosolve), nlp)
18-
(::Type{S})(nlp :: AbstractNLPModel) where {S <: AbstractOptSolver} = S(Float64, nlp)
19-
(::Type{S})(meta :: AbstractNLPModelMeta) where {S <: AbstractOptSolver} = S(Float64, meta)
21+
(::Type{S})(::Val{:nosolve}, nlp :: AbstractNLPModel) where {S <: AbstractOptSolver} = S(eltype(nlp.meta.x0), Val(:nosolve), nlp)
22+
(::Type{S})(nlp :: AbstractNLPModel) where {S <: AbstractOptSolver} = S(eltype(nlp.meta.x0), nlp)
23+
(::Type{S})(meta :: AbstractNLPModelMeta) where {S <: AbstractOptSolver} = S(eltype(meta.x0), meta)
24+
25+
mutable struct OptSolverOutput{T} <: AbstractSolverOutput{T}
26+
status :: Symbol
27+
solution
28+
objective :: T # f(x)
29+
dual_feas :: T # ‖∇f(x)‖₂ for unc, ‖P[x - ∇f(x)] - x‖₂ for bnd, etc.
30+
primal_feas :: T # ‖c(x)‖ for equalities
31+
multipliers
32+
multipliers_L
33+
multipliers_U
34+
iter :: Int
35+
counters :: NLPModels.NLSCounters
36+
elapsed_time :: Float64
37+
solver_specific :: Dict{Symbol,Any}
38+
end
39+
40+
function OptSolverOutput(
41+
status :: Symbol,
42+
solution :: AbstractArray{T},
43+
nlp :: AbstractNLPModel;
44+
objective :: T = T(Inf),
45+
dual_feas :: T = T(Inf),
46+
primal_feas :: T = unconstrained(nlp) || bound_constrained(nlp) ? zero(T) : T(Inf),
47+
multipliers :: Vector = T[],
48+
multipliers_L :: Vector = T[],
49+
multipliers_U :: Vector = T[],
50+
iter :: Int=-1,
51+
elapsed_time :: Float64=Inf,
52+
solver_specific :: Dict = Dict{Symbol,Any}()
53+
) where T
54+
if !(status in keys(STATUSES))
55+
@error "status $status is not a valid status. Use one of the following: " join(keys(STATUSES), ", ")
56+
throw(KeyError(status))
57+
end
58+
c = NLSCounters()
59+
for counter in fieldnames(Counters)
60+
setfield!(c.counters, counter, eval(Meta.parse("$counter"))(nlp))
61+
end
62+
if nlp isa AbstractNLSModel
63+
for counter in fieldnames(NLSCounters)
64+
counter == :counters && continue
65+
setfield!(c, counter, eval(Meta.parse("$counter"))(nlp))
66+
end
67+
end
68+
return OptSolverOutput{T}(status, solution, objective, dual_feas, primal_feas,
69+
multipliers, multipliers_L, multipliers_U, iter,
70+
c, elapsed_time, solver_specific)
71+
end

src/output.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
export AbstractSolverOutput
2+
3+
"""
4+
AbstractSolverOutput
5+
6+
Base type for output of JSO-compliant solvers.
7+
An output must have at least the following:
8+
- `status :: Bool`
9+
- `solution`
10+
"""
11+
abstract type AbstractSolverOutput{T} end
12+
13+
const STATUSES = Dict(
14+
:exception => "unhandled exception",
15+
:first_order => "first-order stationary",
16+
:acceptable => "solved to within acceptable tolerances",
17+
:infeasible => "problem may be infeasible",
18+
:max_eval => "maximum number of function evaluations",
19+
:max_iter => "maximum iteration",
20+
:max_time => "maximum elapsed time",
21+
:neg_pred => "negative predicted reduction",
22+
:not_desc => "not a descent direction",
23+
:small_residual => "small residual",
24+
:small_step => "step too small",
25+
:stalled => "stalled",
26+
:unbounded => "objective function may be unbounded from below",
27+
:unknown => "unknown",
28+
:user => "user-requested stop",
29+
)
30+
31+
function Base.show(io :: IO, output :: AbstractSolverOutput)
32+
println(io, "Solver output of type $(typeof(output))")
33+
println(io, "Status: $(STATUSES[output.status])")
34+
end

src/solver.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ A solver must have three members:
1212
abstract type AbstractSolver{T} end
1313

1414
function Base.show(io :: IO, solver :: AbstractSolver)
15-
show(io, "Solver $(typeof(solver))")
15+
println(io, "Solver $(typeof(solver))")
1616
end
1717

1818
"""
@@ -43,9 +43,4 @@ Each key of `named_tuple` is the name of a parameter, and its value is a NamedTu
4343
function parameters(::Type{AbstractSolver{T}}) where T end
4444

4545
parameters(::Type{S}) where S <: AbstractSolver = parameters(S{Float64})
46-
parameters(solver :: AbstractSolver) = parameters(typeof(solver))
47-
48-
# To be removed in the future
49-
50-
include("optsolver.jl")
51-
# include("linearsolver.jl")
46+
parameters(solver :: AbstractSolver) = parameters(typeof(solver))

src/stats.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,6 @@
11
export AbstractExecutionStats, GenericExecutionStats,
22
statsgetfield, statshead, statsline, getStatus, show_statuses
33

4-
const STATUSES = Dict(
5-
:exception => "unhandled exception",
6-
:first_order => "first-order stationary",
7-
:acceptable => "solved to within acceptable tolerances",
8-
:infeasible => "problem may be infeasible",
9-
:max_eval => "maximum number of function evaluations",
10-
:max_iter => "maximum iteration",
11-
:max_time => "maximum elapsed time",
12-
:neg_pred => "negative predicted reduction",
13-
:not_desc => "not a descent direction",
14-
:small_residual => "small residual",
15-
:small_step => "step too small",
16-
:stalled => "stalled",
17-
:unbounded => "objective function may be unbounded from below",
18-
:unknown => "unknown",
19-
:user => "user-requested stop",
20-
)
214

225
"""
236
show_statuses()

test/dummy_solver.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mutable struct DummySolver{T} <: AbstractSolver{T}
1+
mutable struct DummySolver{T} <: AbstractOptSolver{T}
22
initialized :: Bool
33
params :: Dict
44
workspace
@@ -59,7 +59,7 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
5959
x = solver.workspace.x # Change reference
6060

6161
cx = solver.workspace.cx .= ncon > 0 ? cons(nlp, x) : zeros(T, 0)
62-
ct = solver.workspace.ct = zeros(T, ncon)
62+
ct = solver.workspace.ct .= zero(T)
6363
grad!(nlp, x, solver.workspace.gx)
6464
gx = solver.workspace.gx
6565
Jx = ncon > 0 ? jac(nlp, x) : zeros(T, 0, nvar)
@@ -133,8 +133,9 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
133133
:max_eval
134134
end
135135

136-
return GenericExecutionStats(
136+
return OptSolverOutput(
137137
status,
138+
x,
138139
nlp,
139140
objective=fx,
140141
dual_feas=norm(dual),
@@ -143,7 +144,6 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
143144
multipliers_L=zeros(T, nvar),
144145
multipliers_U=zeros(T, nvar),
145146
elapsed_time=elapsed_time,
146-
solution=x,
147147
iter=iter
148148
)
149149
end

test/test_logging.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ function test_logging()
99

1010
with_logger(ConsoleLogger()) do
1111
@info "Testing dummy solver with logger"
12-
solver = DummySolver()
13-
solver(nlps[1], max_eval=20)
12+
solver = DummySolver(nlps[1].meta)
13+
solve!(solver, nlps[1], max_eval=20)
1414
reset!.(nlps)
1515
end
1616
end

test/test_stats.jl

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
function test_stats()
2-
show_statuses()
32
nlp = ADNLPModel(x->dot(x,x), zeros(2))
4-
stats = GenericExecutionStats(
3+
stats = OptSolverOutput(
54
:first_order,
5+
ones(100),
66
nlp,
77
objective=1.0,
88
dual_feas=1e-12,
9-
solution=ones(100),
109
iter=10,
1110
solver_specific=Dict(:matvec=>10,
1211
:dot=>25,
@@ -17,39 +16,23 @@ function test_stats()
1716
)
1817
)
1918

20-
show(stats)
21-
print(stats)
22-
println(stats)
23-
open("teststats.out", "w") do f
24-
println(f, stats)
25-
end
26-
27-
println(stats, showvec=(io,x)->print(io,x))
28-
open("teststats.out", "a") do f
29-
println(f, stats, showvec=(io,x)->print(io,x))
30-
end
31-
32-
line = [:status, :neval_obj, :objective, :iter]
33-
for field in line
34-
value = statsgetfield(stats, field)
35-
println("$field -> $value")
36-
end
37-
println(statshead(line))
38-
println(statsline(stats, line))
19+
io = IOBuffer()
20+
show(io, stats)
21+
@test String(take!(io)) == "Solver output of type OptSolverOutput{Float64}\nStatus: first-order stationary\n"
3922

4023
@testset "Testing inference" begin
4124
for T in (Float16, Float32, Float64, BigFloat)
4225
nlp = ADNLPModel(x->dot(x, x), ones(T, 2))
4326

44-
stats = GenericExecutionStats(:first_order, nlp)
27+
stats = OptSolverOutput(:first_order, nlp.meta.x0, nlp)
4528
@test stats.status == :first_order
4629
@test typeof(stats.objective) == T
4730
@test typeof(stats.dual_feas) == T
4831
@test typeof(stats.primal_feas) == T
4932

5033
nlp = ADNLPModel(x->dot(x, x), ones(T, 2), x->[sum(x)-1], [0.0], [0.0])
5134

52-
stats = GenericExecutionStats(:first_order, nlp)
35+
stats = OptSolverOutput(:first_order, nlp.meta.x0, nlp)
5336
@test stats.status == :first_order
5437
@test typeof(stats.objective) == T
5538
@test typeof(stats.dual_feas) == T
@@ -58,17 +41,16 @@ function test_stats()
5841
end
5942

6043
@testset "Test throws" begin
61-
@test_throws Exception GenericExecutionStats(:bad, nlp)
62-
@test_throws Exception GenericExecutionStats(:unkwown, nlp, bad=true)
44+
@test_throws Exception OptSolverOutput(:bad, nlp)
45+
@test_throws Exception OptSolverOutput(:unkwown, nlp, bad=true)
6346
end
6447

6548
@testset "Testing Dummy Solver with multi-precision" begin
66-
solver = DummySolver()
6749
for T in (Float16, Float32, Float64, BigFloat)
6850
nlp = ADNLPModel(x->dot(x, x), ones(T, 2))
6951

70-
with_logger(NullLogger()) do
71-
stats = solver(nlp)
52+
stats, solver = with_logger(NullLogger()) do
53+
DummySolver(nlp)
7254
end
7355
@test typeof(stats.objective) == T
7456
@test typeof(stats.dual_feas) == T
@@ -80,8 +62,8 @@ function test_stats()
8062

8163
nlp = ADNLPModel(x->dot(x, x), ones(T, 2), x->[sum(x)-1], [0.0], [0.0])
8264

83-
with_logger(NullLogger()) do
84-
stats = solver(nlp)
65+
stats, solver = with_logger(NullLogger()) do
66+
DummySolver(nlp)
8567
end
8668
@test typeof(stats.objective) == T
8769
@test typeof(stats.dual_feas) == T

0 commit comments

Comments
 (0)