Skip to content

Commit 7904df7

Browse files
Merge pull request #3345 from AayushSabharwal/as/nlsys-init-hotfix
fix: only solve parameter initialization for `NonlinearSystem`
2 parents c4fe363 + ccfc13b commit 7904df7

File tree

5 files changed

+181
-206
lines changed

5 files changed

+181
-206
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ RecursiveArrayTools = "3.26"
132132
Reexport = "0.2, 1"
133133
RuntimeGeneratedFunctions = "0.5.9"
134134
SCCNonlinearSolve = "1.0.0"
135-
SciMLBase = "2.71"
135+
SciMLBase = "2.71.1"
136136
SciMLStructures = "1.0"
137137
Serialization = "1"
138138
Setfield = "0.7, 0.8, 1"

src/systems/diffeqs/abstractodesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1335,7 +1335,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
13351335

13361336
# TODO: throw on uninitialized arrays
13371337
filter!(x -> !(x isa Symbolics.Arr), uninit)
1338-
if !isempty(uninit)
1338+
if is_time_dependent(sys) && !isempty(uninit)
13391339
allow_incomplete || throw(IncompleteInitializationError(uninit))
13401340
# for incomplete initialization, we will add the missing variables as parameters.
13411341
# they will be updated by `update_initializeprob!` and `initializeprobmap` will

src/systems/nonlinear/initializesystem.jl

+58-48
Original file line numberDiff line numberDiff line change
@@ -42,51 +42,53 @@ function generate_initializesystem(sys::AbstractSystem;
4242
diffmap = Dict()
4343
end
4444

45-
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
46-
# 2) process dummy derivatives and u0map into initialization system
47-
# prepare map for dummy derivative substitution
48-
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
49-
# set dummy derivatives to default_dd_guess unless specified
50-
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
51-
end
52-
function process_u0map_with_dummysubs(y, x)
53-
y = get(schedule.dummy_sub, y, y)
54-
y = fixpoint_sub(y, diffmap)
55-
if y vars_set
56-
# variables specified in u0 overrides defaults
57-
push!(defs, y => x)
58-
elseif y isa Symbolics.Arr
59-
# TODO: don't scalarize arrays
60-
merge!(defs, Dict(scalarize(y .=> x)))
61-
elseif y isa Symbolics.BasicSymbolic
62-
# y is a derivative expression expanded; add it to the initialization equations
63-
push!(eqs_ics, y ~ x)
64-
else
65-
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
45+
if is_time_dependent(sys)
46+
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
47+
# 2) process dummy derivatives and u0map into initialization system
48+
# prepare map for dummy derivative substitution
49+
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
50+
# set dummy derivatives to default_dd_guess unless specified
51+
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
6652
end
67-
end
68-
for (y, x) in u0map
69-
if Symbolics.isarraysymbolic(y)
70-
process_u0map_with_dummysubs.(collect(y), collect(x))
71-
else
72-
process_u0map_with_dummysubs(y, x)
53+
function process_u0map_with_dummysubs(y, x)
54+
y = get(schedule.dummy_sub, y, y)
55+
y = fixpoint_sub(y, diffmap)
56+
if y vars_set
57+
# variables specified in u0 overrides defaults
58+
push!(defs, y => x)
59+
elseif y isa Symbolics.Arr
60+
# TODO: don't scalarize arrays
61+
merge!(defs, Dict(scalarize(y .=> x)))
62+
elseif y isa Symbolics.BasicSymbolic
63+
# y is a derivative expression expanded; add it to the initialization equations
64+
push!(eqs_ics, y ~ x)
65+
else
66+
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
67+
end
68+
end
69+
for (y, x) in u0map
70+
if Symbolics.isarraysymbolic(y)
71+
process_u0map_with_dummysubs.(collect(y), collect(x))
72+
else
73+
process_u0map_with_dummysubs(y, x)
74+
end
75+
end
76+
else
77+
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
78+
for (k, v) in u0map
79+
defs[k] = v
7380
end
7481
end
75-
else
76-
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
77-
for (k, v) in u0map
78-
defs[k] = v
79-
end
80-
end
8182

82-
# 3) process other variables
83-
for var in vars
84-
if var keys(defs)
85-
push!(eqs_ics, var ~ defs[var])
86-
elseif var keys(guesses)
87-
push!(defs, var => guesses[var])
88-
elseif check_defguess
89-
error("Invalid setup: variable $(var) has no default value or initial guess")
83+
# 3) process other variables
84+
for var in vars
85+
if var keys(defs)
86+
push!(eqs_ics, var ~ defs[var])
87+
elseif var keys(guesses)
88+
push!(defs, var => guesses[var])
89+
elseif check_defguess
90+
error("Invalid setup: variable $(var) has no default value or initial guess")
91+
end
9092
end
9193
end
9294

@@ -180,16 +182,24 @@ function generate_initializesystem(sys::AbstractSystem;
180182
pars = Vector{SymbolicParam}(filter(p -> !haskey(paramsubs, p), parameters(sys)))
181183
is_time_dependent(sys) && push!(pars, get_iv(sys))
182184

183-
# 8) use observed equations for guesses of observed variables if not provided
184-
for eq in trueobs
185-
haskey(defs, eq.lhs) && continue
186-
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
185+
if is_time_dependent(sys)
186+
# 8) use observed equations for guesses of observed variables if not provided
187+
for eq in trueobs
188+
haskey(defs, eq.lhs) && continue
189+
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
187190

188-
defs[eq.lhs] = eq.rhs
191+
defs[eq.lhs] = eq.rhs
192+
end
193+
append!(eqs_ics, trueobs)
194+
end
195+
196+
eqs_ics = Symbolics.substitute.(eqs_ics, (paramsubs,))
197+
if is_time_dependent(sys)
198+
vars = [vars; collect(values(paramsubs))]
199+
else
200+
vars = collect(values(paramsubs))
189201
end
190202

191-
eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,))
192-
vars = [vars; collect(values(paramsubs))]
193203
for k in keys(defs)
194204
defs[k] = substitute(defs[k], paramsubs)
195205
end

src/systems/problem_utils.jl

+27-11
Original file line numberDiff line numberDiff line change
@@ -546,30 +546,46 @@ function maybe_build_initialization_problem(
546546
initializeprob = ModelingToolkit.InitializationProblem(
547547
sys, t, u0map, pmap; guesses, kwargs...)
548548

549-
all_init_syms = Set(all_symbols(initializeprob))
550-
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
551-
initializeprobmap = getu(initializeprob, solved_unknowns)
549+
if is_time_dependent(sys)
550+
all_init_syms = Set(all_symbols(initializeprob))
551+
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
552+
initializeprobmap = getu(initializeprob, solved_unknowns)
553+
else
554+
initializeprobmap = nothing
555+
end
552556

553557
punknowns = [p
554558
for p in all_variable_symbols(initializeprob)
555559
if is_parameter(sys, p)]
556-
getpunknowns = getu(initializeprob, punknowns)
557-
setpunknowns = setp(sys, punknowns)
558-
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
560+
if isempty(punknowns)
561+
initializeprobpmap = nothing
562+
else
563+
getpunknowns = getu(initializeprob, punknowns)
564+
setpunknowns = setp(sys, punknowns)
565+
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
566+
end
559567

560568
reqd_syms = parameter_symbols(initializeprob)
561-
update_initializeprob! = UpdateInitializeprob(
562-
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
569+
# we still want the `initialization_data` because it helps with `remake`
570+
if initializeprobmap === nothing && initializeprobpmap === nothing
571+
update_initializeprob! = nothing
572+
else
573+
update_initializeprob! = UpdateInitializeprob(
574+
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
575+
end
576+
563577
for p in punknowns
564578
p = unwrap(p)
565579
stype = symtype(p)
566580
op[p] = get_temporary_value(p)
567581
end
568582

569-
for v in missing_unknowns
570-
op[v] = zero_var(v)
583+
if is_time_dependent(sys)
584+
for v in missing_unknowns
585+
op[v] = zero_var(v)
586+
end
587+
empty!(missing_unknowns)
571588
end
572-
empty!(missing_unknowns)
573589
return (;
574590
initialization_data = SciMLBase.OverrideInitData(
575591
initializeprob, update_initializeprob!, initializeprobmap,

0 commit comments

Comments
 (0)