Skip to content

Commit 7529692

Browse files
committed
fix: force \ for linear solving
1 parent 262c592 commit 7529692

File tree

2 files changed

+42
-34
lines changed

2 files changed

+42
-34
lines changed

benchmarks/NonlinearProblem/Manifest.toml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,9 @@ version = "1.0.5"
645645

646646
[[deps.Enzyme]]
647647
deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "PrecompileTools", "Preferences", "Printf", "Random", "SparseArrays"]
648-
git-tree-sha1 = "59c1db6e150d55f2df6a1383759931bf8571c6b8"
648+
git-tree-sha1 = "71147df4e324219b36b74ec4d34cea332c41933e"
649649
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
650-
version = "0.13.35"
650+
version = "0.13.36"
651651

652652
[deps.Enzyme.extensions]
653653
EnzymeBFloat16sExt = "BFloat16s"
@@ -1352,9 +1352,9 @@ version = "3.2.2+2"
13521352

13531353
[[deps.Libglvnd_jll]]
13541354
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll", "Xorg_libXext_jll"]
1355-
git-tree-sha1 = "ff3b4b9d35de638936a525ecd36e86a8bb919d11"
1355+
git-tree-sha1 = "36c4b9df1d1bac2fadb77b27959512ba6c541d91"
13561356
uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29"
1357-
version = "1.7.0+0"
1357+
version = "1.7.1+0"
13581358

13591359
[[deps.Libiconv_jll]]
13601360
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -1670,9 +1670,12 @@ version = "1.2.0"
16701670

16711671
[[deps.NonlinearProblemLibrary]]
16721672
deps = ["LinearAlgebra", "SciMLBase"]
1673-
git-tree-sha1 = "063d428dfdf88b79c834953b9f53b2c464a436bd"
1673+
git-tree-sha1 = "f5a4e83740e335b2cac3c12f016866c2cded2aaa"
1674+
repo-rev = "ap/eff"
1675+
repo-subdir = "lib/NonlinearProblemLibrary"
1676+
repo-url = "https://github.com/SciML/DiffEqProblemLibrary.jl.git"
16741677
uuid = "b7050fa9-e91f-4b37-bcee-a89a063da141"
1675-
version = "0.1.2"
1678+
version = "0.1.3"
16761679

16771680
[[deps.NonlinearSolve]]
16781681
deps = ["ADTypes", "ArrayInterface", "BracketingNonlinearSolve", "CommonSolve", "ConcreteStructs", "DiffEqBase", "DifferentiationInterface", "FastClosures", "FiniteDiff", "ForwardDiff", "LineSearch", "LinearAlgebra", "LinearSolve", "NonlinearSolveBase", "NonlinearSolveFirstOrder", "NonlinearSolveQuasiNewton", "NonlinearSolveSpectralMethods", "PrecompileTools", "Preferences", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseMatrixColorings", "StaticArraysCore", "SymbolicIndexingInterface"]
@@ -1708,9 +1711,9 @@ version = "4.6.0"
17081711

17091712
[[deps.NonlinearSolveBase]]
17101713
deps = ["ADTypes", "Adapt", "ArrayInterface", "CommonSolve", "Compat", "ConcreteStructs", "DifferentiationInterface", "EnzymeCore", "FastClosures", "LinearAlgebra", "Markdown", "MaybeInplace", "Preferences", "Printf", "RecursiveArrayTools", "SciMLBase", "SciMLJacobianOperators", "SciMLOperators", "StaticArraysCore", "SymbolicIndexingInterface", "TimerOutputs"]
1711-
git-tree-sha1 = "f8ece81557f7e42879f017fa089e5283988a5f67"
1714+
git-tree-sha1 = "e56b37efd60a8caefc7f96831bd9dd225afd6a4a"
17121715
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1713-
version = "1.5.2"
1716+
version = "1.5.3"
17141717

17151718
[deps.NonlinearSolveBase.extensions]
17161719
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"

benchmarks/NonlinearProblem/nonlinear_solver_23_tests.jmd

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,31 @@ HagerZhang() = LineSearchesJL(; method = LineSearches.HagerZhang())
5151
MoreThuente() = LineSearchesJL(; method = LineSearches.MoreThuente())
5252

5353
solvers_all = [
54-
(; pkg = :nonlinearsolve, type = :general, name = "Default PolyAlg.", solver = Dict(:alg => FastShortcutNonlinearPolyalg(; u0_len = 10))),
55-
(; pkg = :nonlinearsolve, type = :NR, name = "Newton Raphson", solver = Dict(:alg => NewtonRaphson())),
56-
(; pkg = :nonlinearsolve, type = :NR, name = "NR (HagerZhang)", solver = Dict(:alg => NewtonRaphson(; linesearch = HagerZhang()))),
57-
(; pkg = :nonlinearsolve, type = :NR, name = "NR (MoreThuente)", solver = Dict(:alg => NewtonRaphson(; linesearch = MoreThuente()))),
58-
(; pkg = :nonlinearsolve, type = :NR, name = "NR (BackTracking)", solver = Dict(:alg => NewtonRaphson(; linesearch = BackTracking()))),
59-
(; pkg = :nonlinearsolve, type = :TR, name = "Trust Region", solver = Dict(:alg => TrustRegion())),
60-
(; pkg = :nonlinearsolve, type = :TR, name = "TR (NLsolve Update)", solver = Dict(:alg => TrustRegion(; radius_update_scheme = RUS.NLsolve))),
61-
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Nocedal Wright)", solver = Dict(:alg => TrustRegion(; radius_update_scheme = RUS.NocedalWright))),
62-
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Hei)", solver = Dict(:alg => TrustRegion(; radius_update_scheme = RUS.Hei))),
63-
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Yuan)", solver = Dict(:alg => TrustRegion(; radius_update_scheme = RUS.Yuan))),
64-
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Bastin)", solver = Dict(:alg => TrustRegion(; radius_update_scheme = RUS.Bastin))),
65-
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Fan)", solver = Dict(:alg => TrustRegion(; radius_update_scheme = RUS.Fan))),
54+
(; pkg = :nonlinearsolve, type = :general, name = "Default PolyAlg.", solver = Dict(:alg => FastShortcutNonlinearPolyalg(; u0_len = 10, linsolve = \))),
55+
(; pkg = :nonlinearsolve, type = :NR, name = "Newton Raphson", solver = Dict(:alg => NewtonRaphson(; linsolve = \))),
56+
(; pkg = :nonlinearsolve, type = :NR, name = "NR (HagerZhang)", solver = Dict(:alg => NewtonRaphson(; linsolve = \, linesearch = HagerZhang()))),
57+
(; pkg = :nonlinearsolve, type = :NR, name = "NR (MoreThuente)", solver = Dict(:alg => NewtonRaphson(; linsolve = \, linesearch = MoreThuente()))),
58+
(; pkg = :nonlinearsolve, type = :NR, name = "NR (BackTracking)", solver = Dict(:alg => NewtonRaphson(; linsolve = \, linesearch = BackTracking()))),
59+
(; pkg = :nonlinearsolve, type = :TR, name = "Trust Region", solver = Dict(:alg => TrustRegion(; linsolve = \))),
60+
(; pkg = :nonlinearsolve, type = :TR, name = "TR (NLsolve Update)", solver = Dict(:alg => TrustRegion(; linsolve = \, radius_update_scheme = RUS.NLsolve))),
61+
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Nocedal Wright)", solver = Dict(:alg => TrustRegion(; linsolve = \, radius_update_scheme = RUS.NocedalWright))),
62+
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Hei)", solver = Dict(:alg => TrustRegion(; linsolve = \, radius_update_scheme = RUS.Hei))),
63+
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Yuan)", solver = Dict(:alg => TrustRegion(; linsolve = \, radius_update_scheme = RUS.Yuan))),
64+
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Bastin)", solver = Dict(:alg => TrustRegion(; linsolve = \, radius_update_scheme = RUS.Bastin))),
65+
(; pkg = :nonlinearsolve, type = :TR, name = "TR (Fan)", solver = Dict(:alg => TrustRegion(; linsolve = \, radius_update_scheme = RUS.Fan))),
6666
(; pkg = :nonlinearsolve, type = :LM, name = "Levenberg-Marquardt", solver = Dict(:alg => LevenbergMarquardt(; linsolve = QRFactorization()))),
6767
(; pkg = :nonlinearsolve, type = :LM, name = "LM with Cholesky", solver = Dict(:alg => LevenbergMarquardt(; linsolve = CholeskyFactorization()))),
6868
(; pkg = :nonlinearsolve, type = :LM, name = "LM (α_geodesic=0.5)", solver = Dict(:alg => LevenbergMarquardt(; linsolve = QRFactorization(), α_geodesic=0.5))),
6969
(; pkg = :nonlinearsolve, type = :LM, name = "LM (α_geodesic=0.5) Chol.", solver = Dict(:alg => LevenbergMarquardt(; linsolve = CholeskyFactorization(), α_geodesic=0.5))),
7070
(; pkg = :nonlinearsolve, type = :LM, name = "LM (no Accln.)", solver = Dict(:alg => LevenbergMarquardt(; linsolve = QRFactorization(), disable_geodesic = Val(true)))),
7171
(; pkg = :nonlinearsolve, type = :LM, name = "LM (no Accln.) Chol.", solver = Dict(:alg => LevenbergMarquardt(; linsolve = CholeskyFactorization(), disable_geodesic = Val(true)))),
72-
(; pkg = :nonlinearsolve, type = :general, name = "Pseudo Transient", solver = Dict(:alg => PseudoTransient(; alpha_initial=10.0))),
72+
(; pkg = :nonlinearsolve, type = :general, name = "Pseudo Transient", solver = Dict(:alg => PseudoTransient(; linsolve = \, alpha_initial=10.0))),
7373
(; pkg = :wrapper, type = :general, name = "Powell [MINPACK]", solver = Dict(:alg => CMINPACK(; method=:hybr))),
74-
(; pkg = :wrapper, type = :general, name = "LM [MINPACK]", solver = Dict(:alg => CMINPACK(; method=:lm))),
75-
(; pkg = :wrapper, type = :general, name = "NR [NLsolve.jl]", solver = Dict(:alg => NLsolveJL(; method=:newton))),
76-
(; pkg = :wrapper, type = :general, name = "TR [NLsolve.jl]", solver = Dict(:alg => NLsolveJL())),
77-
(; pkg = :wrapper, type = :general, name = "NR [Sundials]", solver = Dict(:alg => KINSOL(; linear_solver = :LapackDense, maxsetupcalls=1))),
78-
(; pkg = :wrapper, type = :general, name = "NR LineSearch [Sundials]", solver = Dict(:alg => KINSOL(; globalization_strategy=:LineSearch, maxsetupcalls=1)))
74+
(; pkg = :wrapper, type = :LM, name = "LM [MINPACK]", solver = Dict(:alg => CMINPACK(; method=:lm))),
75+
(; pkg = :wrapper, type = :NR, name = "NR [NLsolve.jl]", solver = Dict(:alg => NLsolveJL(; method=:newton))),
76+
(; pkg = :wrapper, type = :TR, name = "TR [NLsolve.jl]", solver = Dict(:alg => NLsolveJL())),
77+
(; pkg = :wrapper, type = :NR, name = "NR [Sundials]", solver = Dict(:alg => KINSOL(; linear_solver = :LapackDense, maxsetupcalls=1))),
78+
(; pkg = :wrapper, type = :NR, name = "NR LineSearch [Sundials]", solver = Dict(:alg => KINSOL(; linear_solver = :LapackDense, globalization_strategy=:LineSearch, maxsetupcalls=1)))
7979
];
8080

8181
solver_tracker = [];
@@ -95,22 +95,26 @@ Prepares various helper functions for benchmarking a specific problem.
9595
function set_ad_chunksize(solvers, u0)
9696
ck = NonlinearSolve.pickchunksize(u0)
9797
for i in eachindex(solvers)
98-
@set! solvers[i].solver[:alg] = __set_ad_chunksize(solvers[i].solver[:alg], ck)
98+
@set! solvers[i].solver[:alg] = __set_ad_chunksize(solvers[i].solver[:alg], ck, length(u0))
9999
end
100100
return solvers
101101
end
102102

103-
function __set_ad_chunksize(solver::GeneralizedFirstOrderAlgorithm, ck)
104-
ad = AutoPolyesterForwardDiff(; chunksize = ck)
103+
function __set_ad_chunksize(solver::GeneralizedFirstOrderAlgorithm, ck, N)
104+
if N > ck
105+
ad = AutoPolyesterForwardDiff(; chunksize = ck)
106+
else
107+
ad = AutoForwardDiff(; chunksize = ck)
108+
end
105109
return GeneralizedFirstOrderAlgorithm(; solver.descent, solver.linesearch,
106110
solver.trustregion, jvp_autodiff = ad, solver.max_shrink_times, solver.vjp_autodiff,
107111
concrete_jac = solver.concrete_jac, name = solver.name)
108112
end
109-
function __set_ad_chunksize(solver::NonlinearSolvePolyAlgorithm, ck)
110-
algs = [__set_ad_chunksize(alg, ck) for alg in solver.algs]
113+
function __set_ad_chunksize(solver::NonlinearSolvePolyAlgorithm, ck, N)
114+
algs = [__set_ad_chunksize(alg, ck, N) for alg in solver.algs]
111115
return NonlinearSolvePolyAlgorithm(algs; solver.start_index)
112116
end
113-
__set_ad_chunksize(solver, ck) = solver
117+
__set_ad_chunksize(solver, ck, N) = solver
114118

115119
# Benchmarks a specific problem, checks which solvers can solve it and their performance
116120
function benchmark_problem!(prob_name; solver_tracker=solver_tracker)
@@ -147,7 +151,8 @@ function benchmark_problem!(prob_name; solver_tracker=solver_tracker)
147151

148152
wp_general = WorkPrecisionSet(prob.prob, abstols, reltols,
149153
getfield.(solvers_general, :solver); names=getfield.(solvers_general, :name),
150-
numruns=100, error_estimate=:l∞, maxiters=1000)
154+
numruns=100, error_estimate=:l∞, maxiters=1000,
155+
termination_condition = NonlinearSolve.AbsNormTerminationMode(Base.Fix1(maximum, abs)))
151156

152157
push!(wp_general_tracker, prob_name => wp_general)
153158

@@ -321,7 +326,7 @@ end
321326

322327
# Benchmarks
323328

324-
We here run benchmarks for each of the 23 models.
329+
We here run benchmarks for each of the 23 models.
325330

326331
### Problem 1 (Generalized Rosenbrock function)
327332

0 commit comments

Comments
 (0)