Skip to content

Commit 0a5c1ce

Browse files
Merge pull request #3364 from AayushSabharwal/as/assertions
feat: add `assertions` functionality
2 parents 3d7a825 + 7cd399f commit 0a5c1ce

11 files changed

+212
-14
lines changed

docs/src/basics/Debugging.md

+33
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@ dsol = solve(dprob, Tsit5());
3535
Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `` function.
3636
We could have figured that out ourselves, but it is not always so obvious for more complex models.
3737

38+
Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the assertions functionality.
39+
40+
```@example debug
41+
@mtkbuild sys = ODESystem(eqs, t; defaults, assertions = [(u1 + u2 >= 2.0) => "Oh no!"])
42+
```
43+
44+
The assertions must be an iterable of pairs, where the first element is the symbolic condition and
45+
the second is a message to be logged when the condition fails. All assertions are added to the
46+
generated code and will cause the solver to reject steps that fail the assertions. For systems such
47+
as the above where the assertion is guaranteed to eventually fail, the solver will likely exit
48+
with a `dtmin` failure..
49+
50+
```@example debug
51+
prob = ODEProblem(sys, [], (0.0, 10.0))
52+
sol = solve(prob, Tsit5())
53+
```
54+
55+
We can use `debug_system` to log the failing assertions in each call to the RHS function.
56+
57+
```@repl debug
58+
dsys = debug_system(sys; functions = []);
59+
dprob = ODEProblem(dsys, [], (0.0, 10.0));
60+
dsol = solve(dprob, Tsit5());
61+
```
62+
63+
Note the logs containing the failed assertion and corresponding message. To temporarily disable
64+
logging in a system returned from `debug_system`, use `ModelingToolkit.ASSERTION_LOG_VARIABLE`.
65+
66+
```@repl debug
67+
dprob[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false;
68+
solve(drob, Tsit5());
69+
```
70+
3871
```@docs
3972
debug_system
4073
```

docs/src/basics/Variable_metadata.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ A variable can be marked `irreducible` to prevent it from being moved to an
183183
it can be accessed in [callbacks](@ref events)
184184

185185
```@example metadata
186-
@variable important_value [irreducible = true]
186+
@variables important_value [irreducible = true]
187187
isirreducible(important_value)
188188
```
189189

@@ -192,7 +192,7 @@ isirreducible(important_value)
192192
When a model is structurally simplified, the algorithm will try to ensure that the variables with higher state priority become states of the system. A variable's state priority is a number set using the `state_priority` metadata.
193193

194194
```@example metadata
195-
@variable important_dof [state_priority = 10] unimportant_dof [state_priority = -2]
195+
@variables important_dof [state_priority = 10] unimportant_dof [state_priority = -2]
196196
state_priority(important_dof)
197197
```
198198

@@ -201,7 +201,7 @@ state_priority(important_dof)
201201
Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions.
202202

203203
```@example metadata
204-
@variable speed [unit = u"m/s"]
204+
@variables speed [unit = u"m/s"]
205205
hasunit(speed)
206206
```
207207

src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
249249
istunable, getdist, hasdist,
250250
tunable_parameters, isirreducible, getdescription, hasdescription,
251251
hasunit, getunit, hasconnect, getconnect,
252-
hasmisc, getmisc
252+
hasmisc, getmisc, state_priority
253253
export ode_order_lowering, dae_order_lowering, liouville_transform
254254
export PDESystem
255255
export Differential, expand_derivatives, @derivatives

src/debugging.jl

+56
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,59 @@ function debug_sub(ex, funcs; kw...)
4242
f in funcs ? logged_fun(f, args...; kw...) :
4343
maketerm(typeof(ex), f, args, metadata(ex))
4444
end
45+
46+
"""
47+
$(TYPEDSIGNATURES)
48+
49+
A function which returns `NaN` if `condition` fails, and `0.0` otherwise.
50+
"""
51+
function _nan_condition(condition::Bool)
52+
condition ? 0.0 : NaN
53+
end
54+
55+
@register_symbolic _nan_condition(condition::Bool)
56+
57+
"""
58+
$(TYPEDSIGNATURES)
59+
60+
A function which takes a condition `expr` and returns `NaN` if it is false,
61+
and zero if it is true. In case the condition is false and `log == true`,
62+
`message` will be logged as an `@error`.
63+
"""
64+
function _debug_assertion(expr::Bool, message::String, log::Bool)
65+
value = _nan_condition(expr)
66+
isnan(value) || return value
67+
log && @error message
68+
return value
69+
end
70+
71+
@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool)
72+
73+
"""
74+
Boolean parameter added to models returned from `debug_system` to control logging of
75+
assertions.
76+
"""
77+
const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false)
78+
79+
"""
80+
$(TYPEDSIGNATURES)
81+
82+
Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN`
83+
if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a
84+
parameter in the system, it will control whether the message associated with each
85+
assertion is logged when it fails.
86+
"""
87+
function get_assertions_expr(sys::AbstractSystem)
88+
asserts = assertions(sys)
89+
term = 0
90+
if is_parameter(sys, ASSERTION_LOG_VARIABLE)
91+
for (k, v) in asserts
92+
term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE)
93+
end
94+
else
95+
for (k, v) in asserts
96+
term += _nan_condition(k)
97+
end
98+
end
99+
return term
100+
end

src/systems/abstractsystem.jl

+38-1
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ for prop in [:eqs
983983
:gui_metadata
984984
:discrete_subsystems
985985
:parameter_dependencies
986+
:assertions
986987
:solved_unknowns
987988
:split_idxs
988989
:parent
@@ -1468,6 +1469,24 @@ end
14681469
"""
14691470
$(TYPEDSIGNATURES)
14701471
1472+
Get the assertions for a system `sys` and its subsystems.
1473+
"""
1474+
function assertions(sys::AbstractSystem)
1475+
has_assertions(sys) || return Dict{BasicSymbolic, String}()
1476+
1477+
asserts = get_assertions(sys)
1478+
systems = get_systems(sys)
1479+
namespaced_asserts = mapreduce(
1480+
merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys
1481+
Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v
1482+
for (k, v) in assertions(subsys))
1483+
end
1484+
return merge(asserts, namespaced_asserts)
1485+
end
1486+
1487+
"""
1488+
$(TYPEDSIGNATURES)
1489+
14711490
Get the guesses for variables in the initialization system of the system `sys` and its subsystems.
14721491
14731492
See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref).
@@ -2283,6 +2302,13 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
22832302
1 => 1
22842303
sin(P(t)) => 0.0
22852304
```
2305+
2306+
Additionally, all assertions in the system are optionally logged when they fail.
2307+
A new parameter is also added to the system which controls whether the message associated
2308+
with each assertion will be logged when the assertion fails. This parameter defaults to
2309+
`true` and can be toggled by symbolic indexing with
2310+
`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example,
2311+
`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging.
22862312
"""
22872313
function debug_system(
22882314
sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...)
@@ -2293,11 +2319,17 @@ function debug_system(
22932319
error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.")
22942320
end
22952321
if has_eqs(sys)
2296-
@set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2322+
eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2323+
@set! sys.eqs = eqs
2324+
@set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE])
2325+
@set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true))
22972326
end
22982327
if has_observed(sys)
22992328
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
23002329
end
2330+
if iscomplete(sys)
2331+
sys = complete(sys; split = is_split(sys))
2332+
end
23012333
return sys
23022334
end
23032335

@@ -3036,6 +3068,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
30363068
kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses))
30373069
end
30383070

3071+
if has_assertions(basesys)
3072+
kwargs = merge(
3073+
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
3074+
end
3075+
30393076
return T(args...; kwargs...)
30403077
end
30413078

src/systems/diffeqs/abstractodesystem.jl

+4
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
168168
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
169169
[eq.rhs for eq in eqs]
170170

171+
if !isempty(assertions(sys))
172+
rhss[end] += unwrap(get_assertions_expr(sys))
173+
end
174+
171175
# TODO: add an optional check on the ordering of observed equations
172176
u = dvs
173177
p = reorder_parameters(sys, ps)

src/systems/diffeqs/odesystem.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ struct ODESystem <: AbstractODESystem
137137
"""
138138
parameter_dependencies::Vector{Equation}
139139
"""
140+
Mapping of conditions which should be true throughout the solution process to corresponding error
141+
messages. These will be added to the equations when calling `debug_system`.
142+
"""
143+
assertions::Dict{BasicSymbolic, String}
144+
"""
140145
Metadata for the system, to be used by downstream packages.
141146
"""
142147
metadata::Any
@@ -190,7 +195,7 @@ struct ODESystem <: AbstractODESystem
190195
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
191196
torn_matching, initializesystem, initialization_eqs, schedule,
192197
connector_type, preface, cevents,
193-
devents, parameter_dependencies,
198+
devents, parameter_dependencies, assertions = Dict{BasicSymbolic, String}(),
194199
metadata = nothing, gui_metadata = nothing, is_dde = false,
195200
tstops = [], tearing_state = nothing,
196201
substitutions = nothing, complete = false, index_cache = nothing,
@@ -210,7 +215,7 @@ struct ODESystem <: AbstractODESystem
210215
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
211216
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
212217
initializesystem, initialization_eqs, schedule, connector_type, preface,
213-
cevents, devents, parameter_dependencies, metadata,
218+
cevents, devents, parameter_dependencies, assertions, metadata,
214219
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
215220
discrete_subsystems, solved_unknowns, split_idxs, parent)
216221
end
@@ -235,6 +240,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
235240
continuous_events = nothing,
236241
discrete_events = nothing,
237242
parameter_dependencies = Equation[],
243+
assertions = Dict(),
238244
checks = true,
239245
metadata = nothing,
240246
gui_metadata = nothing,
@@ -286,12 +292,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
286292
if is_dde === nothing
287293
is_dde = _check_if_dde(deqs, iv′, systems)
288294
end
295+
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
289296
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
290297
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
291298
ctrl_jac, Wfact, Wfact_t, name, description, systems,
292299
defaults, guesses, nothing, initializesystem,
293300
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
294-
disc_callbacks, parameter_dependencies,
301+
disc_callbacks, parameter_dependencies, assertions,
295302
metadata, gui_metadata, is_dde, tstops, checks = checks)
296303
end
297304

@@ -364,6 +371,7 @@ function flatten(sys::ODESystem, noeqs = false)
364371
name = nameof(sys),
365372
description = description(sys),
366373
initialization_eqs = initialization_equations(sys),
374+
assertions = assertions(sys),
367375
is_dde = is_dde(sys),
368376
tstops = symbolic_tstops(sys),
369377
metadata = get_metadata(sys),

src/systems/diffeqs/sdesystem.jl

+14-5
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ struct SDESystem <: AbstractODESystem
126126
"""
127127
parameter_dependencies::Vector{Equation}
128128
"""
129+
Mapping of conditions which should be true throughout the solution process to corresponding error
130+
messages. These will be added to the equations when calling `debug_system`.
131+
"""
132+
assertions::Dict{BasicSymbolic, String}
133+
"""
129134
Metadata for the system, to be used by downstream packages.
130135
"""
131136
metadata::Any
@@ -159,7 +164,9 @@ struct SDESystem <: AbstractODESystem
159164
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
160165
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
161166
guesses, initializesystem, initialization_eqs, connector_type,
162-
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
167+
cevents, devents, parameter_dependencies, assertions = Dict{
168+
BasicSymbolic, Nothing},
169+
metadata = nothing, gui_metadata = nothing,
163170
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
164171
is_dde = false,
165172
isscheduled = false;
@@ -185,9 +192,8 @@ struct SDESystem <: AbstractODESystem
185192
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
186193
ctrl_jac, Wfact, Wfact_t, name, description, systems,
187194
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
188-
devents,
189-
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
190-
is_dde, isscheduled)
195+
devents, parameter_dependencies, assertions, metadata, gui_metadata, complete,
196+
index_cache, parent, is_scalar_noise, is_dde, isscheduled)
191197
end
192198
end
193199

@@ -209,6 +215,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
209215
continuous_events = nothing,
210216
discrete_events = nothing,
211217
parameter_dependencies = Equation[],
218+
assertions = Dict{BasicSymbolic, String}(),
212219
metadata = nothing,
213220
gui_metadata = nothing,
214221
complete = false,
@@ -261,11 +268,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
261268
if is_dde === nothing
262269
is_dde = _check_if_dde(deqs, iv′, systems)
263270
end
271+
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
264272
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
265273
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
266274
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
267275
initializesystem, initialization_eqs, connector_type,
268-
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
276+
cont_callbacks, disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata,
269277
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
270278
end
271279

@@ -378,6 +386,7 @@ function ODESystem(sys::SDESystem)
378386
newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys);
379387
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
380388
continuous_events = continuous_events(sys), discrete_events = discrete_events(sys),
389+
assertions = assertions(sys),
381390
name = nameof(sys), description = description(sys), metadata = get_metadata(sys))
382391
@set newsys.parent = sys
383392
end

src/systems/systems.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
165165
return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
166166
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
167167
name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys),
168-
parameter_dependencies = parameter_dependencies(sys),
168+
parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys),
169169
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
170170
end
171171
end

0 commit comments

Comments
 (0)