diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index f2d0867a32..0954d9a40c 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -721,7 +721,7 @@ end function Base.showerror(io::IO, e::MissingParametersError) println(io, MISSING_PARAMETERS_MESSAGE) - println(io, e.vars) + println(io, join(e.vars, ", ")) end function InvalidParameterSizeException(param, val) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 0be41ea6f7..1f40f61eed 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -752,12 +752,14 @@ function process_SciMLProblem( u0Type = typeof(u0map) pType = typeof(pmap) - _u0map = u0map + u0map = to_varmap(u0map, dvs) symbols_to_symbolics!(sys, u0map) - _pmap = pmap pmap = to_varmap(pmap, parameters(sys)) symbols_to_symbolics!(sys, pmap) + + check_inputmap_keys(sys, u0map, pmap) + defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) @@ -854,6 +856,43 @@ function process_SciMLProblem( implicit_dae ? (f, du0, u0, p) : (f, u0, p) end +# Check that the keys of a u0map or pmap are valid +# (i.e. are symbolic keys, and are defined for the system.) +function check_inputmap_keys(sys, u0map, pmap) + badvarkeys = Any[] + for k in keys(u0map) + if symbolic_type(k) === NotSymbolic() + push!(badvarkeys, k) + end + end + + badparamkeys = Any[] + for k in keys(pmap) + if symbolic_type(k) === NotSymbolic() + push!(badparamkeys, k) + end + end + (isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys))) +end + +const BAD_KEY_MESSAGE = """ + Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. + The following keys are invalid: + """ + +struct InvalidKeyError <: Exception + vars::Any + params::Any +end + +function Base.showerror(io::IO, e::InvalidKeyError) + println(io, BAD_KEY_MESSAGE) + println(io, "u0map: $(join(e.vars, ", "))") + println(io, "pmap: $(join(e.params, ", "))") +end + + + ############## # Legacy functions for backward compatibility ############## diff --git a/test/problem_validation.jl b/test/problem_validation.jl new file mode 100644 index 0000000000..bce39b51d2 --- /dev/null +++ b/test/problem_validation.jl @@ -0,0 +1,34 @@ +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D + +@testset "Input map validation" begin + import ModelingToolkit: InvalidKeyError, MissingParametersError + @variables X(t) + @parameters p d + eqs = [D(X) ~ p - d*X] + @mtkbuild osys = ODESystem(eqs, t) + + p = "I accidentally renamed p" + u0 = [X => 1.0] + ps = [p => 1.0, d => 0.5] + @test_throws MissingParametersError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + + @parameters p d + ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0] + @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + + u0 = [:X => 1.0, "random" => 3.0] + @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + + @variables x(t) y(t) z(t) + @parameters a b c d + eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d] + @mtkbuild sys = ODESystem(eqs, t) + pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2] + u0map = [x => 1, y => 2, z => 3] + @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) + + pmap = [a => 1, b => 2, c => 3, d => 4] + u0map = [x => 1, y => 2, z => 3, :0 => 3] + @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) +end