# Patch summary: move autodiff selection for the SimpleNonlinearSolve algorithms into
# `configure_autodiff(prob, alg)` (called from `CommonSolve.solve`), and replace the
# Jacobian+Hessian evaluation in SimpleHalley with a cached Jacobian plus `compute_hvvp`.
#
# Reviewer changes relative to the patch as received (added lines only):
#   * utils.jl: scalar `compute_hvvp` returns `H * dir * dir` (was `H*dir`) so that it
#     agrees with the array method and with the removed code, which formed (H × aᵢ) × aᵢ.
#   * halley.jl: `Aaᵢ, cᵢ = x, x` (a stale third `x` remained after `A` was dropped).
#   * explicit `return` in both `configure_autodiff` methods and the array `compute_hvvp`.
#   * comments added above the `compute_hvvp` methods (utils.jl new-side count 20 -> 23).
#
# NOTE(review): line breaks, indentation and blank context lines were reconstructed from
# a whitespace-collapsed copy using the hunk line counts; verify against the repository
# before applying. The `index` hashes are those of the patch as received.
# NOTE(review): halley.jl keeps `NLBUtils.can_setindex(x) || (A = J)` as context although
# `A` is no longer used in the visible hunks; presumably removable, confirm.
diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
index 2a87eb54e..495cadef2 100644
--- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
+++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
@@ -39,6 +39,7 @@ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
 } where {iip, T, V, P}
 
 abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearSolveAlgorithm end
+configure_autodiff(prob, alg::AbstractSimpleNonlinearSolveAlgorithm) = alg
 
 const NLBUtils = NonlinearSolveBase.Utils
 
@@ -59,12 +60,6 @@ function CommonSolve.solve(
         prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
         args...; kwargs...
     )
-    cache = SciMLBase.__init(prob, alg, args...; kwargs...)
-    prob = cache.prob
-    if cache.retcode == ReturnCode.InitialFailure
-        return SciMLBase.build_solution(prob, alg, prob.u0,
-            NonlinearSolveBase.Utils.evaluate_f(prob, prob.u0); cache.retcode)
-    end
     prob = convert(ImmutableNonlinearProblem, prob)
     return solve(prob, alg, args...; kwargs...)
 end
@@ -73,9 +68,7 @@ function CommonSolve.solve(
         prob::DualNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
         args...; kwargs...
     )
-    if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
-        @set! alg.autodiff = AutoForwardDiff()
-    end
+    alg = configure_autodiff(prob, alg)
     prob = convert(ImmutableNonlinearProblem, prob)
     sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
     dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
@@ -88,9 +81,7 @@ function CommonSolve.solve(
         prob::DualNonlinearLeastSquaresProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
         args...; kwargs...
     )
-    if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
-        @set! alg.autodiff = AutoForwardDiff()
-    end
+    alg = configure_autodiff(prob, alg)
     sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
     dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
     return SciMLBase.build_solution(
@@ -103,6 +94,7 @@ function CommonSolve.solve(
         alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
         sensealg = nothing, u0 = nothing, p = nothing, kwargs...
     )
+    alg = configure_autodiff(prob, alg)
     cache = SciMLBase.__init(prob, alg, args...; kwargs...)
     prob = cache.prob
     if cache.retcode == ReturnCode.InitialFailure
diff --git a/lib/SimpleNonlinearSolve/src/halley.jl b/lib/SimpleNonlinearSolve/src/halley.jl
index 2d8446d90..773f4b569 100644
--- a/lib/SimpleNonlinearSolve/src/halley.jl
+++ b/lib/SimpleNonlinearSolve/src/halley.jl
@@ -20,11 +20,20 @@ A low-overhead implementation of Halley's Method.
     autodiff = nothing
 end
 
+function configure_autodiff(prob, alg::SimpleHalley)
+    autodiff = something(alg.autodiff, AutoForwardDiff())
+    autodiff = SciMLBase.has_jac(prob.f) ? autodiff :
+        NonlinearSolveBase.select_jacobian_autodiff(prob, autodiff)
+    @set! alg.autodiff = autodiff
+    return alg
+end
+
 function SciMLBase.__solve(
         prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
         abstol = nothing, reltol = nothing, maxiters = 1000,
         alias_u0 = false, termination_condition = nothing, kwargs...
     )
+    autodiff = alg.autodiff
     x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
     fx = NLBUtils.evaluate_f(prob, x)
     T = promote_type(eltype(fx), eltype(x))
@@ -36,23 +45,21 @@ function SciMLBase.__solve(
         prob, abstol, reltol, fx, x, termination_condition, Val(:simple)
     )
 
-    # The way we write the 2nd order derivatives, we know Enzyme won't work there
-    autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff
-    @set! alg.autodiff = autodiff
-
     @bb xo = copy(x)
 
+    fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
+        NLBUtils.safe_similar(fx) : fx
+    jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
+
     if NLBUtils.can_setindex(x)
-        A = NLBUtils.safe_similar(x, length(x), length(x))
         Aaᵢ = NLBUtils.safe_similar(x, length(x))
         cᵢ = NLBUtils.safe_similar(x)
     else
-        A, Aaᵢ, cᵢ = x, x, x
+        Aaᵢ, cᵢ = x, x
     end
 
+    J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
     for _ in 1:maxiters
-        fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
-
         NLBUtils.can_setindex(x) || (A = J)
 
         # Factorize Once and Reuse
@@ -67,13 +74,8 @@ function SciMLBase.__solve(
         end
 
         aᵢ = J_fact \ NLBUtils.safe_vec(fx)
-        A_ = NLBUtils.safe_vec(A)
-        @bb A_ = H × aᵢ
-        A = NLBUtils.restructure(A, A_)
-
-        @bb Aaᵢ = A × aᵢ
-        @bb A .*= -1
-        bᵢ = J_fact \ NLBUtils.safe_vec(Aaᵢ)
+        hvvp = Utils.compute_hvvp(prob, autodiff, fx_cache, x, aᵢ)
+        bᵢ = J_fact \ NLBUtils.safe_vec(hvvp)
 
         cᵢ_ = NLBUtils.safe_vec(cᵢ)
         @bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))
@@ -84,6 +86,9 @@ function SciMLBase.__solve(
 
         @bb @. x += cᵢ
         @bb copyto!(xo, x)
+
+        fx = NLBUtils.evaluate_f!!(prob, fx, x)
+        J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
     end
 
     return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl
index 34efcbb90..510de2901 100644
--- a/lib/SimpleNonlinearSolve/src/raphson.jl
+++ b/lib/SimpleNonlinearSolve/src/raphson.jl
@@ -23,12 +23,21 @@ end
 
 const SimpleGaussNewton = SimpleNewtonRaphson
 
+function configure_autodiff(prob, alg::SimpleNewtonRaphson)
+    autodiff = something(alg.autodiff, AutoForwardDiff())
+    autodiff = SciMLBase.has_jac(prob.f) ? autodiff :
+        NonlinearSolveBase.select_jacobian_autodiff(prob, autodiff)
+    @set! alg.autodiff = autodiff
+    return alg
+end
+
 function SciMLBase.__solve(
         prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
         alg::SimpleNewtonRaphson, args...;
         abstol = nothing, reltol = nothing, maxiters = 1000,
         alias_u0 = false, termination_condition = nothing, kwargs...
     )
+    autodiff = alg.autodiff
     x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
     fx = NLBUtils.evaluate_f(prob, x)
 
@@ -39,10 +48,6 @@ function SciMLBase.__solve(
         prob, abstol, reltol, fx, x, termination_condition, Val(:simple)
     )
 
-    autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff :
-        NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
-    @set! alg.autodiff = autodiff
-
     @bb xo = similar(x)
     fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
         NLBUtils.safe_similar(fx) : fx
diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl
index 8c35a324f..c19e7538c 100644
--- a/lib/SimpleNonlinearSolve/src/utils.jl
+++ b/lib/SimpleNonlinearSolve/src/utils.jl
@@ -158,26 +158,23 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, ::DINoPreparation)
     return J
 end
 
-function compute_jacobian_and_hessian(autodiff, prob, _, x::Number)
+# Scalar Hessian-vector-vector product: second derivative of `prob.f` at `x` times dir^2.
+function compute_hvvp(prob, autodiff, _, x::Number, dir::Number)
     H = DI.second_derivative(prob.f, autodiff, x, Constant(prob.p))
-    fx, J = DI.value_and_derivative(prob.f, autodiff, x, Constant(prob.p))
-    return fx, J, H
+    return H * dir * dir  # both factors of `dir`, matching the array method below
 end
 
-function compute_jacobian_and_hessian(autodiff, prob, fx, x)
-    if SciMLBase.isinplace(prob)
-        jac_fn = @closure (u, p) -> begin
+# Hessian-vector-vector product H[dir, dir]: the pushforward along `dir` of the map
+# u -> (pushforward of `prob.f` at `u` along `dir`), i.e. two nested pushforwards.
+function compute_hvvp(prob, autodiff, fx, x, dir)
+    jvp_fn = if SciMLBase.isinplace(prob)
+        @closure (u, p) -> begin
             du = NLBUtils.safe_similar(fx, promote_type(eltype(fx), eltype(u)))
-            return DI.jacobian(prob.f, du, autodiff, u, Constant(p))
+            return only(DI.pushforward(prob.f, du, autodiff, u, (dir,), Constant(p)))
         end
-        J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
-        fx = NLBUtils.evaluate_f!!(prob, fx, x)
-        return fx, J, H
     else
-        jac_fn = @closure (u, p) -> DI.jacobian(prob.f, autodiff, u, Constant(p))
-        J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
-        fx = NLBUtils.evaluate_f!!(prob, fx, x)
-        return fx, J, H
+        @closure (u, p) -> only(DI.pushforward(prob.f, autodiff, u, (dir,), Constant(p)))
     end
+    return only(DI.pushforward(jvp_fn, autodiff, x, (dir,), Constant(prob.p)))
 end
 end
diff --git a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
index b6a39a9e9..4481645c7 100644
--- a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
+++ b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
@@ -28,10 +28,10 @@
     ]
 
     function run_nlsolve_oop(f::F, u0, p = 2.0; solver) where {F}
-        return solve(NonlinearProblem{false}(f, u0, p), solver; abstol = 1e-9)
+        return @inferred solve(NonlinearProblem{false}(f, u0, p), solver; abstol = 1e-9)
     end
     function run_nlsolve_iip(f!::F, u0, p = 2.0; solver) where {F}
-        return solve(NonlinearProblem{true}(f!, u0, p), solver; abstol = 1e-9)
+        return @inferred solve(NonlinearProblem{true}(f!, u0, p), solver; abstol = 1e-9)
     end
 end
 