From 9530a2e6bce78f9fdd49ed1841b9ed6120e04ad2 Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Thu, 23 Jan 2025 16:31:25 -0500 Subject: [PATCH 1/7] init --- src/systems/parameter_buffer.jl | 6 +++-- src/systems/problem_utils.jl | 42 +++++++++++++++++++++++++++++++-- test/problem_validation.jl | 24 +++++++++++++++++++ 3 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 test/problem_validation.jl diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index bc4e62a773..7a88489b48 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -33,17 +33,19 @@ function MTKParameters( else error("Cannot create MTKParameters if system does not have index_cache") end + all_ps = Set(unwrap.(parameters(sys))) union!(all_ps, default_toterm.(unwrap.(parameters(sys)))) if p isa Vector && !(eltype(p) <: Pair) && !isempty(p) ps = parameters(sys) - length(p) == length(ps) || error("Invalid parameters") + length(p) == length(ps) || error("The number of parameter values is not equal to the number of parameters.") p = ps .=> p end if p isa SciMLBase.NullParameters || isempty(p) p = Dict() end p = todict(p) + defs = Dict(default_toterm(unwrap(k)) => v for (k, v) in defaults(sys)) if eltype(u0) <: Pair u0 = todict(u0) @@ -761,7 +763,7 @@ end function Base.showerror(io::IO, e::MissingParametersError) println(io, MISSING_PARAMETERS_MESSAGE) - println(io, e.vars) + println(io, join(e.vars, ", ")) end function InvalidParameterSizeException(param, val) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index bf3d72e1e5..e2af471a15 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -684,12 +684,17 @@ function process_SciMLProblem( u0Type = typeof(u0map) pType = typeof(pmap) - _u0map = u0map + u0map = to_varmap(u0map, dvs) symbols_to_symbolics!(sys, u0map) - _pmap = pmap + check_keys(sys, u0map) + pmap = to_varmap(pmap, ps) symbols_to_symbolics!(sys, pmap) + check_keys(sys, pmap) + badkeys = filter(k -> symbolic_type(k) === NotSymbolic(), keys(pmap)) + isempty(badkeys) || throw(BadKeyError(collect(badkeys))) + defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) @@ -778,6 +783,39 @@ function process_SciMLProblem( implicit_dae ? (f, du0, u0, p) : (f, u0, p) end +# Check that the keys of a u0map or pmap are valid +# (i.e. are symbolic keys, and are defined for the system.) +function check_keys(sys, map) + badkeys = Any[] + for k in keys(map) + if symbolic_type(k) === NotSymbolic() + push!(badkeys, k) + elseif k isa Symbol + !hasproperty(sys, k) && push!(badkeys, k) + elseif k ∉ Set(parameters(sys)) && k ∉ Set(unknowns(sys)) + push!(badkeys, k) + end + end + + isempty(badkeys) || throw(BadKeyError(collect(badkeys))) +end + +const BAD_KEY_MESSAGE = """ + Undefined keys found in the parameter or initial condition maps. + The following keys are either invalid or not parameters/states of the system: + """ + +struct BadKeyError <: Exception + vars::Any +end + +function Base.showerror(io::IO, e::BadKeyError) + println(io, BAD_KEY_MESSAGE) + println(io, join(e.vars, ", ")) +end + + + ############## # Legacy functions for backward compatibility ############## diff --git a/test/problem_validation.jl b/test/problem_validation.jl new file mode 100644 index 0000000000..fb724c55bd --- /dev/null +++ b/test/problem_validation.jl @@ -0,0 +1,24 @@ +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D + +@testset "Input map validation" begin + @variables X(t) + @parameters p d + eqs = [D(X) ~ p - d*X] + @mtkbuild osys = ODESystem(eqs, t) + + p = "I accidentally renamed p" + u0 = [X => 1.0] + ps = [p => 1.0, d => 0.5] + @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + + ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0] + @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + + u0 = [:X => 1.0, "random" => 3.0] + @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + + @parameters k + ps = [p => 1., d => 0.5, k => 3.] + @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) +end From a9dc112905ced2e1b1b16e1af0e179b0604f2563 Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Thu, 23 Jan 2025 16:34:38 -0500 Subject: [PATCH 2/7] up --- test/problem_validation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/problem_validation.jl b/test/problem_validation.jl index fb724c55bd..f871327ae8 100644 --- a/test/problem_validation.jl +++ b/test/problem_validation.jl @@ -12,6 +12,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D ps = [p => 1.0, d => 0.5] @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + @parameters p d ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0] @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) From 3c629ac227950886235d85307f3a82c7b3183ac7 Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Thu, 23 Jan 2025 16:36:30 -0500 Subject: [PATCH 3/7] up --- src/systems/problem_utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index e2af471a15..8c6082e436 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -692,8 +692,6 @@ function process_SciMLProblem( pmap = to_varmap(pmap, ps) symbols_to_symbolics!(sys, pmap) check_keys(sys, pmap) - badkeys = filter(k -> symbolic_type(k) === NotSymbolic(), keys(pmap)) - isempty(badkeys) || throw(BadKeyError(collect(badkeys))) defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) From 417b386a24a6c8c1ed8c49ee6a5c2581a562afae Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Fri, 24 Jan 2025 08:48:29 -0500 Subject: [PATCH 4/7] just check not-symbolic --- src/systems/problem_utils.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 8c6082e436..8541056272 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -788,10 +788,6 @@ function check_keys(sys, map) for k in keys(map) if symbolic_type(k) === NotSymbolic() push!(badkeys, k) - elseif k isa Symbol - !hasproperty(sys, k) && push!(badkeys, k) - elseif k ∉ Set(parameters(sys)) && k ∉ Set(unknowns(sys)) - push!(badkeys, k) end end From ae4e6f70200221e736bdeea732f50c9c8fbeced2 Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Mon, 17 Feb 2025 11:44:51 -0500 Subject: [PATCH 5/7] add tests --- src/systems/problem_utils.jl | 32 ++++++++++++++++++++------------ test/odesystem.jl | 15 +++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 8541056272..3d5b00c418 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -687,11 +687,11 @@ function process_SciMLProblem( u0map = to_varmap(u0map, dvs) symbols_to_symbolics!(sys, u0map) - check_keys(sys, u0map) pmap = to_varmap(pmap, ps) symbols_to_symbolics!(sys, pmap) - check_keys(sys, pmap) + + check_inputmap_keys(sys, u0map, pmap) defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) @@ -783,29 +783,37 @@ end # Check that the keys of a u0map or pmap are valid # (i.e. are symbolic keys, and are defined for the system.) -function check_keys(sys, map) - badkeys = Any[] - for k in keys(map) +function check_inputmap_keys(sys, u0map, pmap) + badvarkeys = Any[] + for k in keys(u0map) if symbolic_type(k) === NotSymbolic() - push!(badkeys, k) + push!(badvarkeys, k) end end - isempty(badkeys) || throw(BadKeyError(collect(badkeys))) + badparamkeys = Any[] + for k in keys(pmap) + if symbolic_type(k) === NotSymbolic() + push!(badparamkeys, k) + end + end + (isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys))) end const BAD_KEY_MESSAGE = """ - Undefined keys found in the parameter or initial condition maps. - The following keys are either invalid or not parameters/states of the system: + Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. + The following keys are invalid: """ -struct BadKeyError <: Exception +struct InvalidKeyError <: Exception vars::Any + params::Any end -function Base.showerror(io::IO, e::BadKeyError) +function Base.showerror(io::IO, e::InvalidKeyError) println(io, BAD_KEY_MESSAGE) - println(io, join(e.vars, ", ")) + println(io, "u0map: $(join(e.vars, ", "))") + println(io, "pmap: $(join(e.params, ", "))") end diff --git a/test/odesystem.jl b/test/odesystem.jl index a635c3dad9..7b4d580718 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1626,3 +1626,18 @@ end prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...)) @test prob.u0 isa SVector end + +@testset "input map validation" begin + import ModelingToolkit: InvalidKeyError + @variables x(t) y(t) z(t) + @parameters a b c d + eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d] + @mtkbuild sys = ODESystem(eqs, t) + pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2] + u0map = [x => 1, y => 2, z => 3] + @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) + + pmap = [a => 1, b => 2, c => 3, d => 4] + u0map = [x => 1, y => 2, z => 3, :0 => 3] + @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) +end From 83f572c42a3787a395675bd5a0de4daa5aa30063 Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Mon, 17 Feb 2025 18:25:45 -0500 Subject: [PATCH 6/7] refactor tests --- test/odesystem.jl | 15 --------------- test/problem_validation.jl | 21 +++++++++++++++------ 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/test/odesystem.jl b/test/odesystem.jl index 62bcf4c355..de166ef0a1 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1673,18 +1673,3 @@ end prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...)) @test prob.u0 isa SVector end - -@testset "input map validation" begin - import ModelingToolkit: InvalidKeyError - @variables x(t) y(t) z(t) - @parameters a b c d - eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d] - @mtkbuild sys = ODESystem(eqs, t) - pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2] - u0map = [x => 1, y => 2, z => 3] - @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) - - pmap = [a => 1, b => 2, c => 3, d => 4] - u0map = [x => 1, y => 2, z => 3, :0 => 3] - @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) -end diff --git a/test/problem_validation.jl b/test/problem_validation.jl index f871327ae8..bce39b51d2 100644 --- a/test/problem_validation.jl +++ b/test/problem_validation.jl @@ -2,6 +2,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D @testset "Input map validation" begin + import ModelingToolkit: InvalidKeyError, MissingParametersError @variables X(t) @parameters p d eqs = [D(X) ~ p - d*X] @@ -10,16 +11,24 @@ using ModelingToolkit: t_nounits as t, D_nounits as D p = "I accidentally renamed p" u0 = [X => 1.0] ps = [p => 1.0, d => 0.5] - @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + @test_throws MissingParametersError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) @parameters p d ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0] - @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) u0 = [:X => 1.0, "random" => 3.0] - @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) - @parameters k - ps = [p => 1., d => 0.5, k => 3.] - @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps) + @variables x(t) y(t) z(t) + @parameters a b c d + eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d] + @mtkbuild sys = ODESystem(eqs, t) + pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2] + u0map = [x => 1, y => 2, z => 3] + @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) + + pmap = [a => 1, b => 2, c => 3, d => 4] + u0map = [x => 1, y => 2, z => 3, :0 => 3] + @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap) end From 18e766efa165188ea6df5b2ea9a095c85c9ff40d Mon Sep 17 00:00:00 2001 From: vyudu <vincent.duyuan@gmail.com> Date: Fri, 21 Feb 2025 11:06:26 -0800 Subject: [PATCH 7/7] fix bug --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index e7345f5d5d..1f40f61eed 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -755,7 +755,7 @@ function process_SciMLProblem( u0map = to_varmap(u0map, dvs) symbols_to_symbolics!(sys, u0map) - pmap = to_varmap(pmap, ps) + pmap = to_varmap(pmap, parameters(sys)) symbols_to_symbolics!(sys, pmap) check_inputmap_keys(sys, u0map, pmap)