From 9530a2e6bce78f9fdd49ed1841b9ed6120e04ad2 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 16:31:25 -0500
Subject: [PATCH 1/7] init

---
 src/systems/parameter_buffer.jl |  6 +++--
 src/systems/problem_utils.jl    | 42 +++++++++++++++++++++++++++++++--
 test/problem_validation.jl      | 24 +++++++++++++++++++
 3 files changed, 68 insertions(+), 4 deletions(-)
 create mode 100644 test/problem_validation.jl

diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl
index bc4e62a773..7a88489b48 100644
--- a/src/systems/parameter_buffer.jl
+++ b/src/systems/parameter_buffer.jl
@@ -33,17 +33,19 @@ function MTKParameters(
     else
         error("Cannot create MTKParameters if system does not have index_cache")
     end
+
     all_ps = Set(unwrap.(parameters(sys)))
     union!(all_ps, default_toterm.(unwrap.(parameters(sys))))
     if p isa Vector && !(eltype(p) <: Pair) && !isempty(p)
         ps = parameters(sys)
-        length(p) == length(ps) || error("Invalid parameters")
+        length(p) == length(ps) || error("The number of parameter values is not equal to the number of parameters.")
         p = ps .=> p
     end
     if p isa SciMLBase.NullParameters || isempty(p)
         p = Dict()
     end
     p = todict(p)
+
     defs = Dict(default_toterm(unwrap(k)) => v for (k, v) in defaults(sys))
     if eltype(u0) <: Pair
         u0 = todict(u0)
@@ -761,7 +763,7 @@ end
 
 function Base.showerror(io::IO, e::MissingParametersError)
     println(io, MISSING_PARAMETERS_MESSAGE)
-    println(io, e.vars)
+    println(io, join(e.vars, ", "))
 end
 
 function InvalidParameterSizeException(param, val)
diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index bf3d72e1e5..e2af471a15 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -684,12 +684,17 @@ function process_SciMLProblem(
 
     u0Type = typeof(u0map)
     pType = typeof(pmap)
-    _u0map = u0map
+
     u0map = to_varmap(u0map, dvs)
     symbols_to_symbolics!(sys, u0map)
-    _pmap = pmap
+    check_keys(sys, u0map)
+
     pmap = to_varmap(pmap, ps)
     symbols_to_symbolics!(sys, pmap)
+    check_keys(sys, pmap)
+    badkeys = filter(k -> symbolic_type(k) === NotSymbolic(), keys(pmap))
+    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
+
     defs = add_toterms(recursive_unwrap(defaults(sys)))
     cmap, cs = get_cmap(sys)
     kwargs = NamedTuple(kwargs)
@@ -778,6 +783,39 @@ function process_SciMLProblem(
     implicit_dae ? (f, du0, u0, p) : (f, u0, p)
 end
 
+# Check that the keys of a u0map or pmap are valid
+# (i.e. are symbolic keys, and are defined for the system.)
+function check_keys(sys, map) 
+    badkeys = Any[]
+    for k in keys(map)
+        if symbolic_type(k) === NotSymbolic()
+            push!(badkeys, k)
+        elseif k isa Symbol
+            !hasproperty(sys, k) && push!(badkeys, k)
+        elseif k ∉ Set(parameters(sys)) && k ∉ Set(unknowns(sys)) 
+            push!(badkeys, k)
+        end
+    end
+
+    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
+end
+
+const BAD_KEY_MESSAGE = """
+                        Undefined keys found in the parameter or initial condition maps. 
+                        The following keys are either invalid or not parameters/states of the system:
+                        """
+
+struct BadKeyError <: Exception
+    vars::Any
+end
+
+function Base.showerror(io::IO, e::BadKeyError) 
+    println(io, BAD_KEY_MESSAGE)
+    println(io, join(e.vars, ", "))
+end
+
+
+
 ##############
 # Legacy functions for backward compatibility
 ##############
diff --git a/test/problem_validation.jl b/test/problem_validation.jl
new file mode 100644
index 0000000000..fb724c55bd
--- /dev/null
+++ b/test/problem_validation.jl
@@ -0,0 +1,24 @@
+using ModelingToolkit
+using ModelingToolkit: t_nounits as t, D_nounits as D
+
+@testset "Input map validation" begin
+    @variables X(t)
+    @parameters p d
+    eqs = [D(X) ~ p - d*X]
+    @mtkbuild osys = ODESystem(eqs, t)
+    
+    p = "I accidentally renamed p"
+    u0 = [X => 1.0]
+    ps = [p => 1.0, d => 0.5]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    
+    ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+
+    u0 = [:X => 1.0, "random" => 3.0]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+
+    @parameters k
+    ps = [p => 1., d => 0.5, k => 3.]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+end

From a9dc112905ced2e1b1b16e1af0e179b0604f2563 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 16:34:38 -0500
Subject: [PATCH 2/7] up

---
 test/problem_validation.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/problem_validation.jl b/test/problem_validation.jl
index fb724c55bd..f871327ae8 100644
--- a/test/problem_validation.jl
+++ b/test/problem_validation.jl
@@ -12,6 +12,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
     ps = [p => 1.0, d => 0.5]
     @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
     
+    @parameters p d
     ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
     @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
 

From 3c629ac227950886235d85307f3a82c7b3183ac7 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 16:36:30 -0500
Subject: [PATCH 3/7] up

---
 src/systems/problem_utils.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index e2af471a15..8c6082e436 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -692,8 +692,6 @@ function process_SciMLProblem(
     pmap = to_varmap(pmap, ps)
     symbols_to_symbolics!(sys, pmap)
     check_keys(sys, pmap)
-    badkeys = filter(k -> symbolic_type(k) === NotSymbolic(), keys(pmap))
-    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
 
     defs = add_toterms(recursive_unwrap(defaults(sys)))
     cmap, cs = get_cmap(sys)

From 417b386a24a6c8c1ed8c49ee6a5c2581a562afae Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 24 Jan 2025 08:48:29 -0500
Subject: [PATCH 4/7] just check not-symbolic

---
 src/systems/problem_utils.jl | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index 8c6082e436..8541056272 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -788,10 +788,6 @@ function check_keys(sys, map)
     for k in keys(map)
         if symbolic_type(k) === NotSymbolic()
             push!(badkeys, k)
-        elseif k isa Symbol
-            !hasproperty(sys, k) && push!(badkeys, k)
-        elseif k ∉ Set(parameters(sys)) && k ∉ Set(unknowns(sys)) 
-            push!(badkeys, k)
         end
     end
 

From ae4e6f70200221e736bdeea732f50c9c8fbeced2 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 17 Feb 2025 11:44:51 -0500
Subject: [PATCH 5/7] add tests

---
 src/systems/problem_utils.jl | 32 ++++++++++++++++++++------------
 test/odesystem.jl            | 15 +++++++++++++++
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index 8541056272..3d5b00c418 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -687,11 +687,11 @@ function process_SciMLProblem(
 
     u0map = to_varmap(u0map, dvs)
     symbols_to_symbolics!(sys, u0map)
-    check_keys(sys, u0map)
 
     pmap = to_varmap(pmap, ps)
     symbols_to_symbolics!(sys, pmap)
-    check_keys(sys, pmap)
+
+    check_inputmap_keys(sys, u0map, pmap)
 
     defs = add_toterms(recursive_unwrap(defaults(sys)))
     cmap, cs = get_cmap(sys)
@@ -783,29 +783,37 @@ end
 
 # Check that the keys of a u0map or pmap are valid
 # (i.e. are symbolic keys, and are defined for the system.)
-function check_keys(sys, map) 
-    badkeys = Any[]
-    for k in keys(map)
+function check_inputmap_keys(sys, u0map, pmap)
+    badvarkeys = Any[]
+    for k in keys(u0map)
         if symbolic_type(k) === NotSymbolic()
-            push!(badkeys, k)
+            push!(badvarkeys, k)
         end
     end
 
-    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
+    badparamkeys = Any[]
+    for k in keys(pmap)
+        if symbolic_type(k) === NotSymbolic()
+            push!(badparamkeys, k)
+        end
+    end
+    (isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
 end
 
 const BAD_KEY_MESSAGE = """
-                        Undefined keys found in the parameter or initial condition maps. 
-                        The following keys are either invalid or not parameters/states of the system:
+                        Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. 
+                        The following keys are invalid:
                         """
 
-struct BadKeyError <: Exception
+struct InvalidKeyError <: Exception
     vars::Any
+    params::Any
 end
 
-function Base.showerror(io::IO, e::BadKeyError) 
+function Base.showerror(io::IO, e::InvalidKeyError) 
     println(io, BAD_KEY_MESSAGE)
-    println(io, join(e.vars, ", "))
+    println(io, "u0map: $(join(e.vars, ", "))")
+    println(io, "pmap: $(join(e.params, ", "))")
 end
 
 
diff --git a/test/odesystem.jl b/test/odesystem.jl
index a635c3dad9..7b4d580718 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1626,3 +1626,18 @@ end
     prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...))
     @test prob.u0 isa SVector
 end
+
+@testset "input map validation" begin
+    import ModelingToolkit: InvalidKeyError 
+    @variables x(t) y(t) z(t)
+    @parameters a b c d 
+    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
+    @mtkbuild sys = ODESystem(eqs, t)
+    pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
+    u0map = [x => 1, y => 2, z => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+
+    pmap = [a => 1, b => 2, c => 3, d => 4]
+    u0map = [x => 1, y => 2, z => 3, :0 => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+end

From 83f572c42a3787a395675bd5a0de4daa5aa30063 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 17 Feb 2025 18:25:45 -0500
Subject: [PATCH 6/7] refactor tests

---
 test/odesystem.jl          | 15 ---------------
 test/problem_validation.jl | 21 +++++++++++++++------
 2 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/test/odesystem.jl b/test/odesystem.jl
index 62bcf4c355..de166ef0a1 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1673,18 +1673,3 @@ end
     prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...))
     @test prob.u0 isa SVector
 end
-
-@testset "input map validation" begin
-    import ModelingToolkit: InvalidKeyError 
-    @variables x(t) y(t) z(t)
-    @parameters a b c d 
-    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
-    @mtkbuild sys = ODESystem(eqs, t)
-    pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
-    u0map = [x => 1, y => 2, z => 3]
-    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
-
-    pmap = [a => 1, b => 2, c => 3, d => 4]
-    u0map = [x => 1, y => 2, z => 3, :0 => 3]
-    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
-end
diff --git a/test/problem_validation.jl b/test/problem_validation.jl
index f871327ae8..bce39b51d2 100644
--- a/test/problem_validation.jl
+++ b/test/problem_validation.jl
@@ -2,6 +2,7 @@ using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
 @testset "Input map validation" begin
+    import ModelingToolkit: InvalidKeyError, MissingParametersError 
     @variables X(t)
     @parameters p d
     eqs = [D(X) ~ p - d*X]
@@ -10,16 +11,24 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
     p = "I accidentally renamed p"
     u0 = [X => 1.0]
     ps = [p => 1.0, d => 0.5]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws MissingParametersError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
     
     @parameters p d
     ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
 
     u0 = [:X => 1.0, "random" => 3.0]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
 
-    @parameters k
-    ps = [p => 1., d => 0.5, k => 3.]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @variables x(t) y(t) z(t)
+    @parameters a b c d 
+    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
+    @mtkbuild sys = ODESystem(eqs, t)
+    pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
+    u0map = [x => 1, y => 2, z => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+
+    pmap = [a => 1, b => 2, c => 3, d => 4]
+    u0map = [x => 1, y => 2, z => 3, :0 => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
 end

From 18e766efa165188ea6df5b2ea9a095c85c9ff40d Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 21 Feb 2025 11:06:26 -0800
Subject: [PATCH 7/7] fix bug

---
 src/systems/problem_utils.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index e7345f5d5d..1f40f61eed 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -755,7 +755,7 @@ function process_SciMLProblem(
 
     u0map = to_varmap(u0map, dvs)
     symbols_to_symbolics!(sys, u0map)
-    pmap = to_varmap(pmap, ps)
+    pmap = to_varmap(pmap, parameters(sys))
     symbols_to_symbolics!(sys, pmap)
 
     check_inputmap_keys(sys, u0map, pmap)