Skip to content

Commit 55e8ad6

Browse files
Merge pull request SciML#2911 from BenChung/improve-events
Support more of the SciMLBase events API
2 parents 3cef655 + 7053b34 commit 55e8ad6

File tree

2 files changed

+433
-41
lines changed

2 files changed

+433
-41
lines changed

Diff for: src/systems/callbacks.jl

+174-41
Original file line numberDiff line numberDiff line change
@@ -76,24 +76,62 @@ end
7676
#################################### continuous events #####################################
7777

7878
const NULL_AFFECT = Equation[]
79+
"""
80+
SymbolicContinuousCallback(eqs::Vector{Equation}, affect, affect_neg, rootfind)
81+
82+
A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq`
83+
as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied.
84+
By default `affect_neg = affect`; to only get rising edges specify `affect_neg = nothing`.
85+
86+
Assume without loss of generality that the equation is of the form `c(u,p,t) ~ 0`; we denote the integrator state as `i.u`.
87+
For compactness, we define `prev_sign = sign(c(u[t-1], p[t-1], t-1))` and `cur_sign = sign(c(u[t], p[t], t))`.
88+
A condition edge will be detected and the callback will be invoked iff `prev_sign * cur_sign <= 0`.
89+
Inter-sample condition activation is not guaranteed; for example if we use the dirac delta function as `c` to insert a
90+
sharp discontinuity between integrator steps (which in this example would not normally be identified by adaptivity) then the condition is not
91+
guaranteed to be triggered.
92+
93+
Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used
94+
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). Multiple callbacks in the same system with different `rootfind` operations will be resolved
95+
into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`, which may cause some callbacks to not fire if several become
96+
active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules.
97+
98+
The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be
99+
triggered iff an edge is detected `prev_sign > 0`.
100+
101+
Affects (i.e. `affect` and `affect_neg`) can be specified as either:
102+
* A list of equations that should be applied when the callback is triggered (e.g. `x ~ 3, y ~ 7`) which must be of the form `unknown ~ observed value` where each `unknown` appears only once. Equations will be applied in the order that they appear in the vector; parameters and state updates will become immediately visible to following equations.
103+
* A tuple `(f!, unknowns, read_parameters, modified_parameters, ctx)`, where:
104+
+ `f!` is a function with signature `(integ, u, p, ctx)` that is called with the integrator, a state *index* vector `u` derived from `unknowns`, a parameter *index* vector `p` derived from `read_parameters`, and the `ctx` that was given at construction time. Note that `ctx` is aliased between instances.
105+
+ `unknowns` is a vector of symbolic unknown variables and optionally their aliases (e.g. if the model was defined with `@variables x(t)` then a valid value for `unknowns` would be `[x]`). A variable can be aliased with a pair `x => :y`. The indices of these `unknowns` will be passed to `f!` in `u` in a named tuple; in the earlier example, if we pass `[x]` as `unknowns` then `f!` can access `x` as `integ.u[u.x]`. If no alias is specified the name of the index will be the symbol version of the variable name.
106+
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107+
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108+
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109+
"""
79110
struct SymbolicContinuousCallback
80111
eqs::Vector{Equation}
81112
affect::Union{Vector{Equation}, FunctionalAffect}
82-
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
83-
new(eqs, make_affect(affect))
113+
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
114+
rootfind::SciMLBase.RootfindOpt
115+
function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT,
116+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
117+
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
84118
end # Default affect to nothing
85119
end
86120
make_affect(affect) = affect
87121
make_affect(affect::Tuple) = FunctionalAffect(affect...)
88122
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
89123

90124
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
91-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
125+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
126+
isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
92127
end
93128
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
94129
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
95130
s = foldr(hash, cb.eqs, init = s)
96-
cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
131+
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
132+
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) :
133+
hash(cb.affect_neg, s)
134+
hash(cb.rootfind, s)
97135
end
98136

99137
to_equation_vector(eq::Equation) = [eq]
@@ -108,6 +146,14 @@ function SymbolicContinuousCallback(args...)
108146
end # wrap eq in vector
109147
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
110148
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
149+
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
150+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
151+
SymbolicContinuousCallback(eqs=[eqs], affect=affect, affect_neg=affect_neg, rootfind=rootfind)
152+
end
153+
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
154+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
155+
SymbolicContinuousCallback(eqs=eqs, affect=affect, affect_neg=affect_neg, rootfind=rootfind)
156+
end
111157

112158
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
113159
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
@@ -130,12 +176,20 @@ function affects(cbs::Vector{SymbolicContinuousCallback})
130176
mapreduce(affects, vcat, cbs, init = Equation[])
131177
end
132178

179+
affect_negs(cb::SymbolicContinuousCallback) = cb.affect_neg
180+
function affect_negs(cbs::Vector{SymbolicContinuousCallback})
181+
mapreduce(affect_negs, vcat, cbs, init = Equation[])
182+
end
183+
133184
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
134185
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
186+
namespace_affects(::Nothing, s) = nothing
135187

136188
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
137-
SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)),
138-
namespace_affects(affects(cb), s))
189+
SymbolicContinuousCallback(
190+
namespace_equation.(equations(cb), (s,)),
191+
namespace_affects(affects(cb), s),
192+
namespace_affects(affect_negs(cb), s))
139193
end
140194

141195
"""
@@ -159,7 +213,7 @@ function continuous_events(sys::AbstractSystem)
159213
filter(!isempty, cbs)
160214
end
161215

162-
#################################### continuous events #####################################
216+
#################################### discrete events #####################################
163217

164218
struct SymbolicDiscreteCallback
165219
# condition can be one of:
@@ -462,12 +516,38 @@ function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sy
462516
isempty(cbs) && return nothing
463517
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
464518
end
465-
466-
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
519+
"""
520+
Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to
521+
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
522+
"""
523+
function generate_single_rootfinding_callback(
524+
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
467525
ps = full_parameters(sys); kwargs...)
526+
if !isequal(eq.lhs, 0)
527+
eq = 0 ~ eq.lhs - eq.rhs
528+
end
529+
530+
rf_oop, rf_ip = generate_custom_function(
531+
sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...)
532+
affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs)
533+
cond = function (u, t, integ)
534+
if DiffEqBase.isinplace(integ.sol.prob)
535+
tmp, = DiffEqBase.get_tmp_cache(integ)
536+
rf_ip(tmp, u, parameter_values(integ), t)
537+
tmp[1]
538+
else
539+
rf_oop(u, parameter_values(integ), t)
540+
end
541+
end
542+
return ContinuousCallback(
543+
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind)
544+
end
545+
546+
function generate_vector_rootfinding_callback(
547+
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
548+
ps = full_parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
468549
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
469550
num_eqs = length.(eqs)
470-
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
471551
# fuse equations to create VectorContinuousCallback
472552
eqs = reduce(vcat, eqs)
473553
# rewrite all equations as 0 ~ interesting stuff
@@ -477,45 +557,99 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
477557
end
478558

479559
rhss = map(x -> x.rhs, eqs)
480-
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
560+
_, rf_ip = generate_custom_function(
561+
sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
562+
563+
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(
564+
cb,
565+
sys,
566+
dvs,
567+
ps,
568+
kwargs)
569+
for cb in cbs]
570+
cond = function (out, u, t, integ)
571+
rf_ip(out, u, parameter_values(integ), t)
572+
end
481573

482-
rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
483-
kwargs...)
574+
# since there may be different number of conditions and affects,
575+
# we build a map that translates the condition eq. number to the affect number
576+
eq_ind2affect = reduce(vcat,
577+
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
578+
@assert length(eq_ind2affect) == length(eqs)
579+
@assert maximum(eq_ind2affect) == length(affect_functions)
484580

485-
affect_functions = map(cbs) do cb # Keep affect function separate
486-
eq_aff = affects(cb)
487-
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
581+
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
582+
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
583+
affect_functions[eq_ind2affect[eq_ind]].affect(integ)
584+
end
488585
end
489-
490-
if length(eqs) == 1
491-
cond = function (u, t, integ)
492-
if DiffEqBase.isinplace(integ.sol.prob)
493-
tmp, = DiffEqBase.get_tmp_cache(integ)
494-
rf_ip(tmp, u, parameter_values(integ), t)
495-
tmp[1]
496-
else
497-
rf_oop(u, parameter_values(integ), t)
586+
affect_neg = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
587+
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
588+
affect_neg = affect_functions[eq_ind2affect[eq_ind]].affect_neg
589+
if isnothing(affect_neg)
590+
return # skip if the neg function doesn't exist - don't want to split this into a separate VCC because that'd break ordering
498591
end
592+
affect_neg(integ)
499593
end
500-
ContinuousCallback(cond, affect_functions[])
594+
end
595+
return VectorContinuousCallback(
596+
cond, affect, affect_neg, length(eqs), rootfind = rootfind)
597+
end
598+
599+
"""
600+
Compile a single continuous callback affect function(s).
601+
"""
602+
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
603+
eq_aff = affects(cb)
604+
eq_neg_aff = affect_negs(cb)
605+
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
606+
if eq_neg_aff === eq_aff
607+
affect_neg = affect
608+
elseif isnothing(eq_neg_aff)
609+
affect_neg = nothing
501610
else
502-
cond = function (out, u, t, integ)
503-
rf_ip(out, u, parameter_values(integ), t)
611+
affect_neg = compile_affect(
612+
eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
613+
end
614+
(affect = affect, affect_neg = affect_neg)
615+
end
616+
617+
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
618+
ps = full_parameters(sys); kwargs...)
619+
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
620+
num_eqs = length.(eqs)
621+
total_eqs = sum(num_eqs)
622+
(isempty(eqs) || total_eqs == 0) && return nothing
623+
if total_eqs == 1
624+
# find the callback with only one eq
625+
cb_ind = findfirst(>(0), num_eqs)
626+
if isnothing(cb_ind)
627+
error("Inconsistent state in affect compilation; one equation but no callback with equations?")
504628
end
629+
cb = cbs[cb_ind]
630+
return generate_single_rootfinding_callback(cb.eqs[], cb, sys, dvs, ps; kwargs...)
631+
end
505632

506-
# since there may be different number of conditions and affects,
507-
# we build a map that translates the condition eq. number to the affect number
508-
eq_ind2affect = reduce(vcat,
509-
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
510-
@assert length(eq_ind2affect) == length(eqs)
511-
@assert maximum(eq_ind2affect) == length(affect_functions)
633+
# group the cbs by what rootfind op they use
634+
# groupby would be very useful here, but alas
635+
cb_classes = Dict{
636+
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
637+
for cb in cbs
638+
push!(
639+
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
640+
cb)
641+
end
512642

513-
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
514-
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
515-
affect_functions[eq_ind2affect[eq_ind]](integ)
516-
end
517-
end
518-
VectorContinuousCallback(cond, affect, length(eqs))
643+
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
644+
compiled_callbacks = map(collect(pairs(sort!(
645+
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
646+
return generate_vector_rootfinding_callback(
647+
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
648+
end
649+
if length(compiled_callbacks) == 1
650+
return compiled_callbacks[]
651+
else
652+
return CallbackSet(compiled_callbacks...)
519653
end
520654
end
521655

@@ -529,7 +663,6 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
529663
ps_ind = Dict(reverse(en) for en in enumerate(ps))
530664
p_inds = map(sym -> ps_ind[sym], parameters(affect))
531665
end
532-
533666
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
534667
# (MTK should keep these symbols)
535668
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>

0 commit comments

Comments
 (0)