Skip to content

Commit da48a1f

Browse files
Merge pull request #3390 from AayushSabharwal/as/initsys-fixes
fix: fix several bugs
2 parents baefe85 + 069b096 commit da48a1f

21 files changed

+169
-63
lines changed

.github/workflows/Downstream.yml

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ jobs:
3737
- {user: SciML, repo: MethodOfLines.jl, group: Interface}
3838
- {user: SciML, repo: MethodOfLines.jl, group: 2D_Diffusion}
3939
- {user: SciML, repo: MethodOfLines.jl, group: DAE}
40+
- {user: SciML, repo: ModelingToolkitNeuralNets.jl, group: All}
4041
steps:
4142
- uses: actions/checkout@v4
4243
- uses: julia-actions/setup-julia@v1

Project.toml

+7-5
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
104104
EnumX = "1.0.4"
105105
ExprTools = "0.1.10"
106106
Expronicon = "0.8"
107+
FMI = "0.14"
107108
FindFirstFunctions = "1"
108109
ForwardDiff = "0.10.3"
109110
FunctionWrappers = "1.1"
110111
FunctionWrappersWrappers = "0.1"
111-
FMI = "0.14"
112112
Graphs = "1.5.2"
113113
HomotopyContinuation = "2.11"
114114
InfiniteOpt = "0.5"
@@ -119,6 +119,7 @@ LabelledArrays = "1.3"
119119
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
120120
Libdl = "1"
121121
LinearAlgebra = "1"
122+
Logging = "1"
122123
MLStyle = "0.4.17"
123124
ModelingToolkitStandardLibrary = "2.19"
124125
NaNMath = "0.3, 1"
@@ -128,7 +129,7 @@ OrderedCollections = "1"
128129
OrdinaryDiffEq = "6.82.0"
129130
OrdinaryDiffEqCore = "1.15.0"
130131
OrdinaryDiffEqDefault = "1.2"
131-
OrdinaryDiffEqNonlinearSolve = "1.3.0"
132+
OrdinaryDiffEqNonlinearSolve = "1.5.0"
132133
PrecompileTools = "1"
133134
REPL = "1"
134135
RecursiveArrayTools = "3.26"
@@ -143,11 +144,11 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
143144
SparseArrays = "1"
144145
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
145146
StaticArrays = "0.10, 0.11, 0.12, 1.0"
146-
StochasticDiffEq = "6.72.1"
147147
StochasticDelayDiffEq = "1.8.1"
148+
StochasticDiffEq = "6.72.1"
148149
SymbolicIndexingInterface = "0.3.37"
149150
SymbolicUtils = "3.14"
150-
Symbolics = "6.29"
151+
Symbolics = "6.29.1"
151152
URIs = "1"
152153
UnPack = "0.1, 1.0"
153154
Unitful = "1.1"
@@ -164,6 +165,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
164165
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
165166
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
166167
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
168+
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
167169
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
168170
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
169171
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -187,4 +189,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
187189
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
188190

189191
[targets]
190-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
192+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging"]

ext/MTKBifurcationKitExt.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
9797
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
9898
# Creates F and J functions.
9999
ofun = NonlinearFunction(nsys; jac = jac)
100-
F = ofun.f
100+
F = let f = ofun.f
101+
_f(resid, u, p) = (f(resid, u, p); resid)
102+
_f(u, p) = f(u, p)
103+
end
101104
J = jac ? ofun.jac : nothing
102105

103106
# Converts the input state guess.
@@ -136,6 +139,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
136139
args...;
137140
record_from_solution = record_from_solution,
138141
J = J,
142+
inplace = true,
139143
kwargs...)
140144
end
141145

ext/MTKChainRulesCoreExt.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ end
1616

1717
function subset_idxs(idxs, portion, template)
1818
ntuple(Val(length(template))) do subi
19-
[Base.tail(idx.idx) for idx in idxs if idx.portion == portion && idx.idx[1] == subi]
19+
result = [Base.tail(idx.idx)
20+
for idx in idxs if idx.portion == portion && idx.idx[1] == subi]
21+
if isempty(result)
22+
result = []
23+
end
24+
result
2025
end
2126
end
2227

src/linearization.jl

+9-5
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ function (linfun::LinearizationFunction)(u, p, t)
170170
if u !== nothing # Handle systems without unknowns
171171
linfun.num_states == length(u) ||
172172
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
173-
integ_cache = linfun.caches
174-
integ = MockIntegrator{true}(u, p, t, integ_cache)
173+
integ_cache = (linfun.caches,)
174+
integ = MockIntegrator{true}(u, p, t, integ_cache, nothing)
175175
u, p, success = SciMLBase.get_initial_values(
176176
linfun.prob, integ, fun, linfun.initializealg, Val(true);
177177
linfun.initialize_kwargs...)
@@ -218,7 +218,7 @@ Mock `DEIntegrator` to allow using `CheckInit` without having to create a new in
218218
219219
$(TYPEDFIELDS)
220220
"""
221-
struct MockIntegrator{iip, U, P, T, C} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
221+
struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
222222
"""
223223
The state vector.
224224
"""
@@ -235,10 +235,14 @@ struct MockIntegrator{iip, U, P, T, C} <: SciMLBase.DEIntegrator{Nothing, iip, U
235235
The integrator cache.
236236
"""
237237
cache::C
238+
"""
239+
Integrator "options" for `CheckInit`.
240+
"""
241+
opts::O
238242
end
239243

240-
function MockIntegrator{iip}(u::U, p::P, t::T, cache::C) where {iip, U, P, T, C}
241-
return MockIntegrator{iip, U, P, T, C}(u, p, t, cache)
244+
function MockIntegrator{iip}(u::U, p::P, t::T, cache::C, opts::O) where {iip, U, P, T, C, O}
245+
return MockIntegrator{iip, U, P, T, C, O}(u, p, t, cache, opts)
242246
end
243247

244248
SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u

src/structural_transformation/symbolics_tearing.jl

+2
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,7 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
647647
# HACK 1
648648
if cse && is_getindexed_array(rhs)
649649
rhs_arr = arguments(rhs)[1]
650+
iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue
650651
if !haskey(rhs_to_tempvar, rhs_arr)
651652
tempvar = gensym(Symbol(lhs))
652653
N = length(rhs_arr)
@@ -719,6 +720,7 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
719720
Symbolics.shape(sym) != Symbolics.Unknown() || continue
720721
arg1 = arguments(sym)[1]
721722
cnt = get(arr_obs_occurrences, arg1, 0)
723+
cnt == 0 && continue
722724
arr_obs_occurrences[arg1] = cnt + 1
723725
end
724726
for eq in neweqs

src/systems/abstractsystem.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,11 @@ end
669669

670670
# This is required so `fast_substitute` works
671671
function SymbolicUtils.maketerm(::Type{<:BasicSymbolic}, ::Initial, args, meta)
672-
return metadata(Initial()(args...), meta)
672+
val = Initial()(args...)
673+
if symbolic_type(val) == NotSymbolic()
674+
return val
675+
end
676+
return metadata(val, meta)
673677
end
674678

675679
function add_initialization_parameters(sys::AbstractSystem)

src/systems/codegen_utils.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@ end
1212
1313
Given the arguments to `build_function_wrapper`, return a list of assignments which
1414
reconstruct array variables if they are present scalarized in `args`.
15+
16+
# Keyword Arguments
17+
18+
- `argument_name` a function of the form `(::Int) -> Symbol` which takes the index of
19+
an argument to the generated function and returns the name of the argument in the
20+
generated function.
1521
"""
16-
function array_variable_assignments(args...)
22+
function array_variable_assignments(args...; argument_name = generated_argument_name)
1723
# map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
1824
var_to_arridxs = Dict{BasicSymbolic, Array{Tuple{Int, Int}}}()
1925
for (i, arg) in enumerate(args)
@@ -60,12 +66,12 @@ function array_variable_assignments(args...)
6066
end
6167
# view and reshape
6268

63-
expr = term(reshape, term(view, generated_argument_name(buffer_idx), idxs),
69+
expr = term(reshape, term(view, argument_name(buffer_idx), idxs),
6470
size(arrvar))
6571
else
6672
elems = map(idxs) do idx
6773
i, j = idx
68-
term(getindex, generated_argument_name(i), j)
74+
term(getindex, argument_name(i), j)
6975
end
7076
# use `MakeArray` syntax and generate a stack-allocated array
7177
expr = term(SymbolicUtils.Code.create_array, SArray, nothing,

src/systems/diffeqs/abstractodesystem.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,9 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
12911291
guesses = Dict()
12921292
end
12931293

1294-
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), todict(u0map))
1294+
filter_missing_values!(u0map)
1295+
filter_missing_values!(parammap)
1296+
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), u0map)
12951297

12961298
fullmap = merge(u0map, parammap)
12971299
u0T = Union{}

src/systems/diffeqs/odesystem.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,16 @@ function build_explicit_observed_function(sys, ts;
476476
end
477477
end
478478
allsyms = Set(all_symbols(sys))
479+
iv = has_iv(sys) ? get_iv(sys) : nothing
479480
for var in vs
480481
var = unwrap(var)
481482
newvar = get(ns_map, var, nothing)
482483
if newvar !== nothing
483484
namespace_subs[var] = newvar
485+
var = newvar
486+
end
487+
if throw && !var_in_varlist(var, allsyms, iv)
488+
Base.throw(ArgumentError("Symbol $var is not present in the system."))
484489
end
485490
end
486491
ts = fast_substitute(ts, namespace_subs)
@@ -522,12 +527,14 @@ function build_explicit_observed_function(sys, ts;
522527
if fns isa Tuple
523528
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
524529
f = GeneratedFunctionWrapper{(
525-
p_start, length(args) - length(ps) + 1, is_split(sys))}(oop, iip)
530+
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
531+
oop, iip)
526532
return return_inplace ? (f, f) : f
527533
else
528534
f = eval_or_rgf(fns; eval_expression, eval_module)
529535
f = GeneratedFunctionWrapper{(
530-
p_start, length(args) - length(ps) + 1, is_split(sys))}(f, nothing)
536+
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
537+
f, nothing)
531538
return f
532539
end
533540
end

src/systems/nonlinear/initializesystem.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,8 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
322322
if state_values(dstvalp) === nothing
323323
return nothing, newp
324324
end
325-
T = eltype(state_values(srcvalp))
325+
srcu0 = state_values(srcvalp)
326+
T = srcu0 === nothing || isempty(srcu0) ? Union{} : eltype(srcu0)
326327
if parameter_values(dstvalp) isa MTKParameters
327328
if !isempty(newp.tunable)
328329
T = promote_type(eltype(newp.tunable), T)
@@ -332,7 +333,7 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
332333
end
333334
if T == eltype(state_values(dstvalp))
334335
u0 = state_values(dstvalp)
335-
else
336+
elseif T != Union{}
336337
u0 = T.(state_values(dstvalp))
337338
end
338339
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
@@ -511,9 +512,10 @@ end
511512

512513
function SciMLBase.late_binding_update_u0_p(
513514
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
514-
u0 === missing && return newu0, newp
515-
eltype(u0) <: Pair || return newu0, newp
515+
u0 === missing && return newu0, (p === missing ? copy(newp) : newp)
516+
eltype(u0) <: Pair || return newu0, (p === missing ? copy(newp) : newp)
516517

518+
newp = p === missing ? copy(newp) : newp
517519
newu0 = DiffEqBase.promote_u0(newu0, newp, t0)
518520
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
519521
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)

src/systems/nonlinear/nonlinearsystem.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -578,11 +578,19 @@ function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
578578
Symbol(:tmp, i) SetArray(true, :(out[$i]), get(exprs, T, []))
579579
end
580580

581+
function argument_name(i::Int)
582+
if i <= length(solsyms)
583+
return :($(generated_argument_name(1))[$i])
584+
end
585+
return generated_argument_name(i - length(solsyms))
586+
end
587+
array_assignments = array_variable_assignments(solsyms...; argument_name)
581588
fn = build_function_wrapper(
582-
sys, nothing, :out, DestructuredArgs(DestructuredArgs.(solsyms)),
583-
DestructuredArgs.(rps)...; p_start = 3, p_end = length(rps) + 2,
589+
sys, nothing, :out,
590+
DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)),
591+
rps...; p_start = 3, p_end = length(rps) + 2,
584592
expression = Val{true}, add_observed = false,
585-
extra_assignments = [obs_assigns; body])
593+
extra_assignments = [array_assignments; obs_assigns; body])
586594
fn = eval_or_rgf(fn; eval_expression, eval_module)
587595
fn = GeneratedFunctionWrapper{(3, 3, is_split(sys))}(fn, nothing)
588596
return CacheWriter(fn)

src/systems/problem_utils.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Variables as they are specified in `vars` will take priority over their `toterm`
117117
function add_fallbacks!(
118118
varmap::AnyDict, vars::Vector, fallbacks::Dict; toterm = default_toterm)
119119
missingvars = Set()
120+
arrvars = Set()
120121
for var in vars
121122
haskey(varmap, var) && continue
122123
ttvar = toterm(var)
@@ -157,6 +158,7 @@ function add_fallbacks!(
157158
fallbacks, arrvar, nothing) get(fallbacks, ttarrvar, nothing) Some(nothing)
158159
if val !== nothing
159160
val = val[idxs...]
161+
is_sized_array_symbolic(arrvar) && push!(arrvars, arrvar)
160162
end
161163
else
162164
val = nothing
@@ -170,6 +172,10 @@ function add_fallbacks!(
170172
end
171173
end
172174

175+
for arrvar in arrvars
176+
varmap[arrvar] = collect(arrvar)
177+
end
178+
173179
return missingvars
174180
end
175181

@@ -269,9 +275,9 @@ entry for `eq.lhs`, insert the reverse mapping if `eq.rhs` is not a number.
269275
"""
270276
function add_observed_equations!(varmap::AbstractDict, eqs)
271277
for eq in eqs
272-
if haskey(varmap, eq.lhs)
278+
if var_in_varlist(eq.lhs, keys(varmap), nothing)
273279
eq.rhs isa Number && continue
274-
haskey(varmap, eq.rhs) && continue
280+
var_in_varlist(eq.rhs, keys(varmap), nothing) && continue
275281
!iscall(eq.rhs) || issym(operation(eq.rhs)) || continue
276282
varmap[eq.rhs] = eq.lhs
277283
else

src/utils.jl

+19
Original file line numberDiff line numberDiff line change
@@ -1267,3 +1267,22 @@ function symbol_to_symbolic(sys::AbstractSystem, sym; allsyms = all_symbols(sys)
12671267
end
12681268
return sym
12691269
end
1270+
1271+
"""
1272+
$(TYPEDSIGNATURES)
1273+
1274+
Check if `var` is present in `varlist`. `iv` is the independent variable of the system,
1275+
and should be `nothing` if not applicable.
1276+
"""
1277+
function var_in_varlist(var, varlist::AbstractSet, iv)
1278+
var = unwrap(var)
1279+
# simple case
1280+
return var in varlist ||
1281+
# indexed array symbolic, unscalarized array present
1282+
(iscall(var) && operation(var) === getindex && arguments(var)[1] in varlist) ||
1283+
# unscalarized sized array symbolic, all scalarized elements present
1284+
(symbolic_type(var) == ArraySymbolic() && is_sized_array_symbolic(var) &&
1285+
all(x -> x in varlist, collect(var))) ||
1286+
# delayed variables
1287+
(isdelay(var, iv) && var_in_varlist(operation(var)(iv), varlist, iv))
1288+
end

test/debugging.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq, SymbolicIndexingInterface
2+
import Logging
23
using ModelingToolkit: t_nounits as t, D_nounits as D, ASSERTION_LOG_VARIABLE
34

45
@variables x(t)
@@ -25,11 +26,11 @@ end
2526
dsys = debug_system(sys; functions = [])
2627
@test is_parameter(dsys, ASSERTION_LOG_VARIABLE)
2728
prob = Problem(dsys, [x => 0.1], (0.0, 5.0))
28-
sol = solve(prob, alg)
29-
@test !SciMLBase.successful_retcode(sol)
30-
prob.ps[ASSERTION_LOG_VARIABLE] = true
3129
sol = @test_logs (:error, r"ohno") match_mode=:any solve(prob, alg)
3230
@test !SciMLBase.successful_retcode(sol)
31+
prob.ps[ASSERTION_LOG_VARIABLE] = false
32+
sol = @test_logs min_level=Logging.Error solve(prob, alg)
33+
@test !SciMLBase.successful_retcode(sol)
3334
end
3435
end
3536

@@ -41,10 +42,10 @@ end
4142
dsys = debug_system(outer; functions = [])
4243
@test is_parameter(dsys, ASSERTION_LOG_VARIABLE)
4344
prob = Problem(dsys, [inner.x => 0.1], (0.0, 5.0))
44-
sol = solve(prob, alg)
45-
@test !SciMLBase.successful_retcode(sol)
46-
prob.ps[ASSERTION_LOG_VARIABLE] = true
4745
sol = @test_logs (:error, r"ohno") match_mode=:any solve(prob, alg)
4846
@test !SciMLBase.successful_retcode(sol)
47+
prob.ps[ASSERTION_LOG_VARIABLE] = false
48+
sol = @test_logs min_level=Logging.Error solve(prob, alg)
49+
@test !SciMLBase.successful_retcode(sol)
4950
end
5051
end

0 commit comments

Comments
 (0)