Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add assertions functionality #3364

Merged
merged 11 commits into from
Feb 10, 2025
33 changes: 33 additions & 0 deletions docs/src/basics/Debugging.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,39 @@
Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `√` function.
We could have figured that out ourselves, but it is not always so obvious for more complex models.

Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the assertions functionality.

```@example debug
@mtkbuild sys = ODESystem(eqs, t; defaults, assertions = [(u1 + u2 >= 2.0) => "Oh no!"])
```

The assertions must be an iterable of pairs, where the first element is the symbolic condition and
the second is a message to be logged when the condition fails. All assertions are added to the
generated code and will cause the solver to reject steps that fail the assertions. For systems such
as the above where the assertion is guaranteed to eventually fail, the solver will likely exit
with a `dtmin` failure..

```@example debug
prob = ODEProblem(sys, [], (0.0, 10.0))
sol = solve(prob, Tsit5())
```

We can use `debug_system` to log the failing assertions in each call to the RHS function.

```@repl debug
dsys = debug_system(sys; functions = []);
dprob = ODEProblem(dsys, [], (0.0, 10.0));
dsol = solve(dprob, Tsit5());
```

Note the logs containing the failed assertion and corresponding message. To temporarily disable
logging in a system returned from `debug_system`, use `ModelingToolkit.ASSERTION_LOG_VARIABLE`.

```@repl debug
dprob[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false;
solve(drob, Tsit5());

Check warning on line 68 in docs/src/basics/Debugging.md

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"drob" should be "drop".
```

```@docs
debug_system
```
6 changes: 3 additions & 3 deletions docs/src/basics/Variable_metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ A variable can be marked `irreducible` to prevent it from being moved to an
it can be accessed in [callbacks](@ref events)

```@example metadata
@variable important_value [irreducible = true]
@variables important_value [irreducible = true]
isirreducible(important_value)
```

Expand All @@ -192,7 +192,7 @@ isirreducible(important_value)
When a model is structurally simplified, the algorithm will try to ensure that the variables with higher state priority become states of the system. A variable's state priority is a number set using the `state_priority` metadata.

```@example metadata
@variable important_dof [state_priority = 10] unimportant_dof [state_priority = -2]
@variables important_dof [state_priority = 10] unimportant_dof [state_priority = -2]
state_priority(important_dof)
```

Expand All @@ -201,7 +201,7 @@ state_priority(important_dof)
Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions.

```@example metadata
@variable speed [unit = u"m/s"]
@variables speed [unit = u"m/s"]
hasunit(speed)
```

Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
istunable, getdist, hasdist,
tunable_parameters, isirreducible, getdescription, hasdescription,
hasunit, getunit, hasconnect, getconnect,
hasmisc, getmisc
hasmisc, getmisc, state_priority
export ode_order_lowering, dae_order_lowering, liouville_transform
export PDESystem
export Differential, expand_derivatives, @derivatives
Expand Down
56 changes: 56 additions & 0 deletions src/debugging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,59 @@ function debug_sub(ex, funcs; kw...)
f in funcs ? logged_fun(f, args...; kw...) :
maketerm(typeof(ex), f, args, metadata(ex))
end

"""
$(TYPEDSIGNATURES)

A function which returns `NaN` if `condition` fails, and `0.0` otherwise.
"""
function _nan_condition(condition::Bool)
condition ? 0.0 : NaN
end

@register_symbolic _nan_condition(condition::Bool)

"""
$(TYPEDSIGNATURES)

A function which takes a condition `expr` and returns `NaN` if it is false,
and zero if it is true. In case the condition is false and `log == true`,
`message` will be logged as an `@error`.
"""
function _debug_assertion(expr::Bool, message::String, log::Bool)
value = _nan_condition(expr)
isnan(value) || return value
log && @error message
return value
end

@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool)

"""
Boolean parameter added to models returned from `debug_system` to control logging of
assertions.
"""
const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false)

"""
$(TYPEDSIGNATURES)

Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN`
if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a
parameter in the system, it will control whether the message associated with each
assertion is logged when it fails.
"""
function get_assertions_expr(sys::AbstractSystem)
asserts = assertions(sys)
term = 0
if is_parameter(sys, ASSERTION_LOG_VARIABLE)
for (k, v) in asserts
term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE)
end
else
for (k, v) in asserts
term += _nan_condition(k)
end
end
return term
end
39 changes: 38 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ for prop in [:eqs
:gui_metadata
:discrete_subsystems
:parameter_dependencies
:assertions
:solved_unknowns
:split_idxs
:parent
Expand Down Expand Up @@ -1468,6 +1469,24 @@ end
"""
$(TYPEDSIGNATURES)

Get the assertions for a system `sys` and its subsystems.
"""
function assertions(sys::AbstractSystem)
has_assertions(sys) || return Dict{BasicSymbolic, String}()

asserts = get_assertions(sys)
systems = get_systems(sys)
namespaced_asserts = mapreduce(
merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys
Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v
for (k, v) in assertions(subsys))
end
return merge(asserts, namespaced_asserts)
end

"""
$(TYPEDSIGNATURES)

Get the guesses for variables in the initialization system of the system `sys` and its subsystems.

See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref).
Expand Down Expand Up @@ -2283,6 +2302,13 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
1 => 1
sin(P(t)) => 0.0
```

Additionally, all assertions in the system are optionally logged when they fail.
A new parameter is also added to the system which controls whether the message associated
with each assertion will be logged when the assertion fails. This parameter defaults to
`true` and can be toggled by symbolic indexing with
`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example,
`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging.
"""
function debug_system(
sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...)
Expand All @@ -2293,11 +2319,17 @@ function debug_system(
error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.")
end
if has_eqs(sys)
@set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...)
eqs = debug_sub.(equations(sys), Ref(functions); kw...)
@set! sys.eqs = eqs
@set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE])
@set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true))
end
if has_observed(sys)
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
end
if iscomplete(sys)
sys = complete(sys; split = is_split(sys))
end
return sys
end

Expand Down Expand Up @@ -3036,6 +3068,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses))
end

if has_assertions(basesys)
kwargs = merge(
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
end

return T(args...; kwargs...)
end

Expand Down
4 changes: 4 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
[eq.rhs for eq in eqs]

if !isempty(assertions(sys))
rhss[end] += unwrap(get_assertions_expr(sys))
end

# TODO: add an optional check on the ordering of observed equations
u = dvs
p = reorder_parameters(sys, ps)
Expand Down
14 changes: 11 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ struct ODESystem <: AbstractODESystem
"""
parameter_dependencies::Vector{Equation}
"""
Mapping of conditions which should be true throughout the solution process to corresponding error
messages. These will be added to the equations when calling `debug_system`.
"""
assertions::Dict{BasicSymbolic, String}
"""
Metadata for the system, to be used by downstream packages.
"""
metadata::Any
Expand Down Expand Up @@ -190,7 +195,7 @@ struct ODESystem <: AbstractODESystem
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
torn_matching, initializesystem, initialization_eqs, schedule,
connector_type, preface, cevents,
devents, parameter_dependencies,
devents, parameter_dependencies, assertions = Dict{BasicSymbolic, String}(),
metadata = nothing, gui_metadata = nothing, is_dde = false,
tstops = [], tearing_state = nothing,
substitutions = nothing, complete = false, index_cache = nothing,
Expand All @@ -210,7 +215,7 @@ struct ODESystem <: AbstractODESystem
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, metadata,
cevents, devents, parameter_dependencies, assertions, metadata,
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
discrete_subsystems, solved_unknowns, split_idxs, parent)
end
Expand All @@ -235,6 +240,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
continuous_events = nothing,
discrete_events = nothing,
parameter_dependencies = Equation[],
assertions = Dict(),
checks = true,
metadata = nothing,
gui_metadata = nothing,
Expand Down Expand Up @@ -286,12 +292,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, nothing, initializesystem,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies,
disc_callbacks, parameter_dependencies, assertions,
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

Expand Down Expand Up @@ -364,6 +371,7 @@ function flatten(sys::ODESystem, noeqs = false)
name = nameof(sys),
description = description(sys),
initialization_eqs = initialization_equations(sys),
assertions = assertions(sys),
is_dde = is_dde(sys),
tstops = symbolic_tstops(sys),
metadata = get_metadata(sys),
Expand Down
19 changes: 14 additions & 5 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ struct SDESystem <: AbstractODESystem
"""
parameter_dependencies::Vector{Equation}
"""
Mapping of conditions which should be true throughout the solution process to corresponding error
messages. These will be added to the equations when calling `debug_system`.
"""
assertions::Dict{BasicSymbolic, String}
"""
Metadata for the system, to be used by downstream packages.
"""
metadata::Any
Expand Down Expand Up @@ -159,7 +164,9 @@ struct SDESystem <: AbstractODESystem
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
guesses, initializesystem, initialization_eqs, connector_type,
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
cevents, devents, parameter_dependencies, assertions = Dict{
BasicSymbolic, Nothing},
metadata = nothing, gui_metadata = nothing,
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
is_dde = false,
isscheduled = false;
Expand All @@ -185,9 +192,8 @@ struct SDESystem <: AbstractODESystem
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
devents,
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
is_dde, isscheduled)
devents, parameter_dependencies, assertions, metadata, gui_metadata, complete,
index_cache, parent, is_scalar_noise, is_dde, isscheduled)
end
end

Expand All @@ -209,6 +215,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
continuous_events = nothing,
discrete_events = nothing,
parameter_dependencies = Equation[],
assertions = Dict{BasicSymbolic, String}(),
metadata = nothing,
gui_metadata = nothing,
complete = false,
Expand Down Expand Up @@ -261,11 +268,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
initializesystem, initialization_eqs, connector_type,
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
cont_callbacks, disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata,
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
end

Expand Down Expand Up @@ -378,6 +386,7 @@ function ODESystem(sys::SDESystem)
newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys);
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
continuous_events = continuous_events(sys), discrete_events = discrete_events(sys),
assertions = assertions(sys),
name = nameof(sys), description = description(sys), metadata = get_metadata(sys))
@set newsys.parent = sys
end
Expand Down
2 changes: 1 addition & 1 deletion src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys),
parameter_dependencies = parameter_dependencies(sys),
parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys),
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
end
end
Loading
Loading