Skip to content

Commit e524c83

Browse files
Merge pull request #3440 from AayushSabharwal/as/ode-to-nonlinear
feat: allow `NonlinearSystem(::ODESystem)` and `NonlinearProblem(::ODESystem)`
2 parents a569f50 + 82b9b40 commit e524c83

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

src/systems/nonlinear/nonlinearsystem.jl

+30
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,32 @@ function NonlinearSystem(eqs; kwargs...)
232232
return NonlinearSystem(eqs, collect(allunknowns), collect(new_ps); kwargs...)
233233
end
234234

235+
"""
236+
$(TYPEDSIGNATURES)
237+
238+
Convert an `ODESystem` to a `NonlinearSystem` solving for its steady state (where derivatives are zero).
239+
Any differential variable `D(x) ~ f(...)` will be turned into `0 ~ f(...)`. The returned system is not
240+
simplified. If the input system is `complete`d, then so will the returned system.
241+
"""
242+
function NonlinearSystem(sys::AbstractODESystem)
243+
eqs = equations(sys)
244+
obs = observed(sys)
245+
subrules = Dict(D(x) => 0.0 for x in unknowns(sys))
246+
eqs = map(eqs) do eq
247+
fast_substitute(eq, subrules)
248+
end
249+
250+
nsys = NonlinearSystem(eqs, unknowns(sys), [parameters(sys); get_iv(sys)];
251+
parameter_dependencies = parameter_dependencies(sys),
252+
defaults = merge(defaults(sys), Dict(get_iv(sys) => Inf)), guesses = guesses(sys),
253+
initialization_eqs = initialization_equations(sys), name = nameof(sys),
254+
observed = obs)
255+
if iscomplete(sys)
256+
nsys = complete(nsys; split = is_split(sys))
257+
end
258+
return nsys
259+
end
260+
235261
function calculate_jacobian(sys::NonlinearSystem; sparse = false, simplify = false)
236262
cache = get_jac(sys)[]
237263
if cache isa Tuple && cache[2] == (sparse, simplify)
@@ -529,6 +555,10 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
529555
return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...))
530556
end
531557

558+
function DiffEqBase.NonlinearProblem(sys::AbstractODESystem, args...; kwargs...)
559+
NonlinearProblem(NonlinearSystem(sys), args...; kwargs...)
560+
end
561+
532562
"""
533563
```julia
534564
DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0map,

test/nonlinearsystem.jl

+50
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,53 @@ end
390390
@test !any(isequal(p[1]), parameters(sys))
391391
@test is_parameter(sys, p)
392392
end
393+
394+
@testset "Can convert from `ODESystem`" begin
395+
@variables x(t) y(t)
396+
@parameters p q r
397+
@named sys = ODESystem([D(x) ~ p * x^3 + q, 0 ~ -y + q * x - r], t;
398+
defaults = [x => 1.0, p => missing], guesses = [p => 1.0],
399+
initialization_eqs = [p^3 + q^3 ~ 4r], parameter_dependencies = [r ~ 3p])
400+
nlsys = NonlinearSystem(sys)
401+
defs = defaults(nlsys)
402+
@test length(defs) == 3
403+
@test defs[x] == 1.0
404+
@test defs[p] === missing
405+
@test isinf(defs[t])
406+
@test length(guesses(nlsys)) == 1
407+
@test guesses(nlsys)[p] == 1.0
408+
@test length(initialization_equations(nlsys)) == 1
409+
@test length(parameter_dependencies(nlsys)) == 1
410+
@test length(equations(nlsys)) == 2
411+
@test all(iszero, [eq.lhs for eq in equations(nlsys)])
412+
@test nameof(nlsys) == nameof(sys)
413+
@test !ModelingToolkit.iscomplete(nlsys)
414+
415+
sys1 = complete(sys; split = false)
416+
nlsys = NonlinearSystem(sys1)
417+
@test ModelingToolkit.iscomplete(nlsys)
418+
@test !ModelingToolkit.is_split(nlsys)
419+
420+
sys2 = complete(sys)
421+
nlsys = NonlinearSystem(sys2)
422+
@test ModelingToolkit.iscomplete(nlsys)
423+
@test ModelingToolkit.is_split(nlsys)
424+
425+
sys3 = structural_simplify(sys)
426+
nlsys = NonlinearSystem(sys3)
427+
@test length(equations(nlsys)) == length(ModelingToolkit.observed(nlsys)) == 1
428+
429+
prob = NonlinearProblem(sys3, [q => 2.0])
430+
@test prob.f.initialization_data.initializeprobmap === nothing
431+
sol = solve(prob)
432+
@test SciMLBase.successful_retcode(sol)
433+
@test sol.ps[p^3 + q^3]sol.ps[4r] atol=1e-10
434+
435+
@testset "Differential inside expression also substituted" begin
436+
@named sys = ODESystem([0 ~ y * D(x) + x^2 - p, 0 ~ x * D(y) + y * p], t)
437+
nlsys = NonlinearSystem(sys)
438+
vs = ModelingToolkit.vars(equations(nlsys))
439+
@test !in(D(x), vs)
440+
@test !in(D(y), vs)
441+
end
442+
end

0 commit comments

Comments
 (0)