Skip to content

Commit 8ec5bc5

Browse files
fix: make most tests pass
1 parent c4ce1b7 commit 8ec5bc5

30 files changed

+659
-341
lines changed

src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ using PrecompileTools, Reexport
6363
ParallelForm, SerialForm, MultithreadedForm, build_function,
6464
rhss, lhss, prettify_expr, gradient,
6565
jacobian, hessian, derivative, sparsejacobian, sparsehessian,
66-
substituter, scalarize, getparent
66+
substituter, scalarize, getparent, hasderiv, hasdiff
6767

6868
import DiffEqBase: @add_kwonly
6969
import OrdinaryDiffEq

src/systems/abstractsystem.jl

+68-10
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187187
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
188188
return unwrap(sym) in 1:length(variable_symbols(sys))
189189
end
190-
return any(isequal(sym), variable_symbols(sys)) ||
190+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
191+
ic = get_index_cache(sys)
192+
h = getsymbolhash(sym)
193+
return haskey(ic.unknown_idx, h) || haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) || hasname(sym) && is_variable(sys, getname(sym))
194+
else
195+
return any(isequal(sym), variable_symbols(sys)) ||
191196
hasname(sym) && is_variable(sys, getname(sym))
197+
end
192198
end
193199

194200
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
@@ -202,6 +208,22 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
202208
if unwrap(sym) isa Int
203209
return unwrap(sym)
204210
end
211+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
212+
ic = get_index_cache(sys)
213+
h = getsymbolhash(sym)
214+
return if haskey(ic.unknown_idx, h)
215+
ic.unknown_idx[h]
216+
else
217+
h = getsymbolhash(default_toterm(sym))
218+
if haskey(ic.unknown_idx, h)
219+
ic.unknown_idx[h]
220+
elseif hasname(sym)
221+
variable_index(sys, getname(sym))
222+
else
223+
nothing
224+
end
225+
end
226+
end
205227
idx = findfirst(isequal(sym), variable_symbols(sys))
206228
if idx === nothing && hasname(sym)
207229
idx = variable_index(sys, getname(sym))
@@ -230,7 +252,16 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
230252
if unwrap(sym) isa Int
231253
return unwrap(sym) in 1:length(parameter_symbols(sys))
232254
end
233-
255+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
256+
ic = get_index_cache(sys)
257+
h = getsymbolhash(sym)
258+
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h)
259+
true
260+
else
261+
h = getsymbolhash(default_toterm(sym))
262+
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) || hasname(sym) && is_parameter(sys, getname(sym))
263+
end
264+
end
234265
return any(isequal(sym), parameter_symbols(sys)) ||
235266
hasname(sym) && is_parameter(sys, getname(sym))
236267
end
@@ -246,6 +277,25 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
246277
if unwrap(sym) isa Int
247278
return unwrap(sym)
248279
end
280+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
281+
ic = get_index_cache(sys)
282+
h = getsymbolhash(sym)
283+
return if haskey(ic.param_idx, h)
284+
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
285+
elseif haskey(ic.discrete_idx, h)
286+
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
287+
else
288+
h = getsymbolhash(default_toterm(sym))
289+
if haskey(ic.param_idx, h)
290+
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
291+
elseif haskey(ic.discrete_idx, h)
292+
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
293+
else
294+
nothing
295+
end
296+
end
297+
end
298+
249299
idx = findfirst(isequal(sym), parameter_symbols(sys))
250300
if idx === nothing && hasname(sym)
251301
idx = parameter_index(sys, getname(sym))
@@ -1441,12 +1491,18 @@ function linearization_function(sys::AbstractSystem, inputs,
14411491
end
14421492
sys = ssys
14431493
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
1444-
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1445-
p, split_idxs = split_parameters_by_type(p)
1446-
ps = parameters(sys)
1447-
if p isa Tuple
1448-
ps = Base.Fix1(getindex, ps).(split_idxs)
1449-
ps = (ps...,) #if p is Tuple, ps should be Tuple
1494+
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1495+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1496+
p = (MTKParameters(sys, p)...,)
1497+
ps = reorder_parameters(sys, parameters(sys))
1498+
else
1499+
p = _p
1500+
p, split_idxs = split_parameters_by_type(p)
1501+
ps = parameters(sys)
1502+
if p isa Tuple
1503+
ps = Base.Fix1(getindex, ps).(split_idxs)
1504+
ps = (ps...,) #if p is Tuple, ps should be Tuple
1505+
end
14501506
end
14511507

14521508
lin_fun = let diff_idxs = diff_idxs,
@@ -1472,7 +1528,7 @@ function linearization_function(sys::AbstractSystem, inputs,
14721528
uf = SciMLBase.UJacobianWrapper(fun, t, p)
14731529
fg_xz = ForwardDiff.jacobian(uf, u)
14741530
h_xz = ForwardDiff.jacobian(let p = p, t = t
1475-
xz -> h(xz, p, t)
1531+
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
14761532
end, u)
14771533
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
14781534
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
@@ -1726,7 +1782,9 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
17261782
p = DiffEqBase.NullParameters())
17271783
x0 = merge(defaults(sys), op)
17281784
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1729-
1785+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1786+
p2 = MTKParameters(sys, p)
1787+
end
17301788
linres = lin_fun(u0, p2, t)
17311789
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
17321790

src/systems/callbacks.jl

+87-29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#################################### system operations #####################################
2-
get_continuous_events(sys::AbstractSystem) = Equation[]
2+
get_continuous_events(sys::AbstractSystem) = SymbolicContinuousCallback[]
33
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
44
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
55

@@ -15,10 +15,11 @@ struct FunctionalAffect
1515
sts_syms::Vector{Symbol}
1616
pars::Vector
1717
pars_syms::Vector{Symbol}
18+
discretes::Vector
1819
ctx::Any
1920
end
2021

21-
function FunctionalAffect(f, sts, pars, ctx = nothing)
22+
function FunctionalAffect(f, sts, pars, discretes, ctx = nothing)
2223
# sts & pars contain either pairs: resistor.R => R, or Syms: R
2324
vs = [x isa Pair ? x.first : x for x in sts]
2425
vs_syms = Symbol[x isa Pair ? Symbol(x.second) : getname(x) for x in sts]
@@ -28,17 +29,18 @@ function FunctionalAffect(f, sts, pars, ctx = nothing)
2829
ps_syms = Symbol[x isa Pair ? Symbol(x.second) : getname(x) for x in pars]
2930
length(ps_syms) == length(unique(ps_syms)) || error("Parameters are not unique")
3031

31-
FunctionalAffect(f, vs, vs_syms, ps, ps_syms, ctx)
32+
FunctionalAffect(f, vs, vs_syms, ps, ps_syms, discretes, ctx)
3233
end
3334

34-
FunctionalAffect(; f, sts, pars, ctx = nothing) = FunctionalAffect(f, sts, pars, ctx)
35+
FunctionalAffect(; f, sts, pars, discretes, ctx = nothing) = FunctionalAffect(f, sts, pars, discretes, ctx)
3536

3637
func(f::FunctionalAffect) = f.f
3738
context(a::FunctionalAffect) = a.ctx
3839
parameters(a::FunctionalAffect) = a.pars
3940
parameters_syms(a::FunctionalAffect) = a.pars_syms
4041
unknowns(a::FunctionalAffect) = a.sts
4142
unknowns_syms(a::FunctionalAffect) = a.sts_syms
43+
discretes(a::FunctionalAffect) = a.discretes
4244

4345
function Base.:(==)(a1::FunctionalAffect, a2::FunctionalAffect)
4446
isequal(a1.f, a2.f) && isequal(a1.sts, a2.sts) && isequal(a1.pars, a2.pars) &&
@@ -52,6 +54,7 @@ function Base.hash(a::FunctionalAffect, s::UInt)
5254
s = hash(a.sts_syms, s)
5355
s = hash(a.pars, s)
5456
s = hash(a.pars_syms, s)
57+
s = hash(a.discretes, s)
5558
hash(a.ctx, s)
5659
end
5760

@@ -64,6 +67,7 @@ function namespace_affect(affect::FunctionalAffect, s)
6467
unknowns_syms(affect),
6568
renamespace.((s,), parameters(affect)),
6669
parameters_syms(affect),
70+
renamespace.((s,), discretes(affect)),
6771
context(affect))
6872
end
6973

@@ -121,7 +125,7 @@ end
121125

122126
affects(cb::SymbolicContinuousCallback) = cb.affect
123127
function affects(cbs::Vector{SymbolicContinuousCallback})
124-
reduce(vcat, [affects(cb) for cb in cbs])
128+
reduce(vcat, [affects(cb) for cb in cbs], init = [])
125129
end
126130

127131
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
@@ -213,7 +217,7 @@ end
213217
affects(cb::SymbolicDiscreteCallback) = cb.affects
214218

215219
function affects(cbs::Vector{SymbolicDiscreteCallback})
216-
reduce(vcat, affects(cb) for cb in cbs)
220+
reduce(vcat, affects(cb) for cb in cbs; init = [])
217221
end
218222

219223
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
@@ -241,16 +245,54 @@ end
241245
################################# compilation functions ####################################
242246

243247
# handles ensuring that affect! functions work with integrator arguments
244-
function add_integrator_header(integrator = gensym(:MTKIntegrator), out = :u)
245-
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
246-
expr.body),
247-
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
248-
expr.body)
248+
function add_integrator_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrator), out = :u)
249+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
250+
function (expr)
251+
p = gensym(:p)
252+
Func([
253+
DestructuredArgs([expr.args[1], p, expr.args[end]],
254+
integrator, inds = [:u, :p, :t]),
255+
], [], Let([DestructuredArgs([arg.name for arg in expr.args[2:end-1]], p),
256+
expr.args[2:end-1]...], expr.body, false)
257+
)
258+
end,
259+
function (expr)
260+
p = gensym(:p)
261+
Func([
262+
DestructuredArgs([expr.args[1], expr.args[2], p, expr.args[end]],
263+
integrator, inds = [out, :u, :p, :t]),
264+
], [], Let([DestructuredArgs([arg.name for arg in expr.args[3:end-1]], p),
265+
expr.args[3:end-1]...], expr.body, false)
266+
)
267+
end
268+
else
269+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
270+
expr.body),
271+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
272+
expr.body)
273+
end
249274
end
250275

251-
function condition_header(integrator = gensym(:MTKIntegrator))
252-
expr -> Func([expr.args[1], expr.args[2],
253-
DestructuredArgs(expr.args[3:end], integrator, inds = [:p])], [], expr.body)
276+
function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrator))
277+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
278+
function (expr)
279+
p = gensym(:p)
280+
res = Func(
281+
[expr.args[1], expr.args[2], DestructuredArgs([p], integrator, inds = [:p])],
282+
[],
283+
Let(
284+
[
285+
DestructuredArgs([arg.name for arg in expr.args[3:end]], p),
286+
expr.args[3:end]...
287+
], expr.body, false
288+
)
289+
)
290+
return res
291+
end
292+
else
293+
expr -> Func([expr.args[1], expr.args[2],
294+
DestructuredArgs(expr.args[3:end], integrator, inds = [:p])], [], expr.body)
295+
end
254296
end
255297

256298
"""
@@ -267,15 +309,15 @@ Notes
267309
function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
268310
expression = Val{true}, kwargs...)
269311
u = map(x -> time_varying_as_func(value(x), sys), dvs)
270-
p = map(x -> time_varying_as_func(value(x), sys), ps)
312+
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
271313
t = get_iv(sys)
272314
condit = condition(cb)
273315
cs = collect_constants(condit)
274316
if !isempty(cs)
275317
cmap = map(x -> x => getdefault(x), cs)
276318
condit = substitute(condit, cmap)
277319
end
278-
build_function(condit, u, t, p; expression, wrap_code = condition_header(),
320+
build_function(condit, u, t, p...; expression, wrap_code = condition_header(sys),
279321
kwargs...)
280322
end
281323

@@ -325,8 +367,19 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
325367
unknownind = Dict(reverse(en) for en in enumerate(dvs))
326368
update_inds = map(sym -> unknownind[sym], update_vars)
327369
elseif isparameter(first(lhss)) && alleq
328-
psind = Dict(reverse(en) for en in enumerate(ps))
329-
update_inds = map(sym -> psind[sym], update_vars)
370+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
371+
ic = get_index_cache(sys)
372+
update_inds = map(update_vars) do sym
373+
@unpack portion, idx = parameter_index(sys, sym)
374+
if portion == SciMLStructures.Discrete()
375+
idx += length(ic.param_idx)
376+
end
377+
idx
378+
end
379+
else
380+
psind = Dict(reverse(en) for en in enumerate(ps))
381+
update_inds = map(sym -> psind[sym], update_vars)
382+
end
330383
outvar = :p
331384
else
332385
error("Error, building an affect function for a callback that wants to modify both parameters and unknowns. This is not currently allowed in one individual callback.")
@@ -335,9 +388,10 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
335388
update_inds = outputidxs
336389
end
337390

391+
ps = reorder_parameters(sys, ps)
338392
if checkvars
339393
u = map(x -> time_varying_as_func(value(x), sys), dvs)
340-
p = map(x -> time_varying_as_func(value(x), sys), ps)
394+
p = map.(x -> time_varying_as_func(value(x), sys), ps)
341395
else
342396
u = dvs
343397
p = ps
@@ -346,8 +400,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
346400
integ = gensym(:MTKIntegrator)
347401
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
348402
pre = get_preprocess_constants(rhss)
349-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = getexpr,
350-
wrap_code = add_integrator_header(integ, outvar),
403+
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = getexpr,
404+
wrap_code = add_integrator_header(sys, integ, outvar),
351405
outputidxs = update_inds,
352406
postprocess_fbody = pre,
353407
kwargs...)
@@ -385,10 +439,10 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
385439
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
386440

387441
u = map(x -> time_varying_as_func(value(x), sys), dvs)
388-
p = map(x -> time_varying_as_func(value(x), sys), ps)
442+
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
389443
t = get_iv(sys)
390444
pre = get_preprocess_constants(rhss)
391-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false},
445+
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{false},
392446
postprocess_fbody = pre, kwargs...)
393447

394448
affect_functions = map(cbs) do cb # Keep affect function separate
@@ -400,16 +454,16 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
400454
cond = function (u, t, integ)
401455
if DiffEqBase.isinplace(integ.sol.prob)
402456
tmp, = DiffEqBase.get_tmp_cache(integ)
403-
rf_ip(tmp, u, integ.p, t)
457+
rf_ip(tmp, u, parameter_values(integ)..., t)
404458
tmp[1]
405459
else
406-
rf_oop(u, integ.p, t)
460+
rf_oop(u, parameter_values(integ)..., t)
407461
end
408462
end
409463
ContinuousCallback(cond, affect_functions[])
410464
else
411465
cond = function (out, u, t, integ)
412-
rf_ip(out, u, integ.p, t)
466+
rf_ip(out, u, parameter_values(integ)..., t)
413467
end
414468

415469
# since there may be different number of conditions and affects,
@@ -432,9 +486,13 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
432486
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
433487
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
434488

435-
ps_ind = Dict(reverse(en) for en in enumerate(ps))
436-
p_inds = map(sym -> ps_ind[sym], parameters(affect))
437-
489+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
490+
p_inds = [parameter_index(sys, sym) for sym in parameters(affect)]
491+
else
492+
ps_ind = Dict(reverse(en) for en in enumerate(ps))
493+
p_inds = map(sym -> ps_ind[sym], parameters(affect))
494+
end
495+
438496
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
439497
# (MTK should keep these symbols)
440498
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>

0 commit comments

Comments
 (0)