1
1
# ################################### system operations #####################################
2
- get_continuous_events (sys:: AbstractSystem ) = Equation []
2
+ get_continuous_events (sys:: AbstractSystem ) = SymbolicContinuousCallback []
3
3
get_continuous_events (sys:: AbstractODESystem ) = getfield (sys, :continuous_events )
4
4
has_continuous_events (sys:: AbstractSystem ) = isdefined (sys, :continuous_events )
5
5
@@ -15,10 +15,11 @@ struct FunctionalAffect
15
15
sts_syms:: Vector{Symbol}
16
16
pars:: Vector
17
17
pars_syms:: Vector{Symbol}
18
+ discretes:: Vector
18
19
ctx:: Any
19
20
end
20
21
21
- function FunctionalAffect (f, sts, pars, ctx = nothing )
22
+ function FunctionalAffect (f, sts, pars, discretes, ctx = nothing )
22
23
# sts & pars contain either pairs: resistor.R => R, or Syms: R
23
24
vs = [x isa Pair ? x. first : x for x in sts]
24
25
vs_syms = Symbol[x isa Pair ? Symbol (x. second) : getname (x) for x in sts]
@@ -28,17 +29,18 @@ function FunctionalAffect(f, sts, pars, ctx = nothing)
28
29
ps_syms = Symbol[x isa Pair ? Symbol (x. second) : getname (x) for x in pars]
29
30
length (ps_syms) == length (unique (ps_syms)) || error (" Parameters are not unique" )
30
31
31
- FunctionalAffect (f, vs, vs_syms, ps, ps_syms, ctx)
32
+ FunctionalAffect (f, vs, vs_syms, ps, ps_syms, discretes, ctx)
32
33
end
33
34
34
- FunctionalAffect (; f, sts, pars, ctx = nothing ) = FunctionalAffect (f, sts, pars, ctx)
35
+ FunctionalAffect (; f, sts, pars, discretes, ctx = nothing ) = FunctionalAffect (f, sts, pars, discretes , ctx)
35
36
36
37
func (f:: FunctionalAffect ) = f. f
37
38
context (a:: FunctionalAffect ) = a. ctx
38
39
parameters (a:: FunctionalAffect ) = a. pars
39
40
parameters_syms (a:: FunctionalAffect ) = a. pars_syms
40
41
unknowns (a:: FunctionalAffect ) = a. sts
41
42
unknowns_syms (a:: FunctionalAffect ) = a. sts_syms
43
+ discretes (a:: FunctionalAffect ) = a. discretes
42
44
43
45
function Base.:(== )(a1:: FunctionalAffect , a2:: FunctionalAffect )
44
46
isequal (a1. f, a2. f) && isequal (a1. sts, a2. sts) && isequal (a1. pars, a2. pars) &&
@@ -52,6 +54,7 @@ function Base.hash(a::FunctionalAffect, s::UInt)
52
54
s = hash (a. sts_syms, s)
53
55
s = hash (a. pars, s)
54
56
s = hash (a. pars_syms, s)
57
+ s = hash (a. discretes, s)
55
58
hash (a. ctx, s)
56
59
end
57
60
@@ -64,6 +67,7 @@ function namespace_affect(affect::FunctionalAffect, s)
64
67
unknowns_syms (affect),
65
68
renamespace .((s,), parameters (affect)),
66
69
parameters_syms (affect),
70
+ renamespace .((s,), discretes (affect)),
67
71
context (affect))
68
72
end
69
73
121
125
122
126
affects (cb:: SymbolicContinuousCallback ) = cb. affect
123
127
function affects (cbs:: Vector{SymbolicContinuousCallback} )
124
- reduce (vcat, [affects (cb) for cb in cbs])
128
+ reduce (vcat, [affects (cb) for cb in cbs], init = [] )
125
129
end
126
130
127
131
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
213
217
affects (cb:: SymbolicDiscreteCallback ) = cb. affects
214
218
215
219
function affects (cbs:: Vector{SymbolicDiscreteCallback} )
216
- reduce (vcat, affects (cb) for cb in cbs)
220
+ reduce (vcat, affects (cb) for cb in cbs; init = [] )
217
221
end
218
222
219
223
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
@@ -241,16 +245,54 @@ end
241
245
# ################################ compilation functions ####################################
242
246
243
247
# handles ensuring that affect! functions work with integrator arguments
244
- function add_integrator_header (integrator = gensym (:MTKIntegrator ), out = :u )
245
- expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [:u , :p , :t ])], [],
246
- expr. body),
247
- expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [out, :u , :p , :t ])], [],
248
- expr. body)
248
+ function add_integrator_header (sys:: AbstractSystem , integrator = gensym (:MTKIntegrator ), out = :u )
249
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
250
+ function (expr)
251
+ p = gensym (:p )
252
+ Func ([
253
+ DestructuredArgs ([expr. args[1 ], p, expr. args[end ]],
254
+ integrator, inds = [:u , :p , :t ]),
255
+ ], [], Let ([DestructuredArgs ([arg. name for arg in expr. args[2 : end - 1 ]], p),
256
+ expr. args[2 : end - 1 ]. .. ], expr. body, false )
257
+ )
258
+ end ,
259
+ function (expr)
260
+ p = gensym (:p )
261
+ Func ([
262
+ DestructuredArgs ([expr. args[1 ], expr. args[2 ], p, expr. args[end ]],
263
+ integrator, inds = [out, :u , :p , :t ]),
264
+ ], [], Let ([DestructuredArgs ([arg. name for arg in expr. args[3 : end - 1 ]], p),
265
+ expr. args[3 : end - 1 ]. .. ], expr. body, false )
266
+ )
267
+ end
268
+ else
269
+ expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [:u , :p , :t ])], [],
270
+ expr. body),
271
+ expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [out, :u , :p , :t ])], [],
272
+ expr. body)
273
+ end
249
274
end
250
275
251
- function condition_header (integrator = gensym (:MTKIntegrator ))
252
- expr -> Func ([expr. args[1 ], expr. args[2 ],
253
- DestructuredArgs (expr. args[3 : end ], integrator, inds = [:p ])], [], expr. body)
276
+ function condition_header (sys:: AbstractSystem , integrator = gensym (:MTKIntegrator ))
277
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
278
+ function (expr)
279
+ p = gensym (:p )
280
+ res = Func (
281
+ [expr. args[1 ], expr. args[2 ], DestructuredArgs ([p], integrator, inds = [:p ])],
282
+ [],
283
+ Let (
284
+ [
285
+ DestructuredArgs ([arg. name for arg in expr. args[3 : end ]], p),
286
+ expr. args[3 : end ]. ..
287
+ ], expr. body, false
288
+ )
289
+ )
290
+ return res
291
+ end
292
+ else
293
+ expr -> Func ([expr. args[1 ], expr. args[2 ],
294
+ DestructuredArgs (expr. args[3 : end ], integrator, inds = [:p ])], [], expr. body)
295
+ end
254
296
end
255
297
256
298
"""
@@ -267,15 +309,15 @@ Notes
267
309
function compile_condition (cb:: SymbolicDiscreteCallback , sys, dvs, ps;
268
310
expression = Val{true }, kwargs... )
269
311
u = map (x -> time_varying_as_func (value (x), sys), dvs)
270
- p = map (x -> time_varying_as_func (value (x), sys), ps )
312
+ p = map . (x -> time_varying_as_func (value (x), sys), reorder_parameters (sys, ps) )
271
313
t = get_iv (sys)
272
314
condit = condition (cb)
273
315
cs = collect_constants (condit)
274
316
if ! isempty (cs)
275
317
cmap = map (x -> x => getdefault (x), cs)
276
318
condit = substitute (condit, cmap)
277
319
end
278
- build_function (condit, u, t, p; expression, wrap_code = condition_header (),
320
+ build_function (condit, u, t, p... ; expression, wrap_code = condition_header (sys ),
279
321
kwargs... )
280
322
end
281
323
@@ -325,8 +367,19 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
325
367
unknownind = Dict (reverse (en) for en in enumerate (dvs))
326
368
update_inds = map (sym -> unknownind[sym], update_vars)
327
369
elseif isparameter (first (lhss)) && alleq
328
- psind = Dict (reverse (en) for en in enumerate (ps))
329
- update_inds = map (sym -> psind[sym], update_vars)
370
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
371
+ ic = get_index_cache (sys)
372
+ update_inds = map (update_vars) do sym
373
+ @unpack portion, idx = parameter_index (sys, sym)
374
+ if portion == SciMLStructures. Discrete ()
375
+ idx += length (ic. param_idx)
376
+ end
377
+ idx
378
+ end
379
+ else
380
+ psind = Dict (reverse (en) for en in enumerate (ps))
381
+ update_inds = map (sym -> psind[sym], update_vars)
382
+ end
330
383
outvar = :p
331
384
else
332
385
error (" Error, building an affect function for a callback that wants to modify both parameters and unknowns. This is not currently allowed in one individual callback." )
@@ -335,9 +388,10 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
335
388
update_inds = outputidxs
336
389
end
337
390
391
+ ps = reorder_parameters (sys, ps)
338
392
if checkvars
339
393
u = map (x -> time_varying_as_func (value (x), sys), dvs)
340
- p = map (x -> time_varying_as_func (value (x), sys), ps)
394
+ p = map . (x -> time_varying_as_func (value (x), sys), ps)
341
395
else
342
396
u = dvs
343
397
p = ps
@@ -346,8 +400,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
346
400
integ = gensym (:MTKIntegrator )
347
401
getexpr = (postprocess_affect_expr! === nothing ) ? expression : Val{true }
348
402
pre = get_preprocess_constants (rhss)
349
- rf_oop, rf_ip = build_function (rhss, u, p, t; expression = getexpr,
350
- wrap_code = add_integrator_header (integ, outvar),
403
+ rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = getexpr,
404
+ wrap_code = add_integrator_header (sys, integ, outvar),
351
405
outputidxs = update_inds,
352
406
postprocess_fbody = pre,
353
407
kwargs... )
@@ -385,10 +439,10 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
385
439
root_eq_vars = unique (collect (Iterators. flatten (map (ModelingToolkit. vars, rhss))))
386
440
387
441
u = map (x -> time_varying_as_func (value (x), sys), dvs)
388
- p = map (x -> time_varying_as_func (value (x), sys), ps )
442
+ p = map . (x -> time_varying_as_func (value (x), sys), reorder_parameters (sys, ps) )
389
443
t = get_iv (sys)
390
444
pre = get_preprocess_constants (rhss)
391
- rf_oop, rf_ip = build_function (rhss, u, p, t; expression = Val{false },
445
+ rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = Val{false },
392
446
postprocess_fbody = pre, kwargs... )
393
447
394
448
affect_functions = map (cbs) do cb # Keep affect function separate
@@ -400,16 +454,16 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
400
454
cond = function (u, t, integ)
401
455
if DiffEqBase. isinplace (integ. sol. prob)
402
456
tmp, = DiffEqBase. get_tmp_cache (integ)
403
- rf_ip (tmp, u, integ. p , t)
457
+ rf_ip (tmp, u, parameter_values ( integ) ... , t)
404
458
tmp[1 ]
405
459
else
406
- rf_oop (u, integ. p , t)
460
+ rf_oop (u, parameter_values ( integ) ... , t)
407
461
end
408
462
end
409
463
ContinuousCallback (cond, affect_functions[])
410
464
else
411
465
cond = function (out, u, t, integ)
412
- rf_ip (out, u, integ. p , t)
466
+ rf_ip (out, u, parameter_values ( integ) ... , t)
413
467
end
414
468
415
469
# since there may be different number of conditions and affects,
@@ -432,9 +486,13 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
432
486
dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
433
487
v_inds = map (sym -> dvs_ind[sym], unknowns (affect))
434
488
435
- ps_ind = Dict (reverse (en) for en in enumerate (ps))
436
- p_inds = map (sym -> ps_ind[sym], parameters (affect))
437
-
489
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
490
+ p_inds = [parameter_index (sys, sym) for sym in parameters (affect)]
491
+ else
492
+ ps_ind = Dict (reverse (en) for en in enumerate (ps))
493
+ p_inds = map (sym -> ps_ind[sym], parameters (affect))
494
+ end
495
+
438
496
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
439
497
# (MTK should keep these symbols)
440
498
u = filter (x -> ! isnothing (x[2 ]), collect (zip (unknowns_syms (affect), v_inds))) |>
0 commit comments