Skip to content

Commit 585301a

Browse files
Merge pull request #3404 from AayushSabharwal/as/run-trivial-init
feat: run trivial initialization in problem constructor
2 parents 4eb88e8 + c768c3e commit 585301a

File tree

5 files changed

+58
-17
lines changed

5 files changed

+58
-17
lines changed

src/systems/diffeqs/abstractodesystem.jl

+13-8
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
759759
kwargs1 = merge(kwargs1, (; tstops))
760760
end
761761

762-
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
762+
# Call `remake` so it runs initialization if it is trivial
763+
return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...))
763764
end
764765
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
765766

@@ -963,8 +964,10 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
963964
kwargs1 = merge(kwargs1, (; tstops))
964965
end
965966

966-
DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
967-
kwargs..., kwargs1...)
967+
# Call `remake` so it runs initialization if it is trivial
968+
return remake(DAEProblem{iip}(
969+
f, du0, u0, tspan, p; differential_vars = differential_vars,
970+
kwargs..., kwargs1...))
968971
end
969972

970973
function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
@@ -991,7 +994,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
991994
end
992995
f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap;
993996
t = tspan !== nothing ? tspan[1] : tspan,
994-
symbolic_u0 = true,
997+
symbolic_u0 = true, u0_constructor,
995998
check_length, eval_expression, eval_module, kwargs...)
996999
h_gen = generate_history(sys, u0; expression = Val{true})
9971000
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
@@ -1008,7 +1011,8 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10081011
if cbs !== nothing
10091012
kwargs1 = merge(kwargs1, (callback = cbs,))
10101013
end
1011-
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1014+
# Call `remake` so it runs initialization if it is trivial
1015+
return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...))
10121016
end
10131017

10141018
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1029,7 +1033,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10291033
end
10301034
f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap;
10311035
t = tspan !== nothing ? tspan[1] : tspan,
1032-
symbolic_u0 = true, eval_expression, eval_module,
1036+
symbolic_u0 = true, eval_expression, eval_module, u0_constructor,
10331037
check_length, kwargs...)
10341038
h_gen = generate_history(sys, u0; expression = Val{true})
10351039
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
@@ -1057,9 +1061,10 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10571061
else
10581062
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
10591063
end
1060-
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1064+
# Call `remake` so it runs initialization if it is trivial
1065+
return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
10611066
noise_rate_prototype =
1062-
noise_rate_prototype, kwargs1..., kwargs...)
1067+
noise_rate_prototype, kwargs1..., kwargs...))
10631068
end
10641069

10651070
"""

src/systems/diffeqs/sdesystem.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,9 @@ function DiffEqBase.SDEProblem{iip, specialize}(
792792

793793
kwargs = filter_kwargs(kwargs)
794794

795-
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
796-
noise_rate_prototype = noise_rate_prototype, kwargs...)
795+
# Call `remake` so it runs initialization if it is trivial
796+
return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
797+
noise_rate_prototype = noise_rate_prototype, kwargs...))
797798
end
798799

799800
function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
519519
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
520520
check_length, kwargs...)
521521
pt = something(get_metadata(sys), StandardNonlinearProblem())
522-
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
522+
# Call `remake` so it runs initialization if it is trivial
523+
return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...))
523524
end
524525

525526
"""
@@ -548,7 +549,8 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
548549
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
549550
check_length, kwargs...)
550551
pt = something(get_metadata(sys), StandardNonlinearProblem())
551-
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
552+
# Call `remake` so it runs initialization if it is trivial
553+
return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...))
552554
end
553555

554556
const TypeT = Union{DataType, UnionAll}

src/systems/problem_utils.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ All other keyword arguments are forwarded to `InitializationProblem`.
601601
"""
602602
function maybe_build_initialization_problem(
603603
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
604-
guesses, missing_unknowns; implicit_dae = false, kwargs...)
604+
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
605605
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
606606

607607
if t === nothing && is_time_dependent(sys)
@@ -615,7 +615,7 @@ function maybe_build_initialization_problem(
615615
if is_time_dependent(sys)
616616
all_init_syms = Set(all_symbols(initializeprob))
617617
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
618-
initializeprobmap = getu(initializeprob, solved_unknowns)
618+
initializeprobmap = u0_constructor getu(initializeprob, solved_unknowns)
619619
else
620620
initializeprobmap = nothing
621621
end
@@ -774,14 +774,19 @@ function process_SciMLProblem(
774774
op, missing_unknowns, missing_pars = build_operating_point!(sys,
775775
u0map, pmap, defs, cmap, dvs, ps)
776776

777+
if u0_constructor === identity && u0Type <: StaticArray
778+
u0_constructor = vals -> SymbolicUtils.Code.create_array(
779+
u0Type, eltype(vals), Val(1), Val(length(vals)), vals...)
780+
end
777781
if build_initializeprob
778782
kws = maybe_build_initialization_problem(
779783
sys, op, u0map, pmap, t, defs, guesses, missing_unknowns;
780784
implicit_dae, warn_initialize_determined, initialization_eqs,
781785
eval_expression, eval_module, fully_determined,
782786
warn_cyclic_dependency, check_units = check_initialization_units,
783787
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
784-
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete)
788+
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
789+
u0_constructor)
785790

786791
kwargs = merge(kwargs, kws)
787792
end

test/initializationsystem.jl

+30-2
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,8 @@ end
720720
@parameters x0 y0
721721
@mtkbuild sys = ODESystem([x ~ x0, y ~ y0, s ~ x + y], t; guesses = [y0 => 0.0])
722722
prob = ODEProblem(sys, [s => 1.0], (0.0, 1.0), [x0 => 0.3, y0 => missing])
723-
@test prob.ps[y0] 0.0
723+
# trivial initialization run immediately
724+
@test prob.ps[y0] 0.7
724725
@test init(prob, Tsit5()).ps[y0] 0.7
725726
@test solve(prob, Tsit5()).ps[y0] 0.7
726727
end
@@ -745,7 +746,8 @@ end
745746
systems = [fixed, spring, mass, gravity, constant, damper],
746747
guesses = [spring.s_rel0 => 1.0])
747748
prob = ODEProblem(sys, [], (0.0, 1.0), [spring.s_rel0 => missing])
748-
@test prob.ps[spring.s_rel0] 0.0
749+
# trivial initialization run immediately
750+
@test prob.ps[spring.s_rel0] -3.905
749751
@test init(prob, Tsit5()).ps[spring.s_rel0] -3.905
750752
@test solve(prob, Tsit5()).ps[spring.s_rel0] -3.905
751753
end
@@ -1388,3 +1390,29 @@ end
13881390
integ1 = init(oprob1)
13891391
@test integ1[X1] 1.0
13901392
end
1393+
1394+
@testset "Trivial initialization is run on problem construction" begin
1395+
@variables _x(..) y(t)
1396+
@brownian a
1397+
@parameters tot
1398+
x = _x(t)
1399+
@testset "$Problem" for (Problem, lhs, rhs) in [
1400+
(ODEProblem, D, 0.0),
1401+
(SDEProblem, D, a),
1402+
(DDEProblem, D, _x(t - 0.1)),
1403+
(SDDEProblem, D, _x(t - 0.1) + a)
1404+
]
1405+
@mtkbuild sys = ModelingToolkit.System([lhs(x) ~ x + rhs, x + y ~ tot], t;
1406+
guesses = [tot => 1.0], defaults = [tot => missing])
1407+
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
1408+
@test prob.ps[tot] 2.0
1409+
end
1410+
@testset "$Problem" for Problem in [NonlinearProblem, NonlinearLeastSquaresProblem]
1411+
@parameters p1 p2
1412+
@mtkbuild sys = NonlinearSystem([x^2 + y^2 ~ p1, (x - 1)^2 + (y - 1)^2 ~ p2];
1413+
parameter_dependencies = [p2 ~ 2p1],
1414+
guesses = [p1 => 0.0], defaults = [p1 => missing])
1415+
prob = Problem(sys, [x => 1.0, y => 1.0], [p2 => 6.0])
1416+
@test prob.ps[p1] 3.0
1417+
end
1418+
end

0 commit comments

Comments
 (0)