diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 205d7e4601..01f847d5d6 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -759,7 +759,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = kwargs1 = merge(kwargs1, (; tstops)) end - return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) + # Call `remake` so it runs initialization if it is trivial + return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] @@ -963,8 +964,10 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan kwargs1 = merge(kwargs1, (; tstops)) end - DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars, - kwargs..., kwargs1...) + # Call `remake` so it runs initialization if it is trivial + return remake(DAEProblem{iip}( + f, du0, u0, tspan, p; differential_vars = differential_vars, + kwargs..., kwargs1...)) end function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...) @@ -991,7 +994,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], end f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, - symbolic_u0 = true, + symbolic_u0 = true, u0_constructor, check_length, eval_expression, eval_module, kwargs...) h_gen = generate_history(sys, u0; expression = Val{true}) h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module) @@ -1008,7 +1011,8 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], if cbs !== nothing kwargs1 = merge(kwargs1, (callback = cbs,)) end - DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...) + # Call `remake` so it runs initialization if it is trivial + return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)) end function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...) @@ -1029,7 +1033,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], end f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, - symbolic_u0 = true, eval_expression, eval_module, + symbolic_u0 = true, eval_expression, eval_module, u0_constructor, check_length, kwargs...) h_gen = generate_history(sys, u0; expression = Val{true}) h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module) @@ -1057,9 +1061,10 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], else noise_rate_prototype = zeros(eltype(u0), size(noiseeqs)) end - SDDEProblem{iip}(f, f.g, u0, h, tspan, p; + # Call `remake` so it runs initialization if it is trivial + return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype = - noise_rate_prototype, kwargs1..., kwargs...) + noise_rate_prototype, kwargs1..., kwargs...)) end """ diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index bc3fd6b73e..d8ac52dd1a 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -792,8 +792,9 @@ function DiffEqBase.SDEProblem{iip, specialize}( kwargs = filter_kwargs(kwargs) - SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, - noise_rate_prototype = noise_rate_prototype, kwargs...) + # Call `remake` so it runs initialization if it is trivial + return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, + noise_rate_prototype = noise_rate_prototype, kwargs...)) end function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index ad3fbfdeef..42ef380337 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -519,7 +519,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap; check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) - NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...) + # Call `remake` so it runs initialization if it is trivial + return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)) end """ @@ -548,7 +549,8 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap; check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) - NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...) + # Call `remake` so it runs initialization if it is trivial + return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)) end const TypeT = Union{DataType, UnionAll} diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 77f4229696..85dfded21c 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -601,7 +601,7 @@ All other keyword arguments are forwarded to `InitializationProblem`. """ function maybe_build_initialization_problem( sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, - guesses, missing_unknowns; implicit_dae = false, kwargs...) + guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) @@ -615,7 +615,7 @@ function maybe_build_initialization_problem( if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys)) - initializeprobmap = getu(initializeprob, solved_unknowns) + initializeprobmap = u0_constructor ∘ getu(initializeprob, solved_unknowns) else initializeprobmap = nothing end @@ -774,6 +774,10 @@ function process_SciMLProblem( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) + if u0_constructor === identity && u0Type <: StaticArray + u0_constructor = vals -> SymbolicUtils.Code.create_array( + u0Type, eltype(vals), Val(1), Val(length(vals)), vals...) + end if build_initializeprob kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t, defs, guesses, missing_unknowns; @@ -781,7 +785,8 @@ function process_SciMLProblem( eval_expression, eval_module, fully_determined, warn_cyclic_dependency, check_units = check_initialization_units, circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc, - force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete) + force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete, + u0_constructor) kwargs = merge(kwargs, kws) end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 45afa4163a..236f9e5ce3 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -720,7 +720,8 @@ end @parameters x0 y0 @mtkbuild sys = ODESystem([x ~ x0, y ~ y0, s ~ x + y], t; guesses = [y0 => 0.0]) prob = ODEProblem(sys, [s => 1.0], (0.0, 1.0), [x0 => 0.3, y0 => missing]) - @test prob.ps[y0] ≈ 0.0 + # trivial initialization run immediately + @test prob.ps[y0] ≈ 0.7 @test init(prob, Tsit5()).ps[y0] ≈ 0.7 @test solve(prob, Tsit5()).ps[y0] ≈ 0.7 end @@ -745,7 +746,8 @@ end systems = [fixed, spring, mass, gravity, constant, damper], guesses = [spring.s_rel0 => 1.0]) prob = ODEProblem(sys, [], (0.0, 1.0), [spring.s_rel0 => missing]) - @test prob.ps[spring.s_rel0] ≈ 0.0 + # trivial initialization run immediately + @test prob.ps[spring.s_rel0] ≈ -3.905 @test init(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905 @test solve(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905 end @@ -1388,3 +1390,29 @@ end integ1 = init(oprob1) @test integ1[X1] ≈ 1.0 end + +@testset "Trivial initialization is run on problem construction" begin + @variables _x(..) y(t) + @brownian a + @parameters tot + x = _x(t) + @testset "$Problem" for (Problem, lhs, rhs) in [ + (ODEProblem, D, 0.0), + (SDEProblem, D, a), + (DDEProblem, D, _x(t - 0.1)), + (SDDEProblem, D, _x(t - 0.1) + a) + ] + @mtkbuild sys = ModelingToolkit.System([lhs(x) ~ x + rhs, x + y ~ tot], t; + guesses = [tot => 1.0], defaults = [tot => missing]) + prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + @test prob.ps[tot] ≈ 2.0 + end + @testset "$Problem" for Problem in [NonlinearProblem, NonlinearLeastSquaresProblem] + @parameters p1 p2 + @mtkbuild sys = NonlinearSystem([x^2 + y^2 ~ p1, (x - 1)^2 + (y - 1)^2 ~ p2]; + parameter_dependencies = [p2 ~ 2p1], + guesses = [p1 => 0.0], defaults = [p1 => missing]) + prob = Problem(sys, [x => 1.0, y => 1.0], [p2 => 6.0]) + @test prob.ps[p1] ≈ 3.0 + end +end