Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: run trivial initialization in problem constructor #3404

Merged
merged 6 commits into from
Feb 27, 2025
21 changes: 13 additions & 8 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
kwargs1 = merge(kwargs1, (; tstops))
end

return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
# Call `remake` so it runs initialization if it is trivial
return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should at least be commented why it's done. It's a bit of an odd way to get there, but understand why

end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

Expand Down Expand Up @@ -963,8 +964,10 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
kwargs1 = merge(kwargs1, (; tstops))
end

DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
kwargs..., kwargs1...)
# Call `remake` so it runs initialization if it is trivial
return remake(DAEProblem{iip}(
f, du0, u0, tspan, p; differential_vars = differential_vars,
kwargs..., kwargs1...))
end

function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
Expand All @@ -991,7 +994,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
end
f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
symbolic_u0 = true,
symbolic_u0 = true, u0_constructor,
check_length, eval_expression, eval_module, kwargs...)
h_gen = generate_history(sys, u0; expression = Val{true})
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
Expand All @@ -1008,7 +1011,8 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
# Call `remake` so it runs initialization if it is trivial
return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...))
end

function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -1029,7 +1033,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
end
f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
symbolic_u0 = true, eval_expression, eval_module,
symbolic_u0 = true, eval_expression, eval_module, u0_constructor,
check_length, kwargs...)
h_gen = generate_history(sys, u0; expression = Val{true})
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
Expand Down Expand Up @@ -1057,9 +1061,10 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
else
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
end
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
# Call `remake` so it runs initialization if it is trivial
return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
noise_rate_prototype =
noise_rate_prototype, kwargs1..., kwargs...)
noise_rate_prototype, kwargs1..., kwargs...))
end

"""
Expand Down
5 changes: 3 additions & 2 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,9 @@ function DiffEqBase.SDEProblem{iip, specialize}(

kwargs = filter_kwargs(kwargs)

SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...)
# Call `remake` so it runs initialization if it is trivial
return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...))
end

function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
pt = something(get_metadata(sys), StandardNonlinearProblem())
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
# Call `remake` so it runs initialization if it is trivial
return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...))
end

"""
Expand Down Expand Up @@ -548,7 +549,8 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
pt = something(get_metadata(sys), StandardNonlinearProblem())
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
# Call `remake` so it runs initialization if it is trivial
return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...))
end

const TypeT = Union{DataType, UnionAll}
Expand Down
11 changes: 8 additions & 3 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ All other keyword arguments are forwarded to `InitializationProblem`.
"""
function maybe_build_initialization_problem(
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
guesses, missing_unknowns; implicit_dae = false, kwargs...)
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))

if t === nothing && is_time_dependent(sys)
Expand All @@ -615,7 +615,7 @@ function maybe_build_initialization_problem(
if is_time_dependent(sys)
all_init_syms = Set(all_symbols(initializeprob))
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
initializeprobmap = getu(initializeprob, solved_unknowns)
initializeprobmap = u0_constructor ∘ getu(initializeprob, solved_unknowns)
else
initializeprobmap = nothing
end
Expand Down Expand Up @@ -774,14 +774,19 @@ function process_SciMLProblem(
op, missing_unknowns, missing_pars = build_operating_point!(sys,
u0map, pmap, defs, cmap, dvs, ps)

if u0_constructor === identity && u0Type <: StaticArray
u0_constructor = vals -> SymbolicUtils.Code.create_array(
u0Type, eltype(vals), Val(1), Val(length(vals)), vals...)
end
if build_initializeprob
kws = maybe_build_initialization_problem(
sys, op, u0map, pmap, t, defs, guesses, missing_unknowns;
implicit_dae, warn_initialize_determined, initialization_eqs,
eval_expression, eval_module, fully_determined,
warn_cyclic_dependency, check_units = check_initialization_units,
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete)
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
u0_constructor)

kwargs = merge(kwargs, kws)
end
Expand Down
32 changes: 30 additions & 2 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ end
@parameters x0 y0
@mtkbuild sys = ODESystem([x ~ x0, y ~ y0, s ~ x + y], t; guesses = [y0 => 0.0])
prob = ODEProblem(sys, [s => 1.0], (0.0, 1.0), [x0 => 0.3, y0 => missing])
@test prob.ps[y0] ≈ 0.0
# trivial initialization run immediately
@test prob.ps[y0] ≈ 0.7
@test init(prob, Tsit5()).ps[y0] ≈ 0.7
@test solve(prob, Tsit5()).ps[y0] ≈ 0.7
end
Expand All @@ -745,7 +746,8 @@ end
systems = [fixed, spring, mass, gravity, constant, damper],
guesses = [spring.s_rel0 => 1.0])
prob = ODEProblem(sys, [], (0.0, 1.0), [spring.s_rel0 => missing])
@test prob.ps[spring.s_rel0] ≈ 0.0
# trivial initialization run immediately
@test prob.ps[spring.s_rel0] ≈ -3.905
@test init(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905
@test solve(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905
end
Expand Down Expand Up @@ -1388,3 +1390,29 @@ end
integ1 = init(oprob1)
@test integ1[X1] ≈ 1.0
end

@testset "Trivial initialization is run on problem construction" begin
@variables _x(..) y(t)
@brownian a
@parameters tot
x = _x(t)
@testset "$Problem" for (Problem, lhs, rhs) in [
(ODEProblem, D, 0.0),
(SDEProblem, D, a),
(DDEProblem, D, _x(t - 0.1)),
(SDDEProblem, D, _x(t - 0.1) + a)
]
@mtkbuild sys = ModelingToolkit.System([lhs(x) ~ x + rhs, x + y ~ tot], t;
guesses = [tot => 1.0], defaults = [tot => missing])
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
@test prob.ps[tot] ≈ 2.0
end
@testset "$Problem" for Problem in [NonlinearProblem, NonlinearLeastSquaresProblem]
@parameters p1 p2
@mtkbuild sys = NonlinearSystem([x^2 + y^2 ~ p1, (x - 1)^2 + (y - 1)^2 ~ p2];
parameter_dependencies = [p2 ~ 2p1],
guesses = [p1 => 0.0], defaults = [p1 => missing])
prob = Problem(sys, [x => 1.0, y => 1.0], [p2 => 6.0])
@test prob.ps[p1] ≈ 3.0
end
end
Loading