Skip to content

Commit 82d815c

Browse files
refactor: major cleanup of *Problem construction
1 parent 28a5af3 commit 82d815c

File tree

9 files changed

+541
-297
lines changed

9 files changed

+541
-297
lines changed

src/ModelingToolkit.jl

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ include("systems/abstractsystem.jl")
144144
include("systems/model_parsing.jl")
145145
include("systems/connectors.jl")
146146
include("systems/callbacks.jl")
147+
include("systems/problem_utils.jl")
147148

148149
include("systems/nonlinear/nonlinearsystem.jl")
149150
include("systems/diffeqs/odesystem.jl")

src/systems/diffeqs/abstractodesystem.jl

+9-213
Original file line numberDiff line numberDiff line change
@@ -793,211 +793,6 @@ function get_u0(
793793
return u0, defs
794794
end
795795

796-
struct GetUpdatedMTKParameters{G, S}
797-
# `getu` functor which gets parameters that are unknowns during initialization
798-
getpunknowns::G
799-
# `setu` functor which returns a modified MTKParameters using those parameters
800-
setpunknowns::S
801-
end
802-
803-
function (f::GetUpdatedMTKParameters)(prob, initializesol)
804-
mtkp = copy(parameter_values(prob))
805-
f.setpunknowns(mtkp, f.getpunknowns(initializesol))
806-
mtkp
807-
end
808-
809-
struct UpdateInitializeprob{G, S}
810-
# `getu` functor which gets all values from prob
811-
getvals::G
812-
# `setu` functor which updates initializeprob with values
813-
setvals::S
814-
end
815-
816-
function (f::UpdateInitializeprob)(initializeprob, prob)
817-
f.setvals(initializeprob, f.getvals(prob))
818-
end
819-
820-
function get_temporary_value(p)
821-
stype = symtype(unwrap(p))
822-
return if stype == Real
823-
zero(Float64)
824-
elseif stype <: AbstractArray{Real}
825-
zeros(Float64, size(p))
826-
elseif stype <: Real
827-
zero(stype)
828-
elseif stype <: AbstractArray
829-
zeros(eltype(stype), size(p))
830-
else
831-
error("Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization")
832-
end
833-
end
834-
835-
function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
836-
implicit_dae = false, du0map = nothing,
837-
version = nothing, tgrad = false,
838-
jac = false,
839-
checkbounds = false, sparse = false,
840-
simplify = false,
841-
linenumbers = true, parallel = SerialForm(),
842-
eval_expression = false,
843-
eval_module = @__MODULE__,
844-
use_union = false,
845-
tofloat = true,
846-
symbolic_u0 = false,
847-
u0_constructor = identity,
848-
guesses = Dict(),
849-
t = nothing,
850-
warn_initialize_determined = true,
851-
build_initializeprob = true,
852-
initialization_eqs = [],
853-
fully_determined = false,
854-
check_units = true,
855-
kwargs...)
856-
eqs = equations(sys)
857-
dvs = unknowns(sys)
858-
ps = parameters(sys)
859-
iv = get_iv(sys)
860-
861-
check_array_equations_unknowns(eqs, dvs)
862-
# TODO: Pass already computed information to varmap_to_vars call
863-
# in process_u0? That would just be a small optimization
864-
varmap = u0map === nothing || isempty(u0map) || eltype(u0map) <: Number ?
865-
defaults(sys) :
866-
merge(defaults(sys), todict(u0map))
867-
varmap = canonicalize_varmap(varmap)
868-
varlist = collect(map(unwrap, dvs))
869-
missingvars = setdiff(varlist, collect(keys(varmap)))
870-
setobserved = filter(keys(varmap)) do var
871-
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
872-
end
873-
874-
if eltype(parammap) <: Pair
875-
parammap = Dict{Any, Any}(unwrap(k) => v for (k, v) in parammap)
876-
elseif parammap isa AbstractArray
877-
if isempty(parammap)
878-
parammap = SciMLBase.NullParameters()
879-
else
880-
parammap = Dict{Any, Any}(unwrap.(parameters(sys)) .=> parammap)
881-
end
882-
end
883-
defs = defaults(sys)
884-
if has_guesses(sys)
885-
guesses = merge(
886-
ModelingToolkit.guesses(sys), isempty(guesses) ? Dict() : todict(guesses))
887-
solvablepars = [p
888-
for p in parameters(sys)
889-
if is_parameter_solvable(p, parammap, defs, guesses)]
890-
891-
pvarmap = if parammap === nothing || parammap == SciMLBase.NullParameters() ||
892-
!(eltype(parammap) <: Pair) && isempty(parammap)
893-
defs
894-
else
895-
merge(defs, todict(parammap))
896-
end
897-
setparobserved = filter(keys(pvarmap)) do var
898-
has_parameter_dependency_with_lhs(sys, var)
899-
end
900-
else
901-
solvablepars = ()
902-
setparobserved = ()
903-
end
904-
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
905-
if sys isa ODESystem && build_initializeprob &&
906-
(((implicit_dae || !isempty(missingvars) || !isempty(solvablepars) ||
907-
!isempty(setobserved) || !isempty(setparobserved)) &&
908-
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
909-
!isempty(initialization_equations(sys))) && t !== nothing
910-
if eltype(u0map) <: Number
911-
u0map = unknowns(sys) .=> vec(u0map)
912-
end
913-
if u0map === nothing || isempty(u0map)
914-
u0map = Dict()
915-
end
916-
917-
initializeprob = ModelingToolkit.InitializationProblem(
918-
sys, t, u0map, parammap; guesses, warn_initialize_determined,
919-
initialization_eqs, eval_expression, eval_module, fully_determined, check_units)
920-
initializeprobmap = getu(initializeprob, unknowns(sys))
921-
punknowns = [p
922-
for p in all_variable_symbols(initializeprob) if is_parameter(sys, p)]
923-
getpunknowns = getu(initializeprob, punknowns)
924-
setpunknowns = setp(sys, punknowns)
925-
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
926-
reqd_syms = parameter_symbols(initializeprob)
927-
update_initializeprob! = UpdateInitializeprob(
928-
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
929-
930-
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
931-
if parammap isa SciMLBase.NullParameters
932-
parammap = Dict()
933-
end
934-
for p in punknowns
935-
p = unwrap(p)
936-
stype = symtype(p)
937-
parammap[p] = get_temporary_value(p)
938-
end
939-
trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map))
940-
u0map isa StaticArraysCore.StaticArray &&
941-
(trueinit = SVector{length(trueinit)}(trueinit))
942-
else
943-
initializeprob = nothing
944-
update_initializeprob! = nothing
945-
initializeprobmap = nothing
946-
initializeprobpmap = nothing
947-
trueinit = u0map
948-
end
949-
950-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
951-
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0,
952-
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t, use_union)
953-
check_eqs_u0(eqs, dvs, u0; kwargs...)
954-
p = if parammap === nothing ||
955-
parammap == SciMLBase.NullParameters() && isempty(defs)
956-
nothing
957-
else
958-
MTKParameters(sys, parammap, trueinit; t0 = t)
959-
end
960-
else
961-
u0, p, defs = get_u0_p(sys,
962-
trueinit,
963-
parammap;
964-
tofloat,
965-
use_union,
966-
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t,
967-
symbolic_u0)
968-
p, split_idxs = split_parameters_by_type(p)
969-
if p isa Tuple
970-
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
971-
ps = (ps...,) #if p is Tuple, ps should be Tuple
972-
end
973-
end
974-
if u0 !== nothing
975-
u0 = u0_constructor(u0)
976-
end
977-
978-
if implicit_dae && du0map !== nothing
979-
ddvs = map(Differential(iv), dvs)
980-
defs = mergedefaults(defs, du0map, ddvs)
981-
du0 = varmap_to_vars(du0map, ddvs; defaults = defs, toterm = identity,
982-
tofloat = true)
983-
else
984-
du0 = nothing
985-
ddvs = nothing
986-
end
987-
check_eqs_u0(eqs, dvs, u0; kwargs...)
988-
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
989-
checkbounds = checkbounds, p = p,
990-
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
991-
sparse = sparse, eval_expression = eval_expression,
992-
eval_module = eval_module,
993-
initializeprob = initializeprob,
994-
update_initializeprob! = update_initializeprob!,
995-
initializeprobmap = initializeprobmap,
996-
initializeprobpmap = initializeprobpmap,
997-
kwargs...)
998-
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
999-
end
1000-
1001796
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
1002797
ODEFunctionExpr{true}(sys, args...; kwargs...)
1003798
end
@@ -1104,7 +899,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
1104899
if !iscomplete(sys)
1105900
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
1106901
end
1107-
f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
902+
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
1108903
t = tspan !== nothing ? tspan[1] : tspan,
1109904
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
1110905
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
@@ -1147,7 +942,7 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
1147942
if !iscomplete(sys)
1148943
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
1149944
end
1150-
f, du0, u0, p = process_DEProblem(DAEFunction{iip}, sys, u0map, parammap;
945+
f, du0, u0, p = process_SciMLProblem(DAEFunction{iip}, sys, u0map, parammap;
1151946
implicit_dae = true, du0map = du0map, check_length,
1152947
t = tspan !== nothing ? tspan[1] : tspan,
1153948
warn_initialize_determined, kwargs...)
@@ -1179,7 +974,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
1179974
if !iscomplete(sys)
1180975
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`")
1181976
end
1182-
f, u0, p = process_DEProblem(DDEFunction{iip}, sys, u0map, parammap;
977+
f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap;
1183978
t = tspan !== nothing ? tspan[1] : tspan,
1184979
symbolic_u0 = true,
1185980
check_length, eval_expression, eval_module, kwargs...)
@@ -1214,7 +1009,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
12141009
if !iscomplete(sys)
12151010
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`")
12161011
end
1217-
f, u0, p = process_DEProblem(SDDEFunction{iip}, sys, u0map, parammap;
1012+
f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap;
12181013
t = tspan !== nothing ? tspan[1] : tspan,
12191014
symbolic_u0 = true, eval_expression, eval_module,
12201015
check_length, kwargs...)
@@ -1274,7 +1069,8 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
12741069
if !iscomplete(sys)
12751070
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `ODEProblemExpr`")
12761071
end
1277-
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; check_length,
1072+
f, u0, p = process_SciMLProblem(
1073+
ODEFunctionExpr{iip}, sys, u0map, parammap; check_length,
12781074
t = tspan !== nothing ? tspan[1] : tspan,
12791075
kwargs...)
12801076
linenumbers = get(kwargs, :linenumbers, true)
@@ -1320,7 +1116,7 @@ function DAEProblemExpr{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
13201116
if !iscomplete(sys)
13211117
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblemExpr`")
13221118
end
1323-
f, du0, u0, p = process_DEProblem(DAEFunctionExpr{iip}, sys, u0map, parammap;
1119+
f, du0, u0, p = process_SciMLProblem(DAEFunctionExpr{iip}, sys, u0map, parammap;
13241120
t = tspan !== nothing ? tspan[1] : tspan,
13251121
implicit_dae = true, du0map = du0map, check_length,
13261122
kwargs...)
@@ -1372,7 +1168,7 @@ function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem, u0map,
13721168
if !iscomplete(sys)
13731169
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SteadyStateProblem`")
13741170
end
1375-
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap;
1171+
f, u0, p = process_SciMLProblem(ODEFunction{iip}, sys, u0map, parammap;
13761172
steady_state = true,
13771173
check_length, kwargs...)
13781174
kwargs = filter_kwargs(kwargs)
@@ -1404,7 +1200,7 @@ function SteadyStateProblemExpr{iip}(sys::AbstractODESystem, u0map,
14041200
if !iscomplete(sys)
14051201
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SteadyStateProblemExpr`")
14061202
end
1407-
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap;
1203+
f, u0, p = process_SciMLProblem(ODEFunctionExpr{iip}, sys, u0map, parammap;
14081204
steady_state = true,
14091205
check_length, kwargs...)
14101206
linenumbers = get(kwargs, :linenumbers, true)

src/systems/diffeqs/sdesystem.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
659659
if !iscomplete(sys)
660660
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`")
661661
end
662-
f, u0, p = process_DEProblem(
662+
f, u0, p = process_SciMLProblem(
663663
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
664664
kwargs...)
665665
cbs = process_events(sys; callback, kwargs...)
@@ -745,7 +745,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
745745
if !iscomplete(sys)
746746
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblemExpr`")
747747
end
748-
f, u0, p = process_DEProblem(SDEFunctionExpr{iip}, sys, u0map, parammap; check_length,
748+
f, u0, p = process_SciMLProblem(
749+
SDEFunctionExpr{iip}, sys, u0map, parammap; check_length,
749750
kwargs...)
750751
linenumbers = get(kwargs, :linenumbers, true)
751752
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

src/systems/discrete_system/discrete_system.jl

+16-44
Original file line numberDiff line numberDiff line change
@@ -236,55 +236,25 @@ function generate_function(
236236
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
237237
end
238238

239-
function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
240-
linenumbers = true, parallel = SerialForm(),
241-
use_union = false,
242-
tofloat = !use_union,
243-
eval_expression = false, eval_module = @__MODULE__,
244-
kwargs...)
239+
function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
245240
iv = get_iv(sys)
246-
eqs = equations(sys)
247-
dvs = unknowns(sys)
248-
ps = parameters(sys)
249-
250-
if eltype(u0map) <: Number
251-
u0map = unknowns(sys) .=> vec(u0map)
252-
end
253-
if u0map === nothing || isempty(u0map)
254-
u0map = Dict()
255-
end
256-
257-
trueu0map = Dict()
258-
for (k, v) in u0map
259-
k = unwrap(k)
241+
updated = AnyDict()
242+
for k in collect(keys(u0map))
243+
v = u0map[k]
260244
if !((op = operation(k)) isa Shift)
261245
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
262246
end
263-
trueu0map[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
264-
end
265-
defs = ModelingToolkit.get_defaults(sys)
266-
for var in dvs
267-
if (op = operation(var)) isa Shift && !haskey(trueu0map, var)
268-
root = arguments(var)[1]
269-
haskey(defs, root) || error("Initial condition for $var not provided.")
270-
trueu0map[var] = defs[root]
271-
end
247+
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
272248
end
273-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
274-
u0, defs = get_u0(sys, trueu0map, parammap)
275-
p = MTKParameters(sys, parammap, trueu0map)
276-
else
277-
u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union)
249+
for var in unknowns(sys)
250+
op = operation(var)
251+
op isa Shift || continue
252+
haskey(updated, var) && continue
253+
root = first(arguments(var))
254+
haskey(defs, root) || error("Initial condition for $var not provided.")
255+
updated[var] = defs[root]
278256
end
279-
280-
check_eqs_u0(eqs, dvs, u0; kwargs...)
281-
282-
f = constructor(sys, dvs, ps, u0;
283-
linenumbers = linenumbers, parallel = parallel,
284-
syms = Symbol.(dvs), paramsyms = Symbol.(ps),
285-
eval_expression = eval_expression, eval_module = eval_module,
286-
kwargs...)
287-
return f, u0, p
257+
return updated
288258
end
289259

290260
"""
@@ -307,7 +277,9 @@ function SciMLBase.DiscreteProblem(
307277
eqs = equations(sys)
308278
iv = get_iv(sys)
309279

310-
f, u0, p = process_DiscreteProblem(
280+
u0map = to_varmap(u0map, dvs)
281+
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
282+
f, u0, p = process_SciMLProblem(
311283
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
312284
u0 = f(u0, p, tspan[1])
313285
DiscreteProblem(f, u0, tspan, p; kwargs...)

0 commit comments

Comments
 (0)