Skip to content

Commit 4c86290

Browse files
Merge pull request #3296 from hersle/generic_logged_functions
Generic mechanism for debugging/logging functions
2 parents 8668cde + 3bc5e15 commit 4c86290

File tree

5 files changed

+93
-53
lines changed

5 files changed

+93
-53
lines changed

docs/pages.jl

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pages = [
3232
"basics/InputOutput.md",
3333
"basics/MTKLanguage.md",
3434
"basics/Validation.md",
35+
"basics/Debugging.md",
3536
"basics/DependencyGraphs.md",
3637
"basics/Precompilation.md",
3738
"basics/FAQ.md"],

docs/src/basics/Debugging.md

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Debugging
2+
3+
Every (mortal) modeler writes models that contain mistakes or are susceptible to numerical errors in their hunt for the perfect model.
4+
Debugging such errors is part of the modeling process, and ModelingToolkit includes some functionality that helps with this.
5+
6+
For example, consider an ODE model with "dangerous" functions (here ``):
7+
8+
```@example debug
9+
using ModelingToolkit, OrdinaryDiffEq
10+
using ModelingToolkit: t_nounits as t, D_nounits as D
11+
12+
@variables u1(t) u2(t) u3(t)
13+
eqs = [D(u1) ~ -√(u1), D(u2) ~ -√(u2), D(u3) ~ -√(u3)]
14+
defaults = [u1 => 1.0, u2 => 2.0, u3 => 3.0]
15+
@named sys = ODESystem(eqs, t; defaults)
16+
sys = structural_simplify(sys)
17+
```
18+
19+
This problem causes the ODE solver to crash:
20+
21+
```@repl debug
22+
prob = ODEProblem(sys, [], (0.0, 10.0), []);
23+
sol = solve(prob, Tsit5());
24+
```
25+
26+
This suggests *that* something went wrong, but not exactly *what* went wrong and *where* it did.
27+
In such situations, the `debug_system` function is helpful:
28+
29+
```@repl debug
30+
dsys = debug_system(sys; functions = [sqrt]);
31+
dprob = ODEProblem(dsys, [], (0.0, 10.0), []);
32+
dsol = solve(dprob, Tsit5());
33+
```
34+
35+
Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `` function.
36+
We could have figured that out ourselves, but it is not always so obvious for more complex models.
37+
38+
```@docs
39+
debug_system
40+
```

src/debugging.jl

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,44 @@
1-
const LOGGED_FUN = Set([log, sqrt, (^), /, inv])
2-
is_legal(::typeof(/), a, b) = is_legal(inv, b)
3-
is_legal(::typeof(inv), a) = !iszero(a)
4-
is_legal(::Union{typeof(log), typeof(sqrt)}, a) = a isa Complex || a >= zero(a)
5-
is_legal(::typeof(^), a, b) = a isa Complex || b isa Complex || isinteger(b) || a >= zero(a)
6-
1+
struct LoggedFunctionException <: Exception
2+
msg::String
3+
end
74
struct LoggedFun{F}
85
f::F
96
args::Any
7+
error_nonfinite::Bool
8+
end
9+
function LoggedFunctionException(lf::LoggedFun, args, msg)
10+
LoggedFunctionException(
11+
"Function $(lf.f)($(join(lf.args, ", "))) " * msg * " with input" *
12+
join("\n " .* string.(lf.args .=> args)) # one line for each "var => val" for readability
13+
)
1014
end
15+
Base.showerror(io::IO, err::LoggedFunctionException) = print(io, err.msg)
1116
Base.nameof(lf::LoggedFun) = nameof(lf.f)
1217
SymbolicUtils.promote_symtype(::LoggedFun, Ts...) = Real
1318
function (lf::LoggedFun)(args...)
14-
f = lf.f
15-
symbolic_args = lf.args
16-
if is_legal(f, args...)
17-
f(args...)
18-
else
19-
args_str = join(string.(symbolic_args .=> args), ", ", ", and ")
20-
throw(DomainError(args, "$(lf.f) errors with input(s): $args_str"))
19+
val = try
20+
lf.f(args...) # try to call with numerical input, as usual
21+
catch err
22+
throw(LoggedFunctionException(lf, args, "errors")) # Julia automatically attaches original error message
2123
end
24+
if lf.error_nonfinite && !isfinite(val)
25+
throw(LoggedFunctionException(lf, args, "output non-finite value $val"))
26+
end
27+
return val
2228
end
2329

24-
function logged_fun(f, args...)
30+
function logged_fun(f, args...; error_nonfinite = true) # remember to update error_nonfinite in debug_system() docstring
2531
# Currently we don't really support complex numbers
26-
term(LoggedFun(f, args), args..., type = Real)
32+
term(LoggedFun(f, args, error_nonfinite), args..., type = Real)
2733
end
2834

29-
debug_sub(eq::Equation) = debug_sub(eq.lhs) ~ debug_sub(eq.rhs)
30-
function debug_sub(ex)
35+
function debug_sub(eq::Equation, funcs; kw...)
36+
debug_sub(eq.lhs, funcs; kw...) ~ debug_sub(eq.rhs, funcs; kw...)
37+
end
38+
function debug_sub(ex, funcs; kw...)
3139
iscall(ex) || return ex
3240
f = operation(ex)
33-
args = map(debug_sub, arguments(ex))
34-
f in LOGGED_FUN ? logged_fun(f, args...) :
41+
args = map(ex -> debug_sub(ex, funcs; kw...), arguments(ex))
42+
f in funcs ? logged_fun(f, args...; kw...) :
3543
maketerm(typeof(ex), f, args, metadata(ex))
3644
end

src/systems/abstractsystem.jl

+19-19
Original file line numberDiff line numberDiff line change
@@ -2281,37 +2281,37 @@ macro mtkbuild(exprs...)
22812281
end
22822282

22832283
"""
2284-
$(SIGNATURES)
2284+
debug_system(sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], error_nonfinite = true)
22852285
2286-
Replace functions with singularities with a function that errors with symbolic
2287-
information. E.g.
2286+
Wrap `functions` in `sys` so any error thrown in them shows helpful symbolic-numeric
2287+
information about its input. If `error_nonfinite`, functions that output nonfinite
2288+
values (like `Inf` or `NaN`) also display errors, even though the raw function itself
2289+
does not throw an exception (like `1/0`). For example:
22882290
22892291
```julia-repl
2290-
julia> sys = debug_system(sys);
2291-
2292-
julia> sys = complete(sys);
2292+
julia> sys = debug_system(complete(sys))
22932293
2294-
julia> prob = ODEProblem(sys, [], (0, 1.0));
2294+
julia> prob = ODEProblem(sys, [0.0, 2.0], (0.0, 1.0))
22952295
2296-
julia> du = zero(prob.u0);
2297-
2298-
julia> prob.f(du, prob.u0, prob.p, 0.0)
2299-
ERROR: DomainError with (-1.0,):
2300-
log errors with input(s): -cos(Q(t)) => -1.0
2301-
Stacktrace:
2302-
[1] (::ModelingToolkit.LoggedFun{typeof(log)})(args::Float64)
2303-
...
2296+
julia> prob.f(prob.u0, prob.p, 0.0)
2297+
ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
2298+
1 => 1
2299+
sin(P(t)) => 0.0
23042300
```
23052301
"""
2306-
function debug_system(sys::AbstractSystem)
2302+
function debug_system(
2303+
sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...)
2304+
if !(functions isa Set)
2305+
functions = Set(functions) # more efficient "in" lookup
2306+
end
23072307
if has_systems(sys) && !isempty(get_systems(sys))
2308-
error("debug_system only works on systems with no sub-systems!")
2308+
error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.")
23092309
end
23102310
if has_eqs(sys)
2311-
@set! sys.eqs = debug_sub.(equations(sys))
2311+
@set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...)
23122312
end
23132313
if has_observed(sys)
2314-
@set! sys.observed = debug_sub.(observed(sys))
2314+
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
23152315
end
23162316
return sys
23172317
end

test/odesystem.jl

+6-15
Original file line numberDiff line numberDiff line change
@@ -931,22 +931,13 @@ testdict = Dict([:name => "test"])
931931
@named sys = ODESystem(eqs, t, metadata = testdict)
932932
@test get_metadata(sys) == testdict
933933

934-
@variables P(t)=0 Q(t)=2
935-
∂t = D
936-
937-
eqs = [∂t(Q) ~ 1 / sin(P)
938-
∂t(P) ~ log(-cos(Q))]
934+
@variables P(t)=NaN Q(t)=NaN
935+
eqs = [D(Q) ~ 1 / sin(P), D(P) ~ log(-cos(Q))]
939936
@named sys = ODESystem(eqs, t, [P, Q], [])
940-
sys = complete(debug_system(sys));
941-
prob = ODEProblem(sys, [], (0, 1.0));
942-
du = zero(prob.u0);
943-
if VERSION < v"1.8"
944-
@test_throws DomainError prob.f(du, [1, 0], prob.p, 0.0)
945-
@test_throws DomainError prob.f(du, [0, 2], prob.p, 0.0)
946-
else
947-
@test_throws "-cos(Q(t))" prob.f(du, [1, 0], prob.p, 0.0)
948-
@test_throws "sin(P(t))" prob.f(du, [0, 2], prob.p, 0.0)
949-
end
937+
sys = complete(debug_system(sys))
938+
prob = ODEProblem(sys, [], (0.0, 1.0))
939+
@test_throws "log(-cos(Q(t))) errors" prob.f([1, 0], prob.p, 0.0)
940+
@test_throws "/(1, sin(P(t))) output non-finite value" prob.f([0, 2], prob.p, 0.0)
950941

951942
let
952943
@variables x(t) = 1

0 commit comments

Comments
 (0)