Skip to content

Commit d5a48a4

Browse files
Merge pull request #3253 from AayushSabharwal/as/init-everywhere
feat: create initialization systems for all problem types
2 parents dad05e5 + 2e07200 commit d5a48a4

17 files changed

+750
-393
lines changed

Project.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ ConstructionBase = "1"
8989
DataInterpolations = "6.4"
9090
DataStructures = "0.17, 0.18"
9191
DeepDiffs = "1"
92+
DelayDiffEq = "5.50"
9293
DiffEqBase = "6.157"
9394
DiffEqCallbacks = "2.16, 3, 4"
9495
DiffEqNoiseProcess = "5"
@@ -117,7 +118,7 @@ Libdl = "1"
117118
LinearAlgebra = "1"
118119
MLStyle = "0.4.17"
119120
NaNMath = "0.3, 1"
120-
NonlinearSolve = "3.14, 4"
121+
NonlinearSolve = "4.3"
121122
OffsetArrays = "1"
122123
OrderedCollections = "1"
123124
OrdinaryDiffEq = "6.82.0"
@@ -129,15 +130,17 @@ RecursiveArrayTools = "3.26"
129130
Reexport = "0.2, 1"
130131
RuntimeGeneratedFunctions = "0.5.9"
131132
SCCNonlinearSolve = "1.0.0"
132-
SciMLBase = "2.66"
133+
SciMLBase = "2.68.1"
133134
SciMLStructures = "1.0"
134135
Serialization = "1"
135136
Setfield = "0.7, 0.8, 1"
136137
SimpleNonlinearSolve = "0.1.0, 1, 2"
137138
SparseArrays = "1"
138139
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
139140
StaticArrays = "0.10, 0.11, 0.12, 1.0"
140-
SymbolicIndexingInterface = "0.3.35"
141+
StochasticDiffEq = "6.72.1"
142+
StochasticDelayDiffEq = "1.8.1"
143+
SymbolicIndexingInterface = "0.3.36"
141144
SymbolicUtils = "3.7"
142145
Symbolics = "6.19"
143146
URIs = "1"

src/systems/diffeqs/abstractodesystem.jl

+28-33
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
359359
sparsity = false,
360360
analytic = nothing,
361361
split_idxs = nothing,
362-
initializeprob = nothing,
363-
update_initializeprob! = nothing,
364-
initializeprobmap = nothing,
365-
initializeprobpmap = nothing,
362+
initialization_data = nothing,
366363
kwargs...) where {iip, specialize}
367364
if !iscomplete(sys)
368365
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -463,10 +460,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
463460
observed = observedfun,
464461
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
465462
analytic = analytic,
466-
initializeprob = initializeprob,
467-
update_initializeprob! = update_initializeprob!,
468-
initializeprobmap = initializeprobmap,
469-
initializeprobpmap = initializeprobpmap)
463+
initialization_data)
470464
end
471465

472466
"""
@@ -496,10 +490,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
496490
sparse = false, simplify = false,
497491
eval_module = @__MODULE__,
498492
checkbounds = false,
499-
initializeprob = nothing,
500-
initializeprobmap = nothing,
501-
initializeprobpmap = nothing,
502-
update_initializeprob! = nothing,
493+
initialization_data = nothing,
503494
kwargs...) where {iip}
504495
if !iscomplete(sys)
505496
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -547,15 +538,12 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
547538
nothing
548539
end
549540

550-
DAEFunction{iip}(f,
541+
DAEFunction{iip}(f;
551542
sys = sys,
552543
jac = _jac === nothing ? nothing : _jac,
553544
jac_prototype = jac_prototype,
554545
observed = observedfun,
555-
initializeprob = initializeprob,
556-
initializeprobmap = initializeprobmap,
557-
initializeprobpmap = initializeprobpmap,
558-
update_initializeprob! = update_initializeprob!)
546+
initialization_data)
559547
end
560548

561549
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -567,6 +555,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
567555
eval_expression = false,
568556
eval_module = @__MODULE__,
569557
checkbounds = false,
558+
initialization_data = nothing,
570559
kwargs...) where {iip}
571560
if !iscomplete(sys)
572561
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`")
@@ -579,7 +568,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
579568
f(u, h, p, t) = f_oop(u, h, p, t)
580569
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
581570

582-
DDEFunction{iip}(f, sys = sys)
571+
DDEFunction{iip}(f; sys = sys, initialization_data)
583572
end
584573

585574
function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -591,6 +580,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
591580
eval_expression = false,
592581
eval_module = @__MODULE__,
593582
checkbounds = false,
583+
initialization_data = nothing,
594584
kwargs...) where {iip}
595585
if !iscomplete(sys)
596586
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`")
@@ -609,7 +599,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
609599
g(u, h, p, t) = g_oop(u, h, p, t)
610600
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
611601

612-
SDDEFunction{iip}(f, g, sys = sys)
602+
SDDEFunction{iip}(f, g; sys = sys, initialization_data)
613603
end
614604

615605
"""
@@ -933,7 +923,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
933923
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
934924
h(p, t) = h_oop(p, t)
935925
h(p::MTKParameters, t) = h_oop(p..., t)
936-
u0 = h(p, tspan[1])
926+
u0 = float.(h(p, tspan[1]))
937927
if u0 !== nothing
938928
u0 = u0_constructor(u0)
939929
end
@@ -1257,23 +1247,23 @@ Generates a NonlinearProblem or NonlinearLeastSquaresProblem from an ODESystem
12571247
which represents the initialization, i.e. the calculation of the consistent
12581248
initial conditions for the given DAE.
12591249
"""
1260-
function InitializationProblem(sys::AbstractODESystem, args...; kwargs...)
1250+
function InitializationProblem(sys::AbstractSystem, args...; kwargs...)
12611251
InitializationProblem{true}(sys, args...; kwargs...)
12621252
end
12631253

1264-
function InitializationProblem(sys::AbstractODESystem, t,
1254+
function InitializationProblem(sys::AbstractSystem, t,
12651255
u0map::StaticArray,
12661256
args...;
12671257
kwargs...)
12681258
InitializationProblem{false, SciMLBase.FullSpecialize}(
12691259
sys, t, u0map, args...; kwargs...)
12701260
end
12711261

1272-
function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...)
1262+
function InitializationProblem{true}(sys::AbstractSystem, args...; kwargs...)
12731263
InitializationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
12741264
end
12751265

1276-
function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...)
1266+
function InitializationProblem{false}(sys::AbstractSystem, args...; kwargs...)
12771267
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
12781268
end
12791269

@@ -1292,8 +1282,8 @@ function Base.showerror(io::IO, e::IncompleteInitializationError)
12921282
println(io, e.uninit)
12931283
end
12941284

1295-
function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
1296-
t::Number, u0map = [],
1285+
function InitializationProblem{iip, specialize}(sys::AbstractSystem,
1286+
t, u0map = [],
12971287
parammap = DiffEqBase.NullParameters();
12981288
guesses = [],
12991289
check_length = true,
@@ -1320,6 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13201310
pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined)
13211311
end
13221312

1313+
meta = get_metadata(isys)
1314+
if meta isa InitializationSystemMetadata
1315+
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys)
1316+
end
1317+
13231318
ts = get_tearing_state(isys)
13241319
unassigned_vars = StructuralTransformations.singular_check(ts)
13251320
if warn_initialize_determined && !isempty(unassigned_vars)
@@ -1357,13 +1352,13 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13571352
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true"
13581353
end
13591354

1360-
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1361-
[get_iv(sys) => t] :
1362-
merge(todict(parammap), Dict(get_iv(sys) => t))
1363-
parammap = Dict(k => v for (k, v) in parammap if v !== missing)
1364-
if isempty(u0map)
1365-
u0map = Dict()
1355+
parammap = recursive_unwrap(anydict(parammap))
1356+
if t !== nothing
1357+
parammap[get_iv(sys)] = t
13661358
end
1359+
filter!(kvp -> kvp[2] !== missing, parammap)
1360+
1361+
u0map = to_varmap(u0map, unknowns(sys))
13671362
if isempty(guesses)
13681363
guesses = Dict()
13691364
end
@@ -1405,5 +1400,5 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14051400
else
14061401
NonlinearLeastSquaresProblem
14071402
end
1408-
TProb(isys, u0map, parammap; kwargs...)
1403+
TProb(isys, u0map, parammap; kwargs..., build_initializeprob = false)
14091404
end

src/systems/diffeqs/odesystem.jl

+7-20
Original file line numberDiff line numberDiff line change
@@ -256,29 +256,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
256256
:ODESystem, force = true)
257257
end
258258
defaults = Dict{Any, Any}(todict(defaults))
259+
guesses = Dict{Any, Any}(todict(guesses))
259260
var_to_name = Dict()
260-
process_variables!(var_to_name, defaults, dvs′)
261-
process_variables!(var_to_name, defaults, ps′)
262-
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
263-
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
261+
process_variables!(var_to_name, defaults, guesses, dvs′)
262+
process_variables!(var_to_name, defaults, guesses, ps′)
263+
process_variables!(
264+
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
265+
process_variables!(
266+
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
264267
defaults = Dict{Any, Any}(value(k) => value(v)
265268
for (k, v) in pairs(defaults) if v !== nothing)
266-
267-
sysdvsguesses = [ModelingToolkit.getguess(st) for st in dvs′]
268-
hasaguess = findall(!isnothing, sysdvsguesses)
269-
var_guesses = dvs′[hasaguess] .=> sysdvsguesses[hasaguess]
270-
sysdvsguesses = isempty(var_guesses) ? Dict() : todict(var_guesses)
271-
syspsguesses = [ModelingToolkit.getguess(st) for st in ps′]
272-
hasaguess = findall(!isnothing, syspsguesses)
273-
ps_guesses = ps′[hasaguess] .=> syspsguesses[hasaguess]
274-
syspsguesses = isempty(ps_guesses) ? Dict() : todict(ps_guesses)
275-
syspdepguesses = [ModelingToolkit.getguess(eq.lhs) for eq in parameter_dependencies]
276-
hasaguess = findall(!isnothing, syspdepguesses)
277-
pdep_guesses = [eq.lhs for eq in parameter_dependencies][hasaguess] .=>
278-
syspdepguesses[hasaguess]
279-
syspdepguesses = isempty(pdep_guesses) ? Dict() : todict(pdep_guesses)
280-
281-
guesses = merge(sysdvsguesses, syspsguesses, syspdepguesses, todict(guesses))
282269
guesses = Dict{Any, Any}(value(k) => value(v)
283270
for (k, v) in pairs(guesses) if v !== nothing)
284271

src/systems/diffeqs/sdesystem.jl

+44-18
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem
9393
"""
9494
defaults::Dict
9595
"""
96+
The guesses to use as the initial conditions for the
97+
initialization system.
98+
"""
99+
guesses::Dict
100+
"""
101+
The system for performing the initialization.
102+
"""
103+
initializesystem::Union{Nothing, NonlinearSystem}
104+
"""
105+
Extra equations to be enforced during the initialization sequence.
106+
"""
107+
initialization_eqs::Vector{Equation}
108+
"""
96109
Type of the system.
97110
"""
98111
connector_type::Any
@@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem
144157
isscheduled::Bool
145158

146159
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
147-
tgrad,
148-
jac,
149-
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
160+
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
161+
guesses, initializesystem, initialization_eqs, connector_type,
150162
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
151163
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
152164
is_dde = false,
@@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem
171183
check_units(u, deqs, neqs)
172184
end
173185
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
174-
ctrl_jac,
175-
Wfact, Wfact_t, name, description, systems,
176-
defaults, connector_type, cevents, devents,
186+
ctrl_jac, Wfact, Wfact_t, name, description, systems,
187+
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
188+
devents,
177189
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
178190
is_dde, isscheduled)
179191
end
@@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
187199
default_u0 = Dict(),
188200
default_p = Dict(),
189201
defaults = _merge(Dict(default_u0), Dict(default_p)),
202+
guesses = Dict(),
203+
initializesystem = nothing,
204+
initialization_eqs = Equation[],
190205
name = nothing,
191206
description = "",
192207
connector_type = nothing,
@@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
207222
dvs′ = value.(dvs)
208223
ps′ = value.(ps)
209224
ctrl′ = value.(controls)
225+
parameter_dependencies, ps′ = process_parameter_dependencies(
226+
parameter_dependencies, ps′)
210227

211228
sysnames = nameof.(systems)
212229
if length(unique(sysnames)) != length(sysnames)
@@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
217234
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
218235
:SDESystem, force = true)
219236
end
220-
defaults = todict(defaults)
221-
defaults = Dict(value(k) => value(v)
222-
for (k, v) in pairs(defaults) if value(v) !== nothing)
223237

238+
defaults = Dict{Any, Any}(todict(defaults))
239+
guesses = Dict{Any, Any}(todict(guesses))
224240
var_to_name = Dict()
225-
process_variables!(var_to_name, defaults, dvs′)
226-
process_variables!(var_to_name, defaults, ps′)
241+
process_variables!(var_to_name, defaults, guesses, dvs′)
242+
process_variables!(var_to_name, defaults, guesses, ps′)
243+
process_variables!(
244+
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
245+
process_variables!(
246+
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
247+
defaults = Dict{Any, Any}(value(k) => value(v)
248+
for (k, v) in pairs(defaults) if v !== nothing)
249+
guesses = Dict{Any, Any}(value(k) => value(v)
250+
for (k, v) in pairs(guesses) if v !== nothing)
251+
227252
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
228253

229254
tgrad = RefValue(EMPTY_TGRAD)
@@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
233258
Wfact_t = RefValue(EMPTY_JAC)
234259
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
235260
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
236-
parameter_dependencies, ps′ = process_parameter_dependencies(
237-
parameter_dependencies, ps′)
238261
if is_dde === nothing
239262
is_dde = _check_if_dde(deqs, iv′, systems)
240263
end
241264
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
242265
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
243-
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
266+
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
267+
initializesystem, initialization_eqs, connector_type,
244268
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
245269
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
246270
end
@@ -520,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
520544
version = nothing, tgrad = false, sparse = false,
521545
jac = false, Wfact = false, eval_expression = false,
522546
eval_module = @__MODULE__,
523-
checkbounds = false,
547+
checkbounds = false, initialization_data = nothing,
524548
kwargs...) where {iip, specialize}
525549
if !iscomplete(sys)
526550
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
@@ -591,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
591615

592616
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
593617

594-
SDEFunction{iip, specialize}(f, g,
618+
SDEFunction{iip, specialize}(f, g;
595619
sys = sys,
596620
jac = _jac === nothing ? nothing : _jac,
597621
tgrad = _tgrad === nothing ? nothing : _tgrad,
598622
Wfact = _Wfact === nothing ? nothing : _Wfact,
599623
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
600-
mass_matrix = _M,
624+
mass_matrix = _M, initialization_data,
601625
observed = observedfun)
602626
end
603627

@@ -714,7 +738,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
714738
end
715739
f, u0, p = process_SciMLProblem(
716740
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
717-
kwargs...)
741+
t = tspan === nothing ? nothing : tspan[1], kwargs...)
718742
cbs = process_events(sys; callback, kwargs...)
719743
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
720744

@@ -736,6 +760,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
736760
noise = nothing
737761
end
738762

763+
kwargs = filter_kwargs(kwargs)
764+
739765
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
740766
noise_rate_prototype = noise_rate_prototype, kwargs...)
741767
end

0 commit comments

Comments
 (0)