diff --git a/docs/pages.jl b/docs/pages.jl index 2af487adf8..5dd869625c 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -31,6 +31,7 @@ pages = [ "basics/InputOutput.md", "basics/MTKLanguage.md", "basics/Validation.md", + "basics/Debugging.md", "basics/DependencyGraphs.md", "basics/Precompilation.md", "basics/FAQ.md"], diff --git a/docs/src/basics/Debugging.md b/docs/src/basics/Debugging.md new file mode 100644 index 0000000000..d5c51ec0c1 --- /dev/null +++ b/docs/src/basics/Debugging.md @@ -0,0 +1,40 @@ +# Debugging + +Every (mortal) modeler writes models that contain mistakes or are susceptible to numerical errors in their hunt for the perfect model. +Debugging such errors is part of the modeling process, and ModelingToolkit includes some functionality that helps with this. + +For example, consider an ODE model with "dangerous" functions (here `√`): + +```@example debug +using ModelingToolkit, OrdinaryDiffEq +using ModelingToolkit: t_nounits as t, D_nounits as D + +@variables u1(t) u2(t) u3(t) +eqs = [D(u1) ~ -√(u1), D(u2) ~ -√(u2), D(u3) ~ -√(u3)] +defaults = [u1 => 1.0, u2 => 2.0, u3 => 3.0] +@named sys = ODESystem(eqs, t; defaults) +sys = structural_simplify(sys) +``` + +This problem causes the ODE solver to crash: + +```@repl debug +prob = ODEProblem(sys, [], (0.0, 10.0), []); +sol = solve(prob, Tsit5()); +``` + +This suggests *that* something went wrong, but not exactly *what* went wrong and *where* it did. +In such situations, the `debug_system` function is helpful: + +```@repl debug +dsys = debug_system(sys; functions = [sqrt]); +dprob = ODEProblem(dsys, [], (0.0, 10.0), []); +dsol = solve(dprob, Tsit5()); +``` + +Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `√` function. +We could have figured that out ourselves, but it is not always so obvious for more complex models. + +```@docs +debug_system +``` diff --git a/src/debugging.jl b/src/debugging.jl index 06e3edf0d8..a1a168d8dd 100644 --- a/src/debugging.jl +++ b/src/debugging.jl @@ -1,36 +1,44 @@ -const LOGGED_FUN = Set([log, sqrt, (^), /, inv]) -is_legal(::typeof(/), a, b) = is_legal(inv, b) -is_legal(::typeof(inv), a) = !iszero(a) -is_legal(::Union{typeof(log), typeof(sqrt)}, a) = a isa Complex || a >= zero(a) -is_legal(::typeof(^), a, b) = a isa Complex || b isa Complex || isinteger(b) || a >= zero(a) - +struct LoggedFunctionException <: Exception + msg::String +end struct LoggedFun{F} f::F args::Any + error_nonfinite::Bool +end +function LoggedFunctionException(lf::LoggedFun, args, msg) + LoggedFunctionException( + "Function $(lf.f)($(join(lf.args, ", "))) " * msg * " with input" * + join("\n " .* string.(lf.args .=> args)) # one line for each "var => val" for readability + ) end +Base.showerror(io::IO, err::LoggedFunctionException) = print(io, err.msg) Base.nameof(lf::LoggedFun) = nameof(lf.f) SymbolicUtils.promote_symtype(::LoggedFun, Ts...) = Real function (lf::LoggedFun)(args...) - f = lf.f - symbolic_args = lf.args - if is_legal(f, args...) - f(args...) - else - args_str = join(string.(symbolic_args .=> args), ", ", ", and ") - throw(DomainError(args, "$(lf.f) errors with input(s): $args_str")) + val = try + lf.f(args...) # try to call with numerical input, as usual + catch err + throw(LoggedFunctionException(lf, args, "errors")) # Julia automatically attaches original error message end + if lf.error_nonfinite && !isfinite(val) + throw(LoggedFunctionException(lf, args, "output non-finite value $val")) + end + return val end -function logged_fun(f, args...) +function logged_fun(f, args...; error_nonfinite = true) # remember to update error_nonfinite in debug_system() docstring # Currently we don't really support complex numbers - term(LoggedFun(f, args), args..., type = Real) + term(LoggedFun(f, args, error_nonfinite), args..., type = Real) end -debug_sub(eq::Equation) = debug_sub(eq.lhs) ~ debug_sub(eq.rhs) -function debug_sub(ex) +function debug_sub(eq::Equation, funcs; kw...) + debug_sub(eq.lhs, funcs; kw...) ~ debug_sub(eq.rhs, funcs; kw...) +end +function debug_sub(ex, funcs; kw...) iscall(ex) || return ex f = operation(ex) - args = map(debug_sub, arguments(ex)) - f in LOGGED_FUN ? logged_fun(f, args...) : + args = map(ex -> debug_sub(ex, funcs; kw...), arguments(ex)) + f in funcs ? logged_fun(f, args...; kw...) : maketerm(typeof(ex), f, args, metadata(ex)) end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 168260ae69..e6c07bcf9c 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2260,37 +2260,37 @@ macro mtkbuild(exprs...) end """ -$(SIGNATURES) + debug_system(sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], error_nonfinite = true) -Replace functions with singularities with a function that errors with symbolic -information. E.g. +Wrap `functions` in `sys` so any error thrown in them shows helpful symbolic-numeric +information about its input. If `error_nonfinite`, functions that output nonfinite +values (like `Inf` or `NaN`) also display errors, even though the raw function itself +does not throw an exception (like `1/0`). For example: ```julia-repl -julia> sys = debug_system(sys); - -julia> sys = complete(sys); +julia> sys = debug_system(complete(sys)) -julia> prob = ODEProblem(sys, [], (0, 1.0)); +julia> prob = ODEProblem(sys, [0.0, 2.0], (0.0, 1.0)) -julia> du = zero(prob.u0); - -julia> prob.f(du, prob.u0, prob.p, 0.0) -ERROR: DomainError with (-1.0,): -log errors with input(s): -cos(Q(t)) => -1.0 -Stacktrace: - [1] (::ModelingToolkit.LoggedFun{typeof(log)})(args::Float64) - ... +julia> prob.f(prob.u0, prob.p, 0.0) +ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input + 1 => 1 + sin(P(t)) => 0.0 ``` """ -function debug_system(sys::AbstractSystem) +function debug_system( + sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...) + if !(functions isa Set) + functions = Set(functions) # more efficient "in" lookup + end if has_systems(sys) && !isempty(get_systems(sys)) - error("debug_system only works on systems with no sub-systems!") + error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.") end if has_eqs(sys) - @set! sys.eqs = debug_sub.(equations(sys)) + @set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...) end if has_observed(sys) - @set! sys.observed = debug_sub.(observed(sys)) + @set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...) end return sys end diff --git a/test/odesystem.jl b/test/odesystem.jl index 85d135b338..94f676461b 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -931,22 +931,13 @@ testdict = Dict([:name => "test"]) @named sys = ODESystem(eqs, t, metadata = testdict) @test get_metadata(sys) == testdict -@variables P(t)=0 Q(t)=2 -∂t = D - -eqs = [∂t(Q) ~ 1 / sin(P) - ∂t(P) ~ log(-cos(Q))] +@variables P(t)=NaN Q(t)=NaN +eqs = [D(Q) ~ 1 / sin(P), D(P) ~ log(-cos(Q))] @named sys = ODESystem(eqs, t, [P, Q], []) -sys = complete(debug_system(sys)); -prob = ODEProblem(sys, [], (0, 1.0)); -du = zero(prob.u0); -if VERSION < v"1.8" - @test_throws DomainError prob.f(du, [1, 0], prob.p, 0.0) - @test_throws DomainError prob.f(du, [0, 2], prob.p, 0.0) -else - @test_throws "-cos(Q(t))" prob.f(du, [1, 0], prob.p, 0.0) - @test_throws "sin(P(t))" prob.f(du, [0, 2], prob.p, 0.0) -end +sys = complete(debug_system(sys)) +prob = ODEProblem(sys, [], (0.0, 1.0)) +@test_throws "log(-cos(Q(t))) errors" prob.f([1, 0], prob.p, 0.0) +@test_throws "/(1, sin(P(t))) output non-finite value" prob.f([0, 2], prob.p, 0.0) let @variables x(t) = 1