From 9733460668a662b41fd4cf87102d08521c1e4f92 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 22 Nov 2024 13:25:30 -0500
Subject: [PATCH 001/111] init

---
 src/systems/diffeqs/abstractodesystem.jl | 110 +++++++++++++++++++++++
 1 file changed, 110 insertions(+)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 8d9e0b5381..2579ebafcf 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -466,6 +466,116 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
         initializeprobpmap = initializeprobpmap)
 end
 
+"""
+```julia
+SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
+                         parammap = DiffEqBase.NullParameters();
+                         version = nothing, tgrad = false,
+                         jac = true, sparse = true,
+                         simplify = false,
+                         kwargs...) where {iip}
+```
+
+Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
+`ps` are used to set the order of the dependent variable and parameter vectors,
+respectively.
+"""
+function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
+    BVProblem{true}(sys, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem(sys::AbstractODESystem,
+        u0map::StaticArray,
+        args...;
+        kwargs...)
+    BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
+    BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
+    BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
+        tspan = get_tspan(sys),
+        parammap = DiffEqBase.NullParameters();
+        version = nothing, tgrad = false,
+        jac = true, sparse = true, 
+        sparsity = true, 
+        callback = nothing,
+        check_length = true,
+        warn_initialize_determined = true,
+        eval_expression = false,
+        eval_module = @__MODULE__,
+        kwargs...) where {iip, specialize}
+
+    if !iscomplete(sys)
+        error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
+    end
+
+    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
+        t = tspan !== nothing ? tspan[1] : tspan,
+        check_length, warn_initialize_determined, eval_expression, eval_module, jac, sparse, sparsity, kwargs...)
+
+    # if jac
+    #     jac_gen = generate_jacobian(sys, dvs, ps;
+    #         simplify = simplify, sparse = sparse,
+    #         expression = Val{true},
+    #         expression_module = eval_module,
+    #         checkbounds = checkbounds, kwargs...)
+    #     jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
+
+    #     _jac(u, p, t) = jac_oop(u, p, t)
+    #     _jac(J, u, p, t) = jac_iip(J, u, p, t)
+    #     _jac(u, p::Tuple{Vararg{Number}}, t) = jac_oop(u, p, t)
+    #     _jac(J, u, p::Tuple{Vararg{Number}}, t) = jac_iip(J, u, p, t)
+    #     _jac(u, p::Tuple, t) = jac_oop(u, p..., t)
+    #     _jac(J, u, p::Tuple, t) = jac_iip(J, u, p..., t)
+    #     _jac(u, p::MTKParameters, t) = jac_oop(u, p..., t)
+    #     _jac(J, u, p::MTKParameters, t) = jac_iip(J, u, p..., t)
+    # else
+    #     _jac = nothing
+    # end
+
+    # jac_prototype = if sparse
+    #     uElType = u0 === nothing ? Float64 : eltype(u0)
+    #     if jac
+    #         similar(calculate_jacobian(sys, sparse = sparse), uElType)
+    #     else
+    #         similar(jacobian_sparsity(sys), uElType)
+    #     end
+    # else
+    #     nothing
+    # end
+
+    # f.jac = _jac
+    # f.jac_prototype = jac_prototype
+    # f.sparsity = jac ? jacobian_sparsity(sys) : nothing
+
+    cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
+
+    kwargs = filter_kwargs(kwargs)
+
+    kwargs1 = (;)
+    if cbs !== nothing
+        kwargs1 = merge(kwargs1, (callback = cbs,))
+    end
+
+    # Define the boundary conditions
+    bc = if iip 
+        (residual, u, p, t) -> (residual = u[1] - u0)
+    else
+        (u, p, t) -> (u[1] - u0)
+    end
+
+    return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
+end
+
+get_callback(prob::BVProblem) = prob.kwargs[:callback]
+
 """
 ```julia
 DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),

From b3da8137a9288b9c3eeff209a541dd41a5f97524 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Sun, 1 Dec 2024 16:18:18 -0500
Subject: [PATCH 002/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 64 +++++++++---------------
 1 file changed, 23 insertions(+), 41 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 7a8bf56562..bb4cbbb6c3 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -469,6 +469,17 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
         initializeprobpmap = initializeprobpmap)
 end
 
+"""
+```julia
+SciMLBase.BVPFunction{iip}(sys::AbstractODESystem, u0map, tspan,
+                         parammap = DiffEqBase.NullParameters();
+                         version = nothing, tgrad = false,
+                         jac = true, sparse = true,
+                         simplify = false,
+                         kwargs...) where {iip}
+```
+"""
+
 """
 ```julia
 SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
@@ -481,7 +492,7 @@ SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
 
 Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
 `ps` are used to set the order of the dependent variable and parameter vectors,
-respectively.
+respectively. `u0` should be either the initial condition, a vector of values `u(t_i)` for collocation methods, or a function returning one or the other.
 """
 function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{true}(sys, args...; kwargs...)
@@ -502,12 +513,13 @@ function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
 end
 
+# figure out what's going on when we try to set `sparse`?
+
 function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
         tspan = get_tspan(sys),
         parammap = DiffEqBase.NullParameters();
         version = nothing, tgrad = false,
         jac = true, sparse = true, 
-        sparsity = true, 
         callback = nothing,
         check_length = true,
         warn_initialize_determined = true,
@@ -521,57 +533,27 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
 
     f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
         t = tspan !== nothing ? tspan[1] : tspan,
-        check_length, warn_initialize_determined, eval_expression, eval_module, jac, sparse, sparsity, kwargs...)
-
-    # if jac
-    #     jac_gen = generate_jacobian(sys, dvs, ps;
-    #         simplify = simplify, sparse = sparse,
-    #         expression = Val{true},
-    #         expression_module = eval_module,
-    #         checkbounds = checkbounds, kwargs...)
-    #     jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
-
-    #     _jac(u, p, t) = jac_oop(u, p, t)
-    #     _jac(J, u, p, t) = jac_iip(J, u, p, t)
-    #     _jac(u, p::Tuple{Vararg{Number}}, t) = jac_oop(u, p, t)
-    #     _jac(J, u, p::Tuple{Vararg{Number}}, t) = jac_iip(J, u, p, t)
-    #     _jac(u, p::Tuple, t) = jac_oop(u, p..., t)
-    #     _jac(J, u, p::Tuple, t) = jac_iip(J, u, p..., t)
-    #     _jac(u, p::MTKParameters, t) = jac_oop(u, p..., t)
-    #     _jac(J, u, p::MTKParameters, t) = jac_iip(J, u, p..., t)
-    # else
-    #     _jac = nothing
-    # end
-
-    # jac_prototype = if sparse
-    #     uElType = u0 === nothing ? Float64 : eltype(u0)
-    #     if jac
-    #         similar(calculate_jacobian(sys, sparse = sparse), uElType)
-    #     else
-    #         similar(jacobian_sparsity(sys), uElType)
-    #     end
-    # else
-    #     nothing
-    # end
-
-    # f.jac = _jac
-    # f.jac_prototype = jac_prototype
-    # f.sparsity = jac ? jacobian_sparsity(sys) : nothing
+        check_length, warn_initialize_determined, eval_expression, eval_module, jac, kwargs...)
 
     cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
-
     kwargs = filter_kwargs(kwargs)
 
     kwargs1 = (;)
     if cbs !== nothing
         kwargs1 = merge(kwargs1, (callback = cbs,))
     end
+    
+    # Construct initial conditions
+    _u0 = prepare_initial_state(u0)
+    __u0 = if _u0 isa Function 
+        _u0(t_i)
+    end
 
     # Define the boundary conditions
     bc = if iip 
-        (residual, u, p, t) -> (residual = u[1] - u0)
+        (residual, u, p, t) -> (residual = u[1] - __u0)
     else
-        (u, p, t) -> (u[1] - u0)
+        (u, p, t) -> (u[1] - __u0)
     end
 
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)

From 4affeac4b340a06c770373b498ec6b0e94c25a06 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Sun, 1 Dec 2024 17:35:05 -0500
Subject: [PATCH 003/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 263a75d153..33bddece30 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -545,9 +545,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     
     # Construct initial conditions
     _u0 = prepare_initial_state(u0)
-    __u0 = if _u0 isa Function 
-        _u0(t_i)
-    end
+    __u0 = _u0 isa Function ? _u0(tspan[1]) : _u0
 
     # Define the boundary conditions
     bc = if iip 

From a3429ea2b7d9e67898023eac1599e71c3d1b7bec Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Sun, 1 Dec 2024 17:42:39 -0500
Subject: [PATCH 004/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 33bddece30..84f40a01f7 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -513,8 +513,6 @@ function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
 end
 
-# figure out what's going on when we try to set `sparse`?
-
 function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
         tspan = get_tspan(sys),
         parammap = DiffEqBase.NullParameters();
@@ -544,14 +542,13 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     end
     
     # Construct initial conditions
-    _u0 = prepare_initial_state(u0)
-    __u0 = _u0 isa Function ? _u0(tspan[1]) : _u0
+    _u0 = u0 isa Function ? u0(tspan[1]) : u0
 
     # Define the boundary conditions
     bc = if iip 
-        (residual, u, p, t) -> (residual = u[1] - __u0)
+        (residual, u, p, t) -> (residual = u[1] - _u0)
     else
-        (u, p, t) -> (u[1] - __u0)
+        (u, p, t) -> (u[1] - _u0)
     end
 
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)

From f751fbb51c49e0b35ee84a12c54ec2de919660a5 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Sun, 1 Dec 2024 21:22:08 -0500
Subject: [PATCH 005/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 18 ++------
 test/bvproblem.jl                        | 56 ++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 15 deletions(-)
 create mode 100644 test/bvproblem.jl

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 84f40a01f7..112419dea8 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -469,17 +469,6 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
         initializeprobpmap = initializeprobpmap)
 end
 
-"""
-```julia
-SciMLBase.BVPFunction{iip}(sys::AbstractODESystem, u0map, tspan,
-                         parammap = DiffEqBase.NullParameters();
-                         version = nothing, tgrad = false,
-                         jac = true, sparse = true,
-                         simplify = false,
-                         kwargs...) where {iip}
-```
-"""
-
 """
 ```julia
 SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
@@ -492,7 +481,7 @@ SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
 
 Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
 `ps` are used to set the order of the dependent variable and parameter vectors,
-respectively. `u0` should be either the initial condition, a vector of values `u(t_i)` for collocation methods, or a function returning one or the other.
+respectively. `u0map` should be used to specify the initial condition, or be a function returning an initial condition.
 """
 function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{true}(sys, args...; kwargs...)
@@ -517,7 +506,6 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         tspan = get_tspan(sys),
         parammap = DiffEqBase.NullParameters();
         version = nothing, tgrad = false,
-        jac = true, sparse = true, 
         callback = nothing,
         check_length = true,
         warn_initialize_determined = true,
@@ -531,7 +519,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
 
     f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
         t = tspan !== nothing ? tspan[1] : tspan,
-        check_length, warn_initialize_determined, eval_expression, eval_module, jac, kwargs...)
+        check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
 
     cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
     kwargs = filter_kwargs(kwargs)
@@ -546,7 +534,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
 
     # Define the boundary conditions
     bc = if iip 
-        (residual, u, p, t) -> (residual = u[1] - _u0)
+        (residual, u, p, t) -> residual .= u[1] - _u0
     else
         (u, p, t) -> (u[1] - _u0)
     end
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
new file mode 100644
index 0000000000..90d41a96ad
--- /dev/null
+++ b/test/bvproblem.jl
@@ -0,0 +1,56 @@
+using BoundaryValueDiffEq, OrdinaryDiffEq
+using ModelingToolkit
+using ModelingToolkit: t_nounits as t, D_nounits as D
+
+@parameters σ = 10. ρ = 28 β = 8/3
+@variables x(t) = 1 y(t) = 0 z(t) = 0
+
+eqs = [D(x) ~ σ*(y-x),
+       D(y) ~ x*(ρ-z)-y,
+       D(z) ~ x*y - β*z]
+
+u0map = [:x => 1., :y => 0., :z => 0.]
+parammap = [:ρ => 28., :β => 8/3, :σ => 10.]
+tspan = (0., 10.)
+
+@mtkbuild lorenz = ODESystem(eqs, t)
+
+bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lorenz, u0map, tspan, parammap)
+sol = solve(bvp, MIRK4(), dt = 0.1);
+
+bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lorenz, u0map, tspan, parammap)
+sol2 = solve(bvp, MIRK4(), dt = 0.1);
+
+op = ODEProblem(lorenz, u0map, tspan, parammap)
+osol = solve(op)
+
+@test sol.u[end] ≈ osol.u[end]
+@test sol2.u[end] ≈ osol.u[end]
+@test sol.u[1] == [1., 0., 0.]
+@test sol2.u[1] == [1., 0., 0.]
+
+### Testing on pendulum
+
+@parameters g = 9.81 L = 1. 
+@variables θ(t) = π/2 
+
+eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
+
+@mtkbuild pend = ODESystem(eqs, t)
+
+u0map = [θ => π/2, D(θ) => π/2]
+parammap = [:L => 2.]
+tspan = (0., 10.)
+
+bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+sol = solve(bvp, MIRK4(), dt = 0.05);
+
+bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+sol2 = solve(bvp2, MIRK4(), dt = 0.05);
+
+osol = solve(pend)
+
+@test sol.u[end] ≈ osol.u[end]
+@test sol.u[1] == [π/2, π/2]
+@test sol2.u[end] ≈ osol.u[end]
+@test sol2.u[1] == [π/2, π/2]

From a9fdfd6a3a114c4df28e4781c6b667990c01bafc Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 3 Dec 2024 05:04:45 -0500
Subject: [PATCH 006/111] up

---
 src/systems/diffeqs/abstractodesystem.jl |  6 ++--
 test/bvproblem.jl                        | 44 ++++++++++++------------
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 112419dea8..b795bb981d 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -529,12 +529,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         kwargs1 = merge(kwargs1, (callback = cbs,))
     end
     
-    # Construct initial conditions
+    # Construct initial conditions.
     _u0 = u0 isa Function ? u0(tspan[1]) : u0
 
-    # Define the boundary conditions
+    # Define the boundary conditions.
     bc = if iip 
-        (residual, u, p, t) -> residual .= u[1] - _u0
+        (residual, u, p, t) -> (residual .= u[1] - _u0)
     else
         (u, p, t) -> (u[1] - _u0)
     end
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 90d41a96ad..4865c867b3 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -2,32 +2,31 @@ using BoundaryValueDiffEq, OrdinaryDiffEq
 using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
-@parameters σ = 10. ρ = 28 β = 8/3
-@variables x(t) = 1 y(t) = 0 z(t) = 0
+@parameters α = 7.5 β = 4. γ = 8. δ = 5. 
+@variables x(t) = 1. y(t) = 2. 
 
-eqs = [D(x) ~ σ*(y-x),
-       D(y) ~ x*(ρ-z)-y,
-       D(z) ~ x*y - β*z]
+eqs = [D(x) ~ α*x - β*x*y,
+       D(y) ~ -γ*y + δ*x*y]
 
-u0map = [:x => 1., :y => 0., :z => 0.]
-parammap = [:ρ => 28., :β => 8/3, :σ => 10.]
+u0map = [:x => 1., :y => 2.]
+parammap = [:α => 7.5, :β => 4, :γ => 8., :δ => 5.]
 tspan = (0., 10.)
 
-@mtkbuild lorenz = ODESystem(eqs, t)
+@mtkbuild lotkavolterra = ODESystem(eqs, t)
 
-bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lorenz, u0map, tspan, parammap)
-sol = solve(bvp, MIRK4(), dt = 0.1);
+bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
+sol = solve(bvp, MIRK4(), dt = 0.01);
 
-bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lorenz, u0map, tspan, parammap)
-sol2 = solve(bvp, MIRK4(), dt = 0.1);
+bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
+sol2 = solve(bvp, MIRK4(), dt = 0.01);
 
-op = ODEProblem(lorenz, u0map, tspan, parammap)
-osol = solve(op)
+op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
+osol = solve(op, Vern9())
 
-@test sol.u[end] ≈ osol.u[end]
-@test sol2.u[end] ≈ osol.u[end]
-@test sol.u[1] == [1., 0., 0.]
-@test sol2.u[1] == [1., 0., 0.]
+@test isapprox(sol.u[end],osol.u[end]; atol = 0.001)
+@test isapprox(sol2.u[end],osol.u[end]; atol = 0.001)
+@test sol.u[1] == [1., 2.]
+@test sol2.u[1] == [1., 2.]
 
 ### Testing on pendulum
 
@@ -39,16 +38,17 @@ eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
 @mtkbuild pend = ODESystem(eqs, t)
 
 u0map = [θ => π/2, D(θ) => π/2]
-parammap = [:L => 2.]
+parammap = [:L => 2., :g => 9.81]
 tspan = (0., 10.)
 
 bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-sol = solve(bvp, MIRK4(), dt = 0.05);
+sol = solve(bvp, MIRK4(), dt = 0.01);
 
 bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-sol2 = solve(bvp2, MIRK4(), dt = 0.05);
+sol2 = solve(bvp2, MIRK4(), dt = 0.01);
 
-osol = solve(pend)
+op = ODEProblem(pend, u0map, tspan, parammap)
+osol = solve(op, Vern9())
 
 @test sol.u[end] ≈ osol.u[end]
 @test sol.u[1] == [π/2, π/2]

From a9f210691c1f38ef0f575eb89a939763b81c9c1b Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 3 Dec 2024 19:11:35 -0500
Subject: [PATCH 007/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 4 ++--
 test/bvproblem.jl                        | 8 ++++----
 test/runtests.jl                         | 1 +
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index b795bb981d..8dac19f296 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -534,7 +534,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
 
     # Define the boundary conditions.
     bc = if iip 
-        (residual, u, p, t) -> (residual .= u[1] - _u0)
+        (residual, u, p, t) -> (residual .= u[1] .- _u0)
     else
         (u, p, t) -> (u[1] - _u0)
     end
@@ -542,7 +542,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
 end
 
-get_callback(prob::BVProblem) = prob.kwargs[:callback]
+get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
 """
 ```julia
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 4865c867b3..7235638cf0 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -23,8 +23,8 @@ sol2 = solve(bvp, MIRK4(), dt = 0.01);
 op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
 osol = solve(op, Vern9())
 
-@test isapprox(sol.u[end],osol.u[end]; atol = 0.001)
-@test isapprox(sol2.u[end],osol.u[end]; atol = 0.001)
+@test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
+@test isapprox(sol2.u[end],osol.u[end]; atol = 0.01)
 @test sol.u[1] == [1., 2.]
 @test sol2.u[1] == [1., 2.]
 
@@ -50,7 +50,7 @@ sol2 = solve(bvp2, MIRK4(), dt = 0.01);
 op = ODEProblem(pend, u0map, tspan, parammap)
 osol = solve(op, Vern9())
 
-@test sol.u[end] ≈ osol.u[end]
+@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
 @test sol.u[1] == [π/2, π/2]
-@test sol2.u[end] ≈ osol.u[end]
+@test isapprox(sol2.u[end], osol.u[end]; atol = 0.01)
 @test sol2.u[1] == [π/2, π/2]
diff --git a/test/runtests.jl b/test/runtests.jl
index 44846eed57..eaa87e6407 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -80,6 +80,7 @@ end
             @safetestset "NonlinearSystem Test" include("nonlinearsystem.jl")
             @safetestset "PDE Construction Test" include("pde.jl")
             @safetestset "JumpSystem Test" include("jumpsystem.jl")
+            @safetestset "BVProblem Test" include("bvproblem.jl")
             @safetestset "print_tree" include("print_tree.jl")
             @safetestset "Constraints Test" include("constraints.jl")
         end

From 18fdd5f79702627e30ce0acc971677a5cd1d303f Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 13 Dec 2024 10:04:59 +0900
Subject: [PATCH 008/111] up

---
 Project.toml      | 6 ++++--
 test/bvproblem.jl | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/Project.toml b/Project.toml
index 093c632c42..19063a0342 100644
--- a/Project.toml
+++ b/Project.toml
@@ -72,14 +72,15 @@ MTKBifurcationKitExt = "BifurcationKit"
 MTKChainRulesCoreExt = "ChainRulesCore"
 MTKDeepDiffsExt = "DeepDiffs"
 MTKHomotopyContinuationExt = "HomotopyContinuation"
-MTKLabelledArraysExt = "LabelledArrays"
 MTKInfiniteOptExt = "InfiniteOpt"
+MTKLabelledArraysExt = "LabelledArrays"
 
 [compat]
 AbstractTrees = "0.3, 0.4"
 ArrayInterface = "6, 7"
 BifurcationKit = "0.4"
 BlockArrays = "1.1"
+BoundaryValueDiffEq = "5.12.0"
 ChainRulesCore = "1"
 Combinatorics = "1"
 CommonSolve = "0.2.4"
@@ -145,6 +146,7 @@ julia = "1.9"
 [extras]
 AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
 ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
 DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
 DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
@@ -174,4 +176,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
+test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 7235638cf0..7787bd9c3e 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -44,7 +44,7 @@ tspan = (0., 10.)
 bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
 sol = solve(bvp, MIRK4(), dt = 0.01);
 
-bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
 sol2 = solve(bvp2, MIRK4(), dt = 0.01);
 
 op = ODEProblem(pend, u0map, tspan, parammap)

From 9d65a3345ade530654a64ace2069ff24e8c67875 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 16 Dec 2024 23:43:35 +0800
Subject: [PATCH 009/111] fixing create_array

---
 Project.toml                             |  3 ++
 src/systems/diffeqs/abstractodesystem.jl | 10 ++++-
 test/bvproblem.jl                        | 53 +++++++++++++++---------
 3 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/Project.toml b/Project.toml
index 19063a0342..ab854b80bd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,9 @@ version = "9.54.0"
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
+BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
+BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
+BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 8dac19f296..3a502d3183 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -539,11 +539,19 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         (u, p, t) -> (u[1] - _u0)
     end
 
-    return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
+    return BVProblem{iip}(f, bc, _u0, tspan, p; kwargs1..., kwargs...)
 end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
+@inline function create_array(::Type{Base.ReinterpretArray}, ::Nothing, ::Val{1}, ::Val{dims}, elems...) where dims
+    [elems...]
+end
+
+@inline function create_array(::Type{Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where dims
+    T[elems...]
+end
+
 """
 ```julia
 DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 7787bd9c3e..7cdb7e1837 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -2,6 +2,8 @@ using BoundaryValueDiffEq, OrdinaryDiffEq
 using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
+solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
+
 @parameters α = 7.5 β = 4. γ = 8. δ = 5. 
 @variables x(t) = 1. y(t) = 2. 
 
@@ -13,20 +15,26 @@ parammap = [:α => 7.5, :β => 4, :γ => 8., :δ => 5.]
 tspan = (0., 10.)
 
 @mtkbuild lotkavolterra = ODESystem(eqs, t)
+op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
+osol = solve(op, Vern9())
 
-bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
-sol = solve(bvp, MIRK4(), dt = 0.01);
+bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
 
-bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
-sol2 = solve(bvp, MIRK4(), dt = 0.01);
+for solver in solvers
+    println("$solver")
+    sol = solve(bvp, solver(), dt = 0.01)
+    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+    @test sol.u[1] == [1., 2.]
+end
 
-op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-osol = solve(op, Vern9())
+# Test out of place
+bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
 
-@test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
-@test isapprox(sol2.u[end],osol.u[end]; atol = 0.01)
-@test sol.u[1] == [1., 2.]
-@test sol2.u[1] == [1., 2.]
+for solver in solvers
+    sol = solve(bvp2, solver(), dt = 0.01)
+    @test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
+    @test sol.u[1] == [1., 2.]
+end
 
 ### Testing on pendulum
 
@@ -38,19 +46,24 @@ eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
 @mtkbuild pend = ODESystem(eqs, t)
 
 u0map = [θ => π/2, D(θ) => π/2]
-parammap = [:L => 2., :g => 9.81]
+parammap = [:L => 1., :g => 9.81]
 tspan = (0., 10.)
 
+op = ODEProblem(pend, u0map, tspan, parammap)
+osol = solve(op, Vern9())
+
 bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-sol = solve(bvp, MIRK4(), dt = 0.01);
+for solver in solvers
+    sol = solve(bvp2, solver(), dt = 0.01)
+    @test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
+    @test sol.u[1] == [π/2, π/2]
+end
 
+# Test out-of-place
 bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-sol2 = solve(bvp2, MIRK4(), dt = 0.01);
-
-op = ODEProblem(pend, u0map, tspan, parammap)
-osol = solve(op, Vern9())
 
-@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-@test sol.u[1] == [π/2, π/2]
-@test isapprox(sol2.u[end], osol.u[end]; atol = 0.01)
-@test sol2.u[1] == [π/2, π/2]
+for solver in solvers
+    sol = solve(bvp2, solver(), dt = 0.01)
+    @test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
+    @test sol.u[1] == [π/2, π/2]
+end

From 999ec308d58e32c51941743a6ba5a8f096c56ac2 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 16 Dec 2024 23:44:24 +0800
Subject: [PATCH 010/111] revert Project.toml

---
 Project.toml | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/Project.toml b/Project.toml
index ab854b80bd..19063a0342 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,9 +7,6 @@ version = "9.54.0"
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
-BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
-BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
-BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

From 9226ad687ff06cbc1dbbdcaa16ba3df85ef32d4e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 16 Dec 2024 23:56:04 +0800
Subject: [PATCH 011/111] Up

---
 test/bvproblem.jl | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 7cdb7e1837..86e3722eec 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -21,7 +21,6 @@ osol = solve(op, Vern9())
 bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
 
 for solver in solvers
-    println("$solver")
     sol = solve(bvp, solver(), dt = 0.01)
     @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
     @test sol.u[1] == [1., 2.]
@@ -47,15 +46,15 @@ eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
 
 u0map = [θ => π/2, D(θ) => π/2]
 parammap = [:L => 1., :g => 9.81]
-tspan = (0., 10.)
+tspan = (0., 6.)
 
 op = ODEProblem(pend, u0map, tspan, parammap)
 osol = solve(op, Vern9())
 
 bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
 for solver in solvers
-    sol = solve(bvp2, solver(), dt = 0.01)
-    @test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
+    sol = solve(bvp, solver(), dt = 0.01)
+    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
     @test sol.u[1] == [π/2, π/2]
 end
 

From 67d8164c591b74a9b985ffbcaf1ef248fb0efaaa Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 17 Dec 2024 00:09:52 +0800
Subject: [PATCH 012/111] formatting

---
 src/systems/diffeqs/abstractodesystem.jl | 11 ++++---
 test/bvproblem.jl                        | 42 +++++++++++++-----------
 2 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 87ee6e823d..06c83073bf 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -512,7 +512,6 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         eval_expression = false,
         eval_module = @__MODULE__,
         kwargs...) where {iip, specialize}
-
     if !iscomplete(sys)
         error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
     end
@@ -528,12 +527,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     if cbs !== nothing
         kwargs1 = merge(kwargs1, (callback = cbs,))
     end
-    
+
     # Construct initial conditions.
     _u0 = u0 isa Function ? u0(tspan[1]) : u0
 
     # Define the boundary conditions.
-    bc = if iip 
+    bc = if iip
         (residual, u, p, t) -> (residual .= u[1] .- _u0)
     else
         (u, p, t) -> (u[1] - _u0)
@@ -544,11 +543,13 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-@inline function create_array(::Type{Base.ReinterpretArray}, ::Nothing, ::Val{1}, ::Val{dims}, elems...) where dims
+@inline function create_array(::Type{Base.ReinterpretArray}, ::Nothing,
+        ::Val{1}, ::Val{dims}, elems...) where {dims}
     [elems...]
 end
 
-@inline function create_array(::Type{Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where dims
+@inline function create_array(
+        ::Type{Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
     T[elems...]
 end
 
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 86e3722eec..1072874917 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -4,49 +4,51 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
 
 solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
 
-@parameters α = 7.5 β = 4. γ = 8. δ = 5. 
-@variables x(t) = 1. y(t) = 2. 
+@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+@variables x(t)=1.0 y(t)=2.0
 
-eqs = [D(x) ~ α*x - β*x*y,
-       D(y) ~ -γ*y + δ*x*y]
+eqs = [D(x) ~ α * x - β * x * y,
+    D(y) ~ -γ * y + δ * x * y]
 
-u0map = [:x => 1., :y => 2.]
-parammap = [:α => 7.5, :β => 4, :γ => 8., :δ => 5.]
-tspan = (0., 10.)
+u0map = [:x => 1.0, :y => 2.0]
+parammap = [:α => 7.5, :β => 4, :γ => 8.0, :δ => 5.0]
+tspan = (0.0, 10.0)
 
 @mtkbuild lotkavolterra = ODESystem(eqs, t)
 op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
 osol = solve(op, Vern9())
 
-bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+    lotkavolterra, u0map, tspan, parammap; eval_expression = true)
 
 for solver in solvers
     sol = solve(bvp, solver(), dt = 0.01)
     @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [1., 2.]
+    @test sol.u[1] == [1.0, 2.0]
 end
 
 # Test out of place
-bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
+    lotkavolterra, u0map, tspan, parammap; eval_expression = true)
 
 for solver in solvers
     sol = solve(bvp2, solver(), dt = 0.01)
-    @test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [1., 2.]
+    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+    @test sol.u[1] == [1.0, 2.0]
 end
 
 ### Testing on pendulum
 
-@parameters g = 9.81 L = 1. 
-@variables θ(t) = π/2 
+@parameters g=9.81 L=1.0
+@variables θ(t) = π / 2
 
 eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
 
 @mtkbuild pend = ODESystem(eqs, t)
 
-u0map = [θ => π/2, D(θ) => π/2]
-parammap = [:L => 1., :g => 9.81]
-tspan = (0., 6.)
+u0map = [θ => π / 2, D(θ) => π / 2]
+parammap = [:L => 1.0, :g => 9.81]
+tspan = (0.0, 6.0)
 
 op = ODEProblem(pend, u0map, tspan, parammap)
 osol = solve(op, Vern9())
@@ -55,7 +57,7 @@ bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, pa
 for solver in solvers
     sol = solve(bvp, solver(), dt = 0.01)
     @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [π/2, π/2]
+    @test sol.u[1] == [π / 2, π / 2]
 end
 
 # Test out-of-place
@@ -63,6 +65,6 @@ bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan,
 
 for solver in solvers
     sol = solve(bvp2, solver(), dt = 0.01)
-    @test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [π/2, π/2]
+    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+    @test sol.u[1] == [π / 2, π / 2]
 end

From 25988f3bc6b6d66b008b6a40341f75e4547e574a Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 17 Dec 2024 10:52:37 +0800
Subject: [PATCH 013/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 06c83073bf..c84f5ff5be 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -543,13 +543,13 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-@inline function create_array(::Type{Base.ReinterpretArray}, ::Nothing,
+@inline function SciML.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
         ::Val{1}, ::Val{dims}, elems...) where {dims}
     [elems...]
 end
 
-@inline function create_array(
-        ::Type{Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
+@inline function SciML.Code.create_array(
+        ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
     T[elems...]
 end
 

From bb28d4fe2a0d753221104186fe42808c7657f5d6 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 17 Dec 2024 10:53:22 +0800
Subject: [PATCH 014/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index c84f5ff5be..eac2df16aa 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -543,12 +543,12 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-@inline function SciML.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
+@inline function SciMLBase.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
         ::Val{1}, ::Val{dims}, elems...) where {dims}
     [elems...]
 end
 
-@inline function SciML.Code.create_array(
+@inline function SciMLBase.Code.create_array(
         ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
     T[elems...]
 end

From b2bf7c05532bdf68d93e566d07c5e7444473dd7a Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 17 Dec 2024 11:00:30 +0800
Subject: [PATCH 015/111] fix

---
 src/systems/diffeqs/abstractodesystem.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index eac2df16aa..3d143d49fb 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -543,12 +543,12 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-@inline function SciMLBase.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
+@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
         ::Val{1}, ::Val{dims}, elems...) where {dims}
     [elems...]
 end
 
-@inline function SciMLBase.Code.create_array(
+@inline function SymbolicUtils.Code.create_array(
         ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
     T[elems...]
 end

From 3751c2a92c282593ef52b67b727c7db635ea677e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 20 Dec 2024 23:58:28 +0900
Subject: [PATCH 016/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 165 +++++++++++------------
 test/bvproblem.jl                        |   4 +-
 2 files changed, 83 insertions(+), 86 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 3d143d49fb..d8ff71c324 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -469,90 +469,6 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
         initializeprobpmap = initializeprobpmap)
 end
 
-"""
-```julia
-SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
-                         parammap = DiffEqBase.NullParameters();
-                         version = nothing, tgrad = false,
-                         jac = true, sparse = true,
-                         simplify = false,
-                         kwargs...) where {iip}
-```
-
-Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
-`ps` are used to set the order of the dependent variable and parameter vectors,
-respectively. `u0map` should be used to specify the initial condition, or be a function returning an initial condition.
-"""
-function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
-    BVProblem{true}(sys, args...; kwargs...)
-end
-
-function SciMLBase.BVProblem(sys::AbstractODESystem,
-        u0map::StaticArray,
-        args...;
-        kwargs...)
-    BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
-end
-
-function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
-    BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
-end
-
-function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
-    BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
-end
-
-function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
-        tspan = get_tspan(sys),
-        parammap = DiffEqBase.NullParameters();
-        version = nothing, tgrad = false,
-        callback = nothing,
-        check_length = true,
-        warn_initialize_determined = true,
-        eval_expression = false,
-        eval_module = @__MODULE__,
-        kwargs...) where {iip, specialize}
-    if !iscomplete(sys)
-        error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
-    end
-
-    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
-        t = tspan !== nothing ? tspan[1] : tspan,
-        check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
-
-    cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
-    kwargs = filter_kwargs(kwargs)
-
-    kwargs1 = (;)
-    if cbs !== nothing
-        kwargs1 = merge(kwargs1, (callback = cbs,))
-    end
-
-    # Construct initial conditions.
-    _u0 = u0 isa Function ? u0(tspan[1]) : u0
-
-    # Define the boundary conditions.
-    bc = if iip
-        (residual, u, p, t) -> (residual .= u[1] .- _u0)
-    else
-        (u, p, t) -> (u[1] - _u0)
-    end
-
-    return BVProblem{iip}(f, bc, _u0, tspan, p; kwargs1..., kwargs...)
-end
-
-get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
-
-@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
-        ::Val{1}, ::Val{dims}, elems...) where {dims}
-    [elems...]
-end
-
-@inline function SymbolicUtils.Code.create_array(
-        ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
-    T[elems...]
-end
-
 """
 ```julia
 DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -943,6 +859,87 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
 end
 get_callback(prob::ODEProblem) = prob.kwargs[:callback]
 
+"""
+```julia
+SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
+                         parammap = DiffEqBase.NullParameters();
+                         version = nothing, tgrad = false,
+                         jac = true, sparse = true,
+                         simplify = false,
+                         kwargs...) where {iip}
+```
+
+Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
+`ps` are used to set the order of the dependent variable and parameter vectors,
+respectively. `u0map` should be used to specify the initial condition.
+"""
+function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
+    BVProblem{true}(sys, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem(sys::AbstractODESystem,
+        u0map::StaticArray,
+        args...;
+        kwargs...)
+    BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
+    BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
+    BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
+        tspan = get_tspan(sys),
+        parammap = DiffEqBase.NullParameters();
+        version = nothing, tgrad = false,
+        callback = nothing,
+        check_length = true,
+        warn_initialize_determined = true,
+        eval_expression = false,
+        eval_module = @__MODULE__,
+        kwargs...) where {iip, specialize}
+    if !iscomplete(sys)
+        error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
+    end
+
+    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
+        t = tspan !== nothing ? tspan[1] : tspan,
+        check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
+
+    cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
+    kwargs = filter_kwargs(kwargs)
+
+    kwargs1 = (;)
+    if cbs !== nothing
+        kwargs1 = merge(kwargs1, (callback = cbs,))
+    end
+
+    # Define the boundary conditions.
+    bc = if iip
+        (residual, u, p, t) -> (residual .= u[1] .- u0)
+    else
+        (u, p, t) -> (u[1] - u0)
+    end
+
+    return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
+end
+
+get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
+
+@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
+        ::Val{1}, ::Val{dims}, elems...) where {dims}
+    [elems...]
+end
+
+@inline function SymbolicUtils.Code.create_array(
+        ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
+    T[elems...]
+end
+
 """
 ```julia
 DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 1072874917..c5a302147d 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -10,8 +10,8 @@ solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
 eqs = [D(x) ~ α * x - β * x * y,
     D(y) ~ -γ * y + δ * x * y]
 
-u0map = [:x => 1.0, :y => 2.0]
-parammap = [:α => 7.5, :β => 4, :γ => 8.0, :δ => 5.0]
+u0map = [x => 1.0, y => 2.0]
+parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
 tspan = (0.0, 10.0)
 
 @mtkbuild lotkavolterra = ODESystem(eqs, t)

From ef1f089cbd493272e2b9297c2189aa2a6029d72f Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 8 Jan 2025 15:12:12 -0500
Subject: [PATCH 017/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index d8ff71c324..87ab83f10b 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -930,16 +930,6 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
-        ::Val{1}, ::Val{dims}, elems...) where {dims}
-    [elems...]
-end
-
-@inline function SymbolicUtils.Code.create_array(
-        ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
-    T[elems...]
-end
-
 """
 ```julia
 DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,

From 2a25200bcb6a1ad57c4329b6383ba844eb73c2f7 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 9 Jan 2025 17:36:38 -0500
Subject: [PATCH 018/111] extend BVProblem for constraint equations

---
 src/systems/diffeqs/abstractodesystem.jl | 74 +++++++++++++++++--
 test/bvproblem.jl                        | 90 ++++++++++++++++++++++++
 2 files changed, 160 insertions(+), 4 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index f5c1c288d7..74b1bf7596 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -873,7 +873,7 @@ function SciMLBase.BVProblem(sys::AbstractODESystem,
         kwargs...)
     BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
 end
-
+o
 function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
 end
@@ -908,11 +908,32 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         kwargs1 = merge(kwargs1, (callback = cbs,))
     end
 
+    # Handle algebraic equations
+    stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
+    pidxmap = Dict([v => i for (i, v) in enumerate(parameters(sys))])
+    ns = length(stmap)
+    ne = length(get_alg_eqs(sys))
+    
     # Define the boundary conditions.
-    bc = if iip
-        (residual, u, p, t) -> (residual .= u[1] .- u0)
+    bc = if has_alg_eqs(sys)
+        if iip
+            (residual,u,p,t) -> begin
+                residual[1:ns] .= u[1] .- u0
+                residual[ns+1:ns+ne] .= sub_u_p_into_symeq.(get_alg_eqs(sys))
+            end
+        else
+            (u,p,t) -> begin
+                resid = vcat(u[1] - u0, sub_u_p_into_symeq.(get_alg_eqs(sys)))
+            end
+        end
     else
-        (u, p, t) -> (u[1] - u0)
+        if iip
+            (residual,u,p,t) -> begin
+                residual .= u[1] .- u0
+            end
+        else
+            (u,p,t) -> (u[1] - u0)
+        end
     end
 
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
@@ -920,6 +941,51 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
+# Helper to create the dictionary that will substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function. 
+#   Take a system with variables x,y, parameters g
+#
+#   1 + x + y → 1 + u[1][1] + u[1][2]
+#   x(0.5) → u(0.5)[1]
+#   x(0.5)*g(0.5) → u(0.5)[1]*p[1]
+
+function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap)
+    iv = ModelingToolkit.get_iv(sys)
+    eq = Symbolics.unwrap(eq)
+
+    stmap = Dict([st => u[1][i] for st => i in stidxmap])
+    pmap = Dict([pa => p[i] for pa => i in pidxmap])
+    eq = Symbolics.substitute(eq, merge(stmap, pmap))
+
+    csyms = []
+    # Find most nested calls, substitute those first.
+    while !isempty(find_callable_syms!(csyms, eq))
+        for sym in csyms 
+            t = arguments(sym)[1]
+            x = operation(sym)
+
+            if isparameter(x)
+                eq = Symbolics.substitute(eq, Dict(x(t) => p[pidxmap[x(iv)]]))
+            elseif isvariable(x)
+                eq = Symbolics.substitute(eq, Dict(x(t) => u(val)[stidxmap[x(iv)]]))
+            end
+        end
+        empty!(csyms)
+    end
+    eq
+end
+
+function find_callable_syms!(csyms, ex)
+    ex = Symbolics.unwrap(ex)
+
+    if iscall(ex)
+        operation(ex) isa Symbolic && (arguments(ex)[1] isa Symbolic) && push!(csyms, ex) # only add leaf nodes 
+        for arg in arguments(ex)
+            find_callable_syms!(csyms, arg)
+        end
+    end
+    csyms 
+end
+
 """
 ```julia
 DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index c5a302147d..2d5535325a 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -2,6 +2,7 @@ using BoundaryValueDiffEq, OrdinaryDiffEq
 using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
+### Test Collocation solvers on simple problems 
 solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
 
 @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
@@ -68,3 +69,92 @@ for solver in solvers
     @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
     @test sol.u[1] == [π / 2, π / 2]
 end
+
+###################################################
+### TESTING ODESystem with Constraint Equations ###
+###################################################
+
+# Cartesian pendulum from the docs. Testing that initialization is satisfied.
+let
+    @parameters g
+    @variables x(t) y(t) [state_priority = 10] λ(t)
+    eqs = [D(D(x)) ~ λ * x
+           D(D(y)) ~ λ * y - g
+           x^2 + y^2 ~ 1]
+    @mtkbuild pend = ODESystem(eqs, t)
+
+    tspan = (0.0, 1.5)
+    u0map = [x => 1, y => 0]
+    parammap = [g => 1]
+    guesses = [λ => 1]
+
+    prob = ODEProblem(pend, u0map, tspan, pmap; guesses)
+    sol = solve(prob, Rodas5P())
+
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses)
+    
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        conditions = getfield.(equations(pend)[3:end], :rhs)
+        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+    end
+
+    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        conditions = getfield.(equations(pend)[3:end], :rhs)
+        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+    end
+end
+
+# Adding a midpoint boundary condition.
+let 
+    @parameters g
+    @variables x(..) y(t) [state_priority = 10] λ(t)
+    eqs = [D(D(x(t))) ~ λ * x(t)
+           D(D(y)) ~ λ * y - g
+           x(t)^2 + y^2 ~ 1
+           x(0.5) ~ 1]
+    @mtkbuild pend = ODESystem(eqs, t)
+
+    tspan = (0.0, 1.5)
+    u0map = [x(t) => 0.6, y => 0.8]
+    parammap = [g => 1]
+    guesses = [λ => 1]
+
+    prob = ODEProblem(pend, u0map, tspan, pmap; guesses, check_length = false)
+    sol = solve(prob, Rodas5P())
+
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guesses, check_length = false)
+    
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        conditions = getfield.(equations(pend)[3:end], :rhs)
+        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+        @test sol.u[1] == [π / 2, π / 2]
+    end
+
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses)
+    
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        conditions = getfield.(equations(pend)[3:end], :rhs)
+        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+    end
+
+    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        conditions = getfield.(equations(pend)[3:end], :rhs)
+        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+    end
+end
+
+# Testing a more complicated case with multiple constraints.
+let
+end

From 50504abbcf0149d50bd6e858ae1c5f368f8d2835 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Sat, 11 Jan 2025 13:57:59 -0500
Subject: [PATCH 019/111] adding tests

---
 Project.toml                             |   6 +-
 src/systems/diffeqs/abstractodesystem.jl | 140 ++++++++++---
 test/bvproblem.jl                        | 242 ++++++++++++++---------
 3 files changed, 261 insertions(+), 127 deletions(-)

diff --git a/Project.toml b/Project.toml
index 98fd119f0c..b0bd4381b6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -82,6 +82,7 @@ ArrayInterface = "6, 7"
 BifurcationKit = "0.4"
 BlockArrays = "1.1"
 BoundaryValueDiffEq = "5.12.0"
+BoundaryValueDiffEqAscher = "1.1.0"
 ChainRulesCore = "1"
 Combinatorics = "1"
 CommonSolve = "0.2.4"
@@ -140,8 +141,8 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
 SparseArrays = "1"
 SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
 StaticArrays = "0.10, 0.11, 0.12, 1.0"
-StochasticDiffEq = "6.72.1"
 StochasticDelayDiffEq = "1.8.1"
+StochasticDiffEq = "6.72.1"
 SymbolicIndexingInterface = "0.3.36"
 SymbolicUtils = "3.7"
 Symbolics = "6.19"
@@ -154,6 +155,7 @@ julia = "1.9"
 AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
+BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
 ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
 DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
 DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
@@ -185,4 +187,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
+test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 74b1bf7596..3c45b1b2f8 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -853,15 +853,47 @@ get_callback(prob::ODEProblem) = prob.kwargs[:callback]
 ```julia
 SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
                          parammap = DiffEqBase.NullParameters();
+                         constraints = nothing, guesses = nothing,
                          version = nothing, tgrad = false,
                          jac = true, sparse = true,
                          simplify = false,
                          kwargs...) where {iip}
 ```
 
-Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
+Create a boundary value problem from the [`ODESystem`](@ref). The arguments `dvs` and
 `ps` are used to set the order of the dependent variable and parameter vectors,
-respectively. `u0map` should be used to specify the initial condition.
+respectively. `u0map` is used to specify fixed initial values for the states.
+
+Every variable must have either an initial guess supplied using `guesses` or 
+a fixed initial value specified using `u0map`.
+
+`constraints` are used to specify boundary conditions to the ODESystem in the
+form of equations. These values should specify values that state variables should
+take at specific points, as in `x(0.5) ~ 1`). More general constraints that 
+should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be 
+specified as one of the equations used to build the `ODESystem`. Below is an example.
+
+```julia
+    @parameters g
+    @variables x(..) y(t) [state_priority = 10] λ(t)
+    eqs = [D(D(x(t))) ~ λ * x(t)
+           D(D(y)) ~ λ * y - g
+           x(t)^2 + y^2 ~ 1]
+    @mtkbuild pend = ODESystem(eqs, t)
+
+    tspan = (0.0, 1.5)
+    u0map = [x(t) => 0.6, y => 0.8]
+    parammap = [g => 1]
+    guesses = [λ => 1]
+    constraints = [x(0.5) ~ 1]
+
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+```
+
+If no `constraints` are specified, the problem will be treated as an initial value problem.
+
+If the `ODESystem` has algebraic equations like `x(t)^2 + y(t)^2`, the resulting 
+`BVProblem` must be solved using BVDAE solvers, such as Ascher.
 """
 function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{true}(sys, args...; kwargs...)
@@ -873,7 +905,7 @@ function SciMLBase.BVProblem(sys::AbstractODESystem,
         kwargs...)
     BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
 end
-o
+
 function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
     BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
 end
@@ -885,6 +917,7 @@ end
 function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
         tspan = get_tspan(sys),
         parammap = DiffEqBase.NullParameters();
+        constraints = nothing, guesses = nothing,
         version = nothing, tgrad = false,
         callback = nothing,
         check_length = true,
@@ -892,38 +925,63 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         eval_expression = false,
         eval_module = @__MODULE__,
         kwargs...) where {iip, specialize}
+
     if !iscomplete(sys)
         error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
     end
+    !isnothing(callbacks) && error("BVP solvers do not support callbacks.")
 
-    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
-        t = tspan !== nothing ? tspan[1] : tspan,
-        check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
+    iv = get_iv(sys)
+    constraintsts = nothing
+    constraintps = nothing
+    sts = unknowns(sys)
+    ps = parameters(sys)
 
-    cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
-    kwargs = filter_kwargs(kwargs)
+    if !isnothing(constraints)
+        constraints isa Equation || 
+            constraints isa Vector{Equation} || 
+            error("Constraints must be specified as an equation or a vector of equations.")
 
-    kwargs1 = (;)
-    if cbs !== nothing
-        kwargs1 = merge(kwargs1, (callback = cbs,))
+        (length(constraints) + length(u0map) > length(sts)) && 
+            error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
+
+        constraintsts = OrderedSet()
+        constraintps = OrderedSet()
+
+        for eq in constraints
+            collect_vars!(constraintsts, constraintps, eq, iv)
+            validate_constraint_syms(eq, constraintsts, constraintps, Set(sts), Set(ps), iv)
+            empty!(constraintsts)
+            empty!(constraintps)
+        end
     end
 
-    # Handle algebraic equations
-    stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
-    pidxmap = Dict([v => i for (i, v) in enumerate(parameters(sys))])
-    ns = length(stmap)
-    ne = length(get_alg_eqs(sys))
+    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
+        t = tspan !== nothing ? tspan[1] : tspan, guesses,
+        check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
+
+    stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
+    pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
+
+    # Indices of states that have initial constraints.
+    u0i = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for k in keys(u0map)]
+    ni = length(u0i)
     
-    # Define the boundary conditions.
-    bc = if has_alg_eqs(sys)
+    bc = if !isnothing(constraints)
+        ne = length(constraints)
         if iip
             (residual,u,p,t) -> begin
-                residual[1:ns] .= u[1] .- u0
-                residual[ns+1:ns+ne] .= sub_u_p_into_symeq.(get_alg_eqs(sys))
+                residual[1:ni] .= u[1][u0i] .- u0[u0i]
+                residual[ni+1:ni+ne] .= map(constraints) do cons
+                    sub_u_p_into_symeq(cons.rhs - cons.lhs, u, p, stidxmap, pidxmap, iv, tspan)
+                end
             end
         else
             (u,p,t) -> begin
-                resid = vcat(u[1] - u0, sub_u_p_into_symeq.(get_alg_eqs(sys)))
+                consresid = map(constraints) do cons
+                    sub_u_p_into_symeq(cons.rhs-cons.lhs, u, p, stidxmap, pidxmap, iv, tspan)
+                end
+                resid = vcat(u[1][u0i] - u0[u0i], consresid)
             end
         end
     else
@@ -941,32 +999,54 @@ end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-# Helper to create the dictionary that will substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function. 
+# Validate that all the variables in the BVP constraints are well-formed states or parameters.
+function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv) 
+    ModelingToolkit.check_variables(constraintsts)
+    ModelingToolkit.check_parameters(constraintps)
+
+    for var in constraintsts
+        if arguments(var) == iv
+            var ∈ sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
+            error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
+        else
+            operation(var)(iv) ∈ sts || error("Constraint equation $eq contains a variable $(operation(var)) that is not a variable of the ODESystem.")
+        end
+    end
+
+    for var in constraintps
+        if !iscall(var)
+            var ∈ ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
+        else
+            operation(var) ∈ ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
+        end
+    end
+end
+
+# Helper to substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function. 
 #   Take a system with variables x,y, parameters g
 #
-#   1 + x + y → 1 + u[1][1] + u[1][2]
+#   1 + x(0) + y(0) → 1 + u[1][1] + u[1][2]
 #   x(0.5) → u(0.5)[1]
 #   x(0.5)*g(0.5) → u(0.5)[1]*p[1]
-
-function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap)
-    iv = ModelingToolkit.get_iv(sys)
+function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap, iv, tspan)
     eq = Symbolics.unwrap(eq)
 
-    stmap = Dict([st => u[1][i] for st => i in stidxmap])
-    pmap = Dict([pa => p[i] for pa => i in pidxmap])
+    stmap = Dict([st => u[1][i] for (st, i) in stidxmap])
+    pmap = Dict([pa => p[i] for (pa, i) in pidxmap])
     eq = Symbolics.substitute(eq, merge(stmap, pmap))
 
     csyms = []
     # Find most nested calls, substitute those first.
     while !isempty(find_callable_syms!(csyms, eq))
         for sym in csyms 
-            t = arguments(sym)[1]
             x = operation(sym)
+            t = arguments(sym)[1]
+            prog = (tspan[2] - tspan[1])/(t - tspan[1]) # 1 / the % of the timespan elapsed
 
             if isparameter(x)
                 eq = Symbolics.substitute(eq, Dict(x(t) => p[pidxmap[x(iv)]]))
             elseif isvariable(x)
-                eq = Symbolics.substitute(eq, Dict(x(t) => u(val)[stidxmap[x(iv)]]))
+                eq = Symbolics.substitute(eq, Dict(x(t) => u[Int(end ÷ prog)][stidxmap[x(iv)]]))
             end
         end
         empty!(csyms)
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 2d5535325a..6432c5ae02 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,80 +1,85 @@
-using BoundaryValueDiffEq, OrdinaryDiffEq
+using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
 using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
 ### Test Collocation solvers on simple problems 
 solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
+daesolvers = [Ascher2, Ascher4, Ascher6]
 
-@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
-@variables x(t)=1.0 y(t)=2.0
-
-eqs = [D(x) ~ α * x - β * x * y,
-    D(y) ~ -γ * y + δ * x * y]
-
-u0map = [x => 1.0, y => 2.0]
-parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
-tspan = (0.0, 10.0)
-
-@mtkbuild lotkavolterra = ODESystem(eqs, t)
-op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-osol = solve(op, Vern9())
-
-bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
-    lotkavolterra, u0map, tspan, parammap; eval_expression = true)
-
-for solver in solvers
-    sol = solve(bvp, solver(), dt = 0.01)
-    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [1.0, 2.0]
-end
-
-# Test out of place
-bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
-    lotkavolterra, u0map, tspan, parammap; eval_expression = true)
-
-for solver in solvers
-    sol = solve(bvp2, solver(), dt = 0.01)
-    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [1.0, 2.0]
+let
+     @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+     @variables x(t)=1.0 y(t)=2.0
+     
+     eqs = [D(x) ~ α * x - β * x * y,
+         D(y) ~ -γ * y + δ * x * y]
+     
+     u0map = [x => 1.0, y => 2.0]
+     parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
+     tspan = (0.0, 10.0)
+     
+     @mtkbuild lotkavolterra = ODESystem(eqs, t)
+     op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
+     osol = solve(op, Vern9())
+     
+     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+     
+     for solver in solvers
+         sol = solve(bvp, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [1.0, 2.0]
+     end
+     
+     # Test out of place
+     bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
+         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+     
+     for solver in solvers
+         sol = solve(bvp2, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [1.0, 2.0]
+     end
 end
 
 ### Testing on pendulum
-
-@parameters g=9.81 L=1.0
-@variables θ(t) = π / 2
-
-eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
-
-@mtkbuild pend = ODESystem(eqs, t)
-
-u0map = [θ => π / 2, D(θ) => π / 2]
-parammap = [:L => 1.0, :g => 9.81]
-tspan = (0.0, 6.0)
-
-op = ODEProblem(pend, u0map, tspan, parammap)
-osol = solve(op, Vern9())
-
-bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-for solver in solvers
-    sol = solve(bvp, solver(), dt = 0.01)
-    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [π / 2, π / 2]
-end
-
-# Test out-of-place
-bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-
-for solver in solvers
-    sol = solve(bvp2, solver(), dt = 0.01)
-    @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-    @test sol.u[1] == [π / 2, π / 2]
+let
+     @parameters g=9.81 L=1.0
+     @variables θ(t) = π / 2
+     
+     eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
+     
+     @mtkbuild pend = ODESystem(eqs, t)
+     
+     u0map = [θ => π / 2, D(θ) => π / 2]
+     parammap = [:L => 1.0, :g => 9.81]
+     tspan = (0.0, 6.0)
+     
+     op = ODEProblem(pend, u0map, tspan, parammap)
+     osol = solve(op, Vern9())
+     
+     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+     for solver in solvers
+         sol = solve(bvp, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [π / 2, π / 2]
+     end
+     
+     # Test out-of-place
+     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+     
+     for solver in solvers
+         sol = solve(bvp2, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [π / 2, π / 2]
+     end
 end
 
 ###################################################
-### TESTING ODESystem with Constraint Equations ###
+### ODESystem with Constraint Equations, DAEs with constraints ###
 ###################################################
 
-# Cartesian pendulum from the docs. Testing that initialization is satisfied.
+# Cartesian pendulum from the docs.
+# DAE IVP solved using BoundaryValueDiffEq solvers.
 let
     @parameters g
     @variables x(t) y(t) [state_priority = 10] λ(t)
@@ -109,14 +114,74 @@ let
     end
 end
 
-# Adding a midpoint boundary condition.
+function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.01)
+    for solver in solvers
+        sol = solve(bvp, solver(); dt)
+
+        for (k, v) in u0map
+            @test sol[k][1] == v
+        end
+         
+        for cons in constraints
+            @test sol[cons.rhs - cons.lhs] ≈ 0
+        end
+
+        for eq in equations
+            @test sol[eq] ≈ 0
+        end
+    end
+end
+
+# Simple ODESystem with BVP constraints.
+let
+    @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+    @variables x(..) y(t)
+    
+    eqs = [D(x) ~ α * x - β * x * y,
+        D(y) ~ -γ * y + δ * x * y]
+    
+    u0map = [y => 2.0]
+    parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
+    tspan = (0.0, 10.0)
+    guesses = [x => 1.0]
+
+    @mtkbuild lotkavolterra = ODESystem(eqs, t)
+    op = ODEProblem(lotkavolterra, u0map, tspan, parammap, guesses = guesses)
+
+    constraints = [x(6.) ~ 3]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    test_solvers(solvers, bvp, u0map, constraints)
+
+    # Testing that more complicated constraints give correct solutions.
+    constraints = [y(2.) + x(8.) ~ 12]
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    test_solvers(solvers, bvp, u0map, constraints)
+
+    constraints = [α * β - x(6.) ~ 24]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    test_solvers(solvers, bvp, u0map, constraints)
+
+    # Testing that errors are properly thrown when malformed constraints are given.
+    @variables bad(..)
+    constraints = [x(1.) + bad(3.) ~ 10]
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+
+    constraints = [x(t) + y(t) ~ 3]
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+
+    @parameters bad2
+    constraints = [bad2 + x(0.) ~ 3]
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+end
+
+# Adding a midpoint boundary constraint.
+# Solve using BVDAE solvers.
 let 
     @parameters g
     @variables x(..) y(t) [state_priority = 10] λ(t)
     eqs = [D(D(x(t))) ~ λ * x(t)
            D(D(y)) ~ λ * y - g
-           x(t)^2 + y^2 ~ 1
-           x(0.5) ~ 1]
+           x(t)^2 + y^2 ~ 1]
     @mtkbuild pend = ODESystem(eqs, t)
 
     tspan = (0.0, 1.5)
@@ -124,37 +189,24 @@ let
     parammap = [g => 1]
     guesses = [λ => 1]
 
-    prob = ODEProblem(pend, u0map, tspan, pmap; guesses, check_length = false)
-    sol = solve(prob, Rodas5P())
+    constraints = [x(0.5) ~ 1]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+    test_solvers(daesolvers, bvp, u0map, constraints)
 
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guesses, check_length = false)
-    
-    for solver in solvers
-        sol = solve(bvp, solver(), dt = 0.01)
-        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-        conditions = getfield.(equations(pend)[3:end], :rhs)
-        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
-        @test sol.u[1] == [π / 2, π / 2]
-    end
+    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+    test_solvers(daesolvers, bvp2, u0map, constraints, get_alg_eqs(pend))
 
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses)
-    
-    for solver in solvers
-        sol = solve(bvp, solver(), dt = 0.01)
-        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-        conditions = getfield.(equations(pend)[3:end], :rhs)
-        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
-    end
+    # More complicated constraints.
+    u0map = [x(t) => 0.6]
+    guesses = [λ => 1, y(t) => 0.8]
 
-    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-    for solver in solvers
-        sol = solve(bvp, solver(), dt = 0.01)
-        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-        conditions = getfield.(equations(pend)[3:end], :rhs)
-        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
-    end
-end
+    constraints = [x(0.5) ~ 1, 
+                   x(0.3)^3 + y(0.6)^2 ~ 0.5]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+    test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
 
-# Testing a more complicated case with multiple constraints.
-let
+    constraints = [x(0.4) * g ~ y(0.2),
+                   y(0.7) ~ 0.3]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+    test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
 end

From 5d082ab05dc674b005ef1a724aa7273a712c2b66 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Sat, 11 Jan 2025 16:20:55 -0500
Subject: [PATCH 020/111] up

---
 src/systems/diffeqs/abstractodesystem.jl |  8 +++----
 test/bvproblem.jl                        | 27 ++++++++++++------------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 3c45b1b2f8..99b34ec1a8 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -917,7 +917,7 @@ end
 function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
         tspan = get_tspan(sys),
         parammap = DiffEqBase.NullParameters();
-        constraints = nothing, guesses = nothing,
+        constraints = nothing, guesses = Dict(),
         version = nothing, tgrad = false,
         callback = nothing,
         check_length = true,
@@ -929,7 +929,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     if !iscomplete(sys)
         error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
     end
-    !isnothing(callbacks) && error("BVP solvers do not support callbacks.")
+    !isnothing(callback) && error("BVP solvers do not support callbacks.")
 
     iv = get_iv(sys)
     constraintsts = nothing
@@ -964,7 +964,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
 
     # Indices of states that have initial constraints.
-    u0i = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for k in keys(u0map)]
+    u0i = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
     ni = length(u0i)
     
     bc = if !isnothing(constraints)
@@ -994,7 +994,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         end
     end
 
-    return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
+    return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
 end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 6432c5ae02..e864433f3c 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -44,13 +44,14 @@ end
 ### Testing on pendulum
 let
      @parameters g=9.81 L=1.0
-     @variables θ(t) = π / 2
+     @variables θ(t) = π / 2 θ_t(t)
      
-     eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]
+     eqs = [D(θ) ~ θ_t
+            D(θ_t) ~ -(g / L) * sin(θ)]
      
      @mtkbuild pend = ODESystem(eqs, t)
      
-     u0map = [θ => π / 2, D(θ) => π / 2]
+     u0map = [θ => π / 2, θ_t => π / 2]
      parammap = [:L => 1.0, :g => 9.81]
      tspan = (0.0, 6.0)
      
@@ -74,9 +75,9 @@ let
      end
 end
 
-###################################################
-### ODESystem with Constraint Equations, DAEs with constraints ###
-###################################################
+##################################################################
+### ODESystem with constraint equations, DAEs with constraints ###
+##################################################################
 
 # Cartesian pendulum from the docs.
 # DAE IVP solved using BoundaryValueDiffEq solvers.
@@ -90,19 +91,19 @@ let
 
     tspan = (0.0, 1.5)
     u0map = [x => 1, y => 0]
-    parammap = [g => 1]
-    guesses = [λ => 1]
+    pmap = [g => 1]
+    guess = [λ => 1]
 
-    prob = ODEProblem(pend, u0map, tspan, pmap; guesses)
-    sol = solve(prob, Rodas5P())
+    prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
+    osol = solve(prob, Rodas5P())
 
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
     
     for solver in solvers
-        sol = solve(bvp, solver(), dt = 0.01)
+        sol = solve(bvp, solver(), dt = 0.001)
         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
         conditions = getfield.(equations(pend)[3:end], :rhs)
-        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+        @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
     end
 
     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)

From b83e003babe1347439d05ea87a8b90995f1dcc82 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 14 Jan 2025 00:59:31 -0500
Subject: [PATCH 021/111] refactor the bc creation function

---
 src/systems/diffeqs/abstractodesystem.jl | 189 ++++++++++---------
 test/bvproblem.jl                        | 230 ++++++++++++++++-------
 2 files changed, 264 insertions(+), 155 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 99b34ec1a8..25347a17cd 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -931,68 +931,60 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     end
     !isnothing(callback) && error("BVP solvers do not support callbacks.")
 
-    iv = get_iv(sys)
+    has_alg_eqs(sys) && error("The BVProblem currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
+
     constraintsts = nothing
     constraintps = nothing
     sts = unknowns(sys)
     ps = parameters(sys)
 
-    if !isnothing(constraints)
+    # Constraint validation
+    f_cons = if !isnothing(constraints)
         constraints isa Equation || 
             constraints isa Vector{Equation} || 
             error("Constraints must be specified as an equation or a vector of equations.")
 
         (length(constraints) + length(u0map) > length(sts)) && 
-            error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
-
-        constraintsts = OrderedSet()
-        constraintps = OrderedSet()
-
-        for eq in constraints
-            collect_vars!(constraintsts, constraintps, eq, iv)
-            validate_constraint_syms(eq, constraintsts, constraintps, Set(sts), Set(ps), iv)
-            empty!(constraintsts)
-            empty!(constraintps)
-        end
+        error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
     end
 
-    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
+    # ODESystems without algebraic equations should use both fixed values + guesses
+    # for initialization.
+    _u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses)) 
+    f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
         t = tspan !== nothing ? tspan[1] : tspan, guesses,
         check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
 
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
-    pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
-
-    # Indices of states that have initial constraints.
-    u0i = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
-    ni = length(u0i)
-    
-    bc = if !isnothing(constraints)
-        ne = length(constraints)
-        if iip
-            (residual,u,p,t) -> begin
-                residual[1:ni] .= u[1][u0i] .- u0[u0i]
-                residual[ni+1:ni+ne] .= map(constraints) do cons
-                    sub_u_p_into_symeq(cons.rhs - cons.lhs, u, p, stidxmap, pidxmap, iv, tspan)
-                end
-            end
-        else
-            (u,p,t) -> begin
-                consresid = map(constraints) do cons
-                    sub_u_p_into_symeq(cons.rhs-cons.lhs, u, p, stidxmap, pidxmap, iv, tspan)
-                end
-                resid = vcat(u[1][u0i] - u0[u0i], consresid)
-            end
-        end
-    else
-        if iip
-            (residual,u,p,t) -> begin
-                residual .= u[1] .- u0
-            end
-        else
-            (u,p,t) -> (u[1] - u0)
-        end
-    end
+    u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
+
+    # bc = if !isnothing(constraints) && iip
+    #     (residual,u,p,t) -> begin
+    #         println(u(0.5))
+    #         residual[1:ni] .= u[1][u0i] .- u0[u0i]
+    #         for (i, cons) in enumerate(constraints)
+    #             residual[ni+i] = eval_symbolic_residual(cons, u, p, stidxmap, pidxmap, iv, tspan)
+    #         end
+    #     end
+
+    # elseif !isnothing(constraints) && !iip
+    #     (u,p,t) -> begin
+    #         consresid = map(constraints) do cons
+    #             eval_symbolic_residual(cons, u, p, stidxmap, pidxmap, iv, tspan)
+    #         end
+    #         resid = vcat(u[1][u0i] - u0[u0i], consresid)
+    #     end
+
+    # elseif iip
+    #     (residual,u,p,t) -> begin
+    #         println(u(0.5))
+    #         residual .= u[1] .- u0
+    #     end
+
+    # else
+    #     (u,p,t) -> (u[1] - u0)
+    # end
+    bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)
 
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
 end
@@ -1001,11 +993,10 @@ get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
 # Validate that all the variables in the BVP constraints are well-formed states or parameters.
 function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv) 
-    ModelingToolkit.check_variables(constraintsts)
-    ModelingToolkit.check_parameters(constraintps)
-
     for var in constraintsts
-        if arguments(var) == iv
+        if length(arguments(var)) > 1
+            error("Too many arguments for variable $var.")
+        elseif arguments(var) == iv
             var ∈ sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
             error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
         else
@@ -1017,53 +1008,81 @@ function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv)
         if !iscall(var)
             var ∈ ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
         else
+            length(arguments(var)) > 1 && error("Too many arguments for parameter $var.")
             operation(var) ∈ ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
         end
     end
 end
 
-# Helper to substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function. 
-#   Take a system with variables x,y, parameters g
-#
-#   1 + x(0) + y(0) → 1 + u[1][1] + u[1][2]
-#   x(0.5) → u(0.5)[1]
-#   x(0.5)*g(0.5) → u(0.5)[1]*p[1]
-function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap, iv, tspan)
-    eq = Symbolics.unwrap(eq)
-
-    stmap = Dict([st => u[1][i] for (st, i) in stidxmap])
-    pmap = Dict([pa => p[i] for (pa, i) in pidxmap])
-    eq = Symbolics.substitute(eq, merge(stmap, pmap))
-
-    csyms = []
-    # Find most nested calls, substitute those first.
-    while !isempty(find_callable_syms!(csyms, eq))
-        for sym in csyms 
-            x = operation(sym)
-            t = arguments(sym)[1]
-            prog = (tspan[2] - tspan[1])/(t - tspan[1]) # 1 / the % of the timespan elapsed
-
-            if isparameter(x)
-                eq = Symbolics.substitute(eq, Dict(x(t) => p[pidxmap[x(iv)]]))
-            elseif isvariable(x)
-                eq = Symbolics.substitute(eq, Dict(x(t) => u[Int(end ÷ prog)][stidxmap[x(iv)]]))
+"""
+    process_constraints(sys, constraints, u0, tspan, iip)
+
+    Given an ODESystem with some constraints, generate the boundary condition function.
+"""
+function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, iip)
+
+    iv = get_iv(sys)
+    sts = get_unknowns(sys)
+    ps = get_ps(sys)
+    np = length(ps)
+    ns = length(sts)
+
+    stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
+    pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
+
+    @variables sol(..)[1:ns] p[1:np]
+    exprs = Any[]
+
+    constraintsts = OrderedSet()
+    constraintps = OrderedSet()
+
+    !isnothing(constraints) && for cons in constraints
+        collect_vars!(constraintsts, constraintps, cons, iv)
+        validate_constraint_syms(cons, constraintsts, constraintps, Set(sts), Set(ps), iv)
+        expr = cons.rhs - cons.lhs
+
+        for st in constraintsts
+            x = operation(st)
+            t = arguments(st)[1]
+            idx = stidxmap[x(iv)]
+
+            expr = Symbolics.substitute(expr, Dict(x(t) => sol(t)[idx]))
+        end
+
+        for var in constraintps
+            if iscall(var)
+                x = operation(var)
+                t = arguments(var)[1]
+                idx = pidxmap[x]
+
+                expr = Symbolics.substitute(expr, Dict(x(t) => p[idx]))
+            else
+                idx = pidxmap[var]
+                expr = Symbolics.substitute(expr, Dict(var => p[idx]))
             end
         end
-        empty!(csyms)
+
+        empty!(constraintsts)
+        empty!(constraintps)
+        push!(exprs, expr)
     end
-    eq
-end
 
-function find_callable_syms!(csyms, ex)
-    ex = Symbolics.unwrap(ex)
+    init_cond_exprs = Any[]
 
-    if iscall(ex)
-        operation(ex) isa Symbolic && (arguments(ex)[1] isa Symbolic) && push!(csyms, ex) # only add leaf nodes 
-        for arg in arguments(ex)
-            find_callable_syms!(csyms, arg)
+    for i in u0_idxs
+        expr = sol(tspan[1])[i] - u0[i]
+        push!(init_cond_exprs, expr)
+    end
+
+    exprs = vcat(init_cond_exprs, exprs)
+    bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
+    if iip
+        return (resid, u, p, t) -> begin
+            bcs[2](resid, u, p)
         end
+    else
+        return (u, p, t) -> bcs[1](u, p)
     end
-    csyms 
 end
 
 """
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index e864433f3c..7e0d45b128 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,3 +1,5 @@
+### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
+
 using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
 using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
@@ -81,51 +83,140 @@ end
 
 # Cartesian pendulum from the docs.
 # DAE IVP solved using BoundaryValueDiffEq solvers.
+# let
+#     @parameters g
+#     @variables x(t) y(t) [state_priority = 10] λ(t)
+#     eqs = [D(D(x)) ~ λ * x
+#            D(D(y)) ~ λ * y - g
+#            x^2 + y^2 ~ 1]
+#     @mtkbuild pend = ODESystem(eqs, t)
+# 
+#     tspan = (0.0, 1.5)
+#     u0map = [x => 1, y => 0]
+#     pmap = [g => 1]
+#     guess = [λ => 1]
+# 
+#     prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
+#     osol = solve(prob, Rodas5P())
+# 
+#     zeta = [0., 0., 0., 0., 0.]
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
+#     
+#     for solver in solvers
+#         sol = solve(bvp, solver(zeta), dt = 0.001)
+#         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#         conditions = getfield.(equations(pend)[3:end], :rhs)
+#         @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
+#     end
+# 
+#     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+#     for solver in solvers
+#         sol = solve(bvp, solver(zeta), dt = 0.01)
+#         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#         conditions = getfield.(equations(pend)[3:end], :rhs)
+#         @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+#     end
+# end
+
+# Test generation of boundary condition function.
 let
-    @parameters g
-    @variables x(t) y(t) [state_priority = 10] λ(t)
-    eqs = [D(D(x)) ~ λ * x
-           D(D(y)) ~ λ * y - g
-           x^2 + y^2 ~ 1]
-    @mtkbuild pend = ODESystem(eqs, t)
-
-    tspan = (0.0, 1.5)
-    u0map = [x => 1, y => 0]
-    pmap = [g => 1]
-    guess = [λ => 1]
-
-    prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
-    osol = solve(prob, Rodas5P())
-
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
+    @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+    @variables x(..) y(t)
+    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
+           D(y) ~ -γ * y + δ * x(t) * y]
     
-    for solver in solvers
-        sol = solve(bvp, solver(), dt = 0.001)
-        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-        conditions = getfield.(equations(pend)[3:end], :rhs)
-        @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
+    tspan = (0., 10.)
+    @mtkbuild lksys = ODESystem(eqs, t)
+
+    function lotkavolterra!(du, u, p, t) 
+        du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
+        du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
     end
 
-    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-    for solver in solvers
-        sol = solve(bvp, solver(), dt = 0.01)
-        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-        conditions = getfield.(equations(pend)[3:end], :rhs)
-        @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+    function lotkavolterra(u, p, t) 
+        [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]
+    end
+    # Compare the built bc function to the actual constructed one.
+    function bc!(resid, u, p, t) 
+        resid[1] = u[1][1] - 1.
+        resid[2] = u[1][2] - 2.
+        nothing
+    end
+    function bc(u, p, t)
+        [u[1][1] - 1., u[1][2] - 2.]
+    end
+
+    constraints = nothing
+    u0 = [1., 2.]; p = [7.5, 4., 8., 5.]
+    genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1, 2], tspan, true)
+    genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1, 2], tspan, false)
+
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1,2], tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1,2], tspan, p)
+
+    sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
+    sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
+    @test sol1 ≈ sol2
+
+    bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
+    bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
+
+    sol1 = solve(bvpo1, MIRK4(), dt = 0.01)
+    sol2 = solve(bvpo2, MIRK4(), dt = 0.01)
+    @test sol1 ≈ sol2
+
+    # Test with a constraint.
+    constraints = [x(0.5) ~ 1.]
+
+    function bc!(resid, u, p, t) 
+        resid[1] = u[1][2] - 2.
+        resid[2] = u(0.5)[1] - 1.
+    end
+    function bc(u, p, t)
+        [u[1][2] - 2., u(0.5)[1] - 1.]
     end
+
+    u0 = [1., 2.]
+    genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, true)
+    genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, false)
+
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1,2], tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1,2], tspan, p)
+
+    sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
+    sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
+    @test sol1 ≈ sol2 # don't get true equality here, not sure why
+
+    bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
+    bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
+
+    sol1 = solve(bvpo1, MIRK4(), dt = 0.01)
+    sol2 = solve(bvpo2, MIRK4(), dt = 0.01)
+    @test sol1 ≈ sol2
 end
 
 function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.01)
     for solver in solvers
-        sol = solve(bvp, solver(); dt)
+        sol = solve(prob, solver(); dt)
+        @test successful_retcode(sol.retcode)
+        p = prob.p; t = sol.t; bc = prob.f.bc
+        ns = length(prob.u0)
+
+        if isinplace(bvp.f)
+            resid = zeros(ns)
+            bc!(resid, sol, p, t)
+            @test isapprox(zeros(ns), resid)
+        else
+            @test isapprox(zeros(ns), bc(sol, p, t))
+        end
 
         for (k, v) in u0map
             @test sol[k][1] == v
         end
          
-        for cons in constraints
-            @test sol[cons.rhs - cons.lhs] ≈ 0
-        end
+        # for cons in constraints
+        #     @test sol[cons.rhs - cons.lhs] ≈ 0
+        # end
 
         for eq in equations
             @test sol[eq] ≈ 0
@@ -138,19 +229,18 @@ let
     @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
     @variables x(..) y(t)
     
-    eqs = [D(x) ~ α * x - β * x * y,
-        D(y) ~ -γ * y + δ * x * y]
+    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
+           D(y) ~ -γ * y + δ * x(t) * y]
     
     u0map = [y => 2.0]
     parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
     tspan = (0.0, 10.0)
-    guesses = [x => 1.0]
+    guesses = [x(t) => 1.0]
 
     @mtkbuild lotkavolterra = ODESystem(eqs, t)
-    op = ODEProblem(lotkavolterra, u0map, tspan, parammap, guesses = guesses)
 
     constraints = [x(6.) ~ 3]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
     test_solvers(solvers, bvp, u0map, constraints)
 
     # Testing that more complicated constraints give correct solutions.
@@ -177,37 +267,37 @@ end
 
 # Adding a midpoint boundary constraint.
 # Solve using BVDAE solvers.
-let 
-    @parameters g
-    @variables x(..) y(t) [state_priority = 10] λ(t)
-    eqs = [D(D(x(t))) ~ λ * x(t)
-           D(D(y)) ~ λ * y - g
-           x(t)^2 + y^2 ~ 1]
-    @mtkbuild pend = ODESystem(eqs, t)
-
-    tspan = (0.0, 1.5)
-    u0map = [x(t) => 0.6, y => 0.8]
-    parammap = [g => 1]
-    guesses = [λ => 1]
-
-    constraints = [x(0.5) ~ 1]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
-    test_solvers(daesolvers, bvp, u0map, constraints)
-
-    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-    test_solvers(daesolvers, bvp2, u0map, constraints, get_alg_eqs(pend))
-
-    # More complicated constraints.
-    u0map = [x(t) => 0.6]
-    guesses = [λ => 1, y(t) => 0.8]
-
-    constraints = [x(0.5) ~ 1, 
-                   x(0.3)^3 + y(0.6)^2 ~ 0.5]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
-    test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
-
-    constraints = [x(0.4) * g ~ y(0.2),
-                   y(0.7) ~ 0.3]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
-    test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
-end
+# let 
+#     @parameters g
+#     @variables x(..) y(t) [state_priority = 10] λ(t)
+#     eqs = [D(D(x(t))) ~ λ * x(t)
+#            D(D(y)) ~ λ * y - g
+#            x(t)^2 + y^2 ~ 1]
+#     @mtkbuild pend = ODESystem(eqs, t)
+# 
+#     tspan = (0.0, 1.5)
+#     u0map = [x(t) => 0.6, y => 0.8]
+#     parammap = [g => 1]
+#     guesses = [λ => 1]
+# 
+#     constraints = [x(0.5) ~ 1]
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+#     test_solvers(daesolvers, bvp, u0map, constraints)
+# 
+#     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+#     test_solvers(daesolvers, bvp2, u0map, constraints, get_alg_eqs(pend))
+# 
+#     # More complicated constraints.
+#     u0map = [x(t) => 0.6]
+#     guesses = [λ => 1, y(t) => 0.8]
+# 
+#     constraints = [x(0.5) ~ 1, 
+#                    x(0.3)^3 + y(0.6)^2 ~ 0.5]
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+#     test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
+# 
+#     constraints = [x(0.4) * g ~ y(0.2),
+#                    y(0.7) ~ 0.3]
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
+#     test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
+# end

From db5eb66ea29533e51aefd6ea49ea04e0257f3201 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 14 Jan 2025 13:46:31 -0500
Subject: [PATCH 022/111] up

---
 src/systems/diffeqs/abstractodesystem.jl |  33 +-----
 test/bvproblem.jl                        | 144 ++++++++++++-----------
 2 files changed, 76 insertions(+), 101 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 25347a17cd..0e8435942c 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -939,7 +939,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     ps = parameters(sys)
 
     # Constraint validation
-    f_cons = if !isnothing(constraints)
+    if !isnothing(constraints)
         constraints isa Equation || 
             constraints isa Vector{Equation} || 
             error("Constraints must be specified as an equation or a vector of equations.")
@@ -958,32 +958,6 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
     u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
 
-    # bc = if !isnothing(constraints) && iip
-    #     (residual,u,p,t) -> begin
-    #         println(u(0.5))
-    #         residual[1:ni] .= u[1][u0i] .- u0[u0i]
-    #         for (i, cons) in enumerate(constraints)
-    #             residual[ni+i] = eval_symbolic_residual(cons, u, p, stidxmap, pidxmap, iv, tspan)
-    #         end
-    #     end
-
-    # elseif !isnothing(constraints) && !iip
-    #     (u,p,t) -> begin
-    #         consresid = map(constraints) do cons
-    #             eval_symbolic_residual(cons, u, p, stidxmap, pidxmap, iv, tspan)
-    #         end
-    #         resid = vcat(u[1][u0i] - u0[u0i], consresid)
-    #     end
-
-    # elseif iip
-    #     (residual,u,p,t) -> begin
-    #         println(u(0.5))
-    #         residual .= u[1] .- u0
-    #     end
-
-    # else
-    #     (u,p,t) -> (u[1] - u0)
-    # end
     bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)
 
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
@@ -1075,11 +1049,10 @@ function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, ii
     end
 
     exprs = vcat(init_cond_exprs, exprs)
+    @show exprs
     bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
     if iip
-        return (resid, u, p, t) -> begin
-            bcs[2](resid, u, p)
-        end
+        return (resid, u, p, t) -> bcs[2](resid, u, p)
     else
         return (u, p, t) -> bcs[1](u, p)
     end
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 7e0d45b128..42762b8a3f 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -2,7 +2,9 @@
 
 using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
 using ModelingToolkit
+using SciMLBase
 using ModelingToolkit: t_nounits as t, D_nounits as D
+import ModelingToolkit: process_constraints
 
 ### Test Collocation solvers on simple problems 
 solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
@@ -81,46 +83,9 @@ end
 ### ODESystem with constraint equations, DAEs with constraints ###
 ##################################################################
 
-# Cartesian pendulum from the docs.
-# DAE IVP solved using BoundaryValueDiffEq solvers.
-# let
-#     @parameters g
-#     @variables x(t) y(t) [state_priority = 10] λ(t)
-#     eqs = [D(D(x)) ~ λ * x
-#            D(D(y)) ~ λ * y - g
-#            x^2 + y^2 ~ 1]
-#     @mtkbuild pend = ODESystem(eqs, t)
-# 
-#     tspan = (0.0, 1.5)
-#     u0map = [x => 1, y => 0]
-#     pmap = [g => 1]
-#     guess = [λ => 1]
-# 
-#     prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
-#     osol = solve(prob, Rodas5P())
-# 
-#     zeta = [0., 0., 0., 0., 0.]
-#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
-#     
-#     for solver in solvers
-#         sol = solve(bvp, solver(zeta), dt = 0.001)
-#         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-#         conditions = getfield.(equations(pend)[3:end], :rhs)
-#         @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
-#     end
-# 
-#     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-#     for solver in solvers
-#         sol = solve(bvp, solver(zeta), dt = 0.01)
-#         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-#         conditions = getfield.(equations(pend)[3:end], :rhs)
-#         @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
-#     end
-# end
-
 # Test generation of boundary condition function.
 let
-    @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+    @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
     @variables x(..) y(t)
     eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
            D(y) ~ -γ * y + δ * x(t) * y]
@@ -130,11 +95,11 @@ let
 
     function lotkavolterra!(du, u, p, t) 
         du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
-        du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
+        du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
     end
 
     function lotkavolterra(u, p, t) 
-        [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]
+        [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
     end
     # Compare the built bc function to the actual constructed one.
     function bc!(resid, u, p, t) 
@@ -146,23 +111,22 @@ let
         [u[1][1] - 1., u[1][2] - 2.]
     end
 
-    constraints = nothing
-    u0 = [1., 2.]; p = [7.5, 4., 8., 5.]
-    genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1, 2], tspan, true)
-    genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1, 2], tspan, false)
+    u0 = [1., 2.]; p = [1.5, 1., 3., 1.]
+    genbc_iip = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, true)
+    genbc_oop = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, false)
 
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1,2], tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1,2], tspan, p)
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
 
-    sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
-    sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
+    sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
+    sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
     @test sol1 ≈ sol2
 
     bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
     bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
 
-    sol1 = solve(bvpo1, MIRK4(), dt = 0.01)
-    sol2 = solve(bvpo2, MIRK4(), dt = 0.01)
+    sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
+    sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
     @test sol1 ≈ sol2
 
     # Test with a constraint.
@@ -180,28 +144,28 @@ let
     genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, true)
     genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, false)
 
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1,2], tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1,2], tspan, p)
-
-    sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
-    sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
+    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan, parammap; guesses, constraints)
+    
+    sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
+    sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
     @test sol1 ≈ sol2 # don't get true equality here, not sure why
 
     bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
     bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
 
-    sol1 = solve(bvpo1, MIRK4(), dt = 0.01)
-    sol2 = solve(bvpo2, MIRK4(), dt = 0.01)
+    sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
+    sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
     @test sol1 ≈ sol2
 end
 
-function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.01)
+function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05)
     for solver in solvers
         sol = solve(prob, solver(); dt)
-        @test successful_retcode(sol.retcode)
+        @test SciMLBase.successful_retcode(sol.retcode)
         p = prob.p; t = sol.t; bc = prob.f.bc
         ns = length(prob.u0)
-
         if isinplace(bvp.f)
             resid = zeros(ns)
             bc!(resid, sol, p, t)
@@ -226,45 +190,83 @@ end
 
 # Simple ODESystem with BVP constraints.
 let
-    @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+    t = ModelingToolkit.t_nounits; D = ModelingToolkit.D_nounits
+    @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
     @variables x(..) y(t)
     
     eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
            D(y) ~ -γ * y + δ * x(t) * y]
     
-    u0map = [y => 2.0]
+    u0map = []
     parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
     tspan = (0.0, 10.0)
-    guesses = [x(t) => 1.0]
+    guesses = [x(t) => 1.0, y => 2.]
 
     @mtkbuild lotkavolterra = ODESystem(eqs, t)
 
-    constraints = [x(6.) ~ 3]
+    constraints = [x(6.) ~ 1.5]
     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
     test_solvers(solvers, bvp, u0map, constraints)
 
     # Testing that more complicated constraints give correct solutions.
-    constraints = [y(2.) + x(8.) ~ 12]
-    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    constraints = [y(2.) + x(8.) ~ 2.]
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
     test_solvers(solvers, bvp, u0map, constraints)
 
-    constraints = [α * β - x(6.) ~ 24]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    constraints = [α * β - x(6.) ~ 0.5]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
     test_solvers(solvers, bvp, u0map, constraints)
 
     # Testing that errors are properly thrown when malformed constraints are given.
     @variables bad(..)
     constraints = [x(1.) + bad(3.) ~ 10]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
 
     constraints = [x(t) + y(t) ~ 3]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
 
     @parameters bad2
     constraints = [bad2 + x(0.) ~ 3]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
 end
 
+# Cartesian pendulum from the docs.
+# DAE IVP solved using BoundaryValueDiffEq solvers.
+# let
+#     @parameters g
+#     @variables x(t) y(t) [state_priority = 10] λ(t)
+#     eqs = [D(D(x)) ~ λ * x
+#            D(D(y)) ~ λ * y - g
+#            x^2 + y^2 ~ 1]
+#     @mtkbuild pend = ODESystem(eqs, t)
+# 
+#     tspan = (0.0, 1.5)
+#     u0map = [x => 1, y => 0]
+#     pmap = [g => 1]
+#     guess = [λ => 1]
+# 
+#     prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
+#     osol = solve(prob, Rodas5P())
+# 
+#     zeta = [0., 0., 0., 0., 0.]
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
+#     
+#     for solver in solvers
+#         sol = solve(bvp, solver(zeta), dt = 0.001)
+#         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#         conditions = getfield.(equations(pend)[3:end], :rhs)
+#         @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
+#     end
+# 
+#     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+#     for solver in solvers
+#         sol = solve(bvp, solver(zeta), dt = 0.01)
+#         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#         conditions = getfield.(equations(pend)[3:end], :rhs)
+#         @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0 
+#     end
+# end
+
 # Adding a midpoint boundary constraint.
 # Solve using BVDAE solvers.
 # let 

From e802946e597d8f7fa7932ed328cb3acb1599e2c3 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 15 Jan 2025 16:10:55 -0500
Subject: [PATCH 023/111] test update

---
 src/systems/diffeqs/abstractodesystem.jl | 10 ++--
 test/bvproblem.jl                        | 76 +++++++++++++-----------
 2 files changed, 44 insertions(+), 42 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 0e8435942c..d935a94eb8 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -860,12 +860,11 @@ SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
                          kwargs...) where {iip}
 ```
 
-Create a boundary value problem from the [`ODESystem`](@ref). The arguments `dvs` and
-`ps` are used to set the order of the dependent variable and parameter vectors,
-respectively. `u0map` is used to specify fixed initial values for the states.
+Create a boundary value problem from the [`ODESystem`](@ref). 
 
-Every variable must have either an initial guess supplied using `guesses` or 
-a fixed initial value specified using `u0map`.
+`u0map` is used to specify fixed initial values for the states. Every variable 
+must have either an initial guess supplied using `guesses` or a fixed initial 
+value specified using `u0map`.
 
 `constraints` are used to specify boundary conditions to the ODESystem in the
 form of equations. These values should specify values that state variables should
@@ -1049,7 +1048,6 @@ function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, ii
     end
 
     exprs = vcat(init_cond_exprs, exprs)
-    @show exprs
     bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
     if iip
         return (resid, u, p, t) -> bcs[2](resid, u, p)
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 42762b8a3f..e6645c060d 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -83,14 +83,14 @@ end
 ### ODESystem with constraint equations, DAEs with constraints ###
 ##################################################################
 
-# Test generation of boundary condition function.
+# Test generation of boundary condition function using `process_constraints`. Compare solutions to manually written boundary conditions
 let
     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
-    @variables x(..) y(t)
-    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
-           D(y) ~ -γ * y + δ * x(t) * y]
+    @variables x(..) y(..)
+    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
+           D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
     
-    tspan = (0., 10.)
+    tspan = (0., 1.)
     @mtkbuild lksys = ODESystem(eqs, t)
 
     function lotkavolterra!(du, u, p, t) 
@@ -111,7 +111,7 @@ let
         [u[1][1] - 1., u[1][2] - 2.]
     end
 
-    u0 = [1., 2.]; p = [1.5, 1., 3., 1.]
+    u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
     genbc_iip = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, true)
     genbc_oop = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, false)
 
@@ -130,48 +130,54 @@ let
     @test sol1 ≈ sol2
 
     # Test with a constraint.
-    constraints = [x(0.5) ~ 1.]
+    constraints = [y(0.5) ~ 2.]
 
     function bc!(resid, u, p, t) 
-        resid[1] = u[1][2] - 2.
-        resid[2] = u(0.5)[1] - 1.
+        resid[1] = u(0.0)[1] - 1.
+        resid[2] = u(0.5)[2] - 2.
     end
     function bc(u, p, t)
-        [u[1][2] - 2., u(0.5)[1] - 1.]
+        [u(0.0)[1] - 1., u(0.5)[2] - 2.]
     end
 
-    u0 = [1., 2.]
-    genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, true)
-    genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, false)
+    u0 = [1, 1.]
+    genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, true)
+    genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, false)
 
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
-    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan, parammap; guesses, constraints)
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
+    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+    bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
     
-    sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
-    sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2 # don't get true equality here, not sure why
+    @btime sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
+    @btime sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
+    @btime sol3 = solve(bvpi3, MIRK4(), dt = 0.01)
+    @btime sol4 = solve(bvpi4, MIRK4(), dt = 0.01)
+    @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
 
     bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
     bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
+    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan, parammap; guesses = [y(t) => 1.], constraints)
 
-    sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
-    sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2
+    @btime sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
+    @btime sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
+    @btime sol3 = solve(bvpo3, MIRK4(), dt = 0.05)
+    @test sol1 ≈ sol2 ≈ sol3
 end
 
-function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05)
+function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-4)
     for solver in solvers
-        sol = solve(prob, solver(); dt)
+        println("Solver: $solver")
+        @btime sol = solve(prob, solver(), dt = dt, abstol = atol)
         @test SciMLBase.successful_retcode(sol.retcode)
         p = prob.p; t = sol.t; bc = prob.f.bc
         ns = length(prob.u0)
         if isinplace(bvp.f)
             resid = zeros(ns)
             bc!(resid, sol, p, t)
-            @test isapprox(zeros(ns), resid)
+            @test isapprox(zeros(ns), resid; atol)
         else
-            @test isapprox(zeros(ns), bc(sol, p, t))
+            @test isapprox(zeros(ns), bc(sol, p, t); atol)
         end
 
         for (k, v) in u0map
@@ -190,23 +196,21 @@ end
 
 # Simple ODESystem with BVP constraints.
 let
-    t = ModelingToolkit.t_nounits; D = ModelingToolkit.D_nounits
     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
-    @variables x(..) y(t)
+    @variables x(..) y(..)
     
-    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
-           D(y) ~ -γ * y + δ * x(t) * y]
+    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
+           D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
     
     u0map = []
-    parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
     tspan = (0.0, 10.0)
-    guesses = [x(t) => 1.0, y => 2.]
+    guesses = [x(t) => 4.0, y(t) => 2.]
 
-    @mtkbuild lotkavolterra = ODESystem(eqs, t)
+    @mtkbuild lksys = ODESystem(eqs, t)
 
-    constraints = [x(6.) ~ 1.5]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
-    test_solvers(solvers, bvp, u0map, constraints)
+    constraints = [x(6.) ~ 3.5, x(3.) ~ 7.]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    test_solvers(solvers, bvp, u0map, constraints; dt = 0.1)
 
     # Testing that more complicated constraints give correct solutions.
     constraints = [y(2.) + x(8.) ~ 2.]

From e74e047bf333fc197ecd0b5390f22aa2651a4431 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 15 Jan 2025 17:20:43 -0500
Subject: [PATCH 024/111] fix

---
 test/bvproblem.jl | 44 ++++++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index e6645c060d..a8de6af44d 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -7,7 +7,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
 import ModelingToolkit: process_constraints
 
 ### Test Collocation solvers on simple problems 
-solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
+solvers = [MIRK4, RadauIIa5]
 daesolvers = [Ascher2, Ascher4, Ascher6]
 
 let
@@ -149,32 +149,32 @@ let
     bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
     bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
     
-    @btime sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
-    @btime sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
-    @btime sol3 = solve(bvpi3, MIRK4(), dt = 0.01)
-    @btime sol4 = solve(bvpi4, MIRK4(), dt = 0.01)
+    sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
+    sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
+    sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
+    sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
     @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
 
     bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
     bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
-    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan, parammap; guesses = [y(t) => 1.], constraints)
+    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
 
-    @btime sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
-    @btime sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
-    @btime sol3 = solve(bvpo3, MIRK4(), dt = 0.05)
+    sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
+    sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
+    sol3 = @btime solve($bvpo3, MIRK4(), dt = 0.05)
     @test sol1 ≈ sol2 ≈ sol3
 end
 
 function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-4)
     for solver in solvers
         println("Solver: $solver")
-        @btime sol = solve(prob, solver(), dt = dt, abstol = atol)
+        sol = @btime solve($prob, $solver(), dt = $dt, abstol = $atol)
         @test SciMLBase.successful_retcode(sol.retcode)
         p = prob.p; t = sol.t; bc = prob.f.bc
         ns = length(prob.u0)
         if isinplace(bvp.f)
             resid = zeros(ns)
-            bc!(resid, sol, p, t)
+            bc(resid, sol, p, t)
             @test isapprox(zeros(ns), resid; atol)
         else
             @test isapprox(zeros(ns), bc(sol, p, t); atol)
@@ -203,35 +203,35 @@ let
            D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
     
     u0map = []
-    tspan = (0.0, 10.0)
+    tspan = (0.0, 1.0)
     guesses = [x(t) => 4.0, y(t) => 2.]
 
     @mtkbuild lksys = ODESystem(eqs, t)
 
-    constraints = [x(6.) ~ 3.5, x(3.) ~ 7.]
+    constraints = [x(.6) ~ 3.5, x(.3) ~ 7.]
     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
-    test_solvers(solvers, bvp, u0map, constraints; dt = 0.1)
+    test_solvers(solvers, bvp, u0map, constraints; dt = 0.05)
 
     # Testing that more complicated constraints give correct solutions.
-    constraints = [y(2.) + x(8.) ~ 2.]
-    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
-    test_solvers(solvers, bvp, u0map, constraints)
+    constraints = [y(.2) + x(.8) ~ 3., y(.3) + x(.5) ~ 5.]
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses, constraints, jac = true)
+    test_solvers(solvers, bvp, u0map, constraints; dt = 0.05)
 
-    constraints = [α * β - x(6.) ~ 0.5]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
+    constraints = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
     test_solvers(solvers, bvp, u0map, constraints)
 
     # Testing that errors are properly thrown when malformed constraints are given.
     @variables bad(..)
     constraints = [x(1.) + bad(3.) ~ 10]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
 
     constraints = [x(t) + y(t) ~ 3]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
 
     @parameters bad2
     constraints = [bad2 + x(0.) ~ 3]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
+    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
 end
 
 # Cartesian pendulum from the docs.

From 86d4144d910c0ba56ec5057f7a380ccca5caf879 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 17 Jan 2025 16:49:58 -0500
Subject: [PATCH 025/111] test more solvers:

---
 src/systems/diffeqs/abstractodesystem.jl |   2 +-
 test/bvproblem.jl                        | 333 ++++++++++++-----------
 2 files changed, 172 insertions(+), 163 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index d935a94eb8..cf6e0962fd 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -969,7 +969,7 @@ function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv)
     for var in constraintsts
         if length(arguments(var)) > 1
             error("Too many arguments for variable $var.")
-        elseif arguments(var) == iv
+        elseif isequal(arguments(var)[1], iv)
             var ∈ sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
             error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
         else
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index a8de6af44d..2f278e6135 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,183 +1,186 @@
 ### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
 
 using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
+using BenchmarkTools
 using ModelingToolkit
 using SciMLBase
 using ModelingToolkit: t_nounits as t, D_nounits as D
 import ModelingToolkit: process_constraints
 
 ### Test Collocation solvers on simple problems 
-solvers = [MIRK4, RadauIIa5]
+solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
 daesolvers = [Ascher2, Ascher4, Ascher6]
 
-let
-     @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
-     @variables x(t)=1.0 y(t)=2.0
-     
-     eqs = [D(x) ~ α * x - β * x * y,
-         D(y) ~ -γ * y + δ * x * y]
-     
-     u0map = [x => 1.0, y => 2.0]
-     parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
-     tspan = (0.0, 10.0)
-     
-     @mtkbuild lotkavolterra = ODESystem(eqs, t)
-     op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-     osol = solve(op, Vern9())
-     
-     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
-         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
-     
-     for solver in solvers
-         sol = solve(bvp, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [1.0, 2.0]
-     end
-     
-     # Test out of place
-     bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
-         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
-     
-     for solver in solvers
-         sol = solve(bvp2, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [1.0, 2.0]
-     end
-end
-
-### Testing on pendulum
-let
-     @parameters g=9.81 L=1.0
-     @variables θ(t) = π / 2 θ_t(t)
-     
-     eqs = [D(θ) ~ θ_t
-            D(θ_t) ~ -(g / L) * sin(θ)]
-     
-     @mtkbuild pend = ODESystem(eqs, t)
-     
-     u0map = [θ => π / 2, θ_t => π / 2]
-     parammap = [:L => 1.0, :g => 9.81]
-     tspan = (0.0, 6.0)
-     
-     op = ODEProblem(pend, u0map, tspan, parammap)
-     osol = solve(op, Vern9())
-     
-     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-     for solver in solvers
-         sol = solve(bvp, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [π / 2, π / 2]
-     end
-     
-     # Test out-of-place
-     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-     
-     for solver in solvers
-         sol = solve(bvp2, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [π / 2, π / 2]
-     end
-end
-
-##################################################################
-### ODESystem with constraint equations, DAEs with constraints ###
-##################################################################
-
-# Test generation of boundary condition function using `process_constraints`. Compare solutions to manually written boundary conditions
-let
-    @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
-    @variables x(..) y(..)
-    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
-           D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
-    
-    tspan = (0., 1.)
-    @mtkbuild lksys = ODESystem(eqs, t)
-
-    function lotkavolterra!(du, u, p, t) 
-        du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
-        du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
-    end
-
-    function lotkavolterra(u, p, t) 
-        [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
-    end
-    # Compare the built bc function to the actual constructed one.
-    function bc!(resid, u, p, t) 
-        resid[1] = u[1][1] - 1.
-        resid[2] = u[1][2] - 2.
-        nothing
-    end
-    function bc(u, p, t)
-        [u[1][1] - 1., u[1][2] - 2.]
-    end
-
-    u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
-    genbc_iip = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, true)
-    genbc_oop = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, false)
-
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
-
-    sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
-    sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2
-
-    bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
-    bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
-
-    sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
-    sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2
-
-    # Test with a constraint.
-    constraints = [y(0.5) ~ 2.]
-
-    function bc!(resid, u, p, t) 
-        resid[1] = u(0.0)[1] - 1.
-        resid[2] = u(0.5)[2] - 2.
-    end
-    function bc(u, p, t)
-        [u(0.0)[1] - 1., u(0.5)[2] - 2.]
-    end
-
-    u0 = [1, 1.]
-    genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, true)
-    genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, false)
-
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
-    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-    bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-    
-    sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
-    sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
-    sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
-    sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
-    @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
-
-    bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
-    bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
-    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-
-    sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
-    sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
-    sol3 = @btime solve($bvpo3, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2 ≈ sol3
-end
+# let
+#      @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+#      @variables x(t)=1.0 y(t)=2.0
+#      
+#      eqs = [D(x) ~ α * x - β * x * y,
+#          D(y) ~ -γ * y + δ * x * y]
+#      
+#      u0map = [x => 1.0, y => 2.0]
+#      parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
+#      tspan = (0.0, 10.0)
+#      
+#      @mtkbuild lotkavolterra = ODESystem(eqs, t)
+#      op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
+#      osol = solve(op, Vern9())
+#      
+#      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+#          lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+#      
+#      for solver in solvers
+#          sol = solve(bvp, solver(), dt = 0.01)
+#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#          @test sol.u[1] == [1.0, 2.0]
+#      end
+#      
+#      # Test out of place
+#      bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
+#          lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+#      
+#      for solver in solvers
+#          sol = solve(bvp2, solver(), dt = 0.01)
+#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#          @test sol.u[1] == [1.0, 2.0]
+#      end
+# end
+# 
+# ### Testing on pendulum
+# let
+#      @parameters g=9.81 L=1.0
+#      @variables θ(t) = π / 2 θ_t(t)
+#      
+#      eqs = [D(θ) ~ θ_t
+#             D(θ_t) ~ -(g / L) * sin(θ)]
+#      
+#      @mtkbuild pend = ODESystem(eqs, t)
+#      
+#      u0map = [θ => π / 2, θ_t => π / 2]
+#      parammap = [:L => 1.0, :g => 9.81]
+#      tspan = (0.0, 6.0)
+#      
+#      op = ODEProblem(pend, u0map, tspan, parammap)
+#      osol = solve(op, Vern9())
+#      
+#      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+#      for solver in solvers
+#          sol = solve(bvp, solver(), dt = 0.01)
+#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#          @test sol.u[1] == [π / 2, π / 2]
+#      end
+#      
+#      # Test out-of-place
+#      bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+#      
+#      for solver in solvers
+#          sol = solve(bvp2, solver(), dt = 0.01)
+#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+#          @test sol.u[1] == [π / 2, π / 2]
+#      end
+# end
+# 
+# ##################################################################
+# ### ODESystem with constraint equations, DAEs with constraints ###
+# ##################################################################
+# 
+# # Test generation of boundary condition function using `process_constraints`. Compare solutions to manually written boundary conditions
+# let
+#     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
+#     @variables x(..) y(..)
+#     eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
+#            D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
+#     
+#     tspan = (0., 1.)
+#     @mtkbuild lksys = ODESystem(eqs, t)
+# 
+#     function lotkavolterra!(du, u, p, t) 
+#         du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
+#         du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
+#     end
+# 
+#     function lotkavolterra(u, p, t) 
+#         [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
+#     end
+#     # Compare the built bc function to the actual constructed one.
+#     function bc!(resid, u, p, t) 
+#         resid[1] = u[1][1] - 1.
+#         resid[2] = u[1][2] - 2.
+#         nothing
+#     end
+#     function bc(u, p, t)
+#         [u[1][1] - 1., u[1][2] - 2.]
+#     end
+# 
+#     u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
+#     genbc_iip = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, true)
+#     genbc_oop = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, false)
+# 
+#     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
+#     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
+# 
+#     sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
+#     sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
+#     @test sol1 ≈ sol2
+# 
+#     bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
+#     bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
+# 
+#     sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
+#     sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
+#     @test sol1 ≈ sol2
+# 
+#     # Test with a constraint.
+#     constraints = [y(0.5) ~ 2.]
+# 
+#     function bc!(resid, u, p, t) 
+#         resid[1] = u(0.0)[1] - 1.
+#         resid[2] = u(0.5)[2] - 2.
+#     end
+#     function bc(u, p, t)
+#         [u(0.0)[1] - 1., u(0.5)[2] - 2.]
+#     end
+# 
+#     u0 = [1, 1.]
+#     genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, true)
+#     genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, false)
+# 
+#     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
+#     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
+#     bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+#     bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+#     
+#     sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
+#     sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
+#     sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
+#     sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
+#     @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
+# 
+#     bvpo1 = BVProblem(lotkavolterra, bc, u0, tspan, p)
+#     bvpo2 = BVProblem(lotkavolterra, genbc_oop, u0, tspan, p)
+#     bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+# 
+#     sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
+#     sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
+#     sol3 = @btime solve($bvpo3, MIRK4(), dt = 0.05)
+#     @test sol1 ≈ sol2 ≈ sol3
+# end
 
-function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-4)
+function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-3)
     for solver in solvers
         println("Solver: $solver")
         sol = @btime solve($prob, $solver(), dt = $dt, abstol = $atol)
         @test SciMLBase.successful_retcode(sol.retcode)
         p = prob.p; t = sol.t; bc = prob.f.bc
         ns = length(prob.u0)
-        if isinplace(bvp.f)
+        if isinplace(prob.f)
             resid = zeros(ns)
             bc(resid, sol, p, t)
             @test isapprox(zeros(ns), resid; atol)
+            @show resid
         else
             @test isapprox(zeros(ns), bc(sol, p, t); atol)
+            @show bc(sol, p, t)
         end
 
         for (k, v) in u0map
@@ -194,6 +197,12 @@ function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.
     end
 end
 
+solvers = [RadauIIa3, RadauIIa5, RadauIIa7,
+           LobattoIIIa2, LobattoIIIa4, LobattoIIIa5,
+           LobattoIIIb2, LobattoIIIb3, LobattoIIIb4, LobattoIIIb5,
+           LobattoIIIc2, LobattoIIIc3, LobattoIIIc4, LobattoIIIc5]
+weird = [MIRK2, MIRK5, RadauIIa2]
+daesolvers = []
 # Simple ODESystem with BVP constraints.
 let
     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
@@ -213,8 +222,8 @@ let
     test_solvers(solvers, bvp, u0map, constraints; dt = 0.05)
 
     # Testing that more complicated constraints give correct solutions.
-    constraints = [y(.2) + x(.8) ~ 3., y(.3) + x(.5) ~ 5.]
-    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses, constraints, jac = true)
+    constraints = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses, constraints)
     test_solvers(solvers, bvp, u0map, constraints; dt = 0.05)
 
     constraints = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
@@ -224,14 +233,14 @@ let
     # Testing that errors are properly thrown when malformed constraints are given.
     @variables bad(..)
     constraints = [x(1.) + bad(3.) ~ 10]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    @test_throws ErrorException bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
 
     constraints = [x(t) + y(t) ~ 3]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    @test_throws ErrorException bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
 
     @parameters bad2
     constraints = [bad2 + x(0.) ~ 3]
-    @test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    @test_throws ErrorException bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
 end
 
 # Cartesian pendulum from the docs.

From 76e515c4c55af8b9ca3bf8a3318d1dd3339702a3 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 12:33:15 -0500
Subject: [PATCH 026/111] Init

---
 .../discrete_system/implicitdiscretesystem.jl | 406 ++++++++++++++++++
 1 file changed, 406 insertions(+)
 create mode 100644 src/systems/discrete_system/implicitdiscretesystem.jl

diff --git a/src/systems/discrete_system/implicitdiscretesystem.jl b/src/systems/discrete_system/implicitdiscretesystem.jl
new file mode 100644
index 0000000000..404fa7ff49
--- /dev/null
+++ b/src/systems/discrete_system/implicitdiscretesystem.jl
@@ -0,0 +1,406 @@
+"""
+$(TYPEDEF)
+An implicit system of difference equations.
+# Fields
+$(FIELDS)
+# Example
+```
+using ModelingToolkit
+using ModelingToolkit: t_nounits as t
+@parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1
+@variables x(t)=1.0 y(t)=0.0 z(t)=0.0
+k = ShiftIndex(t)
+eqs = [x(k+1) ~ σ*(y-x),
+       y(k+1) ~ x*(ρ-z)-y,
+       z(k+1) ~ x*y - β*z]
+@named de = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or
+@named de = ImplicitDiscreteSystem(eqs)
+```
+"""
+struct ImplicitDiscreteSystem <: AbstractTimeDependentSystem
+    """
+    A tag for the system. If two systems have the same tag, then they are
+    structurally identical.
+    """
+    tag::UInt
+    """The differential equations defining the discrete system."""
+    eqs::Vector{Equation}
+    """Independent variable."""
+    iv::BasicSymbolic{Real}
+    """Dependent (state) variables. Must not contain the independent variable."""
+    unknowns::Vector
+    """Parameter variables. Must not contain the independent variable."""
+    ps::Vector
+    """Time span."""
+    tspan::Union{NTuple{2, Any}, Nothing}
+    """Array variables."""
+    var_to_name::Any
+    """Observed states."""
+    observed::Vector{Equation}
+    """
+    The name of the system
+    """
+    name::Symbol
+    """
+    A description of the system.
+    """
+    description::String
+    """
+    The internal systems. These are required to have unique names.
+    """
+    systems::Vector{DiscreteSystem}
+    """
+    The default values to use when initial conditions and/or
+    parameters are not supplied in `DiscreteProblem`.
+    """
+    defaults::Dict
+    """
+    The guesses to use as the initial conditions for the
+    initialization system.
+    """
+    guesses::Dict
+    """
+    The system for performing the initialization.
+    """
+    initializesystem::Union{Nothing, NonlinearSystem}
+    """
+    Extra equations to be enforced during the initialization sequence.
+    """
+    initialization_eqs::Vector{Equation}
+    """
+    Inject assignment statements before the evaluation of the RHS function.
+    """
+    preface::Any
+    """
+    Type of the system.
+    """
+    connector_type::Any
+    """
+    Topologically sorted parameter dependency equations, where all symbols are parameters and
+    the LHS is a single parameter.
+    """
+    parameter_dependencies::Vector{Equation}
+    """
+    Metadata for the system, to be used by downstream packages.
+    """
+    metadata::Any
+    """
+    Metadata for MTK GUI.
+    """
+    gui_metadata::Union{Nothing, GUIMetadata}
+    """
+    Cache for intermediate tearing state.
+    """
+    tearing_state::Any
+    """
+    Substitutions generated by tearing.
+    """
+    substitutions::Any
+    """
+    If a model `sys` is complete, then `sys.x` no longer performs namespacing.
+    """
+    complete::Bool
+    """
+    Cached data for fast symbolic indexing.
+    """
+    index_cache::Union{Nothing, IndexCache}
+    """
+    The hierarchical parent system before simplification.
+    """
+    parent::Any
+    isscheduled::Bool
+
+    function ImplicitDiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name,
+            observed, name, description, systems, defaults, guesses, initializesystem,
+            initialization_eqs, preface, connector_type, parameter_dependencies = Equation[],
+            metadata = nothing, gui_metadata = nothing,
+            tearing_state = nothing, substitutions = nothing,
+            complete = false, index_cache = nothing, parent = nothing,
+            isscheduled = false;
+            checks::Union{Bool, Int} = true)
+        if checks == true || (checks & CheckComponents) > 0
+            check_independent_variables([iv])
+            check_variables(dvs, iv)
+            check_parameters(ps, iv)
+        end
+        if checks == true || (checks & CheckUnits) > 0
+            u = __get_unit_type(dvs, ps, iv)
+            check_units(u, discreteEqs)
+        end
+        new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, observed, name, description,
+            systems, defaults, guesses, initializesystem, initialization_eqs,
+            preface, connector_type, parameter_dependencies, metadata, gui_metadata,
+            tearing_state, substitutions, complete, index_cache, parent, isscheduled)
+    end
+end
+
+"""
+    $(TYPEDSIGNATURES)
+Constructs a DiscreteSystem.
+"""
+function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
+        observed = Num[],
+        systems = DiscreteSystem[],
+        tspan = nothing,
+        name = nothing,
+        description = "",
+        default_u0 = Dict(),
+        default_p = Dict(),
+        guesses = Dict(),
+        initializesystem = nothing,
+        initialization_eqs = Equation[],
+        defaults = _merge(Dict(default_u0), Dict(default_p)),
+        preface = nothing,
+        connector_type = nothing,
+        parameter_dependencies = Equation[],
+        metadata = nothing,
+        gui_metadata = nothing,
+        kwargs...)
+    name === nothing &&
+        throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
+    iv′ = value(iv)
+    dvs′ = value.(dvs)
+    ps′ = value.(ps)
+    if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs)
+        error("Equations in a `DiscreteSystem` can only have `Shift` operators.")
+    end
+    if !(isempty(default_u0) && isempty(default_p))
+        Base.depwarn(
+            "`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
+            :DiscreteSystem, force = true)
+    end
+
+    defaults = Dict{Any, Any}(todict(defaults))
+    guesses = Dict{Any, Any}(todict(guesses))
+    var_to_name = Dict()
+    process_variables!(var_to_name, defaults, guesses, dvs′)
+    process_variables!(var_to_name, defaults, guesses, ps′)
+    process_variables!(
+        var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
+    process_variables!(
+        var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
+    defaults = Dict{Any, Any}(value(k) => value(v)
+    for (k, v) in pairs(defaults) if v !== nothing)
+    guesses = Dict{Any, Any}(value(k) => value(v)
+    for (k, v) in pairs(guesses) if v !== nothing)
+
+    isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
+
+    sysnames = nameof.(systems)
+    if length(unique(sysnames)) != length(sysnames)
+        throw(ArgumentError("System names must be unique."))
+    end
+    DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
+        eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
+        defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
+        parameter_dependencies, metadata, gui_metadata, kwargs...)
+end
+
+function ImplicitDiscreteSystem(eqs, iv; kwargs...)
+    eqs = collect(eqs)
+    diffvars = OrderedSet()
+    allunknowns = OrderedSet()
+    ps = OrderedSet()
+    iv = value(iv)
+    for eq in eqs
+        collect_vars!(allunknowns, ps, eq, iv; op = Shift)
+        if iscall(eq.lhs) && operation(eq.lhs) isa Shift
+            isequal(iv, operation(eq.lhs).t) ||
+                throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
+            eq.lhs in diffvars &&
+                throw(ArgumentError("The shift variable $(eq.lhs) is not unique in the system of equations."))
+            push!(diffvars, eq.lhs)
+        end
+    end
+    for eq in get(kwargs, :parameter_dependencies, Equation[])
+        if eq isa Pair
+            collect_vars!(allunknowns, ps, eq, iv)
+        else
+            collect_vars!(allunknowns, ps, eq, iv)
+        end
+    end
+    new_ps = OrderedSet()
+    for p in ps
+        if iscall(p) && operation(p) === getindex
+            par = arguments(p)[begin]
+            if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
+               all(par[i] in ps for i in eachindex(par))
+                push!(new_ps, par)
+            else
+                push!(new_ps, p)
+            end
+        else
+            push!(new_ps, p)
+        end
+    end
+    return DiscreteSystem(eqs, iv,
+        collect(allunknowns), collect(new_ps); kwargs...)
+end
+
+function flatten(sys::DiscreteSystem, noeqs = false)
+    systems = get_systems(sys)
+    if isempty(systems)
+        return sys
+    else
+        return DiscreteSystem(noeqs ? Equation[] : equations(sys),
+            get_iv(sys),
+            unknowns(sys),
+            parameters(sys),
+            observed = observed(sys),
+            defaults = defaults(sys),
+            guesses = guesses(sys),
+            initialization_eqs = initialization_equations(sys),
+            name = nameof(sys),
+            description = description(sys),
+            metadata = get_metadata(sys),
+            checks = false)
+    end
+end
+
+function generate_function(
+        sys::DiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
+    exprs = [eq.rhs for eq in equations(sys)]
+    wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
+                wrap_parameter_dependencies(sys, false)
+    generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
+end
+
+function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
+    iv = get_iv(sys)
+    updated = AnyDict()
+    for k in collect(keys(u0map))
+        v = u0map[k]
+        if !((op = operation(k)) isa Shift)
+            error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
+        end
+        updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
+    end
+    for var in unknowns(sys)
+        op = operation(var)
+        op isa Shift || continue
+        haskey(updated, var) && continue
+        root = first(arguments(var))
+        haskey(defs, root) || error("Initial condition for $var not provided.")
+        updated[var] = defs[root]
+    end
+    return updated
+end
+
+"""
+    $(TYPEDSIGNATURES)
+Generates an ImplicitDiscreteProblem from an ImplicitDiscreteSystem.
+"""
+function SciMLBase.ImplicitDiscreteProblem(
+        sys::DiscreteSystem, u0map = [], tspan = get_tspan(sys),
+        parammap = SciMLBase.NullParameters();
+        eval_module = @__MODULE__,
+        eval_expression = false,
+        use_union = false,
+        kwargs...
+)
+    if !iscomplete(sys)
+        error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
+    end
+    dvs = unknowns(sys)
+    ps = parameters(sys)
+    eqs = equations(sys)
+    iv = get_iv(sys)
+
+    u0map = to_varmap(u0map, dvs)
+    u0map = shift_u0map_forward(sys, u0map, defaults(sys))
+    f, u0, p = process_SciMLProblem(
+        DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
+    u0 = f(u0, p, tspan[1])
+    ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
+end
+
+function SciMLBase.ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...; kwargs...)
+    ImplicitDiscreteFunction{true}(sys, args...; kwargs...)
+end
+
+function SciMLBase.ImplicitDiscreteFunction{true}(sys::ImplicitDiscreteSystem, args...; kwargs...)
+    ImplicitDiscreteFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.ImplicitDiscreteFunction{false}(sys::ImplicitDiscreteSystem, args...; kwargs...)
+    ImplicitDiscreteFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
+end
+function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
+        sys::ImplicitDiscreteSystem,
+        dvs = unknowns(sys),
+        ps = parameters(sys),
+        u0 = nothing;
+        version = nothing,
+        p = nothing,
+        t = nothing,
+        eval_expression = false,
+        eval_module = @__MODULE__,
+        analytic = nothing,
+        kwargs...) where {iip, specialize}
+    if !iscomplete(sys)
+        error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
+    end
+    f_gen = generate_function(sys, dvs, ps; expression = Val{true},
+        expression_module = eval_module, kwargs...)
+    f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
+    f(u, p, t) = f_oop(u, p, t)
+    f(du, u, p, t) = f_iip(du, u, p, t)
+
+    if specialize === SciMLBase.FunctionWrapperSpecialize && iip
+        if u0 === nothing || p === nothing || t === nothing
+            error("u0, p, and t must be specified for FunctionWrapperSpecialize on DiscreteFunction.")
+        end
+        f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
+    end
+
+    observedfun = ObservedFunctionCache(
+        sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
+
+    ImplicitDiscreteFunction{iip, specialize}(f;
+        sys = sys,
+        observed = observedfun,
+        analytic = analytic)
+end
+
+"""
+```julia
+ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = states(sys),
+                                  ps = parameters(sys);
+                                  version = nothing,
+                                  kwargs...) where {iip}
+```
+
+Create a Julia expression for an `ImplicitDiscreteFunction` from the [`ImplicitDiscreteSystem`](@ref).
+The arguments `dvs` and `ps` are used to set the order of the dependent
+variable and parameter vectors, respectively.
+"""
+struct ImplicitDiscreteFunctionExpr{iip} end
+struct ImplicitDiscreteFunctionClosure{O, I} <: Function
+    f_oop::O
+    f_iip::I
+end
+(f::ImplicitDiscreteFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
+(f::ImplicitDiscreteFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
+
+function ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = unknowns(sys),
+        ps = parameters(sys), u0 = nothing;
+        version = nothing, p = nothing,
+        linenumbers = false,
+        simplify = false,
+        kwargs...) where {iip}
+    f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
+
+    fsym = gensym(:f)
+    _f = :($fsym = $ImplicitDiscreteFunctionClosure($f_oop, $f_iip))
+
+    ex = quote
+        $_f
+        DiscreteFunction{$iip}($fsym)
+    end
+    !linenumbers ? Base.remove_linenums!(ex) : ex
+end
+
+function ImplicitDiscreteFunctionExpr(sys::ImplicitDiscreteSystem, args...; kwargs...)
+    ImplicitDiscreteFunctionExpr{true}(sys, args...; kwargs...)
+end
+

From 9530a2e6bce78f9fdd49ed1841b9ed6120e04ad2 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 16:31:25 -0500
Subject: [PATCH 027/111] init

---
 src/systems/parameter_buffer.jl |  6 +++--
 src/systems/problem_utils.jl    | 42 +++++++++++++++++++++++++++++++--
 test/problem_validation.jl      | 24 +++++++++++++++++++
 3 files changed, 68 insertions(+), 4 deletions(-)
 create mode 100644 test/problem_validation.jl

diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl
index bc4e62a773..7a88489b48 100644
--- a/src/systems/parameter_buffer.jl
+++ b/src/systems/parameter_buffer.jl
@@ -33,17 +33,19 @@ function MTKParameters(
     else
         error("Cannot create MTKParameters if system does not have index_cache")
     end
+
     all_ps = Set(unwrap.(parameters(sys)))
     union!(all_ps, default_toterm.(unwrap.(parameters(sys))))
     if p isa Vector && !(eltype(p) <: Pair) && !isempty(p)
         ps = parameters(sys)
-        length(p) == length(ps) || error("Invalid parameters")
+        length(p) == length(ps) || error("The number of parameter values is not equal to the number of parameters.")
         p = ps .=> p
     end
     if p isa SciMLBase.NullParameters || isempty(p)
         p = Dict()
     end
     p = todict(p)
+
     defs = Dict(default_toterm(unwrap(k)) => v for (k, v) in defaults(sys))
     if eltype(u0) <: Pair
         u0 = todict(u0)
@@ -761,7 +763,7 @@ end
 
 function Base.showerror(io::IO, e::MissingParametersError)
     println(io, MISSING_PARAMETERS_MESSAGE)
-    println(io, e.vars)
+    println(io, join(e.vars, ", "))
 end
 
 function InvalidParameterSizeException(param, val)
diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index bf3d72e1e5..e2af471a15 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -684,12 +684,17 @@ function process_SciMLProblem(
 
     u0Type = typeof(u0map)
     pType = typeof(pmap)
-    _u0map = u0map
+
     u0map = to_varmap(u0map, dvs)
     symbols_to_symbolics!(sys, u0map)
-    _pmap = pmap
+    check_keys(sys, u0map)
+
     pmap = to_varmap(pmap, ps)
     symbols_to_symbolics!(sys, pmap)
+    check_keys(sys, pmap)
+    badkeys = filter(k -> symbolic_type(k) === NotSymbolic(), keys(pmap))
+    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
+
     defs = add_toterms(recursive_unwrap(defaults(sys)))
     cmap, cs = get_cmap(sys)
     kwargs = NamedTuple(kwargs)
@@ -778,6 +783,39 @@ function process_SciMLProblem(
     implicit_dae ? (f, du0, u0, p) : (f, u0, p)
 end
 
+# Check that the keys of a u0map or pmap are valid
+# (i.e. are symbolic keys, and are defined for the system.)
+function check_keys(sys, map) 
+    badkeys = Any[]
+    for k in keys(map)
+        if symbolic_type(k) === NotSymbolic()
+            push!(badkeys, k)
+        elseif k isa Symbol
+            !hasproperty(sys, k) && push!(badkeys, k)
+        elseif k ∉ Set(parameters(sys)) && k ∉ Set(unknowns(sys)) 
+            push!(badkeys, k)
+        end
+    end
+
+    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
+end
+
+const BAD_KEY_MESSAGE = """
+                        Undefined keys found in the parameter or initial condition maps. 
+                        The following keys are either invalid or not parameters/states of the system:
+                        """
+
+struct BadKeyError <: Exception
+    vars::Any
+end
+
+function Base.showerror(io::IO, e::BadKeyError) 
+    println(io, BAD_KEY_MESSAGE)
+    println(io, join(e.vars, ", "))
+end
+
+
+
 ##############
 # Legacy functions for backward compatibility
 ##############
diff --git a/test/problem_validation.jl b/test/problem_validation.jl
new file mode 100644
index 0000000000..fb724c55bd
--- /dev/null
+++ b/test/problem_validation.jl
@@ -0,0 +1,24 @@
+using ModelingToolkit
+using ModelingToolkit: t_nounits as t, D_nounits as D
+
+@testset "Input map validation" begin
+    @variables X(t)
+    @parameters p d
+    eqs = [D(X) ~ p - d*X]
+    @mtkbuild osys = ODESystem(eqs, t)
+    
+    p = "I accidentally renamed p"
+    u0 = [X => 1.0]
+    ps = [p => 1.0, d => 0.5]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    
+    ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+
+    u0 = [:X => 1.0, "random" => 3.0]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+
+    @parameters k
+    ps = [p => 1., d => 0.5, k => 3.]
+    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+end

From a9dc112905ced2e1b1b16e1af0e179b0604f2563 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 16:34:38 -0500
Subject: [PATCH 028/111] up

---
 test/problem_validation.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/problem_validation.jl b/test/problem_validation.jl
index fb724c55bd..f871327ae8 100644
--- a/test/problem_validation.jl
+++ b/test/problem_validation.jl
@@ -12,6 +12,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
     ps = [p => 1.0, d => 0.5]
     @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
     
+    @parameters p d
     ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
     @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
 

From 3c629ac227950886235d85307f3a82c7b3183ac7 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 23 Jan 2025 16:36:30 -0500
Subject: [PATCH 029/111] up

---
 src/systems/problem_utils.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index e2af471a15..8c6082e436 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -692,8 +692,6 @@ function process_SciMLProblem(
     pmap = to_varmap(pmap, ps)
     symbols_to_symbolics!(sys, pmap)
     check_keys(sys, pmap)
-    badkeys = filter(k -> symbolic_type(k) === NotSymbolic(), keys(pmap))
-    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
 
     defs = add_toterms(recursive_unwrap(defaults(sys)))
     cmap, cs = get_cmap(sys)

From 417b386a24a6c8c1ed8c49ee6a5c2581a562afae Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 24 Jan 2025 08:48:29 -0500
Subject: [PATCH 030/111] just check not-symbolic

---
 src/systems/problem_utils.jl | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index 8c6082e436..8541056272 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -788,10 +788,6 @@ function check_keys(sys, map)
     for k in keys(map)
         if symbolic_type(k) === NotSymbolic()
             push!(badkeys, k)
-        elseif k isa Symbol
-            !hasproperty(sys, k) && push!(badkeys, k)
-        elseif k ∉ Set(parameters(sys)) && k ∉ Set(unknowns(sys)) 
-            push!(badkeys, k)
         end
     end
 

From f1d2a0754ae33524a063d028b6d2515a9cabc06c Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 24 Jan 2025 11:20:11 -0500
Subject: [PATCH 031/111] rename

---
 .../{implicitdiscretesystem.jl => implicit_discrete_system.jl}    | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename src/systems/discrete_system/{implicitdiscretesystem.jl => implicit_discrete_system.jl} (100%)

diff --git a/src/systems/discrete_system/implicitdiscretesystem.jl b/src/systems/discrete_system/implicit_discrete_system.jl
similarity index 100%
rename from src/systems/discrete_system/implicitdiscretesystem.jl
rename to src/systems/discrete_system/implicit_discrete_system.jl

From 2bd65e2702544b1ed561449665115247b65de121 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 24 Jan 2025 16:06:16 -0500
Subject: [PATCH 032/111] up

---
 src/systems/systems.jl | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/systems/systems.jl b/src/systems/systems.jl
index 04c50bc766..ffbee75b33 100644
--- a/src/systems/systems.jl
+++ b/src/systems/systems.jl
@@ -39,13 +39,13 @@ function structural_simplify(
     else
         newsys = newsys′
     end
-    if newsys isa DiscreteSystem &&
-       any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
-        error("""
-            Encountered algebraic equations when simplifying discrete system. This is \
-            not yet supported.
-        """)
-    end
+    # if newsys isa DiscreteSystem &&
+    #    any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
+    #     error("""
+    #         Encountered algebraic equations when simplifying discrete system. This is \
+    #         not yet supported.
+    #     """)
+    # end
     for pass in additional_passes
         newsys = pass(newsys)
     end

From ec386fe041595284c9ab44a0f54fe9a6c232155e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 27 Jan 2025 23:02:07 -0500
Subject: [PATCH 033/111] Refactor constraints

---
 src/ModelingToolkit.jl                        |   8 +-
 src/systems/abstractsystem.jl                 |   1 +
 src/systems/diffeqs/abstractodesystem.jl      | 121 ++----
 src/systems/diffeqs/bvpsystem.jl              | 217 ++++++++++
 src/systems/diffeqs/odesystem.jl              |  86 +++-
 .../optimization/constraints_system.jl        |   5 +
 test/bvproblem.jl                             | 373 +++++++++---------
 test/odesystem.jl                             |  30 ++
 8 files changed, 568 insertions(+), 273 deletions(-)
 create mode 100644 src/systems/diffeqs/bvpsystem.jl

diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl
index 10aba4d8a9..2d53c61401 100644
--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -150,6 +150,10 @@ include("systems/imperative_affect.jl")
 include("systems/callbacks.jl")
 include("systems/problem_utils.jl")
 
+include("systems/optimization/constraints_system.jl")
+include("systems/optimization/optimizationsystem.jl")
+include("systems/optimization/modelingtoolkitize.jl")
+
 include("systems/nonlinear/nonlinearsystem.jl")
 include("systems/nonlinear/homotopy_continuation.jl")
 include("systems/diffeqs/odesystem.jl")
@@ -165,10 +169,6 @@ include("systems/discrete_system/discrete_system.jl")
 
 include("systems/jumps/jumpsystem.jl")
 
-include("systems/optimization/constraints_system.jl")
-include("systems/optimization/optimizationsystem.jl")
-include("systems/optimization/modelingtoolkitize.jl")
-
 include("systems/pde/pdesystem.jl")
 
 include("systems/sparsematrixclil.jl")
diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl
index 168260ae69..7c0fe0285d 100644
--- a/src/systems/abstractsystem.jl
+++ b/src/systems/abstractsystem.jl
@@ -983,6 +983,7 @@ for prop in [:eqs
              :structure
              :op
              :constraints
+             :constraintsystem
              :controls
              :loss
              :bcs
diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index cf6e0962fd..d0d6ed7937 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -827,6 +827,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
     if !iscomplete(sys)
         error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
     end
+
+    if !isnothing(get_constraintsystem(sys))
+        error("An ODESystem with constraints cannot be used to construct a regular ODEProblem. 
+              Consider a BVProblem instead.")
+    end
+
     f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
         t = tspan !== nothing ? tspan[1] : tspan,
         check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
@@ -866,18 +872,23 @@ Create a boundary value problem from the [`ODESystem`](@ref).
 must have either an initial guess supplied using `guesses` or a fixed initial 
 value specified using `u0map`.
 
-`constraints` are used to specify boundary conditions to the ODESystem in the
-form of equations. These values should specify values that state variables should
+Boundary value conditions are supplied to ODESystems
+in the form of a ConstraintsSystem. These equations 
+should specify values that state variables should
 take at specific points, as in `x(0.5) ~ 1`). More general constraints that 
 should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be 
-specified as one of the equations used to build the `ODESystem`. Below is an example.
+specified as one of the equations used to build the `ODESystem`.
+
+If an ODESystem without `constraints` is specified, it will be treated as an initial value problem. 
 
 ```julia
-    @parameters g
+    @parameters g t_c = 0.5
     @variables x(..) y(t) [state_priority = 10] λ(t)
     eqs = [D(D(x(t))) ~ λ * x(t)
            D(D(y)) ~ λ * y - g
            x(t)^2 + y^2 ~ 1]
+    cstr = [x(0.5) ~ 1]
+    @named cstrs = ConstraintsSystem(cstr, t)
     @mtkbuild pend = ODESystem(eqs, t)
 
     tspan = (0.0, 1.5)
@@ -889,9 +900,7 @@ specified as one of the equations used to build the `ODESystem`. Below is an exa
     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
 ```
 
-If no `constraints` are specified, the problem will be treated as an initial value problem.
-
-If the `ODESystem` has algebraic equations like `x(t)^2 + y(t)^2`, the resulting 
+If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting 
 `BVProblem` must be solved using BVDAE solvers, such as Ascher.
 """
 function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -916,7 +925,7 @@ end
 function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
         tspan = get_tspan(sys),
         parammap = DiffEqBase.NullParameters();
-        constraints = nothing, guesses = Dict(),
+        guesses = Dict(),
         version = nothing, tgrad = false,
         callback = nothing,
         check_length = true,
@@ -930,21 +939,14 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     end
     !isnothing(callback) && error("BVP solvers do not support callbacks.")
 
-    has_alg_eqs(sys) && error("The BVProblem currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
+    has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
 
-    constraintsts = nothing
-    constraintps = nothing
     sts = unknowns(sys)
     ps = parameters(sys)
 
-    # Constraint validation
     if !isnothing(constraints)
-        constraints isa Equation || 
-            constraints isa Vector{Equation} || 
-            error("Constraints must be specified as an equation or a vector of equations.")
-
         (length(constraints) + length(u0map) > length(sts)) && 
-        error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
+        @warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
     end
 
     # ODESystems without algebraic equations should use both fixed values + guesses
@@ -957,48 +959,25 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
     u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
 
-    bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)
-
+    bc = generate_function_bc(sys, u0, u0_idxs, tspan, iip)
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
 end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
-# Validate that all the variables in the BVP constraints are well-formed states or parameters.
-function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv) 
-    for var in constraintsts
-        if length(arguments(var)) > 1
-            error("Too many arguments for variable $var.")
-        elseif isequal(arguments(var)[1], iv)
-            var ∈ sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
-            error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
-        else
-            operation(var)(iv) ∈ sts || error("Constraint equation $eq contains a variable $(operation(var)) that is not a variable of the ODESystem.")
-        end
-    end
-
-    for var in constraintps
-        if !iscall(var)
-            var ∈ ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
-        else
-            length(arguments(var)) > 1 && error("Too many arguments for parameter $var.")
-            operation(var) ∈ ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
-        end
-    end
-end
-
 """
-    process_constraints(sys, constraints, u0, tspan, iip)
+    generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
 
-    Given an ODESystem with some constraints, generate the boundary condition function.
+    Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
 """
-function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, iip)
-
+function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
     iv = get_iv(sys)
     sts = get_unknowns(sys)
     ps = get_ps(sys)
     np = length(ps)
     ns = length(sts)
+    conssys = get_constraintsystem(sys)
+    cons = constraints(conssys)
 
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
     pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
@@ -1006,48 +985,34 @@ function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, ii
     @variables sol(..)[1:ns] p[1:np]
     exprs = Any[]
 
-    constraintsts = OrderedSet()
-    constraintps = OrderedSet()
+    for st in get_unknowns(cons)
+        x = operation(st)
+        t = first(arguments(st))
+        idx = stidxmap[x(iv)]
 
-    !isnothing(constraints) && for cons in constraints
-        collect_vars!(constraintsts, constraintps, cons, iv)
-        validate_constraint_syms(cons, constraintsts, constraintps, Set(sts), Set(ps), iv)
-        expr = cons.rhs - cons.lhs
-
-        for st in constraintsts
-            x = operation(st)
-            t = arguments(st)[1]
-            idx = stidxmap[x(iv)]
-
-            expr = Symbolics.substitute(expr, Dict(x(t) => sol(t)[idx]))
-        end
+        cons = Symbolics.substitute(cons, Dict(x(t) => sol(t)[idx]))
+    end
 
-        for var in constraintps
-            if iscall(var)
-                x = operation(var)
-                t = arguments(var)[1]
-                idx = pidxmap[x]
+    for var in get_parameters(cons) 
+        if iscall(var)
+            x = operation(var)
+            t = arguments(var)[1]
+            idx = pidxmap[x]
 
-                expr = Symbolics.substitute(expr, Dict(x(t) => p[idx]))
-            else
-                idx = pidxmap[var]
-                expr = Symbolics.substitute(expr, Dict(var => p[idx]))
-            end
+            cons = Symbolics.substitute(cons, Dict(x(t) => p[idx]))
+        else
+            idx = pidxmap[var]
+            cons = Symbolics.substitute(cons, Dict(var => p[idx]))
         end
-
-        empty!(constraintsts)
-        empty!(constraintps)
-        push!(exprs, expr)
     end
 
-    init_cond_exprs = Any[]
-
+    init_conds = Any[]
     for i in u0_idxs
         expr = sol(tspan[1])[i] - u0[i]
-        push!(init_cond_exprs, expr)
+        push!(init_conds, expr)
     end
 
-    exprs = vcat(init_cond_exprs, exprs)
+    exprs = vcat(init_conds, cons)
     bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
     if iip
         return (resid, u, p, t) -> bcs[2](resid, u, p)
diff --git a/src/systems/diffeqs/bvpsystem.jl b/src/systems/diffeqs/bvpsystem.jl
new file mode 100644
index 0000000000..ae3029fa14
--- /dev/null
+++ b/src/systems/diffeqs/bvpsystem.jl
@@ -0,0 +1,217 @@
+"""
+$(TYPEDEF)
+
+A system of ordinary differential equations.
+
+# Fields
+$(FIELDS)
+
+# Example
+
+```julia
+using ModelingToolkit
+using ModelingToolkit: t_nounits as t, D_nounits as D
+
+@parameters σ ρ β
+@variables x(t) y(t) z(t)
+
+eqs = [D(x) ~ σ*(y-x),
+       D(y) ~ x*(ρ-z)-y,
+       D(z) ~ x*y - β*z]
+
+@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],tspan=(0, 1000.0))
+```
+"""
+struct ODESystem <: AbstractODESystem
+    """
+    A tag for the system. If two systems have the same tag, then they are
+    structurally identical.
+    """
+    tag::UInt
+    """The ODEs defining the system."""
+    eqs::Vector{Equation}
+    """Independent variable."""
+    iv::BasicSymbolic{Real}
+    """
+    Dependent (unknown) variables. Must not contain the independent variable.
+
+    N.B.: If `torn_matching !== nothing`, this includes all variables. Actual
+    ODE unknowns are determined by the `SelectedState()` entries in `torn_matching`.
+    """
+    unknowns::Vector
+    """Parameter variables. Must not contain the independent variable."""
+    ps::Vector
+    """Time span."""
+    tspan::Union{NTuple{2, Any}, Nothing}
+    """Array variables."""
+    var_to_name::Any
+    """Control parameters (some subset of `ps`)."""
+    ctrls::Vector
+    """Observed variables."""
+    observed::Vector{Equation}
+    """
+    Time-derivative matrix. Note: this field will not be defined until
+    [`calculate_tgrad`](@ref) is called on the system.
+    """
+    tgrad::RefValue{Vector{Num}}
+    """
+    Jacobian matrix. Note: this field will not be defined until
+    [`calculate_jacobian`](@ref) is called on the system.
+    """
+    jac::RefValue{Any}
+    """
+    Control Jacobian matrix. Note: this field will not be defined until
+    [`calculate_control_jacobian`](@ref) is called on the system.
+    """
+    ctrl_jac::RefValue{Any}
+    """
+    Note: this field will not be defined until
+    [`generate_factorized_W`](@ref) is called on the system.
+    """
+    Wfact::RefValue{Matrix{Num}}
+    """
+    Note: this field will not be defined until
+    [`generate_factorized_W`](@ref) is called on the system.
+    """
+    Wfact_t::RefValue{Matrix{Num}}
+    """
+    The name of the system.
+    """
+    name::Symbol
+    """
+    A description of the system.
+    """
+    description::String
+    """
+    The internal systems. These are required to have unique names.
+    """
+    systems::Vector{ODESystem}
+    """
+    The default values to use when initial conditions and/or
+    parameters are not supplied in `ODEProblem`.
+    """
+    defaults::Dict
+    """
+    The guesses to use as the initial conditions for the
+    initialization system.
+    """
+    guesses::Dict
+    """
+    Tearing result specifying how to solve the system.
+    """
+    torn_matching::Union{Matching, Nothing}
+    """
+    The system for performing the initialization.
+    """
+    initializesystem::Union{Nothing, NonlinearSystem}
+    """
+    Extra equations to be enforced during the initialization sequence.
+    """
+    initialization_eqs::Vector{Equation}
+    """
+    The schedule for the code generation process.
+    """
+    schedule::Any
+    """
+    Type of the system.
+    """
+    connector_type::Any
+    """
+    Inject assignment statements before the evaluation of the RHS function.
+    """
+    preface::Any
+    """
+    A `Vector{SymbolicContinuousCallback}` that model events.
+    The integrator will use root finding to guarantee that it steps at each zero crossing.
+    """
+    continuous_events::Vector{SymbolicContinuousCallback}
+    """
+    A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
+    analog to `SciMLBase.DiscreteCallback` that executes an affect when a given condition is
+    true at the end of an integration step.
+    """
+    discrete_events::Vector{SymbolicDiscreteCallback}
+    """
+    Topologically sorted parameter dependency equations, where all symbols are parameters and
+    the LHS is a single parameter.
+    """
+    parameter_dependencies::Vector{Equation}
+    """
+    Metadata for the system, to be used by downstream packages.
+    """
+    metadata::Any
+    """
+    Metadata for MTK GUI.
+    """
+    gui_metadata::Union{Nothing, GUIMetadata}
+    """
+    A boolean indicating if the given `ODESystem` represents a system of DDEs.
+    """
+    is_dde::Bool
+    """
+    A list of points to provide to the solver as tstops. Uses the same syntax as discrete
+    events.
+    """
+    tstops::Vector{Any}
+    """
+    Cache for intermediate tearing state.
+    """
+    tearing_state::Any
+    """
+    Substitutions generated by tearing.
+    """
+    substitutions::Any
+    """
+    If a model `sys` is complete, then `sys.x` no longer performs namespacing.
+    """
+    complete::Bool
+    """
+    Cached data for fast symbolic indexing.
+    """
+    index_cache::Union{Nothing, IndexCache}
+    """
+    A list of discrete subsystems.
+    """
+    discrete_subsystems::Any
+    """
+    A list of actual unknowns needed to be solved by solvers.
+    """
+    solved_unknowns::Union{Nothing, Vector{Any}}
+    """
+    A vector of vectors of indices for the split parameters.
+    """
+    split_idxs::Union{Nothing, Vector{Vector{Int}}}
+    """
+    The hierarchical parent system before simplification.
+    """
+    parent::Any
+
+    function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
+            jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
+            torn_matching, initializesystem, initialization_eqs, schedule,
+            connector_type, preface, cevents,
+            devents, parameter_dependencies,
+            metadata = nothing, gui_metadata = nothing, is_dde = false,
+            tstops = [], tearing_state = nothing,
+            substitutions = nothing, complete = false, index_cache = nothing,
+            discrete_subsystems = nothing, solved_unknowns = nothing,
+            split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
+        if checks == true || (checks & CheckComponents) > 0
+            check_independent_variables([iv])
+            check_variables(dvs, iv)
+            check_parameters(ps, iv)
+            check_equations(deqs, iv)
+            check_equations(equations(cevents), iv)
+        end
+        if checks == true || (checks & CheckUnits) > 0
+            u = __get_unit_type(dvs, ps, iv)
+            check_units(u, deqs)
+        end
+        new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
+            ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
+            initializesystem, initialization_eqs, schedule, connector_type, preface,
+            cevents, devents, parameter_dependencies, metadata,
+            gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
+            discrete_subsystems, solved_unknowns, split_idxs, parent)
+    end
+end
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 61f16fd926..6ca6cdf838 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -49,6 +49,8 @@ struct ODESystem <: AbstractODESystem
     ctrls::Vector
     """Observed variables."""
     observed::Vector{Equation}
+    """System of constraints that must be satisfied by the solution to the system."""
+    constraintsystem::Union{Nothing, ConstraintsSystem}
     """
     Time-derivative matrix. Note: this field will not be defined until
     [`calculate_tgrad`](@ref) is called on the system.
@@ -186,7 +188,7 @@ struct ODESystem <: AbstractODESystem
     """
     parent::Any
 
-    function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
+    function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
             jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
             torn_matching, initializesystem, initialization_eqs, schedule,
             connector_type, preface, cevents,
@@ -207,7 +209,7 @@ struct ODESystem <: AbstractODESystem
             u = __get_unit_type(dvs, ps, iv)
             check_units(u, deqs)
         end
-        new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
+        new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad, jac,
             ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
             initializesystem, initialization_eqs, schedule, connector_type, preface,
             cevents, devents, parameter_dependencies, metadata,
@@ -219,6 +221,7 @@ end
 function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
         controls = Num[],
         observed = Equation[],
+        constraints = Equation[],
         systems = ODESystem[],
         tspan = nothing,
         name = nothing,
@@ -283,11 +286,26 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
     cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
     disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
 
+    constraintsys = nothing
+    if !isempty(constraints)
+        constraintsys = process_constraint_system(constraints, dvs′, ps′, iv, systems)
+        dvset = Set(dvs′)
+        pset = Set(ps′)
+        for st in get_unknowns(constraintsys)
+            iscall(st) ? 
+                !in(operation(st)(iv), dvset) && push!(dvs′, st) :
+                !in(st, dvset) && push!(dvs′, st)
+        end
+        for p in parameters(constraintsys)
+            !in(p, pset) && push!(ps′, p)
+        end
+    end
+
     if is_dde === nothing
         is_dde = _check_if_dde(deqs, iv′, systems)
     end
     ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
-        deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
+        deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsys, tgrad, jac,
         ctrl_jac, Wfact, Wfact_t, name, description, systems,
         defaults, guesses, nothing, initializesystem,
         initialization_eqs, schedule, connector_type, preface, cont_callbacks,
@@ -295,7 +313,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
         metadata, gui_metadata, is_dde, tstops, checks = checks)
 end
 
-function ODESystem(eqs, iv; kwargs...)
+function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
     eqs = collect(eqs)
     # NOTE: this assumes that the order of algebraic equations doesn't matter
     diffvars = OrderedSet()
@@ -358,9 +376,10 @@ function ODESystem(eqs, iv; kwargs...)
         end
     end
     algevars = setdiff(allunknowns, diffvars)
+
     # the orders here are very important!
     return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
-        collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
+        collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); constraints, kwargs...)
 end
 
 # NOTE: equality does not check cached Jacobian
@@ -770,3 +789,60 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
 
     return nothing
 end
+
+# Validate that all the variables in the BVP constraints are well-formed states or parameters.
+#  - Any callable with multiple arguments will error.
+#  - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
+#  - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
+function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv) 
+    for var in constraintsts
+        if !iscall(var)
+            occursin(iv, var) && var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))
+        elseif length(arguments(var)) > 1
+            throw(ArgumentError("Too many arguments for variable $var."))
+        elseif length(arguments(var)) == 1
+            arg = first(arguments(var))
+            operation(var)(iv) ∈ sts || throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
+
+            isequal(arg, iv) || 
+                isparameter(arg) || 
+                arg isa Integer || 
+                    arg isa AbstractFloat || 
+                        throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
+        else
+            var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
+        end
+    end
+
+    for var in constraintps
+        !iscall(var) && continue
+
+        if length(arguments(var)) > 1
+            throw(ArgumentError("Too many arguments for parameter $var in equation $eq."))
+        elseif length(arguments(var)) == 1
+            arg = first(arguments(var))
+            operation(var) ∈ ps || throw(ArgumentError("Parameter $var is not a parameter of the ODESystem. Called parameters must be parameters of the ODESystem."))
+
+            isequal(arg, iv) || 
+                isparameter(arg) ||
+                arg isa Integer || 
+                    arg isa AbstractFloat || 
+                        throw(ArgumentError("Invalid argument specified for callable parameter $var. The argument of the parameter should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
+        end
+    end
+end
+
+function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv, subsys::Vector{ODESystem}; name = :cons)
+    isempty(constraints) && return nothing
+
+    constraintsts = OrderedSet()
+    constraintps = OrderedSet()
+
+    for cons in constraints
+        syms = collect_vars!(constraintsts, constraintps, cons, iv)
+    end
+    validate_constraint_syms(constraintsts, constraintps, Set(sts), Set(ps), iv)
+
+    constraint_subsys = get_constraintsystem.(subsys)
+    ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); systems = constraint_subsys, name)
+end
diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl
index 03225fc900..61e190e672 100644
--- a/src/systems/optimization/constraints_system.jl
+++ b/src/systems/optimization/constraints_system.jl
@@ -89,6 +89,10 @@ struct ConstraintsSystem <: AbstractTimeIndependentSystem
             tearing_state = nothing, substitutions = nothing,
             complete = false, index_cache = nothing;
             checks::Union{Bool, Int} = true)
+
+        ##if checks == true || (checks & CheckComponents) > 0 
+        ##    check_variables(unknowns, constraints)
+        ##end
         if checks == true || (checks & CheckUnits) > 0
             u = __get_unit_type(unknowns, ps)
             check_units(u, constraints)
@@ -253,3 +257,4 @@ function get_cmap(sys::ConstraintsSystem, exprs = nothing)
     cmap = map(x -> x ~ getdefault(x), cs)
     return cmap, cs
 end
+
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 2f278e6135..16c12d1be6 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -8,163 +8,163 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
 import ModelingToolkit: process_constraints
 
 ### Test Collocation solvers on simple problems 
-solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
+solvers = [MIRK4]
 daesolvers = [Ascher2, Ascher4, Ascher6]
 
-# let
-#      @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
-#      @variables x(t)=1.0 y(t)=2.0
-#      
-#      eqs = [D(x) ~ α * x - β * x * y,
-#          D(y) ~ -γ * y + δ * x * y]
-#      
-#      u0map = [x => 1.0, y => 2.0]
-#      parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
-#      tspan = (0.0, 10.0)
-#      
-#      @mtkbuild lotkavolterra = ODESystem(eqs, t)
-#      op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-#      osol = solve(op, Vern9())
-#      
-#      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
-#          lotkavolterra, u0map, tspan, parammap; eval_expression = true)
-#      
-#      for solver in solvers
-#          sol = solve(bvp, solver(), dt = 0.01)
-#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-#          @test sol.u[1] == [1.0, 2.0]
-#      end
-#      
-#      # Test out of place
-#      bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
-#          lotkavolterra, u0map, tspan, parammap; eval_expression = true)
-#      
-#      for solver in solvers
-#          sol = solve(bvp2, solver(), dt = 0.01)
-#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-#          @test sol.u[1] == [1.0, 2.0]
-#      end
-# end
-# 
-# ### Testing on pendulum
-# let
-#      @parameters g=9.81 L=1.0
-#      @variables θ(t) = π / 2 θ_t(t)
-#      
-#      eqs = [D(θ) ~ θ_t
-#             D(θ_t) ~ -(g / L) * sin(θ)]
-#      
-#      @mtkbuild pend = ODESystem(eqs, t)
-#      
-#      u0map = [θ => π / 2, θ_t => π / 2]
-#      parammap = [:L => 1.0, :g => 9.81]
-#      tspan = (0.0, 6.0)
-#      
-#      op = ODEProblem(pend, u0map, tspan, parammap)
-#      osol = solve(op, Vern9())
-#      
-#      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-#      for solver in solvers
-#          sol = solve(bvp, solver(), dt = 0.01)
-#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-#          @test sol.u[1] == [π / 2, π / 2]
-#      end
-#      
-#      # Test out-of-place
-#      bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-#      
-#      for solver in solvers
-#          sol = solve(bvp2, solver(), dt = 0.01)
-#          @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-#          @test sol.u[1] == [π / 2, π / 2]
-#      end
-# end
-# 
-# ##################################################################
-# ### ODESystem with constraint equations, DAEs with constraints ###
-# ##################################################################
-# 
-# # Test generation of boundary condition function using `process_constraints`. Compare solutions to manually written boundary conditions
-# let
-#     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
-#     @variables x(..) y(..)
-#     eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
-#            D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
-#     
-#     tspan = (0., 1.)
-#     @mtkbuild lksys = ODESystem(eqs, t)
-# 
-#     function lotkavolterra!(du, u, p, t) 
-#         du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
-#         du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
-#     end
-# 
-#     function lotkavolterra(u, p, t) 
-#         [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
-#     end
-#     # Compare the built bc function to the actual constructed one.
-#     function bc!(resid, u, p, t) 
-#         resid[1] = u[1][1] - 1.
-#         resid[2] = u[1][2] - 2.
-#         nothing
-#     end
-#     function bc(u, p, t)
-#         [u[1][1] - 1., u[1][2] - 2.]
-#     end
-# 
-#     u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
-#     genbc_iip = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, true)
-#     genbc_oop = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, false)
-# 
-#     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
-#     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
-# 
-#     sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
-#     sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
-#     @test sol1 ≈ sol2
-# 
-#     bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
-#     bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
-# 
-#     sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
-#     sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
-#     @test sol1 ≈ sol2
-# 
-#     # Test with a constraint.
-#     constraints = [y(0.5) ~ 2.]
-# 
-#     function bc!(resid, u, p, t) 
-#         resid[1] = u(0.0)[1] - 1.
-#         resid[2] = u(0.5)[2] - 2.
-#     end
-#     function bc(u, p, t)
-#         [u(0.0)[1] - 1., u(0.5)[2] - 2.]
-#     end
-# 
-#     u0 = [1, 1.]
-#     genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, true)
-#     genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1], tspan, false)
-# 
-#     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
-#     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
-#     bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-#     bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-#     
-#     sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
-#     sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
-#     sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
-#     sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
-#     @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
-# 
-#     bvpo1 = BVProblem(lotkavolterra, bc, u0, tspan, p)
-#     bvpo2 = BVProblem(lotkavolterra, genbc_oop, u0, tspan, p)
-#     bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-# 
-#     sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
-#     sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
-#     sol3 = @btime solve($bvpo3, MIRK4(), dt = 0.05)
-#     @test sol1 ≈ sol2 ≈ sol3
-# end
+let
+     @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+     @variables x(t)=1.0 y(t)=2.0
+     
+     eqs = [D(x) ~ α * x - β * x * y,
+         D(y) ~ -γ * y + δ * x * y]
+     
+     u0map = [x => 1.0, y => 2.0]
+     parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
+     tspan = (0.0, 10.0)
+     
+     @mtkbuild lotkavolterra = ODESystem(eqs, t)
+     op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
+     osol = solve(op, Vern9())
+     
+     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+     
+     for solver in solvers
+         sol = solve(bvp, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [1.0, 2.0]
+     end
+     
+     # Test out of place
+     bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
+         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+     
+     for solver in solvers
+         sol = solve(bvp2, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [1.0, 2.0]
+     end
+end
+
+### Testing on pendulum
+let
+     @parameters g=9.81 L=1.0
+     @variables θ(t) = π / 2 θ_t(t)
+     
+     eqs = [D(θ) ~ θ_t
+            D(θ_t) ~ -(g / L) * sin(θ)]
+     
+     @mtkbuild pend = ODESystem(eqs, t)
+     
+     u0map = [θ => π / 2, θ_t => π / 2]
+     parammap = [:L => 1.0, :g => 9.81]
+     tspan = (0.0, 6.0)
+     
+     op = ODEProblem(pend, u0map, tspan, parammap)
+     osol = solve(op, Vern9())
+     
+     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+     for solver in solvers
+         sol = solve(bvp, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [π / 2, π / 2]
+     end
+     
+     # Test out-of-place
+     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
+     
+     for solver in solvers
+         sol = solve(bvp2, solver(), dt = 0.01)
+         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+         @test sol.u[1] == [π / 2, π / 2]
+     end
+end
+
+##################################################################
+### ODESystem with constraint equations, DAEs with constraints ###
+##################################################################
+
+# Test generation of boundary condition function using `generate_function_bc`. Compare solutions to manually written boundary conditions
+let
+    @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
+    @variables x(..) y(..)
+    eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
+           D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
+    
+    tspan = (0., 1.)
+    @mtkbuild lksys = ODESystem(eqs, t)
+
+    function lotkavolterra!(du, u, p, t) 
+        du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
+        du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
+    end
+
+    function lotkavolterra(u, p, t) 
+        [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
+    end
+    # Compare the built bc function to the actual constructed one.
+    function bc!(resid, u, p, t) 
+        resid[1] = u[1][1] - 1.
+        resid[2] = u[1][2] - 2.
+        nothing
+    end
+    function bc(u, p, t)
+        [u[1][1] - 1., u[1][2] - 2.]
+    end
+
+    u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
+    genbc_iip = ModelingToolkit.generate_function_bc(lksys, nothing, u0, [1, 2], tspan, true)
+    genbc_oop = ModelingToolkit.generate_function_bc(lksys, nothing, u0, [1, 2], tspan, false)
+
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
+
+    sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
+    sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
+    @test sol1 ≈ sol2
+
+    bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
+    bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
+
+    sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
+    sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
+    @test sol1 ≈ sol2
+
+    # Test with a constraint.
+    constraints = [y(0.5) ~ 2.]
+
+    function bc!(resid, u, p, t) 
+        resid[1] = u(0.0)[1] - 1.
+        resid[2] = u(0.5)[2] - 2.
+    end
+    function bc(u, p, t)
+        [u(0.0)[1] - 1., u(0.5)[2] - 2.]
+    end
+
+    u0 = [1, 1.]
+    genbc_iip = ModelingToolkit.generate_function_bc(lksys, constraints, u0, [1], tspan, true)
+    genbc_oop = ModelingToolkit.generate_function_bc(lksys, constraints, u0, [1], tspan, false)
+
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
+    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+    bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+    
+    sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
+    sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
+    sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
+    sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
+    @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
+
+    bvpo1 = BVProblem(lotkavolterra, bc, u0, tspan, p)
+    bvpo2 = BVProblem(lotkavolterra, genbc_oop, u0, tspan, p)
+    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+
+    sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
+    sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
+    sol3 = @btime solve($bvpo3, MIRK4(), dt = 0.05)
+    @test sol1 ≈ sol2 ≈ sol3
+end
 
 function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-3)
     for solver in solvers
@@ -214,33 +214,32 @@ let
     u0map = []
     tspan = (0.0, 1.0)
     guesses = [x(t) => 4.0, y(t) => 2.]
+    constr = [x(.6) ~ 3.5, x(.3) ~ 7.]
+    @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
 
-    @mtkbuild lksys = ODESystem(eqs, t)
-
-    constraints = [x(.6) ~ 3.5, x(.3) ~ 7.]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
-    test_solvers(solvers, bvp, u0map, constraints; dt = 0.05)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
+    test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
-    # Testing that more complicated constraints give correct solutions.
-    constraints = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
-    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses, constraints)
-    test_solvers(solvers, bvp, u0map, constraints; dt = 0.05)
+    # Testing that more complicated constr give correct solutions.
+    constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses)
+    test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
-    constraints = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
-    test_solvers(solvers, bvp, u0map, constraints)
+    constr = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
+    test_solvers(solvers, bvp, u0map, constr)
 
-    # Testing that errors are properly thrown when malformed constraints are given.
+    # Testing that errors are properly thrown when malformed constr are given.
     @variables bad(..)
-    constraints = [x(1.) + bad(3.) ~ 10]
-    @test_throws ErrorException bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    constr = [x(1.) + bad(3.) ~ 10]
+    @test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
 
-    constraints = [x(t) + y(t) ~ 3]
-    @test_throws ErrorException bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    constr = [x(t) + y(t) ~ 3]
+    @test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
 
     @parameters bad2
-    constraints = [bad2 + x(0.) ~ 3]
-    @test_throws ErrorException bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses, constraints)
+    constr = [bad2 + x(0.) ~ 3]
+    @test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
 end
 
 # Cartesian pendulum from the docs.
@@ -288,31 +287,33 @@ end
 #     eqs = [D(D(x(t))) ~ λ * x(t)
 #            D(D(y)) ~ λ * y - g
 #            x(t)^2 + y^2 ~ 1]
-#     @mtkbuild pend = ODESystem(eqs, t)
+#     constr = [x(0.5) ~ 1]
+#     @mtkbuild pend = ODESystem(eqs, t; constr)
 # 
 #     tspan = (0.0, 1.5)
 #     u0map = [x(t) => 0.6, y => 0.8]
 #     parammap = [g => 1]
 #     guesses = [λ => 1]
 # 
-#     constraints = [x(0.5) ~ 1]
-#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
-#     test_solvers(daesolvers, bvp, u0map, constraints)
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
+#     test_solvers(daesolvers, bvp, u0map, constr)
 # 
 #     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-#     test_solvers(daesolvers, bvp2, u0map, constraints, get_alg_eqs(pend))
+#     test_solvers(daesolvers, bvp2, u0map, constr, get_alg_eqs(pend))
 # 
-#     # More complicated constraints.
+#     # More complicated constr.
 #     u0map = [x(t) => 0.6]
 #     guesses = [λ => 1, y(t) => 0.8]
 # 
-#     constraints = [x(0.5) ~ 1, 
+#     constr = [x(0.5) ~ 1, 
 #                    x(0.3)^3 + y(0.6)^2 ~ 0.5]
-#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
-#     test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
+#     @mtkbuild pend = ODESystem(eqs, t; constr)
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
+#     test_solvers(daesolvers, bvp, u0map, constr, get_alg_eqs(pend))
 # 
-#     constraints = [x(0.4) * g ~ y(0.2),
+#     constr = [x(0.4) * g ~ y(0.2),
 #                    y(0.7) ~ 0.3]
-#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
-#     test_solvers(daesolvers, bvp, u0map, constraints, get_alg_eqs(pend))
+#     @mtkbuild pend = ODESystem(eqs, t; constr)
+#     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
+#     test_solvers(daesolvers, bvp, u0map, constr, get_alg_eqs(pend))
 # end
diff --git a/test/odesystem.jl b/test/odesystem.jl
index 85d135b338..516fa7d415 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1552,3 +1552,33 @@ end
     expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
     @test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
 end
+
+@testset "Constraint system construction" begin
+    @variables x(..) y(..) z(..)
+    @parameters a b c d e 
+    eqs = [D(x(t)) ~ 3*a*y(t), D(y(t)) ~ x(t) - z(t), D(z(t)) ~ e*x(t)^2]
+    cons = [x(0.3) ~ c*d, y(0.7) ~ 3]
+
+    # Test variables + parameters infer correctly.
+    @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+    @test issetequal(parameters(sys), [a, c, d, e])
+    @test issetequal(unknowns(sys), [x(t), y(t)])
+
+    @parameters t_c
+    cons = [x(t_c) ~ 3]
+    @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+    @test_broken issetequal(parameters(sys), [a, e, t_c]) # TODO: unbreak this.
+
+    # Test that bad constraints throw errors.
+    cons = [x(3, 4) ~ 3]
+    @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+
+    cons = [x(y(t)) ~ 2]
+    @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+
+    @variables u(t) v
+    cons = [x(t) * u ~ 3]
+    @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+    cons = [x(t) * v ~ 3]
+    @test_nowarn @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+end

From 90ce80d658eabd599b99c9e5c9bbc5efc9d41470 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 11:57:37 -0500
Subject: [PATCH 034/111] refactor tests

---
 src/systems/diffeqs/abstractodesystem.jl | 45 +++++++++++++-----------
 src/systems/diffeqs/odesystem.jl         | 11 +++---
 test/bvproblem.jl                        | 38 ++++++--------------
 3 files changed, 39 insertions(+), 55 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index d0d6ed7937..47e5c3b88c 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -943,9 +943,10 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
 
     sts = unknowns(sys)
     ps = parameters(sys)
+    constraintsys = get_constraintsystem(sys)
 
-    if !isnothing(constraints)
-        (length(constraints) + length(u0map) > length(sts)) && 
+    if !isnothing(constraintsys)
+        (length(constraints(constraintsys)) + length(u0map) > length(sts)) && 
         @warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
     end
 
@@ -976,33 +977,35 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
     ps = get_ps(sys)
     np = length(ps)
     ns = length(sts)
-    conssys = get_constraintsystem(sys)
-    cons = constraints(conssys)
-
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
     pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
 
     @variables sol(..)[1:ns] p[1:np]
-    exprs = Any[]
 
-    for st in get_unknowns(cons)
-        x = operation(st)
-        t = first(arguments(st))
-        idx = stidxmap[x(iv)]
+    conssys = get_constraintsystem(sys)
+    cons = Any[]
+    if !isnothing(conssys)
+        cons = [con.lhs - con.rhs for con in constraints(conssys)]
 
-        cons = Symbolics.substitute(cons, Dict(x(t) => sol(t)[idx]))
-    end
+        for st in get_unknowns(conssys)
+            x = operation(st)
+            t = only(arguments(st))
+            idx = stidxmap[x(iv)]
 
-    for var in get_parameters(cons) 
-        if iscall(var)
-            x = operation(var)
-            t = arguments(var)[1]
-            idx = pidxmap[x]
+            cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
+        end
 
-            cons = Symbolics.substitute(cons, Dict(x(t) => p[idx]))
-        else
-            idx = pidxmap[var]
-            cons = Symbolics.substitute(cons, Dict(var => p[idx]))
+        for var in parameters(conssys) 
+            if iscall(var)
+                x = operation(var)
+                t = only(arguments(var))
+                idx = pidxmap[x]
+
+                cons = map(c -> Symbolics.substitute(c, Dict(x(t) => p[idx])), cons)
+            else
+                idx = pidxmap[var]
+                cons = map(c -> Symbolics.substitute(c, Dict(var => p[idx])), cons)
+            end
         end
     end
 
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 6ca6cdf838..4ef9f89c6a 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -802,13 +802,11 @@ function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv)
             throw(ArgumentError("Too many arguments for variable $var."))
         elseif length(arguments(var)) == 1
             arg = first(arguments(var))
-            operation(var)(iv) ∈ sts || throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
+            operation(var)(iv) ∈ sts || 
+                throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
 
-            isequal(arg, iv) || 
-                isparameter(arg) || 
-                arg isa Integer || 
-                    arg isa AbstractFloat || 
-                        throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
+            isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat || 
+                throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
         else
             var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
         end
@@ -824,7 +822,6 @@ function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv)
             operation(var) ∈ ps || throw(ArgumentError("Parameter $var is not a parameter of the ODESystem. Called parameters must be parameters of the ODESystem."))
 
             isequal(arg, iv) || 
-                isparameter(arg) ||
                 arg isa Integer || 
                     arg isa AbstractFloat || 
                         throw(ArgumentError("Invalid argument specified for callable parameter $var. The argument of the parameter should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 16c12d1be6..c6c124d09e 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -5,7 +5,6 @@ using BenchmarkTools
 using ModelingToolkit
 using SciMLBase
 using ModelingToolkit: t_nounits as t, D_nounits as D
-import ModelingToolkit: process_constraints
 
 ### Test Collocation solvers on simple problems 
 solvers = [MIRK4]
@@ -113,8 +112,8 @@ let
     end
 
     u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
-    genbc_iip = ModelingToolkit.generate_function_bc(lksys, nothing, u0, [1, 2], tspan, true)
-    genbc_oop = ModelingToolkit.generate_function_bc(lksys, nothing, u0, [1, 2], tspan, false)
+    genbc_iip = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan, true)
+    genbc_oop = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan, false)
 
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
@@ -131,7 +130,8 @@ let
     @test sol1 ≈ sol2
 
     # Test with a constraint.
-    constraints = [y(0.5) ~ 2.]
+    constr = [y(0.5) ~ 2.]
+    @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
 
     function bc!(resid, u, p, t) 
         resid[1] = u(0.0)[1] - 1.
@@ -142,13 +142,13 @@ let
     end
 
     u0 = [1, 1.]
-    genbc_iip = ModelingToolkit.generate_function_bc(lksys, constraints, u0, [1], tspan, true)
-    genbc_oop = ModelingToolkit.generate_function_bc(lksys, constraints, u0, [1], tspan, false)
+    genbc_iip = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan, true)
+    genbc_oop = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan, false)
 
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
-    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
-    bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
+    bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
     
     sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
     sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
@@ -158,7 +158,7 @@ let
 
     bvpo1 = BVProblem(lotkavolterra, bc, u0, tspan, p)
     bvpo2 = BVProblem(lotkavolterra, genbc_oop, u0, tspan, p)
-    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.], constraints)
+    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
 
     sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
     sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
@@ -197,12 +197,6 @@ function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.
     end
 end
 
-solvers = [RadauIIa3, RadauIIa5, RadauIIa7,
-           LobattoIIIa2, LobattoIIIa4, LobattoIIIa5,
-           LobattoIIIb2, LobattoIIIb3, LobattoIIIb4, LobattoIIIb5,
-           LobattoIIIc2, LobattoIIIc3, LobattoIIIc4, LobattoIIIc5]
-weird = [MIRK2, MIRK5, RadauIIa2]
-daesolvers = []
 # Simple ODESystem with BVP constraints.
 let
     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
@@ -222,24 +216,14 @@ let
 
     # Testing that more complicated constr give correct solutions.
     constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
+    @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
     bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses)
     test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
     constr = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
+    @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
     test_solvers(solvers, bvp, u0map, constr)
-
-    # Testing that errors are properly thrown when malformed constr are given.
-    @variables bad(..)
-    constr = [x(1.) + bad(3.) ~ 10]
-    @test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
-
-    constr = [x(t) + y(t) ~ 3]
-    @test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
-
-    @parameters bad2
-    constr = [bad2 + x(0.) ~ 3]
-    @test_throws ErrorException lksys = ODESystem(eqs, t; constraints = constr)
 end
 
 # Cartesian pendulum from the docs.

From a15c67024d4630193afe0f662d7c4c4183d07037 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 14:23:08 -0500
Subject: [PATCH 035/111] fix sym validation

---
 src/systems/diffeqs/odesystem.jl | 91 ++++++++++++++------------------
 test/odesystem.jl                | 15 ++++--
 2 files changed, 49 insertions(+), 57 deletions(-)

diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 4ef9f89c6a..0c8304e428 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -221,7 +221,7 @@ end
 function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
         controls = Num[],
         observed = Equation[],
-        constraints = Equation[],
+        constraintsystem = nothing,
         systems = ODESystem[],
         tspan = nothing,
         name = nothing,
@@ -286,26 +286,17 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
     cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
     disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
 
-    constraintsys = nothing
-    if !isempty(constraints)
-        constraintsys = process_constraint_system(constraints, dvs′, ps′, iv, systems)
-        dvset = Set(dvs′)
-        pset = Set(ps′)
-        for st in get_unknowns(constraintsys)
-            iscall(st) ? 
-                !in(operation(st)(iv), dvset) && push!(dvs′, st) :
-                !in(st, dvset) && push!(dvs′, st)
-        end
-        for p in parameters(constraintsys)
-            !in(p, pset) && push!(ps′, p)
-        end
-    end
-
     if is_dde === nothing
         is_dde = _check_if_dde(deqs, iv′, systems)
     end
+
+    if !isempty(systems)
+        cons = get_constraintsystems.(systems)
+        @set! constraintsystem.systems = cons
+    end
+
     ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
-        deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsys, tgrad, jac,
+        deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
         ctrl_jac, Wfact, Wfact_t, name, description, systems,
         defaults, guesses, nothing, initializesystem,
         initialization_eqs, schedule, connector_type, preface, cont_callbacks,
@@ -377,9 +368,22 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
     end
     algevars = setdiff(allunknowns, diffvars)
 
+    if !isempty(constraints)
+        consvars = OrderedSet()
+        constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
+        for st in get_unknowns(constraintsystem)
+            iscall(st) ? 
+                !in(operation(st)(iv), allunknowns) && push!(consvars, st) :
+                !in(st, allunknowns) && push!(consvars, st)
+        end
+        for p in parameters(constraintsystem)
+            !in(p, new_ps) && push!(new_ps, p)
+        end
+    end
+
     # the orders here are very important!
     return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
-        collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); constraints, kwargs...)
+        collect(Iterators.flatten((diffvars, algevars, consvars))), collect(new_ps); constraintsystem, kwargs...)
 end
 
 # NOTE: equality does not check cached Jacobian
@@ -791,55 +795,38 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
 end
 
 # Validate that all the variables in the BVP constraints are well-formed states or parameters.
-#  - Any callable with multiple arguments will error.
 #  - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
 #  - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
-function validate_constraint_syms(constraintsts, constraintps, sts, ps, iv) 
+function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
+    isempty(constraints) && return nothing
+
+    constraintsts = OrderedSet()
+    constraintps = OrderedSet()
+
+    # Hack? to extract parameters from callable variables in constraints.
+    for cons in constraints
+        collect_vars!(constraintsts, constraintps, cons, iv)
+    end
+
+    # Validate the states.
     for var in constraintsts
         if !iscall(var)
-            occursin(iv, var) && var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))
+            occursin(iv, var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
         elseif length(arguments(var)) > 1
             throw(ArgumentError("Too many arguments for variable $var."))
         elseif length(arguments(var)) == 1
-            arg = first(arguments(var))
+            arg = only(arguments(var))
             operation(var)(iv) ∈ sts || 
                 throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
 
             isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat || 
                 throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
+
+            isparameter(arg) && push!(constraintps, arg)
         else
             var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
         end
     end
 
-    for var in constraintps
-        !iscall(var) && continue
-
-        if length(arguments(var)) > 1
-            throw(ArgumentError("Too many arguments for parameter $var in equation $eq."))
-        elseif length(arguments(var)) == 1
-            arg = first(arguments(var))
-            operation(var) ∈ ps || throw(ArgumentError("Parameter $var is not a parameter of the ODESystem. Called parameters must be parameters of the ODESystem."))
-
-            isequal(arg, iv) || 
-                arg isa Integer || 
-                    arg isa AbstractFloat || 
-                        throw(ArgumentError("Invalid argument specified for callable parameter $var. The argument of the parameter should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
-        end
-    end
-end
-
-function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv, subsys::Vector{ODESystem}; name = :cons)
-    isempty(constraints) && return nothing
-
-    constraintsts = OrderedSet()
-    constraintps = OrderedSet()
-
-    for cons in constraints
-        syms = collect_vars!(constraintsts, constraintps, cons, iv)
-    end
-    validate_constraint_syms(constraintsts, constraintps, Set(sts), Set(ps), iv)
-
-    constraint_subsys = get_constraintsystem.(subsys)
-    ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); systems = constraint_subsys, name)
+    ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
 end
diff --git a/test/odesystem.jl b/test/odesystem.jl
index 516fa7d415..0276bb4b4f 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1562,23 +1562,28 @@ end
     # Test variables + parameters infer correctly.
     @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
     @test issetequal(parameters(sys), [a, c, d, e])
-    @test issetequal(unknowns(sys), [x(t), y(t)])
+    @test issetequal(unknowns(sys), [x(t), y(t), z(t)])
 
     @parameters t_c
     cons = [x(t_c) ~ 3]
     @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
-    @test_broken issetequal(parameters(sys), [a, e, t_c]) # TODO: unbreak this.
+    @test issetequal(parameters(sys), [a, e, t_c]) 
+
+    @parameters g(..) h i
+    cons = [g(h, i) * x(3) ~ c]
+    @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+    @test issetequal(parameters(sys), [g, h, i, a, e, c]) 
 
     # Test that bad constraints throw errors.
-    cons = [x(3, 4) ~ 3]
+    cons = [x(3, 4) ~ 3] # unknowns cannot have multiple args.
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
 
-    cons = [x(y(t)) ~ 2]
+    cons = [x(y(t)) ~ 2] # unknown arg must be parameter, value, or t
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
 
     @variables u(t) v
     cons = [x(t) * u ~ 3]
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
     cons = [x(t) * v ~ 3]
-    @test_nowarn @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
+    @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons) # Need time argument.
 end

From c6ef04adac961b386f0f37c4750eee28273cff3f Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 14:24:38 -0500
Subject: [PATCH 036/111] remove file

---
 src/systems/diffeqs/bvpsystem.jl | 217 -------------------------------
 1 file changed, 217 deletions(-)
 delete mode 100644 src/systems/diffeqs/bvpsystem.jl

diff --git a/src/systems/diffeqs/bvpsystem.jl b/src/systems/diffeqs/bvpsystem.jl
deleted file mode 100644
index ae3029fa14..0000000000
--- a/src/systems/diffeqs/bvpsystem.jl
+++ /dev/null
@@ -1,217 +0,0 @@
-"""
-$(TYPEDEF)
-
-A system of ordinary differential equations.
-
-# Fields
-$(FIELDS)
-
-# Example
-
-```julia
-using ModelingToolkit
-using ModelingToolkit: t_nounits as t, D_nounits as D
-
-@parameters σ ρ β
-@variables x(t) y(t) z(t)
-
-eqs = [D(x) ~ σ*(y-x),
-       D(y) ~ x*(ρ-z)-y,
-       D(z) ~ x*y - β*z]
-
-@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],tspan=(0, 1000.0))
-```
-"""
-struct ODESystem <: AbstractODESystem
-    """
-    A tag for the system. If two systems have the same tag, then they are
-    structurally identical.
-    """
-    tag::UInt
-    """The ODEs defining the system."""
-    eqs::Vector{Equation}
-    """Independent variable."""
-    iv::BasicSymbolic{Real}
-    """
-    Dependent (unknown) variables. Must not contain the independent variable.
-
-    N.B.: If `torn_matching !== nothing`, this includes all variables. Actual
-    ODE unknowns are determined by the `SelectedState()` entries in `torn_matching`.
-    """
-    unknowns::Vector
-    """Parameter variables. Must not contain the independent variable."""
-    ps::Vector
-    """Time span."""
-    tspan::Union{NTuple{2, Any}, Nothing}
-    """Array variables."""
-    var_to_name::Any
-    """Control parameters (some subset of `ps`)."""
-    ctrls::Vector
-    """Observed variables."""
-    observed::Vector{Equation}
-    """
-    Time-derivative matrix. Note: this field will not be defined until
-    [`calculate_tgrad`](@ref) is called on the system.
-    """
-    tgrad::RefValue{Vector{Num}}
-    """
-    Jacobian matrix. Note: this field will not be defined until
-    [`calculate_jacobian`](@ref) is called on the system.
-    """
-    jac::RefValue{Any}
-    """
-    Control Jacobian matrix. Note: this field will not be defined until
-    [`calculate_control_jacobian`](@ref) is called on the system.
-    """
-    ctrl_jac::RefValue{Any}
-    """
-    Note: this field will not be defined until
-    [`generate_factorized_W`](@ref) is called on the system.
-    """
-    Wfact::RefValue{Matrix{Num}}
-    """
-    Note: this field will not be defined until
-    [`generate_factorized_W`](@ref) is called on the system.
-    """
-    Wfact_t::RefValue{Matrix{Num}}
-    """
-    The name of the system.
-    """
-    name::Symbol
-    """
-    A description of the system.
-    """
-    description::String
-    """
-    The internal systems. These are required to have unique names.
-    """
-    systems::Vector{ODESystem}
-    """
-    The default values to use when initial conditions and/or
-    parameters are not supplied in `ODEProblem`.
-    """
-    defaults::Dict
-    """
-    The guesses to use as the initial conditions for the
-    initialization system.
-    """
-    guesses::Dict
-    """
-    Tearing result specifying how to solve the system.
-    """
-    torn_matching::Union{Matching, Nothing}
-    """
-    The system for performing the initialization.
-    """
-    initializesystem::Union{Nothing, NonlinearSystem}
-    """
-    Extra equations to be enforced during the initialization sequence.
-    """
-    initialization_eqs::Vector{Equation}
-    """
-    The schedule for the code generation process.
-    """
-    schedule::Any
-    """
-    Type of the system.
-    """
-    connector_type::Any
-    """
-    Inject assignment statements before the evaluation of the RHS function.
-    """
-    preface::Any
-    """
-    A `Vector{SymbolicContinuousCallback}` that model events.
-    The integrator will use root finding to guarantee that it steps at each zero crossing.
-    """
-    continuous_events::Vector{SymbolicContinuousCallback}
-    """
-    A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
-    analog to `SciMLBase.DiscreteCallback` that executes an affect when a given condition is
-    true at the end of an integration step.
-    """
-    discrete_events::Vector{SymbolicDiscreteCallback}
-    """
-    Topologically sorted parameter dependency equations, where all symbols are parameters and
-    the LHS is a single parameter.
-    """
-    parameter_dependencies::Vector{Equation}
-    """
-    Metadata for the system, to be used by downstream packages.
-    """
-    metadata::Any
-    """
-    Metadata for MTK GUI.
-    """
-    gui_metadata::Union{Nothing, GUIMetadata}
-    """
-    A boolean indicating if the given `ODESystem` represents a system of DDEs.
-    """
-    is_dde::Bool
-    """
-    A list of points to provide to the solver as tstops. Uses the same syntax as discrete
-    events.
-    """
-    tstops::Vector{Any}
-    """
-    Cache for intermediate tearing state.
-    """
-    tearing_state::Any
-    """
-    Substitutions generated by tearing.
-    """
-    substitutions::Any
-    """
-    If a model `sys` is complete, then `sys.x` no longer performs namespacing.
-    """
-    complete::Bool
-    """
-    Cached data for fast symbolic indexing.
-    """
-    index_cache::Union{Nothing, IndexCache}
-    """
-    A list of discrete subsystems.
-    """
-    discrete_subsystems::Any
-    """
-    A list of actual unknowns needed to be solved by solvers.
-    """
-    solved_unknowns::Union{Nothing, Vector{Any}}
-    """
-    A vector of vectors of indices for the split parameters.
-    """
-    split_idxs::Union{Nothing, Vector{Vector{Int}}}
-    """
-    The hierarchical parent system before simplification.
-    """
-    parent::Any
-
-    function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
-            jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
-            torn_matching, initializesystem, initialization_eqs, schedule,
-            connector_type, preface, cevents,
-            devents, parameter_dependencies,
-            metadata = nothing, gui_metadata = nothing, is_dde = false,
-            tstops = [], tearing_state = nothing,
-            substitutions = nothing, complete = false, index_cache = nothing,
-            discrete_subsystems = nothing, solved_unknowns = nothing,
-            split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
-        if checks == true || (checks & CheckComponents) > 0
-            check_independent_variables([iv])
-            check_variables(dvs, iv)
-            check_parameters(ps, iv)
-            check_equations(deqs, iv)
-            check_equations(equations(cevents), iv)
-        end
-        if checks == true || (checks & CheckUnits) > 0
-            u = __get_unit_type(dvs, ps, iv)
-            check_units(u, deqs)
-        end
-        new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
-            ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
-            initializesystem, initialization_eqs, schedule, connector_type, preface,
-            cevents, devents, parameter_dependencies, metadata,
-            gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
-            discrete_subsystems, solved_unknowns, split_idxs, parent)
-    end
-end

From 78782256b5f3c3f82ae1869b2d6737ea7a4b7acc Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 14:26:12 -0500
Subject: [PATCH 037/111] up

---
 src/systems/optimization/constraints_system.jl | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl
index 61e190e672..3c2aed192d 100644
--- a/src/systems/optimization/constraints_system.jl
+++ b/src/systems/optimization/constraints_system.jl
@@ -90,9 +90,6 @@ struct ConstraintsSystem <: AbstractTimeIndependentSystem
             complete = false, index_cache = nothing;
             checks::Union{Bool, Int} = true)
 
-        ##if checks == true || (checks & CheckComponents) > 0 
-        ##    check_variables(unknowns, constraints)
-        ##end
         if checks == true || (checks & CheckUnits) > 0
             u = __get_unit_type(unknowns, ps)
             check_units(u, constraints)

From 0493b5d2b49e72872976caf2e9c57381e272557a Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 14:28:46 -0500
Subject: [PATCH 038/111] remove lines

---
 src/systems/optimization/constraints_system.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl
index 3c2aed192d..03225fc900 100644
--- a/src/systems/optimization/constraints_system.jl
+++ b/src/systems/optimization/constraints_system.jl
@@ -89,7 +89,6 @@ struct ConstraintsSystem <: AbstractTimeIndependentSystem
             tearing_state = nothing, substitutions = nothing,
             complete = false, index_cache = nothing;
             checks::Union{Bool, Int} = true)
-
         if checks == true || (checks & CheckUnits) > 0
             u = __get_unit_type(unknowns, ps)
             check_units(u, constraints)
@@ -254,4 +253,3 @@ function get_cmap(sys::ConstraintsSystem, exprs = nothing)
     cmap = map(x -> x ~ getdefault(x), cs)
     return cmap, cs
 end
-

From 1d32b6ea8438bda25dc91e095458c5f377c801c0 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 14:50:42 -0500
Subject: [PATCH 039/111] up

---
 src/systems/diffeqs/odesystem.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 8d0e09fce5..1154bcd4c8 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -368,8 +368,8 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
     end
     algevars = setdiff(allunknowns, diffvars)
 
+    consvars = OrderedSet()
     if !isempty(constraints)
-        consvars = OrderedSet()
         constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
         for st in get_unknowns(constraintsystem)
             iscall(st) ? 

From 2b3ca96d23f0010609867f6586449aa7dd9c6ddb Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 15:05:02 -0500
Subject: [PATCH 040/111] up

---
 src/systems/diffeqs/odesystem.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 1154bcd4c8..1754079745 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -369,6 +369,7 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
     algevars = setdiff(allunknowns, diffvars)
 
     consvars = OrderedSet()
+    constraintsystem = nothing
     if !isempty(constraints)
         constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
         for st in get_unknowns(constraintsystem)

From 0324522a419219c8f59f1b584c1ccd65454dc794 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 15:05:57 -0500
Subject: [PATCH 041/111] fix typo

---
 src/systems/diffeqs/odesystem.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 1754079745..2ed7155f6c 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -291,7 +291,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
     end
 
     if !isempty(systems)
-        cons = get_constraintsystems.(systems)
+        cons = get_constraintsystem.(systems)
         @set! constraintsystem.systems = cons
     end
 

From 2a079becdd97c10431ad8a85f707adbb39519e59 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 15:46:23 -0500
Subject: [PATCH 042/111] Fix setter

---
 src/systems/diffeqs/odesystem.jl | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 2ed7155f6c..dac0878d26 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -290,9 +290,14 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
         is_dde = _check_if_dde(deqs, iv′, systems)
     end
 
-    if !isempty(systems)
-        cons = get_constraintsystem.(systems)
-        @set! constraintsystem.systems = cons
+    if !isempty(systems) && !isnothing(constraintsystem)
+        conssystems = ConstraintsSystem[]
+        for sys in systems
+            cons = get_constraintsystem(sys)
+            cons !== nothing && push!(conssystems, cons) 
+        end
+        @show conssystems
+        @set! constraintsystem.systems = conssystems
     end
 
     ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),

From d70a470d546e9c79417fb4c0869956cd1800de00 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 28 Jan 2025 16:21:54 -0500
Subject: [PATCH 043/111] fix

---
 test/odesystem.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/odesystem.jl b/test/odesystem.jl
index 963847731c..53f75bf377 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1660,3 +1660,4 @@ end
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
     cons = [x(t) * v ~ 3]
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons) # Need time argument.
+end

From 37092f12c1fb6958b4d1ab6a917be99221371e61 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 29 Jan 2025 14:32:17 -0500
Subject: [PATCH 044/111] lower tol

---
 test/bvproblem.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index c6c124d09e..30fde44531 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,6 +1,6 @@
 ### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
 
-using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
+using BoundaryValueDiffEq, OrdinaryDiffEqDefault, BoundaryValueDiffEqAscher
 using BenchmarkTools
 using ModelingToolkit
 using SciMLBase
@@ -166,7 +166,7 @@ let
     @test sol1 ≈ sol2 ≈ sol3
 end
 
-function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-3)
+function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-2)
     for solver in solvers
         println("Solver: $solver")
         sol = @btime solve($prob, $solver(), dt = $dt, abstol = $atol)
@@ -214,7 +214,7 @@ let
     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
     test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
-    # Testing that more complicated constr give correct solutions.
+    # Testing that more complicated constraints give correct solutions.
     constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
     bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses)

From 4142923e7e7e38e90c65952265c02e93fc081472 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 29 Jan 2025 15:04:39 -0500
Subject: [PATCH 045/111] up

---
 .../implicit_discrete_system.jl               | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/systems/discrete_system/implicit_discrete_system.jl b/src/systems/discrete_system/implicit_discrete_system.jl
index 404fa7ff49..56e7545c6a 100644
--- a/src/systems/discrete_system/implicit_discrete_system.jl
+++ b/src/systems/discrete_system/implicit_discrete_system.jl
@@ -23,7 +23,7 @@ struct ImplicitDiscreteSystem <: AbstractTimeDependentSystem
     structurally identical.
     """
     tag::UInt
-    """The differential equations defining the discrete system."""
+    """The difference equations defining the discrete system."""
     eqs::Vector{Equation}
     """Independent variable."""
     iv::BasicSymbolic{Real}
@@ -48,10 +48,10 @@ struct ImplicitDiscreteSystem <: AbstractTimeDependentSystem
     """
     The internal systems. These are required to have unique names.
     """
-    systems::Vector{DiscreteSystem}
+    systems::Vector{ImplicitDiscreteSystem}
     """
     The default values to use when initial conditions and/or
-    parameters are not supplied in `DiscreteProblem`.
+    parameters are not supplied in `ImplicitDiscreteProblem`.
     """
     defaults::Dict
     """
@@ -136,11 +136,11 @@ end
 
 """
     $(TYPEDSIGNATURES)
-Constructs a DiscreteSystem.
+Constructs a ImplicitDiscreteSystem.
 """
 function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
         observed = Num[],
-        systems = DiscreteSystem[],
+        systems = ImplicitDiscreteSystem[],
         tspan = nothing,
         name = nothing,
         description = "",
@@ -162,12 +162,12 @@ function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
     dvs′ = value.(dvs)
     ps′ = value.(ps)
     if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs)
-        error("Equations in a `DiscreteSystem` can only have `Shift` operators.")
+        error("Equations in a `ImplicitDiscreteSystem` can only have `Shift` operators.")
     end
     if !(isempty(default_u0) && isempty(default_p))
         Base.depwarn(
             "`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
-            :DiscreteSystem, force = true)
+            :ImplicitDiscreteSystem, force = true)
     end
 
     defaults = Dict{Any, Any}(todict(defaults))
@@ -190,7 +190,7 @@ function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
     if length(unique(sysnames)) != length(sysnames)
         throw(ArgumentError("System names must be unique."))
     end
-    DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
+    ImplicitDiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
         eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
         defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
         parameter_dependencies, metadata, gui_metadata, kwargs...)
@@ -206,7 +206,7 @@ function ImplicitDiscreteSystem(eqs, iv; kwargs...)
         collect_vars!(allunknowns, ps, eq, iv; op = Shift)
         if iscall(eq.lhs) && operation(eq.lhs) isa Shift
             isequal(iv, operation(eq.lhs).t) ||
-                throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
+                throw(ArgumentError("An ImplicitDiscreteSystem can only have one independent variable."))
             eq.lhs in diffvars &&
                 throw(ArgumentError("The shift variable $(eq.lhs) is not unique in the system of equations."))
             push!(diffvars, eq.lhs)
@@ -233,16 +233,16 @@ function ImplicitDiscreteSystem(eqs, iv; kwargs...)
             push!(new_ps, p)
         end
     end
-    return DiscreteSystem(eqs, iv,
+    return ImplicitDiscreteSystem(eqs, iv,
         collect(allunknowns), collect(new_ps); kwargs...)
 end
 
-function flatten(sys::DiscreteSystem, noeqs = false)
+function flatten(sys::ImplicitDiscreteSystem, noeqs = false)
     systems = get_systems(sys)
     if isempty(systems)
         return sys
     else
-        return DiscreteSystem(noeqs ? Equation[] : equations(sys),
+        return ImplicitDiscreteSystem(noeqs ? Equation[] : equations(sys),
             get_iv(sys),
             unknowns(sys),
             parameters(sys),
@@ -258,14 +258,14 @@ function flatten(sys::DiscreteSystem, noeqs = false)
 end
 
 function generate_function(
-        sys::DiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
+        sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
     exprs = [eq.rhs for eq in equations(sys)]
     wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
                 wrap_parameter_dependencies(sys, false)
     generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
 end
 
-function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
+function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
     iv = get_iv(sys)
     updated = AnyDict()
     for k in collect(keys(u0map))
@@ -291,7 +291,7 @@ end
 Generates an ImplicitDiscreteProblem from an ImplicitDiscreteSystem.
 """
 function SciMLBase.ImplicitDiscreteProblem(
-        sys::DiscreteSystem, u0map = [], tspan = get_tspan(sys),
+        sys::ImplicitDiscreteSystem, u0map = [], tspan = get_tspan(sys),
         parammap = SciMLBase.NullParameters();
         eval_module = @__MODULE__,
         eval_expression = false,
@@ -309,7 +309,7 @@ function SciMLBase.ImplicitDiscreteProblem(
     u0map = to_varmap(u0map, dvs)
     u0map = shift_u0map_forward(sys, u0map, defaults(sys))
     f, u0, p = process_SciMLProblem(
-        DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
+        ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
     u0 = f(u0, p, tspan[1])
     ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
 end
@@ -348,7 +348,7 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
 
     if specialize === SciMLBase.FunctionWrapperSpecialize && iip
         if u0 === nothing || p === nothing || t === nothing
-            error("u0, p, and t must be specified for FunctionWrapperSpecialize on DiscreteFunction.")
+            error("u0, p, and t must be specified for FunctionWrapperSpecialize on ImplicitDiscreteFunction.")
         end
         f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
     end
@@ -395,7 +395,7 @@ function ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = un
 
     ex = quote
         $_f
-        DiscreteFunction{$iip}($fsym)
+        ImplicitDiscreteFunction{$iip}($fsym)
     end
     !linenumbers ? Base.remove_linenums!(ex) : ex
 end

From e5eb8bd35a89ab8b0d5346c218ea0bdf6cb15f37 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 29 Jan 2025 15:05:35 -0500
Subject: [PATCH 046/111] fix Project.toml

---
 Project.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 69ad24c8df..865b88658b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -143,7 +143,6 @@ SparseArrays = "1"
 SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
 StaticArrays = "0.10, 0.11, 0.12, 1.0"
 StochasticDelayDiffEq = "1.8.1"
-StochasticDiffEq = "6.72.1"
 SymbolicIndexingInterface = "0.3.36"
 SymbolicUtils = "3.10"
 Symbolics = "6.22.1"

From afc468996d65bff484d66327149f2f8ccabfa83e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 30 Jan 2025 11:03:18 -0500
Subject: [PATCH 047/111] add test file

---
 src/ModelingToolkit.jl                        |  2 +
 .../discrete_system/discrete_system.jl        |  1 +
 .../implicit_discrete_system.jl               | 48 +++++++++++++------
 src/systems/systemstructure.jl                |  4 +-
 test/implicit_discrete_system.jl              |  2 +
 test/runtests.jl                              |  3 +-
 6 files changed, 44 insertions(+), 16 deletions(-)
 create mode 100644 test/implicit_discrete_system.jl

diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl
index 2710d7d1e4..4083f2609d 100644
--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -164,6 +164,7 @@ include("systems/diffeqs/modelingtoolkitize.jl")
 include("systems/diffeqs/basic_transformations.jl")
 
 include("systems/discrete_system/discrete_system.jl")
+include("systems/discrete_system/implicit_discrete_system.jl")
 
 include("systems/jumps/jumpsystem.jl")
 
@@ -229,6 +230,7 @@ export DAEFunctionExpr, DAEProblemExpr
 export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
 export SystemStructure
 export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
+export ImplicitDiscreteSystem, ImplicitDiscreteProblem, ImplicitDiscreteFunction, ImplicitDiscreteFunctionExpr
 export JumpSystem
 export ODEProblem, SDEProblem
 export NonlinearFunction, NonlinearFunctionExpr
diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index bd5c72eec7..0ad3317ea1 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -233,6 +233,7 @@ function DiscreteSystem(eqs, iv; kwargs...)
             push!(new_ps, p)
         end
     end
+    @show allunknowns
     return DiscreteSystem(eqs, iv,
         collect(allunknowns), collect(new_ps); kwargs...)
 end
diff --git a/src/systems/discrete_system/implicit_discrete_system.jl b/src/systems/discrete_system/implicit_discrete_system.jl
index 56e7545c6a..1e1b5d5f6f 100644
--- a/src/systems/discrete_system/implicit_discrete_system.jl
+++ b/src/systems/discrete_system/implicit_discrete_system.jl
@@ -10,11 +10,10 @@ using ModelingToolkit: t_nounits as t
 @parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1
 @variables x(t)=1.0 y(t)=0.0 z(t)=0.0
 k = ShiftIndex(t)
-eqs = [x(k+1) ~ σ*(y-x),
-       y(k+1) ~ x*(ρ-z)-y,
-       z(k+1) ~ x*y - β*z]
-@named de = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or
-@named de = ImplicitDiscreteSystem(eqs)
+eqs = [x ~ σ*(y(k-1)-x),
+       y ~ x*(ρ-z(k-1))-y,
+       z ~ x(k-1)*y - β*z]
+@named ide = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0))
 ```
 """
 struct ImplicitDiscreteSystem <: AbstractTimeDependentSystem
@@ -136,6 +135,7 @@ end
 
 """
     $(TYPEDSIGNATURES)
+
 Constructs a ImplicitDiscreteSystem.
 """
 function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
@@ -170,6 +170,8 @@ function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
             :ImplicitDiscreteSystem, force = true)
     end
 
+    # Copy equations to canonical form, but do not touch array expressions
+    eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs]
     defaults = Dict{Any, Any}(todict(defaults))
     guesses = Dict{Any, Any}(todict(guesses))
     var_to_name = Dict()
@@ -236,6 +238,8 @@ function ImplicitDiscreteSystem(eqs, iv; kwargs...)
     return ImplicitDiscreteSystem(eqs, iv,
         collect(allunknowns), collect(new_ps); kwargs...)
 end
+# basically at every timestep it should build a nonlinear solve
+# Previous timesteps should be treated as parameters? is this right? 
 
 function flatten(sys::ImplicitDiscreteSystem, noeqs = false)
     systems = get_systems(sys)
@@ -259,10 +263,25 @@ end
 
 function generate_function(
         sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
-    exprs = [eq.rhs for eq in equations(sys)]
-    wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
-                wrap_parameter_dependencies(sys, false)
-    generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
+    if !iscomplete(sys)
+        error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
+    end
+    p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
+    isscalar = !(exprs isa AbstractArray)
+    pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
+    if postprocess_fbody === nothing
+        postprocess_fbody = pre
+    end
+    if states === nothing
+        states = sol_states
+    end
+    exprs = [eq.lhs - eq.rhs for eq in equations(sys)]
+    u = map(Shift(iv, -1), dvs)
+    u_next = dvs
+
+    wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘ wrap_parameter_dependencies(sys, false)
+
+    build_function(exprs, u_next, u, p..., get_iv(sys))
 end
 
 function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
@@ -311,7 +330,7 @@ function SciMLBase.ImplicitDiscreteProblem(
     f, u0, p = process_SciMLProblem(
         ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
     u0 = f(u0, p, tspan[1])
-    ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
+    NonlinearProblem(f, u0, tspan, p; kwargs...)
 end
 
 function SciMLBase.ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...; kwargs...)
@@ -337,14 +356,15 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
         eval_module = @__MODULE__,
         analytic = nothing,
         kwargs...) where {iip, specialize}
+
     if !iscomplete(sys)
         error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
     end
     f_gen = generate_function(sys, dvs, ps; expression = Val{true},
         expression_module = eval_module, kwargs...)
     f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
-    f(u, p, t) = f_oop(u, p, t)
-    f(du, u, p, t) = f_iip(du, u, p, t)
+    f(u_next, u, p, t) = f_oop(u_next, u, p, t)
+    f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t)
 
     if specialize === SciMLBase.FunctionWrapperSpecialize && iip
         if u0 === nothing || p === nothing || t === nothing
@@ -379,8 +399,8 @@ struct ImplicitDiscreteFunctionClosure{O, I} <: Function
     f_oop::O
     f_iip::I
 end
-(f::ImplicitDiscreteFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
-(f::ImplicitDiscreteFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
+(f::ImplicitDiscreteFunctionClosure)(u_next, u, p, t) = f.f_oop(u_next, u, p, t)
+(f::ImplicitDiscreteFunctionClosure)(resid, u_next, u, p, t) = f.f_iip(resid, u_next, u, p, t)
 
 function ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = unknowns(sys),
         ps = parameters(sys), u0 = nothing;
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 1bdc11f06a..215d843c2e 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -432,7 +432,7 @@ function TearingState(sys; quick_cancel = false, check = true)
         SystemStructure(complete(var_to_diff), complete(eq_to_diff),
             complete(graph), nothing, var_types, sys isa DiscreteSystem),
         Any[])
-    if sys isa DiscreteSystem
+    if sys isa DiscreteSystem || sys isa ImplicitDiscreteSystem
         ts = shift_discrete_system(ts)
     end
     return ts
@@ -456,6 +456,8 @@ function lower_order_var(dervar, t)
     diffvar
 end
 
+"""
+"""
 function shift_discrete_system(ts::TearingState)
     @unpack fullvars, sys = ts
     discvars = OrderedSet()
diff --git a/test/implicit_discrete_system.jl b/test/implicit_discrete_system.jl
new file mode 100644
index 0000000000..adfb9d2fc1
--- /dev/null
+++ b/test/implicit_discrete_system.jl
@@ -0,0 +1,2 @@
+
+#init
diff --git a/test/runtests.jl b/test/runtests.jl
index 52875fdae5..bd5eecacd0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -79,7 +79,8 @@ end
             @safetestset "Variable Utils Test" include("variable_utils.jl")
             @safetestset "Variable Metadata Test" include("test_variable_metadata.jl")
             @safetestset "OptimizationSystem Test" include("optimizationsystem.jl")
-            @safetestset "Discrete System" include("discrete_system.jl")
+            @safetestset "DiscreteSystem Test" include("discrete_system.jl")
+            @safetestset "ImplicitDiscreteSystem Test" include("implicit_discrete_system.jl")
             @safetestset "SteadyStateSystem Test" include("steadystatesystems.jl")
             @safetestset "SDESystem Test" include("sdesystem.jl")
             @safetestset "DDESystem Test" include("dde.jl")

From 397298fea4c78b4ccfbfc6f8d1c61fde5aabe16d Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 30 Jan 2025 11:34:02 -0500
Subject: [PATCH 048/111] working on structural simplification

---
 src/ModelingToolkit.jl                                  | 1 +
 src/systems/discrete_system/discrete_system.jl          | 2 +-
 src/systems/discrete_system/implicit_discrete_system.jl | 8 ++++----
 src/systems/systemstructure.jl                          | 4 ++--
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl
index 4083f2609d..2dbef8f272 100644
--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -123,6 +123,7 @@ abstract type AbstractTimeIndependentSystem <: AbstractSystem end
 abstract type AbstractODESystem <: AbstractTimeDependentSystem end
 abstract type AbstractMultivariateSystem <: AbstractSystem end
 abstract type AbstractOptimizationSystem <: AbstractTimeIndependentSystem end
+abstract type AbstractDiscreteSystem <: AbstractTimeDependentSystem end
 
 function independent_variable end
 
diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index 0ad3317ea1..f2d271dfe4 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -17,7 +17,7 @@ eqs = [x(k+1) ~ σ*(y-x),
 @named de = DiscreteSystem(eqs)
 ```
 """
-struct DiscreteSystem <: AbstractTimeDependentSystem
+struct DiscreteSystem <: AbstractDiscreteSystem
     """
     A tag for the system. If two systems have the same tag, then they are
     structurally identical.
diff --git a/src/systems/discrete_system/implicit_discrete_system.jl b/src/systems/discrete_system/implicit_discrete_system.jl
index 1e1b5d5f6f..3a137eb8c1 100644
--- a/src/systems/discrete_system/implicit_discrete_system.jl
+++ b/src/systems/discrete_system/implicit_discrete_system.jl
@@ -10,13 +10,13 @@ using ModelingToolkit: t_nounits as t
 @parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1
 @variables x(t)=1.0 y(t)=0.0 z(t)=0.0
 k = ShiftIndex(t)
-eqs = [x ~ σ*(y(k-1)-x),
-       y ~ x*(ρ-z(k-1))-y,
-       z ~ x(k-1)*y - β*z]
+eqs = [x ~ σ*(y-x),
+       y ~ x*(ρ-z)-y,
+       z ~ x*y - β*z]
 @named ide = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0))
 ```
 """
-struct ImplicitDiscreteSystem <: AbstractTimeDependentSystem
+struct ImplicitDiscreteSystem <: AbstractDiscreteSystem
     """
     A tag for the system. If two systems have the same tag, then they are
     structurally identical.
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 215d843c2e..86d1702041 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -430,9 +430,9 @@ function TearingState(sys; quick_cancel = false, check = true)
 
     ts = TearingState(sys, fullvars,
         SystemStructure(complete(var_to_diff), complete(eq_to_diff),
-            complete(graph), nothing, var_types, sys isa DiscreteSystem),
+                        complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
         Any[])
-    if sys isa DiscreteSystem || sys isa ImplicitDiscreteSystem
+    if sys isa AbstractDiscreteSystem 
         ts = shift_discrete_system(ts)
     end
     return ts

From 2ae79ae19941d5502a7a83d13770284cad7b2cc3 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 30 Jan 2025 11:34:36 -0500
Subject: [PATCH 049/111] revert to OrdinaryDiffEq

---
 test/bvproblem.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 30fde44531..069372aef7 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,6 +1,6 @@
 ### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
 
-using BoundaryValueDiffEq, OrdinaryDiffEqDefault, BoundaryValueDiffEqAscher
+using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
 using BenchmarkTools
 using ModelingToolkit
 using SciMLBase

From be1b270298f11b23306cf4b5a9ea677a0ba26894 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 30 Jan 2025 14:26:04 -0500
Subject: [PATCH 050/111] change variable renaming

---
 src/structural_transformation/utils.jl | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index bd24a1d017..29c0f66756 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -452,8 +452,15 @@ end
 function lower_varname_withshift(var, iv, order)
     order == 0 && return var
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
-        op = operation(var)
-        return Shift(op.t, order)(var)
+        O = only(arguments(var))
+        oldop = operation(O)
+        ds = "$iv-$order"
+        d_separator = 'ˍ'
+        newname = Symbol(string(nameof(oldop)), d_separator, ds)
+
+        newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
+        setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
+        return ModelingToolkit._with_unit(identity, newvar, iv)
     end
     return lower_varname_with_unit(var, iv, order)
 end

From cc23c888984a3e811d8be32f7f229b214dd7c7b0 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 30 Jan 2025 14:29:14 -0500
Subject: [PATCH 051/111] up

---
 src/systems/discrete_system/discrete_system.jl |  3 +--
 src/systems/systems.jl                         | 14 +++++++-------
 src/systems/systemstructure.jl                 |  6 ++----
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index f2d271dfe4..bd5c72eec7 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -17,7 +17,7 @@ eqs = [x(k+1) ~ σ*(y-x),
 @named de = DiscreteSystem(eqs)
 ```
 """
-struct DiscreteSystem <: AbstractDiscreteSystem
+struct DiscreteSystem <: AbstractTimeDependentSystem
     """
     A tag for the system. If two systems have the same tag, then they are
     structurally identical.
@@ -233,7 +233,6 @@ function DiscreteSystem(eqs, iv; kwargs...)
             push!(new_ps, p)
         end
     end
-    @show allunknowns
     return DiscreteSystem(eqs, iv,
         collect(allunknowns), collect(new_ps); kwargs...)
 end
diff --git a/src/systems/systems.jl b/src/systems/systems.jl
index 5ff28eba26..9c8c272c5c 100644
--- a/src/systems/systems.jl
+++ b/src/systems/systems.jl
@@ -39,13 +39,13 @@ function structural_simplify(
     else
         newsys = newsys′
     end
-    # if newsys isa DiscreteSystem &&
-    #    any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
-    #     error("""
-    #         Encountered algebraic equations when simplifying discrete system. This is \
-    #         not yet supported.
-    #     """)
-    # end
+    if newsys isa DiscreteSystem &&
+       any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
+        error("""
+            Encountered algebraic equations when simplifying discrete system. This is \
+            not yet supported.
+        """)
+    end
     for pass in additional_passes
         newsys = pass(newsys)
     end
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 86d1702041..1bdc11f06a 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -430,9 +430,9 @@ function TearingState(sys; quick_cancel = false, check = true)
 
     ts = TearingState(sys, fullvars,
         SystemStructure(complete(var_to_diff), complete(eq_to_diff),
-                        complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
+            complete(graph), nothing, var_types, sys isa DiscreteSystem),
         Any[])
-    if sys isa AbstractDiscreteSystem 
+    if sys isa DiscreteSystem
         ts = shift_discrete_system(ts)
     end
     return ts
@@ -456,8 +456,6 @@ function lower_order_var(dervar, t)
     diffvar
 end
 
-"""
-"""
 function shift_discrete_system(ts::TearingState)
     @unpack fullvars, sys = ts
     discvars = OrderedSet()

From f5e5aff51476b66e8fc152ed9b8c8fb4307e8f4c Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 30 Jan 2025 14:30:08 -0500
Subject: [PATCH 052/111] revert all

---
 src/ModelingToolkit.jl                        |   3 -
 .../implicit_discrete_system.jl               | 426 ------------------
 test/implicit_discrete_system.jl              |   2 -
 test/runtests.jl                              |   3 +-
 4 files changed, 1 insertion(+), 433 deletions(-)
 delete mode 100644 src/systems/discrete_system/implicit_discrete_system.jl
 delete mode 100644 test/implicit_discrete_system.jl

diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl
index 2dbef8f272..2710d7d1e4 100644
--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -123,7 +123,6 @@ abstract type AbstractTimeIndependentSystem <: AbstractSystem end
 abstract type AbstractODESystem <: AbstractTimeDependentSystem end
 abstract type AbstractMultivariateSystem <: AbstractSystem end
 abstract type AbstractOptimizationSystem <: AbstractTimeIndependentSystem end
-abstract type AbstractDiscreteSystem <: AbstractTimeDependentSystem end
 
 function independent_variable end
 
@@ -165,7 +164,6 @@ include("systems/diffeqs/modelingtoolkitize.jl")
 include("systems/diffeqs/basic_transformations.jl")
 
 include("systems/discrete_system/discrete_system.jl")
-include("systems/discrete_system/implicit_discrete_system.jl")
 
 include("systems/jumps/jumpsystem.jl")
 
@@ -231,7 +229,6 @@ export DAEFunctionExpr, DAEProblemExpr
 export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
 export SystemStructure
 export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
-export ImplicitDiscreteSystem, ImplicitDiscreteProblem, ImplicitDiscreteFunction, ImplicitDiscreteFunctionExpr
 export JumpSystem
 export ODEProblem, SDEProblem
 export NonlinearFunction, NonlinearFunctionExpr
diff --git a/src/systems/discrete_system/implicit_discrete_system.jl b/src/systems/discrete_system/implicit_discrete_system.jl
deleted file mode 100644
index 3a137eb8c1..0000000000
--- a/src/systems/discrete_system/implicit_discrete_system.jl
+++ /dev/null
@@ -1,426 +0,0 @@
-"""
-$(TYPEDEF)
-An implicit system of difference equations.
-# Fields
-$(FIELDS)
-# Example
-```
-using ModelingToolkit
-using ModelingToolkit: t_nounits as t
-@parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1
-@variables x(t)=1.0 y(t)=0.0 z(t)=0.0
-k = ShiftIndex(t)
-eqs = [x ~ σ*(y-x),
-       y ~ x*(ρ-z)-y,
-       z ~ x*y - β*z]
-@named ide = ImplicitDiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0))
-```
-"""
-struct ImplicitDiscreteSystem <: AbstractDiscreteSystem
-    """
-    A tag for the system. If two systems have the same tag, then they are
-    structurally identical.
-    """
-    tag::UInt
-    """The difference equations defining the discrete system."""
-    eqs::Vector{Equation}
-    """Independent variable."""
-    iv::BasicSymbolic{Real}
-    """Dependent (state) variables. Must not contain the independent variable."""
-    unknowns::Vector
-    """Parameter variables. Must not contain the independent variable."""
-    ps::Vector
-    """Time span."""
-    tspan::Union{NTuple{2, Any}, Nothing}
-    """Array variables."""
-    var_to_name::Any
-    """Observed states."""
-    observed::Vector{Equation}
-    """
-    The name of the system
-    """
-    name::Symbol
-    """
-    A description of the system.
-    """
-    description::String
-    """
-    The internal systems. These are required to have unique names.
-    """
-    systems::Vector{ImplicitDiscreteSystem}
-    """
-    The default values to use when initial conditions and/or
-    parameters are not supplied in `ImplicitDiscreteProblem`.
-    """
-    defaults::Dict
-    """
-    The guesses to use as the initial conditions for the
-    initialization system.
-    """
-    guesses::Dict
-    """
-    The system for performing the initialization.
-    """
-    initializesystem::Union{Nothing, NonlinearSystem}
-    """
-    Extra equations to be enforced during the initialization sequence.
-    """
-    initialization_eqs::Vector{Equation}
-    """
-    Inject assignment statements before the evaluation of the RHS function.
-    """
-    preface::Any
-    """
-    Type of the system.
-    """
-    connector_type::Any
-    """
-    Topologically sorted parameter dependency equations, where all symbols are parameters and
-    the LHS is a single parameter.
-    """
-    parameter_dependencies::Vector{Equation}
-    """
-    Metadata for the system, to be used by downstream packages.
-    """
-    metadata::Any
-    """
-    Metadata for MTK GUI.
-    """
-    gui_metadata::Union{Nothing, GUIMetadata}
-    """
-    Cache for intermediate tearing state.
-    """
-    tearing_state::Any
-    """
-    Substitutions generated by tearing.
-    """
-    substitutions::Any
-    """
-    If a model `sys` is complete, then `sys.x` no longer performs namespacing.
-    """
-    complete::Bool
-    """
-    Cached data for fast symbolic indexing.
-    """
-    index_cache::Union{Nothing, IndexCache}
-    """
-    The hierarchical parent system before simplification.
-    """
-    parent::Any
-    isscheduled::Bool
-
-    function ImplicitDiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name,
-            observed, name, description, systems, defaults, guesses, initializesystem,
-            initialization_eqs, preface, connector_type, parameter_dependencies = Equation[],
-            metadata = nothing, gui_metadata = nothing,
-            tearing_state = nothing, substitutions = nothing,
-            complete = false, index_cache = nothing, parent = nothing,
-            isscheduled = false;
-            checks::Union{Bool, Int} = true)
-        if checks == true || (checks & CheckComponents) > 0
-            check_independent_variables([iv])
-            check_variables(dvs, iv)
-            check_parameters(ps, iv)
-        end
-        if checks == true || (checks & CheckUnits) > 0
-            u = __get_unit_type(dvs, ps, iv)
-            check_units(u, discreteEqs)
-        end
-        new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, observed, name, description,
-            systems, defaults, guesses, initializesystem, initialization_eqs,
-            preface, connector_type, parameter_dependencies, metadata, gui_metadata,
-            tearing_state, substitutions, complete, index_cache, parent, isscheduled)
-    end
-end
-
-"""
-    $(TYPEDSIGNATURES)
-
-Constructs a ImplicitDiscreteSystem.
-"""
-function ImplicitDiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
-        observed = Num[],
-        systems = ImplicitDiscreteSystem[],
-        tspan = nothing,
-        name = nothing,
-        description = "",
-        default_u0 = Dict(),
-        default_p = Dict(),
-        guesses = Dict(),
-        initializesystem = nothing,
-        initialization_eqs = Equation[],
-        defaults = _merge(Dict(default_u0), Dict(default_p)),
-        preface = nothing,
-        connector_type = nothing,
-        parameter_dependencies = Equation[],
-        metadata = nothing,
-        gui_metadata = nothing,
-        kwargs...)
-    name === nothing &&
-        throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
-    iv′ = value(iv)
-    dvs′ = value.(dvs)
-    ps′ = value.(ps)
-    if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs)
-        error("Equations in a `ImplicitDiscreteSystem` can only have `Shift` operators.")
-    end
-    if !(isempty(default_u0) && isempty(default_p))
-        Base.depwarn(
-            "`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
-            :ImplicitDiscreteSystem, force = true)
-    end
-
-    # Copy equations to canonical form, but do not touch array expressions
-    eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs]
-    defaults = Dict{Any, Any}(todict(defaults))
-    guesses = Dict{Any, Any}(todict(guesses))
-    var_to_name = Dict()
-    process_variables!(var_to_name, defaults, guesses, dvs′)
-    process_variables!(var_to_name, defaults, guesses, ps′)
-    process_variables!(
-        var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
-    process_variables!(
-        var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
-    defaults = Dict{Any, Any}(value(k) => value(v)
-    for (k, v) in pairs(defaults) if v !== nothing)
-    guesses = Dict{Any, Any}(value(k) => value(v)
-    for (k, v) in pairs(guesses) if v !== nothing)
-
-    isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
-
-    sysnames = nameof.(systems)
-    if length(unique(sysnames)) != length(sysnames)
-        throw(ArgumentError("System names must be unique."))
-    end
-    ImplicitDiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
-        eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
-        defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
-        parameter_dependencies, metadata, gui_metadata, kwargs...)
-end
-
-function ImplicitDiscreteSystem(eqs, iv; kwargs...)
-    eqs = collect(eqs)
-    diffvars = OrderedSet()
-    allunknowns = OrderedSet()
-    ps = OrderedSet()
-    iv = value(iv)
-    for eq in eqs
-        collect_vars!(allunknowns, ps, eq, iv; op = Shift)
-        if iscall(eq.lhs) && operation(eq.lhs) isa Shift
-            isequal(iv, operation(eq.lhs).t) ||
-                throw(ArgumentError("An ImplicitDiscreteSystem can only have one independent variable."))
-            eq.lhs in diffvars &&
-                throw(ArgumentError("The shift variable $(eq.lhs) is not unique in the system of equations."))
-            push!(diffvars, eq.lhs)
-        end
-    end
-    for eq in get(kwargs, :parameter_dependencies, Equation[])
-        if eq isa Pair
-            collect_vars!(allunknowns, ps, eq, iv)
-        else
-            collect_vars!(allunknowns, ps, eq, iv)
-        end
-    end
-    new_ps = OrderedSet()
-    for p in ps
-        if iscall(p) && operation(p) === getindex
-            par = arguments(p)[begin]
-            if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
-               all(par[i] in ps for i in eachindex(par))
-                push!(new_ps, par)
-            else
-                push!(new_ps, p)
-            end
-        else
-            push!(new_ps, p)
-        end
-    end
-    return ImplicitDiscreteSystem(eqs, iv,
-        collect(allunknowns), collect(new_ps); kwargs...)
-end
-# basically at every timestep it should build a nonlinear solve
-# Previous timesteps should be treated as parameters? is this right? 
-
-function flatten(sys::ImplicitDiscreteSystem, noeqs = false)
-    systems = get_systems(sys)
-    if isempty(systems)
-        return sys
-    else
-        return ImplicitDiscreteSystem(noeqs ? Equation[] : equations(sys),
-            get_iv(sys),
-            unknowns(sys),
-            parameters(sys),
-            observed = observed(sys),
-            defaults = defaults(sys),
-            guesses = guesses(sys),
-            initialization_eqs = initialization_equations(sys),
-            name = nameof(sys),
-            description = description(sys),
-            metadata = get_metadata(sys),
-            checks = false)
-    end
-end
-
-function generate_function(
-        sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
-    if !iscomplete(sys)
-        error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
-    end
-    p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
-    isscalar = !(exprs isa AbstractArray)
-    pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
-    if postprocess_fbody === nothing
-        postprocess_fbody = pre
-    end
-    if states === nothing
-        states = sol_states
-    end
-    exprs = [eq.lhs - eq.rhs for eq in equations(sys)]
-    u = map(Shift(iv, -1), dvs)
-    u_next = dvs
-
-    wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘ wrap_parameter_dependencies(sys, false)
-
-    build_function(exprs, u_next, u, p..., get_iv(sys))
-end
-
-function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
-    iv = get_iv(sys)
-    updated = AnyDict()
-    for k in collect(keys(u0map))
-        v = u0map[k]
-        if !((op = operation(k)) isa Shift)
-            error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
-        end
-        updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
-    end
-    for var in unknowns(sys)
-        op = operation(var)
-        op isa Shift || continue
-        haskey(updated, var) && continue
-        root = first(arguments(var))
-        haskey(defs, root) || error("Initial condition for $var not provided.")
-        updated[var] = defs[root]
-    end
-    return updated
-end
-
-"""
-    $(TYPEDSIGNATURES)
-Generates an ImplicitDiscreteProblem from an ImplicitDiscreteSystem.
-"""
-function SciMLBase.ImplicitDiscreteProblem(
-        sys::ImplicitDiscreteSystem, u0map = [], tspan = get_tspan(sys),
-        parammap = SciMLBase.NullParameters();
-        eval_module = @__MODULE__,
-        eval_expression = false,
-        use_union = false,
-        kwargs...
-)
-    if !iscomplete(sys)
-        error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
-    end
-    dvs = unknowns(sys)
-    ps = parameters(sys)
-    eqs = equations(sys)
-    iv = get_iv(sys)
-
-    u0map = to_varmap(u0map, dvs)
-    u0map = shift_u0map_forward(sys, u0map, defaults(sys))
-    f, u0, p = process_SciMLProblem(
-        ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
-    u0 = f(u0, p, tspan[1])
-    NonlinearProblem(f, u0, tspan, p; kwargs...)
-end
-
-function SciMLBase.ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...; kwargs...)
-    ImplicitDiscreteFunction{true}(sys, args...; kwargs...)
-end
-
-function SciMLBase.ImplicitDiscreteFunction{true}(sys::ImplicitDiscreteSystem, args...; kwargs...)
-    ImplicitDiscreteFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
-end
-
-function SciMLBase.ImplicitDiscreteFunction{false}(sys::ImplicitDiscreteSystem, args...; kwargs...)
-    ImplicitDiscreteFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
-end
-function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
-        sys::ImplicitDiscreteSystem,
-        dvs = unknowns(sys),
-        ps = parameters(sys),
-        u0 = nothing;
-        version = nothing,
-        p = nothing,
-        t = nothing,
-        eval_expression = false,
-        eval_module = @__MODULE__,
-        analytic = nothing,
-        kwargs...) where {iip, specialize}
-
-    if !iscomplete(sys)
-        error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
-    end
-    f_gen = generate_function(sys, dvs, ps; expression = Val{true},
-        expression_module = eval_module, kwargs...)
-    f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
-    f(u_next, u, p, t) = f_oop(u_next, u, p, t)
-    f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t)
-
-    if specialize === SciMLBase.FunctionWrapperSpecialize && iip
-        if u0 === nothing || p === nothing || t === nothing
-            error("u0, p, and t must be specified for FunctionWrapperSpecialize on ImplicitDiscreteFunction.")
-        end
-        f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
-    end
-
-    observedfun = ObservedFunctionCache(
-        sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
-
-    ImplicitDiscreteFunction{iip, specialize}(f;
-        sys = sys,
-        observed = observedfun,
-        analytic = analytic)
-end
-
-"""
-```julia
-ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = states(sys),
-                                  ps = parameters(sys);
-                                  version = nothing,
-                                  kwargs...) where {iip}
-```
-
-Create a Julia expression for an `ImplicitDiscreteFunction` from the [`ImplicitDiscreteSystem`](@ref).
-The arguments `dvs` and `ps` are used to set the order of the dependent
-variable and parameter vectors, respectively.
-"""
-struct ImplicitDiscreteFunctionExpr{iip} end
-struct ImplicitDiscreteFunctionClosure{O, I} <: Function
-    f_oop::O
-    f_iip::I
-end
-(f::ImplicitDiscreteFunctionClosure)(u_next, u, p, t) = f.f_oop(u_next, u, p, t)
-(f::ImplicitDiscreteFunctionClosure)(resid, u_next, u, p, t) = f.f_iip(resid, u_next, u, p, t)
-
-function ImplicitDiscreteFunctionExpr{iip}(sys::ImplicitDiscreteSystem, dvs = unknowns(sys),
-        ps = parameters(sys), u0 = nothing;
-        version = nothing, p = nothing,
-        linenumbers = false,
-        simplify = false,
-        kwargs...) where {iip}
-    f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
-
-    fsym = gensym(:f)
-    _f = :($fsym = $ImplicitDiscreteFunctionClosure($f_oop, $f_iip))
-
-    ex = quote
-        $_f
-        ImplicitDiscreteFunction{$iip}($fsym)
-    end
-    !linenumbers ? Base.remove_linenums!(ex) : ex
-end
-
-function ImplicitDiscreteFunctionExpr(sys::ImplicitDiscreteSystem, args...; kwargs...)
-    ImplicitDiscreteFunctionExpr{true}(sys, args...; kwargs...)
-end
-
diff --git a/test/implicit_discrete_system.jl b/test/implicit_discrete_system.jl
deleted file mode 100644
index adfb9d2fc1..0000000000
--- a/test/implicit_discrete_system.jl
+++ /dev/null
@@ -1,2 +0,0 @@
-
-#init
diff --git a/test/runtests.jl b/test/runtests.jl
index bd5eecacd0..52875fdae5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -79,8 +79,7 @@ end
             @safetestset "Variable Utils Test" include("variable_utils.jl")
             @safetestset "Variable Metadata Test" include("test_variable_metadata.jl")
             @safetestset "OptimizationSystem Test" include("optimizationsystem.jl")
-            @safetestset "DiscreteSystem Test" include("discrete_system.jl")
-            @safetestset "ImplicitDiscreteSystem Test" include("implicit_discrete_system.jl")
+            @safetestset "Discrete System" include("discrete_system.jl")
             @safetestset "SteadyStateSystem Test" include("steadystatesystems.jl")
             @safetestset "SDESystem Test" include("sdesystem.jl")
             @safetestset "DDESystem Test" include("dde.jl")

From b5305950778d256c70462e7858f4cb3b79ff835e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 07:58:14 -0500
Subject: [PATCH 053/111] working simplification

---
 .../symbolics_tearing.jl                      | 58 +++++++++++++++++--
 src/structural_transformation/utils.jl        | 18 +++---
 src/systems/systemstructure.jl                | 30 +++++++---
 3 files changed, 88 insertions(+), 18 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index d95951c41e..8c802d9765 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -240,7 +240,7 @@ end
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     @unpack fullvars, sys, structure = state
-    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     extra_vars = Int[]
     if full_var_eq_matching !== nothing
         for v in 𝑑vertices(state.structure.graph)
@@ -279,6 +279,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         iv = D = nothing
     end
     diff_to_var = invview(var_to_diff)
+
     dummy_sub = Dict()
     for var in 1:length(fullvars)
         dv = var_to_diff[var]
@@ -310,7 +311,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             diff_to_var[dv] = nothing
         end
     end
+    @show neweqs
 
+    println("Post state selection.")
+    
     # `SelectedState` information is no longer needed past here. State selection
     # is done. All non-differentiated variables are algebraic variables, and all
     # variables that appear differentiated are differential variables.
@@ -331,10 +335,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 order += 1
                 dv = dv′
             end
+            println("Order")
+            @show fullvars[dv]
+            is_only_discrete(state.structure) && begin
+                var = fullvars[dv]
+                key = operation(var) isa Shift ? only(arguments(var)) : var
+                order = -get(lowest_shift, key, 0) - order
+            end
             order, dv
         end
     end
 
+    lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
+    # is_only_discrete(state.structure) && for v in 1:length(fullvars)
+    #     var = fullvars[v]
+    #     op = operation(var)
+    #     if op isa Shift
+    #         x = only(arguments(var))
+    #         lowest_shift_idxs[v]
+    #         op.steps == lowest_shift[x] && (fullvars[v] = lower_varname_withshift(var, iv, -op.steps))
+    #     end
+    # end
+
     #retear = BitSet()
     # There are three cases where we want to generate new variables to convert
     # the system into first order (semi-implicit) ODEs.
@@ -384,9 +406,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     eq_var_matching = invview(var_eq_matching)
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
+
     for v in 1:length(var_to_diff)
-        dv = var_to_diff[v]
+        println()
+        @show fullvars
+        @show diff_to_var
+        is_highest_discrete = begin
+            var = fullvars[v]
+            op = operation(var)
+            if (!is_only_discrete(state.structure) || op isa Shift) 
+                false
+            elseif !haskey(lowest_shift, var)
+                false
+            else
+                low = lowest_shift[var]
+                idx = findfirst(x -> isequal(x, Shift(iv, low)(var)), fullvars)
+                true
+            end
+        end
+        dv = is_highest_discrete ? idx : var_to_diff[v]
+        @show (v, fullvars[v], dv)
         dv isa Int || continue
+
         solved = var_eq_matching[dv] isa Int
         solved && continue
         # check if there's `D(x) = x_t` already
@@ -404,17 +445,19 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                diff_to_var[v_t] === nothing)
                 @assert dv in rvs
                 dummy_eq = eq
+                @show "FOUND DUMMY EQ"
                 @goto FOUND_DUMMY_EQ
             end
         end
         dx = fullvars[dv]
         # add `x_t`
-        order, lv = var_order(dv)
-        x_t = lower_varname_withshift(fullvars[lv], iv, order)
+        @show order, lv = var_order(dv)
+        x_t = lower_name(fullvars[lv], iv, order)
         push!(fullvars, simplify_shifts(x_t))
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
         add_vertex!(graph, DST)
+        @show x_t, dx
         # TODO: do we care about solvable_graph? We don't use them after
         # `dummy_derivative_graph`.
         add_vertex!(solvable_graph, DST)
@@ -433,10 +476,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         add_edge!(solvable_graph, dummy_eq, dv)
         @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
         @label FOUND_DUMMY_EQ
+        @show is_highest_discrete
+        @show diff_to_var
+        @show v_t, dv
+        # If var = x with no shift, then 
+        is_highest_discrete && (lowest_shift[x_t] = lowest_shift[fullvars[v]])
         var_to_diff[v_t] = var_to_diff[dv]
         var_eq_matching[dv] = unassigned
         eq_var_matching[dummy_eq] = dv
     end
+    @show neweqs
 
     # Will reorder equations and unknowns to be:
     # [diffeqs; ...]
@@ -537,6 +586,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
 
     deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
                        for i in 1:length(solved_equations)]
+
     # Contract the vertices in the structure graph to make the structure match
     # the new reality of the system we've just created.
     graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 29c0f66756..dd063c0b07 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -451,18 +451,22 @@ end
 
 function lower_varname_withshift(var, iv, order)
     order == 0 && return var
+    ds = "$iv-$order"
+    d_separator = 'ˍ'
+
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
         O = only(arguments(var))
         oldop = operation(O)
-        ds = "$iv-$order"
-        d_separator = 'ˍ'
         newname = Symbol(string(nameof(oldop)), d_separator, ds)
-
-        newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
-        setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
-        return ModelingToolkit._with_unit(identity, newvar, iv)
+    else
+        O = var
+        oldop = operation(var) 
+        varname = split(string(nameof(oldop)), d_separator)[1]
+        newname = Symbol(varname, d_separator, ds)
     end
-    return lower_varname_with_unit(var, iv, order)
+    newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
+    setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
+    return ModelingToolkit._with_unit(identity, newvar, iv)
 end
 
 function isdoubleshift(var)
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 1bdc11f06a..44558ade6c 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -140,17 +140,21 @@ get_fullvars(ts::TransformationState) = ts.fullvars
 has_equations(::TransformationState) = true
 
 Base.@kwdef mutable struct SystemStructure
-    # Maps the (index of) a variable to the (index of) the variable describing
-    # its derivative.
+    """Maps the (index of) a variable to the (index of) the variable describing its derivative."""
     var_to_diff::DiffGraph
+    """Maps the (index of) a """
     eq_to_diff::DiffGraph
     # Can be access as
     # `graph` to automatically look at the bipartite graph
     # or as `torn` to assert that tearing has run.
+    """Incidence graph of the system of equations. An edge from equation x to variable y exists if variable y appears in equation x."""
     graph::BipartiteGraph{Int, Nothing}
+    """."""
     solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
     var_types::Union{Vector{VariableType}, Nothing}
+    """Whether the system is discrete."""
     only_discrete::Bool
+    lowest_shift::Union{Dict, Nothing}
 end
 
 function Base.copy(structure::SystemStructure)
@@ -346,6 +350,8 @@ function TearingState(sys; quick_cancel = false, check = true)
             eqs[i] = eqs[i].lhs ~ rhs
         end
     end
+
+    ### Handle discrete variables
     lowest_shift = Dict()
     for var in fullvars
         if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
@@ -430,10 +436,10 @@ function TearingState(sys; quick_cancel = false, check = true)
 
     ts = TearingState(sys, fullvars,
         SystemStructure(complete(var_to_diff), complete(eq_to_diff),
-            complete(graph), nothing, var_types, sys isa DiscreteSystem),
+            complete(graph), nothing, var_types, sys isa DiscreteSystem, lowest_shift),
         Any[])
     if sys isa DiscreteSystem
-        ts = shift_discrete_system(ts)
+        ts = shift_discrete_system(ts, lowest_shift)
     end
     return ts
 end
@@ -456,17 +462,27 @@ function lower_order_var(dervar, t)
     diffvar
 end
 
-function shift_discrete_system(ts::TearingState)
+"""
+    Shift variable x by the largest shift s such that x(k-s) appears in the system of equations.
+    The lowest-shift term will have.
+"""
+function shift_discrete_system(ts::TearingState, lowest_shift)
     @unpack fullvars, sys = ts
+    return ts
     discvars = OrderedSet()
     eqs = equations(sys)
+
     for eq in eqs
         vars!(discvars, eq; op = Union{Sample, Hold})
     end
     iv = get_iv(sys)
-    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
+    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, -get(lowest_shift, k, 0))(k))
     for k in discvars
-    if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
+    if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 
+
+    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))    for k in discvars 
+    if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 
+
     for i in eachindex(fullvars)
         fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(
             fullvars[i], discmap; operator = Union{Sample, Hold}))

From fe372865a7dd79d1e576f19d50677b1977e70fd1 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 14:21:03 -0500
Subject: [PATCH 054/111] solving equations

---
 .../symbolics_tearing.jl                      | 64 ++++++++++---------
 1 file changed, 34 insertions(+), 30 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 8c802d9765..599a75f4f7 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -237,6 +237,18 @@ function check_diff_graph(var_to_diff, fullvars)
 end
 =#
 
+function state_selection() 
+    
+end
+
+function create_new_deriv_variables() 
+    
+end
+
+function solve_solvable_equations() 
+    
+end
+
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     @unpack fullvars, sys, structure = state
@@ -323,6 +335,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     is_solvable = let solvable_graph = solvable_graph
         (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
     end
+    idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
+    for (i,var) in enumerate(fullvars)
+        key = operation(var) isa Shift ? only(arguments(var)) : var
+        idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
+    end
 
     # if var is like D(x)
     isdervar = let diff_to_var = diff_to_var
@@ -335,27 +352,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 order += 1
                 dv = dv′
             end
-            println("Order")
-            @show fullvars[dv]
-            is_only_discrete(state.structure) && begin
-                var = fullvars[dv]
-                key = operation(var) isa Shift ? only(arguments(var)) : var
-                order = -get(lowest_shift, key, 0) - order
-            end
+            is_only_discrete(state.structure) && (order = -idx_to_lowest_shift[dv] - order - 1)
             order, dv
         end
     end
-
     lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
-    # is_only_discrete(state.structure) && for v in 1:length(fullvars)
-    #     var = fullvars[v]
-    #     op = operation(var)
-    #     if op isa Shift
-    #         x = only(arguments(var))
-    #         lowest_shift_idxs[v]
-    #         op.steps == lowest_shift[x] && (fullvars[v] = lower_varname_withshift(var, iv, -op.steps))
-    #     end
-    # end
 
     #retear = BitSet()
     # There are three cases where we want to generate new variables to convert
@@ -408,9 +409,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
     for v in 1:length(var_to_diff)
-        println()
-        @show fullvars
-        @show diff_to_var
         is_highest_discrete = begin
             var = fullvars[v]
             op = operation(var)
@@ -425,7 +423,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             end
         end
         dv = is_highest_discrete ? idx : var_to_diff[v]
-        @show (v, fullvars[v], dv)
         dv isa Int || continue
 
         solved = var_eq_matching[dv] isa Int
@@ -445,19 +442,17 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                diff_to_var[v_t] === nothing)
                 @assert dv in rvs
                 dummy_eq = eq
-                @show "FOUND DUMMY EQ"
                 @goto FOUND_DUMMY_EQ
             end
         end
         dx = fullvars[dv]
         # add `x_t`
-        @show order, lv = var_order(dv)
+        order, lv = var_order(dv)
         x_t = lower_name(fullvars[lv], iv, order)
         push!(fullvars, simplify_shifts(x_t))
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
         add_vertex!(graph, DST)
-        @show x_t, dx
         # TODO: do we care about solvable_graph? We don't use them after
         # `dummy_derivative_graph`.
         add_vertex!(solvable_graph, DST)
@@ -476,16 +471,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         add_edge!(solvable_graph, dummy_eq, dv)
         @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
         @label FOUND_DUMMY_EQ
-        @show is_highest_discrete
-        @show diff_to_var
-        @show v_t, dv
         # If var = x with no shift, then 
-        is_highest_discrete && (lowest_shift[x_t] = lowest_shift[fullvars[v]])
+        is_only_discrete(state.structure) && (idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv])
         var_to_diff[v_t] = var_to_diff[dv]
         var_eq_matching[dv] = unassigned
         eq_var_matching[dummy_eq] = dv
     end
-    @show neweqs
 
     # Will reorder equations and unknowns to be:
     # [diffeqs; ...]
@@ -501,9 +492,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     subeqs = Equation[]
     solved_equations = Int[]
     solved_variables = Int[]
+
     # Solve solvable equations
+    println()
+    println("SOLVING SOLVABLE EQUATIONS.")
+    @show eq_var_matching
     toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching))
     eqs = Iterators.reverse(toporder)
+    @show eqs
+    @show neweqs
+    @show fullvars
     total_sub = Dict()
     idep = iv
     for ieq in eqs
@@ -516,12 +514,18 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 isnothing(D) &&
                     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
                 order, lv = var_order(iv)
-                dx = D(simplify_shifts(lower_varname_withshift(
+                @show fullvars[iv]
+                @show (order, lv)
+                dx = D(simplify_shifts(lower_name(
                     fullvars[lv], idep, order - 1)))
+                @show dx
+                @show neweqs[ieq]
                 eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
                     Symbolics.symbolic_linear_solve(neweqs[ieq],
                         fullvars[iv]),
                     total_sub; operator = ModelingToolkit.Shift))
+                @show total_sub
+                @show eq
                 for e in 𝑑neighbors(graph, iv)
                     e == ieq && continue
                     for v in 𝑠neighbors(graph, e)

From 13a242c35ed4ec7c19143ccde26f3a8e6edddf3f Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 15:34:05 -0500
Subject: [PATCH 055/111] update to use updated codegen

---
 src/systems/diffeqs/abstractodesystem.jl | 33 ++++++++----------------
 test/bvproblem.jl                        | 11 ++++----
 2 files changed, 17 insertions(+), 27 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index f017cb902f..95e966a461 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -881,18 +881,23 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
     u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
 
-    bc = generate_function_bc(sys, u0, u0_idxs, tspan, iip)
+    fns = generate_function_bc(sys, u0, u0_idxs, tspan)
+    bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module) 
+    # bc(sol, p, t) = bc_oop(sol, p, t)
+    bc(resid, u, p, t) = bc_iip(resid, u, p, t)
+
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
 end
 
 get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 
 """
-    generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
+    generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan)
 
     Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
+    Expression uses the constraints and the provided initial conditions.
 """
-function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
+function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
     iv = get_iv(sys)
     sts = get_unknowns(sys)
     ps = get_ps(sys)
@@ -915,19 +920,6 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
 
             cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
         end
-
-        for var in parameters(conssys) 
-            if iscall(var)
-                x = operation(var)
-                t = only(arguments(var))
-                idx = pidxmap[x]
-
-                cons = map(c -> Symbolics.substitute(c, Dict(x(t) => p[idx])), cons)
-            else
-                idx = pidxmap[var]
-                cons = map(c -> Symbolics.substitute(c, Dict(var => p[idx])), cons)
-            end
-        end
     end
 
     init_conds = Any[]
@@ -937,12 +929,9 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
     end
 
     exprs = vcat(init_conds, cons)
-    bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
-    if iip
-        return (resid, u, p, t) -> bcs[2](resid, u, p)
-    else
-        return (u, p, t) -> bcs[1](u, p)
-    end
+    _p = reorder_parameters(sys, ps)
+
+    build_function_wrapper(sys, exprs, sol, _p..., t; kwargs...)
 end
 
 """
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 069372aef7..cedd0eef8d 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,6 +1,7 @@
 ### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
 
-using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
+using OrdinaryDiffEqVerner
+using BoundaryValueDiffEqMIRK, BoundaryValueDiffEqAscher
 using BenchmarkTools
 using ModelingToolkit
 using SciMLBase
@@ -207,22 +208,22 @@ let
     
     u0map = []
     tspan = (0.0, 1.0)
-    guesses = [x(t) => 4.0, y(t) => 2.]
+    guess = [x(t) => 4.0, y(t) => 2.0]
     constr = [x(.6) ~ 3.5, x(.3) ~ 7.]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
 
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
     test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
     # Testing that more complicated constraints give correct solutions.
     constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
-    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses)
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses = guess)
     test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
     constr = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
     test_solvers(solvers, bvp, u0map, constr)
 end
 

From 2fcb9c930c0e4d607d509f1f13a01cef11df6b12 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 15:34:37 -0500
Subject: [PATCH 056/111] up

---
 src/systems/diffeqs/abstractodesystem.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 95e966a461..056ecd80e4 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -883,7 +883,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
 
     fns = generate_function_bc(sys, u0, u0_idxs, tspan)
     bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module) 
-    # bc(sol, p, t) = bc_oop(sol, p, t)
+    bc(sol, p, t) = bc_oop(sol, p, t)
     bc(resid, u, p, t) = bc_iip(resid, u, p, t)
 
     return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)

From 25b56d77859cb4afd8c33bf2c48771e81591859e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 22:03:33 -0500
Subject: [PATCH 057/111] working codegen

---
 src/systems/diffeqs/abstractodesystem.jl       |  8 ++++----
 src/systems/diffeqs/odesystem.jl               |  2 +-
 src/systems/optimization/constraints_system.jl |  2 +-
 test/odesystem.jl                              | 12 ++++++++++++
 4 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 056ecd80e4..8d8dd80d47 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -899,14 +899,14 @@ get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
 """
 function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
     iv = get_iv(sys)
-    sts = get_unknowns(sys)
-    ps = get_ps(sys)
+    sts = unknowns(sys)
+    ps = parameters(sys)
     np = length(ps)
     ns = length(sts)
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
     pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
 
-    @variables sol(..)[1:ns] p[1:np]
+    @variables sol(..)[1:ns]
 
     conssys = get_constraintsystem(sys)
     cons = Any[]
@@ -931,7 +931,7 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
     exprs = vcat(init_conds, cons)
     _p = reorder_parameters(sys, ps)
 
-    build_function_wrapper(sys, exprs, sol, _p..., t; kwargs...)
+    build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
 end
 
 """
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index c38f20c235..44cc7df46b 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -687,7 +687,6 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
     constraintsts = OrderedSet()
     constraintps = OrderedSet()
 
-    # Hack? to extract parameters from callable variables in constraints.
     for cons in constraints
         collect_vars!(constraintsts, constraintps, cons, iv)
     end
@@ -712,5 +711,6 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
         end
     end
 
+    @show constraints
     ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
 end
diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl
index 275079297a..ecdbafa044 100644
--- a/src/systems/optimization/constraints_system.jl
+++ b/src/systems/optimization/constraints_system.jl
@@ -123,7 +123,7 @@ function ConstraintsSystem(constraints, unknowns, ps;
     name === nothing &&
         throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
 
-    cstr = value.(Symbolics.canonical_form.(scalarize(constraints)))
+    cstr = value.(Symbolics.canonical_form.(vcat(scalarize(constraints)...)))
     unknowns′ = value.(scalarize(unknowns))
     ps′ = value.(ps)
 
diff --git a/test/odesystem.jl b/test/odesystem.jl
index 98ad6dd87a..97073457b3 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1670,4 +1670,16 @@ end
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
     cons = [x(t) * v ~ 3]
     @test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons) # Need time argument.
+
+    # Test array variables
+    @variables x(..)[1:5]
+    mat = [1 2 0 3 2
+           0 0 3 2 0
+           0 1 3 0 4
+           2 0 0 2 1
+           0 0 2 0 5]
+    eqs = D(x(t)) ~ mat * x(t)
+    cons = [x(3) ~ [2,3,3,5,4]]
+    @mtkbuild ode = ODESystem(D(x(t)) ~ mat * x(t), t; constraints = cons)
+    @test length(constraints(ModelingToolkit.get_constraintsystem(ode))) == 5
 end

From dd71d2349f0d8eca81b04fc212209045bfb18018 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 22:10:22 -0500
Subject: [PATCH 058/111] add docs

---
 docs/src/systems/DiscreteSystem.md | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 docs/src/systems/DiscreteSystem.md

diff --git a/docs/src/systems/DiscreteSystem.md b/docs/src/systems/DiscreteSystem.md
new file mode 100644
index 0000000000..b6a8061e50
--- /dev/null
+++ b/docs/src/systems/DiscreteSystem.md
@@ -0,0 +1,28 @@
+# DiscreteSystem
+
+## System Constructors
+
+```@docs
+DiscreteSystem
+```
+
+## Composition and Accessor Functions
+
+  - `get_eqs(sys)` or `equations(sys)`: The equations that define the discrete system.
+  - `get_unknowns(sys)` or `unknowns(sys)`: The set of unknowns in the discrete system.
+  - `get_ps(sys)` or `parameters(sys)`: The parameters of the discrete system.
+  - `get_iv(sys)`: The independent variable of the discrete system.
+  - `discrete_events(sys)`: The set of discrete events in the discrete system.
+
+## Transformations
+
+```@docs; canonical=false
+structural_simplify
+```
+
+## Problem Constructors
+
+```@docs; canonical=false
+DiscreteProblem(sys::DiscreteSystem, u0map, tspan)
+DiscreteFunction(sys::DiscreteSystem, u0map, tspan)
+```

From 4475a5a6ca7b61a897f3c90041895ed23b3df5d4 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 22:37:56 -0500
Subject: [PATCH 059/111] up

---
 .../symbolics_tearing.jl                      | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 599a75f4f7..059c7ca9e4 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -237,18 +237,6 @@ function check_diff_graph(var_to_diff, fullvars)
 end
 =#
 
-function state_selection() 
-    
-end
-
-function create_new_deriv_variables() 
-    
-end
-
-function solve_solvable_equations() 
-    
-end
-
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     @unpack fullvars, sys, structure = state
@@ -323,7 +311,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             diff_to_var[dv] = nothing
         end
     end
-    @show neweqs
+    @show var_eq_matching
 
     println("Post state selection.")
     
@@ -337,7 +325,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     end
     idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
     for (i,var) in enumerate(fullvars)
-        key = operation(var) isa Shift ? only(arguments(var)) : var
+        key = (operation(var) isa Shift) ? only(arguments(var)) : var
         idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
     end
 
@@ -345,6 +333,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     isdervar = let diff_to_var = diff_to_var
         var -> diff_to_var[var] !== nothing
     end
+    # For discrete variables, we want the substitution to turn 
+    # Shift(t, k)(x(t)) => x_t-k(t)
     var_order = let diff_to_var = diff_to_var
         dv -> begin
             order = 0
@@ -358,7 +348,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     end
     lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
 
-    #retear = BitSet()
     # There are three cases where we want to generate new variables to convert
     # the system into first order (semi-implicit) ODEs.
     #

From c35b7974f2d96584c96cfbffa2de87a390d92d8d Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 22:43:39 -0500
Subject: [PATCH 060/111] revert to OrdinaryDiffEqDefault

---
 test/bvproblem.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index cedd0eef8d..311a17e172 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,6 +1,6 @@
 ### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
 
-using OrdinaryDiffEqVerner
+using OrdinaryDiffEqDefault
 using BoundaryValueDiffEqMIRK, BoundaryValueDiffEqAscher
 using BenchmarkTools
 using ModelingToolkit
@@ -24,7 +24,7 @@ let
      
      @mtkbuild lotkavolterra = ODESystem(eqs, t)
      op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-     osol = solve(op, Vern9())
+     osol = solve(op)
      
      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
          lotkavolterra, u0map, tspan, parammap; eval_expression = true)
@@ -61,7 +61,7 @@ let
      tspan = (0.0, 6.0)
      
      op = ODEProblem(pend, u0map, tspan, parammap)
-     osol = solve(op, Vern9())
+     osol = solve(op)
      
      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
      for solver in solvers

From a964dd7e1660f599afdbf0e9b2c9ddbfa717ea83 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 23:01:48 -0500
Subject: [PATCH 061/111] up

---
 src/structural_transformation/symbolics_tearing.jl | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 059c7ca9e4..a80bebd436 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -237,6 +237,10 @@ function check_diff_graph(var_to_diff, fullvars)
 end
 =#
 
+function substitute_lower_order!(state::TearingState) 
+    
+end
+
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     @unpack fullvars, sys, structure = state

From 26a545dbaaf4ced4b59ae45cc02377523dec5f10 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 23:07:07 -0500
Subject: [PATCH 062/111] up

---
 src/structural_transformation/symbolics_tearing.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index a80bebd436..e612369498 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -241,6 +241,7 @@ function substitute_lower_order!(state::TearingState)
     
 end
 
+import ModelingToolkit: Shift
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     @unpack fullvars, sys, structure = state
@@ -255,6 +256,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     end
 
     neweqs = collect(equations(state))
+    lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
+
     # Terminology and Definition:
     #
     # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
@@ -350,7 +353,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             order, dv
         end
     end
-    lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
 
     # There are three cases where we want to generate new variables to convert
     # the system into first order (semi-implicit) ODEs.
@@ -464,7 +466,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         add_edge!(solvable_graph, dummy_eq, dv)
         @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
         @label FOUND_DUMMY_EQ
-        # If var = x with no shift, then 
         is_only_discrete(state.structure) && (idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv])
         var_to_diff[v_t] = var_to_diff[dv]
         var_eq_matching[dv] = unassigned

From 25e84dbf39a92c569dd3a8f4027c9b33dd99f34e Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 23:08:12 -0500
Subject: [PATCH 063/111] use MIRK

---
 Project.toml | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/Project.toml b/Project.toml
index c3646021ff..41c5ad0730 100644
--- a/Project.toml
+++ b/Project.toml
@@ -85,6 +85,7 @@ BifurcationKit = "0.4"
 BlockArrays = "1.1"
 BoundaryValueDiffEq = "5.12.0"
 BoundaryValueDiffEqAscher = "1.1.0"
+BoundaryValueDiffEqMIRK = "1.4.0"
 ChainRulesCore = "1"
 Combinatorics = "1"
 CommonSolve = "0.2.4"
@@ -106,11 +107,11 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
 EnumX = "1.0.4"
 ExprTools = "0.1.10"
 Expronicon = "0.8"
+FMI = "0.14"
 FindFirstFunctions = "1"
 ForwardDiff = "0.10.3"
 FunctionWrappers = "1.1"
 FunctionWrappersWrappers = "0.1"
-FMI = "0.14"
 Graphs = "1.5.2"
 HomotopyContinuation = "2.11"
 InfiniteOpt = "0.5"
@@ -157,7 +158,7 @@ julia = "1.9"
 [extras]
 AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
-BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
+BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
 BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
 ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
 DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
@@ -190,4 +191,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
+test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]

From e6a6932d192ef72bc46be8ed6124965372e87f26 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 23:09:04 -0500
Subject: [PATCH 064/111] up

---
 Project.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 41c5ad0730..cd2c2c4380 100644
--- a/Project.toml
+++ b/Project.toml
@@ -83,7 +83,6 @@ AbstractTrees = "0.3, 0.4"
 ArrayInterface = "6, 7"
 BifurcationKit = "0.4"
 BlockArrays = "1.1"
-BoundaryValueDiffEq = "5.12.0"
 BoundaryValueDiffEqAscher = "1.1.0"
 BoundaryValueDiffEqMIRK = "1.4.0"
 ChainRulesCore = "1"

From 5e5c24c071f00a8d50b6e7441a141ca9cbeb1681 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 3 Feb 2025 23:30:44 -0500
Subject: [PATCH 065/111] revert to OrdinaryDiffEq

---
 test/bvproblem.jl | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 311a17e172..1d7d4b3793 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -1,6 +1,6 @@
 ### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions 
 
-using OrdinaryDiffEqDefault
+using OrdinaryDiffEq
 using BoundaryValueDiffEqMIRK, BoundaryValueDiffEqAscher
 using BenchmarkTools
 using ModelingToolkit
@@ -24,10 +24,9 @@ let
      
      @mtkbuild lotkavolterra = ODESystem(eqs, t)
      op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-     osol = solve(op)
+     osol = solve(op, Vern9())
      
-     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
-         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
      
      for solver in solvers
          sol = solve(bvp, solver(), dt = 0.01)
@@ -36,8 +35,7 @@ let
      end
      
      # Test out of place
-     bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
-         lotkavolterra, u0map, tspan, parammap; eval_expression = true)
+     bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
      
      for solver in solvers
          sol = solve(bvp2, solver(), dt = 0.01)

From 5338d4f6b60704ebcc1df0c425a0a721f1a46138 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 4 Feb 2025 00:08:54 -0500
Subject: [PATCH 066/111] tests passing

---
 test/bvproblem.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index 1d7d4b3793..b4032e2927 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -59,7 +59,7 @@ let
      tspan = (0.0, 6.0)
      
      op = ODEProblem(pend, u0map, tspan, parammap)
-     osol = solve(op)
+     osol = solve(op, Vern9())
      
      bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
      for solver in solvers
@@ -111,8 +111,8 @@ let
     end
 
     u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
-    genbc_iip = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan, true)
-    genbc_oop = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan, false)
+    fns = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan)
+    genbc_oop, genbc_iip = ModelingToolkit.eval_or_rgf.(fns)
 
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
@@ -141,8 +141,8 @@ let
     end
 
     u0 = [1, 1.]
-    genbc_iip = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan, true)
-    genbc_oop = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan, false)
+    fns = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan)
+    genbc_oop, genbc_iip = ModelingToolkit.eval_or_rgf.(fns)
 
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
     bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)

From 1656277f45e7a12229d8ed0545a979cb8bddb273 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 4 Feb 2025 02:49:11 -0500
Subject: [PATCH 067/111] maadd comments

---
 .../symbolics_tearing.jl                      | 68 ++++++++++++++-----
 src/structural_transformation/utils.jl        |  3 +-
 src/systems/systemstructure.jl                |  3 -
 3 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index e612369498..cf6736b419 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -241,6 +241,32 @@ function substitute_lower_order!(state::TearingState)
     
 end
 
+# Documenting the differences to structural simplification for discrete systems:
+# In discrete systems the lowest-order term is x_k-i, instead of x(t). In order
+# to substitute dummy variables for x_k-1, x_k-2, ... instead you need to reverse 
+# the order. So for discrete systems `var_order` is defined a little differently.
+#
+# The orders will also be off by one. The reason this is is that the dynamics of
+# the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
+# having the observables be indexed by the next time step is not so nice. So we 
+# handle the shifts in the renaming, rather than explicitly.
+#
+# The substitution should look like the following: 
+#   x(t) -> Shift(t, 1)(x(t))
+#   x(k-1) -> x(t)
+#   x(k-2) -> x_{t-1}(t)
+#   x(k-3) -> x_{t-2}(t)
+#   and so on...
+#
+# In the implicit discrete case this shouldn't happen. The simplification should 
+# look like a NonlinearSystem.
+#
+# For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
+# This is different from the continuous case where D(D(x)) can be substituted for 
+# by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
+# total_sub dict is updated at the time that the renamed variables are written,
+# inside the loop where new variables are generated.
+
 import ModelingToolkit: Shift
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
@@ -318,9 +344,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             diff_to_var[dv] = nothing
         end
     end
-    @show var_eq_matching
-
-    println("Post state selection.")
     
     # `SelectedState` information is no longer needed past here. State selection
     # is done. All non-differentiated variables are algebraic variables, and all
@@ -330,18 +353,19 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     is_solvable = let solvable_graph = solvable_graph
         (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
     end
-    idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
-    for (i,var) in enumerate(fullvars)
-        key = (operation(var) isa Shift) ? only(arguments(var)) : var
-        idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
+
+    if is_only_discrete(state.structure)
+         idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
+         for (i,var) in enumerate(fullvars)
+             key = (operation(var) isa Shift) ? only(arguments(var)) : var
+             idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
+         end
     end
 
     # if var is like D(x)
     isdervar = let diff_to_var = diff_to_var
         var -> diff_to_var[var] !== nothing
     end
-    # For discrete variables, we want the substitution to turn 
-    # Shift(t, k)(x(t)) => x_t-k(t)
     var_order = let diff_to_var = diff_to_var
         dv -> begin
             order = 0
@@ -349,7 +373,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 order += 1
                 dv = dv′
             end
-            is_only_discrete(state.structure) && (order = -idx_to_lowest_shift[dv] - order - 1)
+            is_only_discrete(state.structure) && (order = -idx_to_lowest_shift[dv] - order)
             order, dv
         end
     end
@@ -403,6 +427,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
+    total_sub = Dict()
     for v in 1:length(var_to_diff)
         is_highest_discrete = begin
             var = fullvars[v]
@@ -442,8 +467,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         end
         dx = fullvars[dv]
         # add `x_t`
-        order, lv = var_order(dv)
+        println()
+        @show order, lv = var_order(dv)
         x_t = lower_name(fullvars[lv], iv, order)
+        @show fullvars[v]
+        @show fullvars[dv]
+        @show fullvars[lv]
+        @show dx, x_t
         push!(fullvars, simplify_shifts(x_t))
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
@@ -466,11 +496,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         add_edge!(solvable_graph, dummy_eq, dv)
         @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
         @label FOUND_DUMMY_EQ
-        is_only_discrete(state.structure) && (idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv])
+        is_only_discrete(state.structure) && begin
+            idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
+            operation(dx) isa Shift && (total_sub[dx] = x_t)
+            order == 1 && (total_sub[x_t] = fullvars[var_to_diff[dv]])
+        end
         var_to_diff[v_t] = var_to_diff[dv]
         var_eq_matching[dv] = unassigned
         eq_var_matching[dummy_eq] = dv
     end
+    @show total_sub
 
     # Will reorder equations and unknowns to be:
     # [diffeqs; ...]
@@ -490,15 +525,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     # Solve solvable equations
     println()
     println("SOLVING SOLVABLE EQUATIONS.")
-    @show eq_var_matching
     toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching))
     eqs = Iterators.reverse(toporder)
-    @show eqs
-    @show neweqs
-    @show fullvars
-    total_sub = Dict()
     idep = iv
+    @show eq_var_matching
+
     for ieq in eqs
+        println()
         iv = eq_var_matching[ieq]
         if is_solvable(ieq, iv)
             # We don't solve differential equations, but we will need to try to
@@ -513,7 +546,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 dx = D(simplify_shifts(lower_name(
                     fullvars[lv], idep, order - 1)))
                 @show dx
-                @show neweqs[ieq]
                 eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
                     Symbolics.symbolic_linear_solve(neweqs[ieq],
                         fullvars[iv]),
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index dd063c0b07..707a57097b 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -451,7 +451,8 @@ end
 
 function lower_varname_withshift(var, iv, order)
     order == 0 && return var
-    ds = "$iv-$order"
+    #order == -1 && return Shift(iv, 1)(var)
+    ds = "$iv-$(order-1)"
     d_separator = 'ˍ'
 
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 44558ade6c..89fcebf549 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -476,9 +476,6 @@ function shift_discrete_system(ts::TearingState, lowest_shift)
         vars!(discvars, eq; op = Union{Sample, Hold})
     end
     iv = get_iv(sys)
-    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, -get(lowest_shift, k, 0))(k))
-    for k in discvars
-    if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 
 
     discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))    for k in discvars 
     if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 

From 810d4fa26af98e0231da9323bb61053f02fd13ce Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 4 Feb 2025 02:58:11 -0500
Subject: [PATCH 068/111] remove problematic tests, codegen assumes
 MTKParameters

---
 test/bvproblem.jl | 44 ++------------------------------------------
 1 file changed, 2 insertions(+), 42 deletions(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index b4032e2927..edc85b4cbd 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -100,33 +100,6 @@ let
     function lotkavolterra(u, p, t) 
         [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
     end
-    # Compare the built bc function to the actual constructed one.
-    function bc!(resid, u, p, t) 
-        resid[1] = u[1][1] - 1.
-        resid[2] = u[1][2] - 2.
-        nothing
-    end
-    function bc(u, p, t)
-        [u[1][1] - 1., u[1][2] - 2.]
-    end
-
-    u0 = [1., 2.]; p = [1.5, 1., 1., 3.]
-    fns = ModelingToolkit.generate_function_bc(lksys, u0, [1, 2], tspan)
-    genbc_oop, genbc_iip = ModelingToolkit.eval_or_rgf.(fns)
-
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
-
-    sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
-    sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2
-
-    bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
-    bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
-
-    sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
-    sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2
 
     # Test with a constraint.
     constr = [y(0.5) ~ 2.]
@@ -140,29 +113,16 @@ let
         [u(0.0)[1] - 1., u(0.5)[2] - 2.]
     end
 
-    u0 = [1, 1.]
-    fns = ModelingToolkit.generate_function_bc(lksys, u0, [1], tspan)
-    genbc_oop, genbc_iip = ModelingToolkit.eval_or_rgf.(fns)
-
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
-    bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, u0, tspan, p)
+    bvpi1 = SciMLBase.BVProblem(lotkavolterra, bc, u0, tspan, p)
     bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
-    bvpi4 = SciMLBase.BVProblem{true, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
+    bvpi4 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
     
     sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
     sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
     sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
     sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
     @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
-
-    bvpo1 = BVProblem(lotkavolterra, bc, u0, tspan, p)
-    bvpo2 = BVProblem(lotkavolterra, genbc_oop, u0, tspan, p)
-    bvpo3 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
-
-    sol1 = @btime solve($bvpo1, MIRK4(), dt = 0.05)
-    sol2 = @btime solve($bvpo2, MIRK4(), dt = 0.05)
-    sol3 = @btime solve($bvpo3, MIRK4(), dt = 0.05)
-    @test sol1 ≈ sol2 ≈ sol3
 end
 
 function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-2)

From 16ea18bbde45a0c50db93d0922375a5f54312fe9 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 4 Feb 2025 03:27:27 -0500
Subject: [PATCH 069/111] begin refactor

---
 .../symbolics_tearing.jl                      | 129 +++++++++---------
 1 file changed, 66 insertions(+), 63 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index cf6736b419..6340087003 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -237,8 +237,49 @@ function check_diff_graph(var_to_diff, fullvars)
 end
 =#
 
-function substitute_lower_order!(state::TearingState) 
-    
+
+function substitute_dummy_derivatives!(fullvars, var_to_diff, diff_to_var, var_eq_matching, neweqs)
+    # Dummy derivatives may determine that some differential variables are
+    # algebraic variables in disguise. The derivative of such variables are
+    # called dummy derivatives.
+
+    # Step 1:
+    # Replace derivatives of non-selected unknown variables by dummy derivatives
+
+    dummy_sub = Dict()
+    for var in 1:length(fullvars)
+        dv = var_to_diff[var]
+        dv === nothing && continue
+        if var_eq_matching[var] !== SelectedState()
+            dd = fullvars[dv]
+            v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false)
+            for eq in 𝑑neighbors(graph, dv)
+                dummy_sub[dd] = v_t
+                neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)
+            end
+            fullvars[dv] = v_t
+            # If we have:
+            # x -> D(x) -> D(D(x))
+            # We need to to transform it to:
+            # x   x_t -> D(x_t)
+            # update the structural information
+            dx = dv
+            x_t = v_t
+            while (ddx = var_to_diff[dx]) !== nothing
+                dx_t = D(x_t)
+                for eq in 𝑑neighbors(graph, ddx)
+                    neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t)
+                end
+                fullvars[ddx] = dx_t
+                dx = ddx
+                x_t = dx_t
+            end
+            diff_to_var[dv] = nothing
+        end
+    end
+end
+
+function generate_derivative_variables!()
 end
 
 # Documenting the differences to structural simplification for discrete systems:
@@ -281,88 +322,50 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         end
     end
 
-    neweqs = collect(equations(state))
-    lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
-
-    # Terminology and Definition:
-    #
-    # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
-    # characterize variables in `u(t)` into two classes: differential variables
-    # (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
-    # variables are marked as `SelectedState` and they are differentiated in the
-    # DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
-    # appear in the system. Algebraic variables are variables that are not
-    # differential variables.
-    #
-    # Dummy derivatives may determine that some differential variables are
-    # algebraic variables in disguise. The derivative of such variables are
-    # called dummy derivatives.
-
-    # Step 1:
-    # Replace derivatives of non-selected unknown variables by dummy derivatives
-
     if ModelingToolkit.has_iv(state.sys)
         iv = get_iv(state.sys)
         if is_only_discrete(state.structure)
             D = Shift(iv, 1)
+            lower_name = lower_varname_withshift
+            idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
+            for (i,var) in enumerate(fullvars)
+                key = (operation(var) isa Shift) ? only(arguments(var)) : var
+                idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
+            end
         else
             D = Differential(iv)
+            lower_name = lower_varname_with_unit
         end
     else
         iv = D = nothing
+        lower_name = lower_varname_with_unit
     end
+
     diff_to_var = invview(var_to_diff)
+    neweqs = collect(equations(state))
 
-    dummy_sub = Dict()
-    for var in 1:length(fullvars)
-        dv = var_to_diff[var]
-        dv === nothing && continue
-        if var_eq_matching[var] !== SelectedState()
-            dd = fullvars[dv]
-            v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false)
-            for eq in 𝑑neighbors(graph, dv)
-                dummy_sub[dd] = v_t
-                neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)
-            end
-            fullvars[dv] = v_t
-            # If we have:
-            # x -> D(x) -> D(D(x))
-            # We need to to transform it to:
-            # x   x_t -> D(x_t)
-            # update the structural information
-            dx = dv
-            x_t = v_t
-            while (ddx = var_to_diff[dx]) !== nothing
-                dx_t = D(x_t)
-                for eq in 𝑑neighbors(graph, ddx)
-                    neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t)
-                end
-                fullvars[ddx] = dx_t
-                dx = ddx
-                x_t = dx_t
-            end
-            diff_to_var[dv] = nothing
-        end
-    end
+    # Terminology and Definition:
+    #
+    # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
+    # characterize variables in `u(t)` into two classes: differential variables
+    # (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
+    # variables are marked as `SelectedState` and they are differentiated in the
+    # DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
+    # appear in the system. Algebraic variables are variables that are not
+    # differential variables.
     
+    substitute_dummy_derivatives!(fullvars, var_to_diff, diff_to_var, var_eq_matching, neweqs)
+
     # `SelectedState` information is no longer needed past here. State selection
     # is done. All non-differentiated variables are algebraic variables, and all
     # variables that appear differentiated are differential variables.
 
-    ### extract partition information
+    ### Extract partition information
     is_solvable = let solvable_graph = solvable_graph
         (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
     end
 
-    if is_only_discrete(state.structure)
-         idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
-         for (i,var) in enumerate(fullvars)
-             key = (operation(var) isa Shift) ? only(arguments(var)) : var
-             idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
-         end
-    end
-
-    # if var is like D(x)
+    # if var is like D(x) or Shift(t, 1)(x)
     isdervar = let diff_to_var = diff_to_var
         var -> diff_to_var[var] !== nothing
     end

From 6740b8c89b6992e3ddfc1b004f74fe632ef2295d Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 4 Feb 2025 03:31:53 -0500
Subject: [PATCH 070/111] test fix

---
 test/bvproblem.jl | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index edc85b4cbd..f05be90281 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -113,8 +113,11 @@ let
         [u(0.0)[1] - 1., u(0.5)[2] - 2.]
     end
 
+    u0 = [1., 1.]
+    tspan = (0., 1.)
+    p = [1.5, 1., 1., 3.]
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
-    bvpi1 = SciMLBase.BVProblem(lotkavolterra, bc, u0, tspan, p)
+    bvpi2 = SciMLBase.BVProblem(lotkavolterra, bc, u0, tspan, p)
     bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
     bvpi4 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
     

From bf3cd3394212acd34678e1927f79b958b5320bcd Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 4 Feb 2025 14:52:43 -0500
Subject: [PATCH 071/111] reorganization into functions

---
 .../symbolics_tearing.jl                      | 376 ++++++++++--------
 1 file changed, 208 insertions(+), 168 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 6340087003..cf70914a51 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -237,16 +237,22 @@ function check_diff_graph(var_to_diff, fullvars)
 end
 =#
 
+"""
+Replace derivatives of non-selected unknown variables by dummy derivatives. 
 
-function substitute_dummy_derivatives!(fullvars, var_to_diff, diff_to_var, var_eq_matching, neweqs)
-    # Dummy derivatives may determine that some differential variables are
-    # algebraic variables in disguise. The derivative of such variables are
-    # called dummy derivatives.
+State selection may determine that some differential variables are
+algebraic variables in disguise. The derivative of such variables are
+called dummy derivatives.
 
-    # Step 1:
-    # Replace derivatives of non-selected unknown variables by dummy derivatives
+`SelectedState` information is no longer needed past here. State selection
+is done. All non-differentiated variables are algebraic variables, and all
+variables that appear differentiated are differential variables.
+"""
+function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_eq_matching)
+    @unpack fullvars, sys, structure = ts
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+    diff_to_var = invview(var_to_diff)
 
-    dummy_sub = Dict()
     for var in 1:length(fullvars)
         dv = var_to_diff[var]
         dv === nothing && continue
@@ -279,163 +285,73 @@ function substitute_dummy_derivatives!(fullvars, var_to_diff, diff_to_var, var_e
     end
 end
 
-function generate_derivative_variables!()
-end
+"""
+Generate new derivative variables for the system.
 
-# Documenting the differences to structural simplification for discrete systems:
-# In discrete systems the lowest-order term is x_k-i, instead of x(t). In order
-# to substitute dummy variables for x_k-1, x_k-2, ... instead you need to reverse 
-# the order. So for discrete systems `var_order` is defined a little differently.
-#
-# The orders will also be off by one. The reason this is is that the dynamics of
-# the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
-# having the observables be indexed by the next time step is not so nice. So we 
-# handle the shifts in the renaming, rather than explicitly.
-#
-# The substitution should look like the following: 
-#   x(t) -> Shift(t, 1)(x(t))
-#   x(k-1) -> x(t)
-#   x(k-2) -> x_{t-1}(t)
-#   x(k-3) -> x_{t-2}(t)
-#   and so on...
-#
-# In the implicit discrete case this shouldn't happen. The simplification should 
-# look like a NonlinearSystem.
-#
-# For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
-# This is different from the continuous case where D(D(x)) can be substituted for 
-# by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
-# total_sub dict is updated at the time that the renamed variables are written,
-# inside the loop where new variables are generated.
+There are three cases where we want to generate new variables to convert
+the system into first order (semi-implicit) ODEs.
+    
+1. To first order:
+Whenever higher order differentiated variable like `D(D(D(x)))` appears,
+we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations
+```
+D(x_tt) = x_ttt
+D(x_t) = x_tt
+D(x) = x_t
+```
+and replace `D(x)` to `x_t`, `D(D(x))` to `x_tt`, and `D(D(D(x)))` to
+`x_ttt`.
+
+2. To implicit to semi-implicit ODEs:
+2.1: Unsolvable derivative:
+If one derivative variable `D(x)` is unsolvable in all the equations it
+appears in, then we introduce a new variable `x_t`, a new equation
+```
+D(x) ~ x_t
+```
+and replace all other `D(x)` to `x_t`.
+
+2.2: Solvable derivative:
+If one derivative variable `D(x)` is solvable in at least one of the
+equations it appears in, then we introduce a new variable `x_t`. One of
+the solvable equations must be in the form of `0 ~ L(D(x), u...)` and
+there exists a function `l` such that `D(x) ~ l(u...)`. We should replace
+it to
+```
+0 ~ x_t - l(u...)
+D(x) ~ x_t
+```
+and replace all other `D(x)` to `x_t`.
+
+Observe that we don't need to actually introduce a new variable `x_t`, as
+the above equations can be lowered to
+```
+x_t := l(u...)
+D(x) ~ x_t
+```
+where `:=` denotes assignment.
+
+As a final note, in all the above cases where we need to introduce new
+variables and equations, don't add them when they already exist.
+"""
+function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching, var_order;
+        is_discrete = false, mm = nothing)
 
-import ModelingToolkit: Shift
-function tearing_reassemble(state::TearingState, var_eq_matching,
-        full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
-    @unpack fullvars, sys, structure = state
+    @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
-    extra_vars = Int[]
-    if full_var_eq_matching !== nothing
-        for v in 𝑑vertices(state.structure.graph)
-            eq = full_var_eq_matching[v]
-            eq isa Int && continue
-            push!(extra_vars, v)
-        end
-    end
-
-    if ModelingToolkit.has_iv(state.sys)
-        iv = get_iv(state.sys)
-        if is_only_discrete(state.structure)
-            D = Shift(iv, 1)
-            lower_name = lower_varname_withshift
-            idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
-            for (i,var) in enumerate(fullvars)
-                key = (operation(var) isa Shift) ? only(arguments(var)) : var
-                idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
-            end
-        else
-            D = Differential(iv)
-            lower_name = lower_varname_with_unit
-        end
-    else
-        iv = D = nothing
-        lower_name = lower_varname_with_unit
-    end
-
+    eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
-    neweqs = collect(equations(state))
-
-    # Terminology and Definition:
-    #
-    # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
-    # characterize variables in `u(t)` into two classes: differential variables
-    # (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
-    # variables are marked as `SelectedState` and they are differentiated in the
-    # DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
-    # appear in the system. Algebraic variables are variables that are not
-    # differential variables.
-    
-    substitute_dummy_derivatives!(fullvars, var_to_diff, diff_to_var, var_eq_matching, neweqs)
-
-    # `SelectedState` information is no longer needed past here. State selection
-    # is done. All non-differentiated variables are algebraic variables, and all
-    # variables that appear differentiated are differential variables.
-
-    ### Extract partition information
-    is_solvable = let solvable_graph = solvable_graph
-        (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
-    end
-
-    # if var is like D(x) or Shift(t, 1)(x)
-    isdervar = let diff_to_var = diff_to_var
-        var -> diff_to_var[var] !== nothing
-    end
-    var_order = let diff_to_var = diff_to_var
-        dv -> begin
-            order = 0
-            while (dv′ = diff_to_var[dv]) !== nothing
-                order += 1
-                dv = dv′
-            end
-            is_only_discrete(state.structure) && (order = -idx_to_lowest_shift[dv] - order)
-            order, dv
-        end
-    end
+    iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
+    lower_name = is_discrete ? lower_name_withshift : lower_name_with_unit
 
-    # There are three cases where we want to generate new variables to convert
-    # the system into first order (semi-implicit) ODEs.
-    #
-    # 1. To first order:
-    # Whenever higher order differentiated variable like `D(D(D(x)))` appears,
-    # we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations
-    # ```
-    # D(x_tt) = x_ttt
-    # D(x_t) = x_tt
-    # D(x) = x_t
-    # ```
-    # and replace `D(x)` to `x_t`, `D(D(x))` to `x_tt`, and `D(D(D(x)))` to
-    # `x_ttt`.
-    #
-    # 2. To implicit to semi-implicit ODEs:
-    # 2.1: Unsolvable derivative:
-    # If one derivative variable `D(x)` is unsolvable in all the equations it
-    # appears in, then we introduce a new variable `x_t`, a new equation
-    # ```
-    # D(x) ~ x_t
-    # ```
-    # and replace all other `D(x)` to `x_t`.
-    #
-    # 2.2: Solvable derivative:
-    # If one derivative variable `D(x)` is solvable in at least one of the
-    # equations it appears in, then we introduce a new variable `x_t`. One of
-    # the solvable equations must be in the form of `0 ~ L(D(x), u...)` and
-    # there exists a function `l` such that `D(x) ~ l(u...)`. We should replace
-    # it to
-    # ```
-    # 0 ~ x_t - l(u...)
-    # D(x) ~ x_t
-    # ```
-    # and replace all other `D(x)` to `x_t`.
-    #
-    # Observe that we don't need to actually introduce a new variable `x_t`, as
-    # the above equations can be lowered to
-    # ```
-    # x_t := l(u...)
-    # D(x) ~ x_t
-    # ```
-    # where `:=` denotes assignment.
-    #
-    # As a final note, in all the above cases where we need to introduce new
-    # variables and equations, don't add them when they already exist.
-    eq_var_matching = invview(var_eq_matching)
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
-    total_sub = Dict()
     for v in 1:length(var_to_diff)
         is_highest_discrete = begin
             var = fullvars[v]
             op = operation(var)
-            if (!is_only_discrete(state.structure) || op isa Shift) 
+            if (!is_discrete || op isa Shift) 
                 false
             elseif !haskey(lowest_shift, var)
                 false
@@ -450,7 +366,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
 
         solved = var_eq_matching[dv] isa Int
         solved && continue
-        # check if there's `D(x) = x_t` already
+
+        # Check if there's `D(x) = x_t` already
         local v_t, dummy_eq
         for eq in 𝑑neighbors(solvable_graph, dv)
             mi = get(linear_eqs, eq, 0)
@@ -468,6 +385,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 @goto FOUND_DUMMY_EQ
             end
         end
+
         dx = fullvars[dv]
         # add `x_t`
         println()
@@ -499,7 +417,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         add_edge!(solvable_graph, dummy_eq, dv)
         @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
         @label FOUND_DUMMY_EQ
-        is_only_discrete(state.structure) && begin
+        is_discrete && begin
             idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
             operation(dx) isa Shift && (total_sub[dx] = x_t)
             order == 1 && (total_sub[x_t] = fullvars[var_to_diff[dv]])
@@ -509,13 +427,51 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         eq_var_matching[dummy_eq] = dv
     end
     @show total_sub
+end
+
+"""
+Solve the solvable equations of the system and generate differential (or discrete)
+equations in terms of the selected states.
+
+Will reorder equations and unknowns to be:
+   [diffeqs; ...]
+   [diffvars; ...]
+such that the mass matrix is:
+   [I  0
+    0  0].
+
+Update the state to account for the new ordering and equations.
+"""
+function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching, var_order)
+    @unpack fullvars, sys, structure = state 
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+    eq_var_matching = invview(var_eq_matching)
+    diff_to_var = invview(var_to_diff)
+
+    if ModelingToolkit.has_iv(sys)
+        iv = get_iv(sys)
+        if is_only_discrete(structure)
+            D = Shift(iv, 1)
+            lower_name = lower_varname_withshift
+        else
+            D = Differential(iv)
+            lower_name = lower_varname_with_unit
+        end
+    else
+        iv = D = nothing
+        lower_name = lower_varname_with_unit
+    end
+
+    # if var is like D(x) or Shift(t, 1)(x)
+    isdervar = let diff_to_var = diff_to_var
+        var -> diff_to_var[var] !== nothing
+    end
+
+    ### Extract partition information
+    is_solvable = let solvable_graph = solvable_graph
+        (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
+    end
 
-    # Will reorder equations and unknowns to be:
-    # [diffeqs; ...]
-    # [diffvars; ...]
-    # such that the mass matrix is:
-    # [I  0
-    #  0  0].
     diffeq_idxs = Int[]
     algeeq_idxs = Int[]
     diff_eqs = Equation[]
@@ -525,14 +481,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     solved_equations = Int[]
     solved_variables = Int[]
 
-    # Solve solvable equations
-    println()
-    println("SOLVING SOLVABLE EQUATIONS.")
     toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching))
     eqs = Iterators.reverse(toporder)
-    idep = iv
-    @show eq_var_matching
+    idep = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
 
+    # Generate differential equations in terms of the new variables. 
     for ieq in eqs
         println()
         iv = eq_var_matching[ieq]
@@ -597,6 +550,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             push!(algeeq_idxs, ieq)
         end
     end
+
     # TODO: BLT sorting
     neweqs = [diff_eqs; alge_eqs]
     inveqsperm = [diffeq_idxs; algeeq_idxs]
@@ -625,7 +579,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
         length(solved_variables), length(solved_variables_set))
 
-    # Update system
     new_var_to_diff = complete(DiffGraph(length(invvarsperm)))
     for (v, d) in enumerate(var_to_diff)
         v′ = varsperm[v]
@@ -641,6 +594,92 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing
     end
 
+    fullvars[invvarsperm], new_var_to_diff, new_eq_to_diff, neweqs, subeqs
+end
+
+# Documenting the differences to structural simplification for discrete systems:
+# In discrete systems the lowest-order term is x_k-i, instead of x(t). In order
+# to substitute dummy variables for x_k-1, x_k-2, ... instead you need to reverse 
+# the order. So for discrete systems `var_order` is defined a little differently.
+#
+# The orders will also be off by one. The reason this is is that the dynamics of
+# the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
+# having the observables be indexed by the next time step is not so nice. So we 
+# handle the shifts in the renaming, rather than explicitly.
+#
+# The substitution should look like the following: 
+#   x(t) -> Shift(t, 1)(x(t))
+#   x(k-1) -> x(t)
+#   x(k-2) -> x_{t-1}(t)
+#   x(k-3) -> x_{t-2}(t)
+#   and so on...
+#
+# In the implicit discrete case this shouldn't happen. The simplification should 
+# look like a NonlinearSystem.
+#
+# For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
+# This is different from the continuous case where D(D(x)) can be substituted for 
+# by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
+# total_sub dict is updated at the time that the renamed variables are written,
+# inside the loop where new variables are generated.
+
+
+# Terminology and Definition:
+#
+# A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
+# characterize variables in `u(t)` into two classes: differential variables
+# (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
+# variables are marked as `SelectedState` and they are differentiated in the
+# DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
+# appear in the system. Algebraic variables are variables that are not
+# differential variables.
+    
+import ModelingToolkit: Shift
+
+function tearing_reassemble(state::TearingState, var_eq_matching,
+        full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
+    @unpack fullvars, sys, structure = state
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+    extra_vars = Int[]
+    if full_var_eq_matching !== nothing
+        for v in 𝑑vertices(state.structure.graph)
+            eq = full_var_eq_matching[v]
+            eq isa Int && continue
+            push!(extra_vars, v)
+        end
+    end
+    neweqs = collect(equations(state))
+    diff_to_var = invview(var_to_diff)
+    total_sub = Dict()
+    dummy_sub = Dict()
+    is_discrete = is_only_discrete(state.structure)
+
+    if is_discrete
+        idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
+        for (i,var) in enumerate(fullvars)
+            key = (operation(var) isa Shift) ? only(arguments(var)) : var
+            idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
+        end
+    end
+    var_order = let diff_to_var = diff_to_var
+        dv -> begin
+            order = 0
+            while (dv′ = diff_to_var[dv]) !== nothing
+                order += 1
+                dv = dv′
+            end
+            is_discrete && (order = -idx_to_lowest_shift[dv] - order)
+            order, dv
+        end
+    end
+
+    # Structural simplification 
+    substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
+    generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching, var_order; 
+                                   is_discrete, mm)
+    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs = solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching, var_order)
+
+    # Update system
     var_to_diff = new_var_to_diff
     eq_to_diff = new_eq_to_diff
     diff_to_var = invview(var_to_diff)
@@ -649,7 +688,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     @set! state.structure.graph = complete(graph)
     @set! state.structure.var_to_diff = var_to_diff
     @set! state.structure.eq_to_diff = eq_to_diff
-    @set! state.fullvars = fullvars = fullvars[invvarsperm]
+    @set! state.fullvars = fullvars = new_fullvars
     ispresent = let var_to_diff = var_to_diff, graph = graph
         i -> (!isempty(𝑑neighbors(graph, i)) ||
               (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
@@ -657,11 +696,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
 
     sys = state.sys
 
+    @show dummy_sub
     obs_sub = dummy_sub
     for eq in neweqs
         isdiffeq(eq) || continue
         obs_sub[eq.lhs] = eq.rhs
     end
+    @show obs_sub
     # TODO: compute the dependency correctly so that we don't have to do this
     obs = [fast_substitute(observed(sys), obs_sub); subeqs]
 
@@ -680,7 +721,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
 
     @set! sys.eqs = neweqs
     @set! sys.observed = obs
-
     @set! sys.substitutions = Substitutions(subeqs, deps)
 
     # Only makes sense for time-dependent

From 5a77f4347b1b9259cb6da56128146af254b827a2 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 6 Feb 2025 17:22:49 -0500
Subject: [PATCH 072/111] explicit case

---
 .../symbolics_tearing.jl                      | 180 ++++++++++--------
 src/structural_transformation/utils.jl        |   7 +-
 src/systems/systems.jl                        |   8 +-
 3 files changed, 104 insertions(+), 91 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index cf70914a51..76bba2a303 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -333,6 +333,32 @@ where `:=` denotes assignment.
 
 As a final note, in all the above cases where we need to introduce new
 variables and equations, don't add them when they already exist.
+
+###### DISCRETE SYSTEMS ####### 
+
+Documenting the differences to structural simplification for discrete systems:
+In discrete systems the lowest-order term is x_k-i, instead of x(t).
+
+The orders will also be off by one. The reason this is is that the dynamics of
+the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
+having the observables be indexed by the next time step is not so nice. So we 
+handle the shifts in the renaming, rather than explicitly.
+
+The substitution should look like the following: 
+  x(t) -> Shift(t, 1)(x(t))
+  Shift(t, -1)(x(t)) -> x(t)
+  Shift(t, -2)(x(t)) -> x_{t-1}(t)
+  Shift(t, -3)(x(t)) -> x_{t-2}(t)
+  and so on...
+
+In the implicit discrete case this shouldn't happen. The simplification should 
+look like a NonlinearSystem.
+
+For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
+This is different from the continuous case where D(D(x)) can be substituted for 
+by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
+total_sub dict is updated at the time that the renamed variables are written,
+inside the loop where new variables are generated.
 """
 function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching, var_order;
         is_discrete = false, mm = nothing)
@@ -342,26 +368,41 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
-    lower_name = is_discrete ? lower_name_withshift : lower_name_with_unit
+    lower_name = is_discrete ? lower_varname_withshift : lower_varname_with_unit
+
+    # index v gets mapped to the lowest shift and the index of the unshifted variable
+    if is_discrete
+        idx_to_lowest_shift = Dict{Int, Tuple{Int, Int}}(var => (0,0) for var in 1:length(fullvars))
+        var_to_unshiftedidx = Dict{Any, Int}(var => findfirst(x -> isequal(x, var), fullvars) for var in keys(lowest_shift))
+
+        for (i,var) in enumerate(fullvars)
+            key = (operation(var) isa Shift) ? only(arguments(var)) : var
+            idx_to_lowest_shift[i] = (get(lowest_shift, key, 0), get(var_to_unshiftedidx, key, i))
+        end
+    end
 
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
+    # v is the index of the current variable, x = fullvars[v]
+    # dv is the index of the derivative dx = D(x), x_t is the substituted variable 
+    #
+    # For ODESystems: lv is the index of the lowest-order variable (x(t))
+    # For DiscreteSystems: 
+    # - lv is the index of the lowest-order variable (Shift(t, k)(x(t)))
+    # - uv is the index of the highest-order variable (x(t))
     for v in 1:length(var_to_diff)
-        is_highest_discrete = begin
-            var = fullvars[v]
-            op = operation(var)
-            if (!is_discrete || op isa Shift) 
-                false
-            elseif !haskey(lowest_shift, var)
-                false
-            else
-                low = lowest_shift[var]
-                idx = findfirst(x -> isequal(x, Shift(iv, low)(var)), fullvars)
-                true
+        dv = var_to_diff[v]
+        if is_discrete 
+            x = fullvars[v]
+            op = operation(x)
+            (low, uv) = idx_to_lowest_shift[v]
+
+            # If v is unshifted (i.e. x(t)), then substitute the lowest-shift variable
+            if !(op isa Shift) && (low != 0)
+                dv = findfirst(_x -> isequal(_x, Shift(iv, low)(x)), fullvars)
             end
         end
-        dv = is_highest_discrete ? idx : var_to_diff[v]
         dv isa Int || continue
 
         solved = var_eq_matching[dv] isa Int
@@ -388,13 +429,9 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
 
         dx = fullvars[dv]
         # add `x_t`
-        println()
-        @show order, lv = var_order(dv)
-        x_t = lower_name(fullvars[lv], iv, order)
-        @show fullvars[v]
-        @show fullvars[dv]
-        @show fullvars[lv]
-        @show dx, x_t
+        order, lv = var_order(dv)
+        x_t = is_discrete ? lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv]) :
+                            lower_name(fullvars[lv], iv, order)
         push!(fullvars, simplify_shifts(x_t))
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
@@ -402,11 +439,27 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
         # TODO: do we care about solvable_graph? We don't use them after
         # `dummy_derivative_graph`.
         add_vertex!(solvable_graph, DST)
-        # var_eq_matching is a bit odd.
-        # length(var_eq_matching) == length(invview(var_eq_matching))
         push!(var_eq_matching, unassigned)
         @assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
                 length(var_eq_matching)
+
+        # Add the substitutions to total_sub directly.  
+        is_discrete && begin
+            idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
+            @show dx
+            if operation(dx) isa Shift 
+                total_sub[dx] = x_t
+                for e in 𝑑neighbors(graph, dv)
+                    add_edge!(graph, e, v_t)
+                    rem_edge!(graph, e, dv)
+                end
+                @show graph
+                !(operation(x) isa Shift) && begin
+                    var_to_diff[v_t] = var_to_diff[dv]
+                    continue
+                end
+            end
+        end
         # add `D(x) - x_t ~ 0`
         push!(neweqs, 0 ~ x_t - dx)
         add_vertex!(graph, SRC)
@@ -417,16 +470,10 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
         add_edge!(solvable_graph, dummy_eq, dv)
         @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
         @label FOUND_DUMMY_EQ
-        is_discrete && begin
-            idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
-            operation(dx) isa Shift && (total_sub[dx] = x_t)
-            order == 1 && (total_sub[x_t] = fullvars[var_to_diff[dv]])
-        end
         var_to_diff[v_t] = var_to_diff[dv]
         var_eq_matching[dv] = unassigned
         eq_var_matching[dummy_eq] = dv
     end
-    @show total_sub
 end
 
 """
@@ -442,7 +489,7 @@ such that the mass matrix is:
 
 Update the state to account for the new ordering and equations.
 """
-function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching, var_order)
+function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching, var_order; simplify = false)
     @unpack fullvars, sys, structure = state 
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     eq_var_matching = invview(var_eq_matching)
@@ -467,7 +514,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
         var -> diff_to_var[var] !== nothing
     end
 
-    ### Extract partition information
+    # Extract partition information
     is_solvable = let solvable_graph = solvable_graph
         (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
     end
@@ -485,9 +532,12 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
     eqs = Iterators.reverse(toporder)
     idep = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
 
-    # Generate differential equations in terms of the new variables. 
+    @show eq_var_matching
+    @show fullvars
+    @show neweqs
+
+    # Equation ieq is solved for the RHS of iv 
     for ieq in eqs
-        println()
         iv = eq_var_matching[ieq]
         if is_solvable(ieq, iv)
             # We don't solve differential equations, but we will need to try to
@@ -497,23 +547,14 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
                 isnothing(D) &&
                     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
                 order, lv = var_order(iv)
-                @show fullvars[iv]
-                @show (order, lv)
-                dx = D(simplify_shifts(lower_name(
-                    fullvars[lv], idep, order - 1)))
-                @show dx
+                dx = D(fullvars[lv])
                 eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
                     Symbolics.symbolic_linear_solve(neweqs[ieq],
                         fullvars[iv]),
                     total_sub; operator = ModelingToolkit.Shift))
-                @show total_sub
-                @show eq
                 for e in 𝑑neighbors(graph, iv)
-                    e == ieq && continue
-                    for v in 𝑠neighbors(graph, e)
-                        add_edge!(graph, e, v)
-                    end
                     rem_edge!(graph, e, iv)
+                    add_edge!(graph, e, lv)
                 end
                 push!(diff_eqs, eq)
                 total_sub[simplify_shifts(eq.lhs)] = eq.rhs
@@ -594,38 +635,10 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
         new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing
     end
 
-    fullvars[invvarsperm], new_var_to_diff, new_eq_to_diff, neweqs, subeqs
+    fullvars[invvarsperm], new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph
 end
 
-# Documenting the differences to structural simplification for discrete systems:
-# In discrete systems the lowest-order term is x_k-i, instead of x(t). In order
-# to substitute dummy variables for x_k-1, x_k-2, ... instead you need to reverse 
-# the order. So for discrete systems `var_order` is defined a little differently.
-#
-# The orders will also be off by one. The reason this is is that the dynamics of
-# the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
-# having the observables be indexed by the next time step is not so nice. So we 
-# handle the shifts in the renaming, rather than explicitly.
-#
-# The substitution should look like the following: 
-#   x(t) -> Shift(t, 1)(x(t))
-#   x(k-1) -> x(t)
-#   x(k-2) -> x_{t-1}(t)
-#   x(k-3) -> x_{t-2}(t)
-#   and so on...
-#
-# In the implicit discrete case this shouldn't happen. The simplification should 
-# look like a NonlinearSystem.
-#
-# For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
-# This is different from the continuous case where D(D(x)) can be substituted for 
-# by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
-# total_sub dict is updated at the time that the renamed variables are written,
-# inside the loop where new variables are generated.
-
-
 # Terminology and Definition:
-#
 # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
 # characterize variables in `u(t)` into two classes: differential variables
 # (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
@@ -654,13 +667,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     dummy_sub = Dict()
     is_discrete = is_only_discrete(state.structure)
 
-    if is_discrete
-        idx_to_lowest_shift = Dict{Int, Int}(var => 0 for var in 1:length(fullvars))
-        for (i,var) in enumerate(fullvars)
-            key = (operation(var) isa Shift) ? only(arguments(var)) : var
-            idx_to_lowest_shift[i] = get(lowest_shift, key, 0)
-        end
-    end
     var_order = let diff_to_var = diff_to_var
         dv -> begin
             order = 0
@@ -668,7 +674,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
                 order += 1
                 dv = dv′
             end
-            is_discrete && (order = -idx_to_lowest_shift[dv] - order)
             order, dv
         end
     end
@@ -677,7 +682,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
     generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching, var_order; 
                                    is_discrete, mm)
-    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs = solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching, var_order)
+    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph = solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching, var_order; simplify)
 
     # Update system
     var_to_diff = new_var_to_diff
@@ -693,16 +698,25 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         i -> (!isempty(𝑑neighbors(graph, i)) ||
               (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
     end
+    @show graph
+    println()
+    println("Shift test...")
+    @show neweqs
+    @show fullvars
+    @show 𝑑neighbors(graph, 5)
 
     sys = state.sys
-
-    @show dummy_sub
     obs_sub = dummy_sub
     for eq in neweqs
         isdiffeq(eq) || continue
         obs_sub[eq.lhs] = eq.rhs
     end
+    is_discrete && for eq in subeqs
+        obs_sub[eq.rhs] = eq.lhs
+    end
+
     @show obs_sub
+    @show observed(sys)
     # TODO: compute the dependency correctly so that we don't have to do this
     obs = [fast_substitute(observed(sys), obs_sub); subeqs]
 
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 707a57097b..8df2d4177b 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -449,10 +449,9 @@ end
 ### Misc
 ###
 
-function lower_varname_withshift(var, iv, order)
-    order == 0 && return var
-    #order == -1 && return Shift(iv, 1)(var)
-    ds = "$iv-$(order-1)"
+function lower_varname_withshift(var, iv, backshift; unshifted = nothing)
+    backshift == 0 && return unshifted 
+    ds = "$iv-$backshift"
     d_separator = 'ˍ'
 
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
diff --git a/src/systems/systems.jl b/src/systems/systems.jl
index 9c8c272c5c..0a7e4264e4 100644
--- a/src/systems/systems.jl
+++ b/src/systems/systems.jl
@@ -41,10 +41,10 @@ function structural_simplify(
     end
     if newsys isa DiscreteSystem &&
        any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
-        error("""
-            Encountered algebraic equations when simplifying discrete system. This is \
-            not yet supported.
-        """)
+        # error("""
+        #     Encountered algebraic equations when simplifying discrete system. This is \
+        #     not yet supported.
+        # """)
     end
     for pass in additional_passes
         newsys = pass(newsys)

From 91acf912e7651b1482a5e29552026774ed5fe5f0 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 6 Feb 2025 22:48:29 -0500
Subject: [PATCH 073/111] beginning implicit equation

---
 .../symbolics_tearing.jl                      | 76 +++++++++----------
 src/structural_transformation/utils.jl        |  1 +
 2 files changed, 37 insertions(+), 40 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 76bba2a303..fb1765964e 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -360,9 +360,8 @@ by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
 total_sub dict is updated at the time that the renamed variables are written,
 inside the loop where new variables are generated.
 """
-function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching, var_order;
+function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching;
         is_discrete = false, mm = nothing)
-
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     eq_var_matching = invview(var_eq_matching)
@@ -386,13 +385,14 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
 
     # v is the index of the current variable, x = fullvars[v]
     # dv is the index of the derivative dx = D(x), x_t is the substituted variable 
-    #
     # For ODESystems: lv is the index of the lowest-order variable (x(t))
     # For DiscreteSystems: 
     # - lv is the index of the lowest-order variable (Shift(t, k)(x(t)))
     # - uv is the index of the highest-order variable (x(t))
     for v in 1:length(var_to_diff)
         dv = var_to_diff[v]
+        println()
+        @show (v, dv)
         if is_discrete 
             x = fullvars[v]
             op = operation(x)
@@ -405,6 +405,9 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
         end
         dv isa Int || continue
 
+        @show dv
+        @show var_eq_matching[dv]
+        @show fullvars
         solved = var_eq_matching[dv] isa Int
         solved && continue
 
@@ -429,9 +432,10 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
 
         dx = fullvars[dv]
         # add `x_t`
-        order, lv = var_order(dv)
+        order, lv = var_order(diff_to_var, dv)
         x_t = is_discrete ? lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv]) :
                             lower_name(fullvars[lv], iv, order)
+        @show dx, x_t
         push!(fullvars, simplify_shifts(x_t))
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
@@ -443,23 +447,23 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
         @assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
                 length(var_eq_matching)
 
-        # Add the substitutions to total_sub directly.  
+        # Add discrete substitutions to total_sub directly.  
         is_discrete && begin
             idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
-            @show dx
             if operation(dx) isa Shift 
                 total_sub[dx] = x_t
                 for e in 𝑑neighbors(graph, dv)
                     add_edge!(graph, e, v_t)
                     rem_edge!(graph, e, dv)
                 end
-                @show graph
+                # Do not add the lowest-order substitution as an equation, just substitute
                 !(operation(x) isa Shift) && begin
                     var_to_diff[v_t] = var_to_diff[dv]
                     continue
                 end
             end
         end
+
         # add `D(x) - x_t ~ 0`
         push!(neweqs, 0 ~ x_t - dx)
         add_vertex!(graph, SRC)
@@ -489,7 +493,7 @@ such that the mass matrix is:
 
 Update the state to account for the new ordering and equations.
 """
-function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching, var_order; simplify = false)
+function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching; simplify = false)
     @unpack fullvars, sys, structure = state 
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     eq_var_matching = invview(var_eq_matching)
@@ -530,13 +534,11 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
 
     toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching))
     eqs = Iterators.reverse(toporder)
-    idep = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
-
-    @show eq_var_matching
-    @show fullvars
-    @show neweqs
+    idep = iv
 
-    # Equation ieq is solved for the RHS of iv 
+    # Generate differential equations.
+    # fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
+    # is solved to give the RHS.
     for ieq in eqs
         iv = eq_var_matching[ieq]
         if is_solvable(ieq, iv)
@@ -546,7 +548,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
             if isdervar(iv)
                 isnothing(D) &&
                     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
-                order, lv = var_order(iv)
+                order, lv = var_order(diff_to_var, iv)
                 dx = D(fullvars[lv])
                 eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
                     Symbolics.symbolic_linear_solve(neweqs[ieq],
@@ -634,8 +636,9 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
         d′ = eqsperm[d]
         new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing
     end
+    new_fullvars = fullvars[invvarsperm]
 
-    fullvars[invvarsperm], new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph
+    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph
 end
 
 # Terminology and Definition:
@@ -649,6 +652,16 @@ end
     
 import ModelingToolkit: Shift
 
+# Give the order of the variable indexed by dv
+function var_order(diff_to_var, dv) 
+    order = 0
+    while (dv′ = diff_to_var[dv]) !== nothing
+        order += 1
+        dv = dv′
+    end
+    order, dv
+end
+
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     @unpack fullvars, sys, structure = state
@@ -667,22 +680,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     dummy_sub = Dict()
     is_discrete = is_only_discrete(state.structure)
 
-    var_order = let diff_to_var = diff_to_var
-        dv -> begin
-            order = 0
-            while (dv′ = diff_to_var[dv]) !== nothing
-                order += 1
-                dv = dv′
-            end
-            order, dv
-        end
-    end
-
     # Structural simplification 
     substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
-    generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching, var_order; 
+    generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching; 
                                    is_discrete, mm)
-    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph = solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching, var_order; simplify)
+    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph = 
+        solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching; simplify)
 
     # Update system
     var_to_diff = new_var_to_diff
@@ -698,25 +701,18 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         i -> (!isempty(𝑑neighbors(graph, i)) ||
               (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
     end
-    @show graph
-    println()
-    println("Shift test...")
-    @show neweqs
-    @show fullvars
+    @show ispresent.(collect(1:length(fullvars)))
     @show 𝑑neighbors(graph, 5)
+    @show var_to_diff[5]
 
+    @show neweqs
+    @show fullvars
     sys = state.sys
     obs_sub = dummy_sub
     for eq in neweqs
         isdiffeq(eq) || continue
         obs_sub[eq.lhs] = eq.rhs
     end
-    is_discrete && for eq in subeqs
-        obs_sub[eq.rhs] = eq.lhs
-    end
-
-    @show obs_sub
-    @show observed(sys)
     # TODO: compute the dependency correctly so that we don't have to do this
     obs = [fast_substitute(observed(sys), obs_sub); subeqs]
 
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 8df2d4177b..b2d67dc547 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -450,6 +450,7 @@ end
 ###
 
 function lower_varname_withshift(var, iv, backshift; unshifted = nothing)
+    backshift < 0 && return Shift(iv, -backshift)(var)
     backshift == 0 && return unshifted 
     ds = "$iv-$backshift"
     d_separator = 'ˍ'

From 56829f7c05e94a6048c55b482af8c1752e9cbad7 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 7 Feb 2025 02:26:13 -0500
Subject: [PATCH 074/111] up

---
 .../symbolics_tearing.jl                      | 112 ++++++++++--------
 src/structural_transformation/utils.jl        |  13 +-
 2 files changed, 73 insertions(+), 52 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index fb1765964e..020c85b05f 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -252,6 +252,7 @@ function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     diff_to_var = invview(var_to_diff)
+    iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
 
     for var in 1:length(fullvars)
         dv = var_to_diff[var]
@@ -337,16 +338,17 @@ variables and equations, don't add them when they already exist.
 ###### DISCRETE SYSTEMS ####### 
 
 Documenting the differences to structural simplification for discrete systems:
-In discrete systems the lowest-order term is x_k-i, instead of x(t).
+
+1. In discrete systems the lowest-order term is Shift(t, k)(x(t)), instead of x(t). We need to substitute the k-1 lowest order terms instead of the k-1 highest order terms.
 
 The orders will also be off by one. The reason this is is that the dynamics of
 the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
 having the observables be indexed by the next time step is not so nice. So we 
-handle the shifts in the renaming, rather than explicitly.
+handle the shifts in the renaming.
 
 The substitution should look like the following: 
   x(t) -> Shift(t, 1)(x(t))
-  Shift(t, -1)(x(t)) -> x(t)
+  Shift(t, -1)(x(t)) -> Shift(t, 0)(x(t))
   Shift(t, -2)(x(t)) -> x_{t-1}(t)
   Shift(t, -3)(x(t)) -> x_{t-2}(t)
   and so on...
@@ -354,14 +356,14 @@ The substitution should look like the following:
 In the implicit discrete case this shouldn't happen. The simplification should 
 look like a NonlinearSystem.
 
-For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
+2. For discrete systems Shift(t, 2)(x(t)) cannot be substituted as Shift(t, 1)(Shift(t,1)(x(t)). 
 This is different from the continuous case where D(D(x)) can be substituted for 
 by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
-total_sub dict is updated at the time that the renamed variables are written,
+shift_sub dict is updated at the time that the renamed variables are written,
 inside the loop where new variables are generated.
 """
-function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching;
-        is_discrete = false, mm = nothing)
+function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
+        is_discrete = false, mm = nothing, shift_sub = nothing)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     eq_var_matching = invview(var_eq_matching)
@@ -391,23 +393,24 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
     # - uv is the index of the highest-order variable (x(t))
     for v in 1:length(var_to_diff)
         dv = var_to_diff[v]
-        println()
-        @show (v, dv)
+
         if is_discrete 
             x = fullvars[v]
             op = operation(x)
             (low, uv) = idx_to_lowest_shift[v]
 
             # If v is unshifted (i.e. x(t)), then substitute the lowest-shift variable
-            if !(op isa Shift) && (low != 0)
+            if !(op isa Shift)
                 dv = findfirst(_x -> isequal(_x, Shift(iv, low)(x)), fullvars)
             end
+            dx = fullvars[dv]
+            order, lv = var_order(diff_to_var, dv)
+            @show fullvars[uv]
+            x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv])
+            shift_sub[dx] = x_t
+            (var_eq_matching[dv] isa Int) ? continue : @goto DISCRETE_VARIABLE
         end
         dv isa Int || continue
-
-        @show dv
-        @show var_eq_matching[dv]
-        @show fullvars
         solved = var_eq_matching[dv] isa Int
         solved && continue
 
@@ -430,13 +433,13 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
             end
         end
 
-        dx = fullvars[dv]
         # add `x_t`
+        dx = fullvars[dv]
         order, lv = var_order(diff_to_var, dv)
-        x_t = is_discrete ? lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv]) :
-                            lower_name(fullvars[lv], iv, order)
-        @show dx, x_t
-        push!(fullvars, simplify_shifts(x_t))
+        x_t = lower_name(fullvars[lv], iv, order)
+        
+        @label DISCRETE_VARIABLE
+        push!(fullvars, x_t)
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
         add_vertex!(graph, DST)
@@ -444,23 +447,18 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
         # `dummy_derivative_graph`.
         add_vertex!(solvable_graph, DST)
         push!(var_eq_matching, unassigned)
-        @assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
-                length(var_eq_matching)
 
         # Add discrete substitutions to total_sub directly.  
         is_discrete && begin
             idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
-            if operation(dx) isa Shift 
-                total_sub[dx] = x_t
-                for e in 𝑑neighbors(graph, dv)
-                    add_edge!(graph, e, v_t)
-                    rem_edge!(graph, e, dv)
-                end
-                # Do not add the lowest-order substitution as an equation, just substitute
-                !(operation(x) isa Shift) && begin
-                    var_to_diff[v_t] = var_to_diff[dv]
-                    continue
-                end
+            for e in 𝑑neighbors(graph, dv)
+                add_edge!(graph, e, v_t)
+                rem_edge!(graph, e, dv)
+            end
+            # Do not add the lowest-order substitution as an equation, just substitute
+            !(operation(x) isa Shift) && begin
+                var_to_diff[v_t] = var_to_diff[dv]
+                continue
             end
         end
 
@@ -472,7 +470,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
         add_edge!(graph, dummy_eq, v_t)
         add_vertex!(solvable_graph, SRC)
         add_edge!(solvable_graph, dummy_eq, dv)
-        @assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
+
         @label FOUND_DUMMY_EQ
         var_to_diff[v_t] = var_to_diff[dv]
         var_eq_matching[dv] = unassigned
@@ -480,6 +478,14 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
     end
 end
 
+function add_solvable_variable!() 
+    
+end
+
+function add_solvable_equation!() 
+    
+end
+
 """
 Solve the solvable equations of the system and generate differential (or discrete)
 equations in terms of the selected states.
@@ -492,21 +498,27 @@ such that the mass matrix is:
     0  0].
 
 Update the state to account for the new ordering and equations.
+
+####### DISCRETE CASE
+- only substitute Shift(t, -2) 
 """
-function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching; simplify = false)
+function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, shift_sub = Dict())
     @unpack fullvars, sys, structure = state 
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
+    dx_sub = Dict()
 
     if ModelingToolkit.has_iv(sys)
         iv = get_iv(sys)
         if is_only_discrete(structure)
             D = Shift(iv, 1)
             lower_name = lower_varname_withshift
+            total_sub = shift_sub
         else
             D = Differential(iv)
             lower_name = lower_varname_with_unit
+            total_sub = dx_sub
         end
     else
         iv = D = nothing
@@ -540,6 +552,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
     # fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
     # is solved to give the RHS.
     for ieq in eqs
+        println()
         iv = eq_var_matching[ieq]
         if is_solvable(ieq, iv)
             # We don't solve differential equations, but we will need to try to
@@ -549,7 +562,9 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
                 isnothing(D) &&
                     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
                 order, lv = var_order(diff_to_var, iv)
-                dx = D(fullvars[lv])
+                @show fullvars[lv]
+                @show simplify_shifts(fullvars[lv])
+                dx = D(simplify_shifts(fullvars[lv]))
                 eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
                     Symbolics.symbolic_linear_solve(neweqs[ieq],
                         fullvars[iv]),
@@ -560,6 +575,9 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
                 end
                 push!(diff_eqs, eq)
                 total_sub[simplify_shifts(eq.lhs)] = eq.rhs
+                dx_sub[simplify_shifts(eq.lhs)] = eq.rhs
+                @show total_sub
+                @show eq
                 push!(diffeq_idxs, ieq)
                 push!(diff_vars, diff_to_var[iv])
                 continue
@@ -575,10 +593,10 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
                 @warn "Tearing: solving $eq for $var is singular!"
             else
                 rhs = -b / a
-                neweq = var ~ Symbolics.fixpoint_sub(
+                neweq = var ~ simplify_shifts(Symbolics.fixpoint_sub(
                     simplify ?
                     Symbolics.simplify(rhs) : rhs,
-                    total_sub; operator = ModelingToolkit.Shift)
+                    dx_sub; operator = ModelingToolkit.Shift))
                 push!(subeqs, neweq)
                 push!(solved_equations, ieq)
                 push!(solved_variables, iv)
@@ -589,7 +607,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
             if !(eq.lhs isa Number && eq.lhs == 0)
                 rhs = eq.rhs - eq.lhs
             end
-            push!(alge_eqs, 0 ~ Symbolics.fixpoint_sub(rhs, total_sub))
+            push!(alge_eqs, 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)))
             push!(algeeq_idxs, ieq)
         end
     end
@@ -676,16 +694,19 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     end
     neweqs = collect(equations(state))
     diff_to_var = invview(var_to_diff)
-    total_sub = Dict()
-    dummy_sub = Dict()
     is_discrete = is_only_discrete(state.structure)
 
+    shift_sub = Dict() 
+
     # Structural simplification 
+    dummy_sub = Dict()
     substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
-    generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching; 
-                                   is_discrete, mm)
+
+    generate_derivative_variables!(state, neweqs, var_eq_matching; 
+                                   is_discrete, mm, shift_sub)
+
     new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph = 
-        solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching; simplify)
+        solve_and_generate_equations!(state, neweqs, var_eq_matching; simplify, shift_sub)
 
     # Update system
     var_to_diff = new_var_to_diff
@@ -701,12 +722,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         i -> (!isempty(𝑑neighbors(graph, i)) ||
               (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
     end
-    @show ispresent.(collect(1:length(fullvars)))
-    @show 𝑑neighbors(graph, 5)
-    @show var_to_diff[5]
 
-    @show neweqs
-    @show fullvars
     sys = state.sys
     obs_sub = dummy_sub
     for eq in neweqs
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index b2d67dc547..c202556906 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -449,10 +449,9 @@ end
 ### Misc
 ###
 
-function lower_varname_withshift(var, iv, backshift; unshifted = nothing)
-    backshift < 0 && return Shift(iv, -backshift)(var)
-    backshift == 0 && return unshifted 
-    ds = "$iv-$backshift"
+function lower_varname_withshift(var, iv, backshift; unshifted = nothing, allow_zero = true)
+    backshift <= 0 && return Shift(iv, -backshift)(unshifted, allow_zero)
+    ds = backshift > 0 ? "$iv-$backshift" : "$iv+$(-backshift)"
     d_separator = 'ˍ'
 
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
@@ -475,9 +474,15 @@ function isdoubleshift(var)
            ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
 end
 
+### Rules
+# 1. x(t) -> x(t)
+# 2. Shift(t, 0)(x(t)) -> x(t)
+# 3. Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
+
 function simplify_shifts(var)
     ModelingToolkit.hasshift(var) || return var
     var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
+    ((op = operation(var)) isa Shift) && op.steps == 0 && return simplify_shifts(arguments(var)[1])
     if isdoubleshift(var)
         op1 = operation(var)
         vv1 = arguments(var)[1]

From 1c578c3535fe0d7af1d82d08285d90fd244ce2af Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 7 Feb 2025 14:46:09 -0500
Subject: [PATCH 075/111] rename functions

---
 .../symbolics_tearing.jl                      | 37 +++++++-----
 src/structural_transformation/utils.jl        | 60 ++++++++++++++++++-
 test/structural_transformation/utils.jl       | 17 ++++++
 3 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 020c85b05f..778473907e 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -254,7 +254,9 @@ function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_
     diff_to_var = invview(var_to_diff)
     iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
 
+    @show neweqs
     for var in 1:length(fullvars)
+        #@show neweqs
         dv = var_to_diff[var]
         dv === nothing && continue
         if var_eq_matching[var] !== SelectedState()
@@ -286,9 +288,7 @@ function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_
     end
 end
 
-"""
-Generate new derivative variables for the system.
-
+#= 
 There are three cases where we want to generate new variables to convert
 the system into first order (semi-implicit) ODEs.
     
@@ -361,6 +361,16 @@ This is different from the continuous case where D(D(x)) can be substituted for
 by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
 shift_sub dict is updated at the time that the renamed variables are written,
 inside the loop where new variables are generated.
+=#
+"""
+Generate new derivative variables for the system.
+
+Effects on the state: 
+- fullvars: add the new derivative variables x_t
+- neweqs: add the identity equations for the new variables, D(x) ~ x_t
+- graph: update graph with the new equations and variables, and their connections
+- solvable_graph:
+- var_eq_matching: solvable equations
 """
 function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
         is_discrete = false, mm = nothing, shift_sub = nothing)
@@ -406,7 +416,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
             dx = fullvars[dv]
             order, lv = var_order(diff_to_var, dv)
             @show fullvars[uv]
-            x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv])
+            x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv], allow_zero = true)
             shift_sub[dx] = x_t
             (var_eq_matching[dv] isa Int) ? continue : @goto DISCRETE_VARIABLE
         end
@@ -439,7 +449,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
         x_t = lower_name(fullvars[lv], iv, order)
         
         @label DISCRETE_VARIABLE
-        push!(fullvars, x_t)
+        push!(fullvars, simplify_shifts(x_t))
         v_t = length(fullvars)
         v_t_idx = add_vertex!(var_to_diff)
         add_vertex!(graph, DST)
@@ -448,7 +458,6 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
         add_vertex!(solvable_graph, DST)
         push!(var_eq_matching, unassigned)
 
-        # Add discrete substitutions to total_sub directly.  
         is_discrete && begin
             idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
             for e in 𝑑neighbors(graph, dv)
@@ -463,7 +472,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
         end
 
         # add `D(x) - x_t ~ 0`
-        push!(neweqs, 0 ~ x_t - dx)
+        push!(neweqs, 0 ~ dx - x_t)
         add_vertex!(graph, SRC)
         dummy_eq = length(neweqs)
         add_edge!(graph, dummy_eq, dv)
@@ -478,12 +487,11 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     end
 end
 
-function add_solvable_variable!() 
+function add_solvable_variable!(state::TearingState)
     
 end
 
-function add_solvable_equation!() 
-    
+function add_solvable_equation!(s::SystemStructure, neweqs, eq)
 end
 
 """
@@ -500,7 +508,8 @@ such that the mass matrix is:
 Update the state to account for the new ordering and equations.
 
 ####### DISCRETE CASE
-- only substitute Shift(t, -2) 
+- Differential equations: substitute variables with everything shifted forward one timestep.
+- Algebraic and observable equations: substitute variables with everything shifted back one timestep.
 """
 function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, shift_sub = Dict())
     @unpack fullvars, sys, structure = state 
@@ -562,8 +571,6 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
                 isnothing(D) &&
                     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
                 order, lv = var_order(diff_to_var, iv)
-                @show fullvars[lv]
-                @show simplify_shifts(fullvars[lv])
                 dx = D(simplify_shifts(fullvars[lv]))
                 eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
                     Symbolics.symbolic_linear_solve(neweqs[ieq],
@@ -576,8 +583,6 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
                 push!(diff_eqs, eq)
                 total_sub[simplify_shifts(eq.lhs)] = eq.rhs
                 dx_sub[simplify_shifts(eq.lhs)] = eq.rhs
-                @show total_sub
-                @show eq
                 push!(diffeq_idxs, ieq)
                 push!(diff_vars, diff_to_var[iv])
                 continue
@@ -611,6 +616,8 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
             push!(algeeq_idxs, ieq)
         end
     end
+    @show neweqs
+    @show subeqs
 
     # TODO: BLT sorting
     neweqs = [diff_eqs; alge_eqs]
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index c202556906..b947fa4269 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -271,6 +271,8 @@ end
 
 function find_solvables!(state::TearingState; kwargs...)
     @assert state.structure.solvable_graph === nothing
+    println("in find_solvables")
+    @show eqs
     eqs = equations(state)
     graph = state.structure.graph
     state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
@@ -278,6 +280,7 @@ function find_solvables!(state::TearingState; kwargs...)
     for ieq in 1:length(eqs)
         find_eq_solvables!(state, ieq, to_rm; kwargs...)
     end
+    @show eqs
     return nothing
 end
 
@@ -477,7 +480,7 @@ end
 ### Rules
 # 1. x(t) -> x(t)
 # 2. Shift(t, 0)(x(t)) -> x(t)
-# 3. Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
+# 3. Shift(t, 3)(Shift(t, 2)(x(t)) -> Shift(t, 5)(x(t))
 
 function simplify_shifts(var)
     ModelingToolkit.hasshift(var) || return var
@@ -498,3 +501,58 @@ function simplify_shifts(var)
             unwrap(var).metadata)
     end
 end
+
+"""
+Power expand the shifts. Used for substitution.
+
+Shift(t, -3)(x(t)) -> Shift(t, -1)(Shift(t, -1)(Shift(t, -1)(x)))
+"""
+function expand_shifts(var)
+    ModelingToolkit.hasshift(var) || return var
+    var = ModelingToolkit.value(var)
+
+    var isa Equation && return expand_shifts(var.lhs) ~ expand_shifts(var.rhs)
+    op = operation(var)
+    s = sign(op.steps)
+    arg = only(arguments(var))
+
+    if ModelingToolkit.isvariable(arg) && (ModelingToolkit.getvariabletype(arg) === VARIABLE) && isequal(op.t, only(arguments(arg)))
+        out = arg
+        for i in 1:op.steps
+            out = Shift(op.t, s)(out)
+        end
+        return out
+    elseif iscall(arg)
+        return maketerm(typeof(var), operation(var), expand_shifts.(arguments(var)),
+            unwrap(var).metadata)
+    else
+        return arg
+    end
+end
+
+"""
+Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
+"""
+function distribute_shift(var) 
+    ModelingToolkit.hasshift(var) || return var
+    var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs)
+    shift = operation(var)
+    expr = only(arguments(var))
+    _distribute_shift(expr, shift)
+end
+
+function _distribute_shift(expr, shift)
+    op = operation(expr)
+    args = arguments(expr)
+
+    if length(args) == 1 
+        if ModelingToolkit.isvariable(only(args)) && isequal(op.t, only(args))
+            return shift(only(args))
+        else
+            return only(args)
+        end
+    else iscall(op)
+        return maketerm(typeof(expr), operation(expr), _distribute_shift.(args, shift),
+            unwrap(var).metadata)
+    end
+end
diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl
index 621aa8b8a8..98d986ebd4 100644
--- a/test/structural_transformation/utils.jl
+++ b/test/structural_transformation/utils.jl
@@ -4,6 +4,7 @@ using Graphs
 using SparseArrays
 using UnPack
 using ModelingToolkit: t_nounits as t, D_nounits as D
+const ST = StructuralTransformations
 
 # Define some variables
 @parameters L g
@@ -161,3 +162,19 @@ end
     structural_simplify(sys; additional_passes = [pass])
     @test value[] == 1
 end
+
+@testset "Shift simplification" begin
+    @variables x(t) y(t) z(t)
+    @parameters a b c
+    
+    # Expand shifts
+    @test isequal(ST.expand_shifts(Shift(t, -3)(x)), Shift(t, -1)(Shift(t, -1)(Shift(t, -1)(x))))
+    expr = a * Shift(t, -2)(x) + Shift(t, 2)(y) + b
+    @test isequal(ST.expand_shifts(expr), 
+                    a * Shift(t, -1)(Shift(t, -1)(x)) + Shift(t, 1)(Shift(t, 1)(y)) + b)
+    @test isequal(ST.expand_shifts(Shift(t, 2)(Shift(t, 1)(a))), a)
+
+
+    # Distribute shifts
+
+end

From 603c894044cd6589a0a15a7f8765369393ed9762 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Mon, 10 Feb 2025 06:10:32 -0800
Subject: [PATCH 076/111] Update src/systems/diffeqs/odesystem.jl

---
 src/systems/diffeqs/odesystem.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 44cc7df46b..5b8041ba33 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -711,6 +711,5 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
         end
     end
 
-    @show constraints
     ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
 end

From ed143a5d85d4b4660d478718fc9bcd9d79ec810c Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 10 Feb 2025 15:18:05 -0500
Subject: [PATCH 077/111] refactor: refactor tearing_assemble into functions

---
 .../symbolics_tearing.jl                      | 503 ++++++++----------
 src/structural_transformation/utils.jl        |  71 +--
 src/systems/systemstructure.jl                |   6 +-
 3 files changed, 239 insertions(+), 341 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 778473907e..b99e6ec21a 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -244,19 +244,17 @@ State selection may determine that some differential variables are
 algebraic variables in disguise. The derivative of such variables are
 called dummy derivatives.
 
-`SelectedState` information is no longer needed past here. State selection
-is done. All non-differentiated variables are algebraic variables, and all
-variables that appear differentiated are differential variables.
+`SelectedState` information is no longer needed after this function is called. 
+State selection is done. All non-differentiated variables are algebraic 
+variables, and all variables that appear differentiated are differential variables.
 """
-function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_eq_matching)
+function substitute_derivatives_algevars!(ts::TearingState, neweqs, dummy_sub, var_eq_matching)
     @unpack fullvars, sys, structure = ts
-    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     diff_to_var = invview(var_to_diff)
     iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
 
-    @show neweqs
     for var in 1:length(fullvars)
-        #@show neweqs
         dv = var_to_diff[var]
         dv === nothing && continue
         if var_eq_matching[var] !== SelectedState()
@@ -339,199 +337,137 @@ variables and equations, don't add them when they already exist.
 
 Documenting the differences to structural simplification for discrete systems:
 
-1. In discrete systems the lowest-order term is Shift(t, k)(x(t)), instead of x(t). We need to substitute the k-1 lowest order terms instead of the k-1 highest order terms.
+In discrete systems the lowest-order term is Shift(t, k)(x(t)), instead of x(t). 
+We want to substitute the k-1 lowest order terms instead of the k-1 highest order terms.
 
-The orders will also be off by one. The reason this is is that the dynamics of
-the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But 
-having the observables be indexed by the next time step is not so nice. So we 
-handle the shifts in the renaming.
-
-The substitution should look like the following: 
-  x(t) -> Shift(t, 1)(x(t))
-  Shift(t, -1)(x(t)) -> Shift(t, 0)(x(t))
-  Shift(t, -2)(x(t)) -> x_{t-1}(t)
-  Shift(t, -3)(x(t)) -> x_{t-2}(t)
-  and so on...
-
-In the implicit discrete case this shouldn't happen. The simplification should 
-look like a NonlinearSystem.
-
-2. For discrete systems Shift(t, 2)(x(t)) cannot be substituted as Shift(t, 1)(Shift(t,1)(x(t)). 
-This is different from the continuous case where D(D(x)) can be substituted for 
-by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
-shift_sub dict is updated at the time that the renamed variables are written,
-inside the loop where new variables are generated.
+In the system x(k) ~ x(k-1) + x(k-2), we want to lower
+Shift(t, -1)(x(t)) -> x\_{t-1}(t)
 =#
 """
 Generate new derivative variables for the system.
 
-Effects on the state: 
+Effects on the system structure: 
 - fullvars: add the new derivative variables x_t
 - neweqs: add the identity equations for the new variables, D(x) ~ x_t
 - graph: update graph with the new equations and variables, and their connections
 - solvable_graph:
-- var_eq_matching: solvable equations
+- var_eq_matching: match D(x) to the added identity equation
 """
 function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
-        is_discrete = false, mm = nothing, shift_sub = nothing)
+        is_discrete = false, mm = nothing)
     @unpack fullvars, sys, structure = ts
-    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
     lower_name = is_discrete ? lower_varname_withshift : lower_varname_with_unit
 
-    # index v gets mapped to the lowest shift and the index of the unshifted variable
-    if is_discrete
-        idx_to_lowest_shift = Dict{Int, Tuple{Int, Int}}(var => (0,0) for var in 1:length(fullvars))
-        var_to_unshiftedidx = Dict{Any, Int}(var => findfirst(x -> isequal(x, var), fullvars) for var in keys(lowest_shift))
-
-        for (i,var) in enumerate(fullvars)
-            key = (operation(var) isa Shift) ? only(arguments(var)) : var
-            idx_to_lowest_shift[i] = (get(lowest_shift, key, 0), get(var_to_unshiftedidx, key, i))
-        end
-    end
-
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
-    # v is the index of the current variable, x = fullvars[v]
-    # dv is the index of the derivative dx = D(x), x_t is the substituted variable 
-    # For ODESystems: lv is the index of the lowest-order variable (x(t))
-    # For DiscreteSystems: 
-    # - lv is the index of the lowest-order variable (Shift(t, k)(x(t)))
-    # - uv is the index of the highest-order variable (x(t))
+    # Generate new derivative variables for all unsolved variables that have a derivative in the system 
     for v in 1:length(var_to_diff)
+        # Check if a derivative 1) exists and 2) is unsolved for
         dv = var_to_diff[v]
-
-        if is_discrete 
-            x = fullvars[v]
-            op = operation(x)
-            (low, uv) = idx_to_lowest_shift[v]
-
-            # If v is unshifted (i.e. x(t)), then substitute the lowest-shift variable
-            if !(op isa Shift)
-                dv = findfirst(_x -> isequal(_x, Shift(iv, low)(x)), fullvars)
-            end
-            dx = fullvars[dv]
-            order, lv = var_order(diff_to_var, dv)
-            @show fullvars[uv]
-            x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv], allow_zero = true)
-            shift_sub[dx] = x_t
-            (var_eq_matching[dv] isa Int) ? continue : @goto DISCRETE_VARIABLE
-        end
         dv isa Int || continue
         solved = var_eq_matching[dv] isa Int
         solved && continue
 
-        # Check if there's `D(x) = x_t` already
-        local v_t, dummy_eq
-        for eq in 𝑑neighbors(solvable_graph, dv)
-            mi = get(linear_eqs, eq, 0)
-            iszero(mi) && continue
-            row = @view mm[mi, :]
-            nzs = nonzeros(row)
-            rvs = SparseArrays.nonzeroinds(row)
-            # note that `v_t` must not be differentiated
-            if length(nzs) == 2 &&
-               (abs(nzs[1]) == 1 && nzs[1] == -nzs[2]) &&
-               (v_t = rvs[1] == dv ? rvs[2] : rvs[1];
-               diff_to_var[v_t] === nothing)
-                @assert dv in rvs
-                dummy_eq = eq
-                @goto FOUND_DUMMY_EQ
-            end
+        # If there's `D(x) = x_t` already, update mappings and continue without
+        # adding new equations/variables
+        dd = find_duplicate_dd(dv, lineareqs, mm)
+
+        if !isnothing(dd)
+            dummy_eq, v_t = dd
+            var_to_diff[v_t] = var_to_diff[dv]
+            var_eq_matching[dv] = unassigned
+            eq_var_matching[dummy_eq] = dv
+            continue
         end
 
-        # add `x_t`
         dx = fullvars[dv]
         order, lv = var_order(diff_to_var, dv)
-        x_t = lower_name(fullvars[lv], iv, order)
+        x_t = is_discrete ? lower_name(fullvars[lv], iv)
+                      : lower_name(fullvars[lv], iv, order)
         
-        @label DISCRETE_VARIABLE
-        push!(fullvars, simplify_shifts(x_t))
-        v_t = length(fullvars)
-        v_t_idx = add_vertex!(var_to_diff)
-        add_vertex!(graph, DST)
-        # TODO: do we care about solvable_graph? We don't use them after
-        # `dummy_derivative_graph`.
-        add_vertex!(solvable_graph, DST)
-        push!(var_eq_matching, unassigned)
-
-        is_discrete && begin
-            idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
-            for e in 𝑑neighbors(graph, dv)
-                add_edge!(graph, e, v_t)
-                rem_edge!(graph, e, dv)
-            end
-            # Do not add the lowest-order substitution as an equation, just substitute
-            !(operation(x) isa Shift) && begin
-                var_to_diff[v_t] = var_to_diff[dv]
-                continue
-            end
-        end
+        # Add `x_t` to the graph
+        add_dd_variable!(structure, x_t, dv)
+        # Add `D(x) - x_t ~ 0` to the graph
+        add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv)
 
-        # add `D(x) - x_t ~ 0`
-        push!(neweqs, 0 ~ dx - x_t)
-        add_vertex!(graph, SRC)
-        dummy_eq = length(neweqs)
-        add_edge!(graph, dummy_eq, dv)
-        add_edge!(graph, dummy_eq, v_t)
-        add_vertex!(solvable_graph, SRC)
-        add_edge!(solvable_graph, dummy_eq, dv)
-
-        @label FOUND_DUMMY_EQ
-        var_to_diff[v_t] = var_to_diff[dv]
+        # Update matching
+        push!(var_eq_matching, unassigned)
         var_eq_matching[dv] = unassigned
         eq_var_matching[dummy_eq] = dv
     end
 end
 
-function add_solvable_variable!(state::TearingState)
-    
+"""
+Check if there's `D(x) = x_t` already. 
+"""
+function find_duplicate_dd(dv, lineareqs, mm)
+    for eq in 𝑑neighbors(solvable_graph, dv)
+        mi = get(linear_eqs, eq, 0)
+        iszero(mi) && continue
+        row = @view mm[mi, :]
+        nzs = nonzeros(row)
+        rvs = SparseArrays.nonzeroinds(row)
+        # note that `v_t` must not be differentiated
+        if length(nzs) == 2 &&
+           (abs(nzs[1]) == 1 && nzs[1] == -nzs[2]) &&
+           (v_t = rvs[1] == dv ? rvs[2] : rvs[1];
+           diff_to_var[v_t] === nothing)
+            @assert dv in rvs
+            return eq, v_t
+        end
+    end
+    return nothing 
+end
+
+function add_dd_variable!(s::SystemStructure, x_t, dv)
+    push!(s.fullvars, simplify_shifts(x_t))
+    v_t_idx = add_vertex!(s.var_to_diff)
+    @assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
+        length(var_eq_matching)
+    add_vertex!(s.graph, DST)
+    # TODO: do we care about solvable_graph? We don't use them after
+    # `dummy_derivative_graph`.
+    add_vertex!(s.solvable_graph, DST)
+    var_to_diff[v_t] = var_to_diff[dv]
 end
 
-function add_solvable_equation!(s::SystemStructure, neweqs, eq)
+# dv = index of D(x), v_t = index of x_t
+function add_dd_equation!(s::SystemStructure, neweqs, eq, dv)
+    push!(neweqs, eq)
+    add_vertex!(s.graph, SRC)
+    v_t = length(s.fullvars)
+    dummy_eq = length(neweqs)
+    add_edge!(s.graph, dummy_eq, dv)
+    add_edge!(s.graph, dummy_eq, v_t)
+    add_vertex!(s.solvable_graph, SRC)
+    add_edge!(s.solvable_graph, dummy_eq, dv)
 end
 
 """
 Solve the solvable equations of the system and generate differential (or discrete)
 equations in terms of the selected states.
-
-Will reorder equations and unknowns to be:
-   [diffeqs; ...]
-   [diffvars; ...]
-such that the mass matrix is:
-   [I  0
-    0  0].
-
-Update the state to account for the new ordering and equations.
-
-####### DISCRETE CASE
-- Differential equations: substitute variables with everything shifted forward one timestep.
-- Algebraic and observable equations: substitute variables with everything shifted back one timestep.
 """
-function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, shift_sub = Dict())
+function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false)
     @unpack fullvars, sys, structure = state 
-    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
-    dx_sub = Dict()
+    total_sub = Dict()
 
     if ModelingToolkit.has_iv(sys)
         iv = get_iv(sys)
         if is_only_discrete(structure)
             D = Shift(iv, 1)
-            lower_name = lower_varname_withshift
-            total_sub = shift_sub
         else
             D = Differential(iv)
-            lower_name = lower_varname_with_unit
-            total_sub = dx_sub
         end
     else
         iv = D = nothing
-        lower_name = lower_varname_with_unit
     end
 
     # if var is like D(x) or Shift(t, 1)(x)
@@ -544,14 +480,14 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
         (eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
     end
 
-    diffeq_idxs = Int[]
-    algeeq_idxs = Int[]
     diff_eqs = Equation[]
-    alge_eqs = Equation[]
+    diffeq_idxs = Int[]
     diff_vars = Int[]
-    subeqs = Equation[]
-    solved_equations = Int[]
-    solved_variables = Int[]
+    alge_eqs = Equation[]
+    algeeq_idxs = Int[]
+    solved_eqs = Equation[]
+    solvedeq_idxs = Int[]
+    solved_vars = Int[]
 
     toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching))
     eqs = Iterators.reverse(toporder)
@@ -561,100 +497,126 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
     # fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
     # is solved to give the RHS.
     for ieq in eqs
-        println()
         iv = eq_var_matching[ieq]
         if is_solvable(ieq, iv)
-            # We don't solve differential equations, but we will need to try to
-            # convert it into the mass matrix form.
-            # We cannot solve the differential variable like D(x)
             if isdervar(iv)
                 isnothing(D) &&
                     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
-                order, lv = var_order(diff_to_var, iv)
-                dx = D(simplify_shifts(fullvars[lv]))
-                eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
-                    Symbolics.symbolic_linear_solve(neweqs[ieq],
-                        fullvars[iv]),
-                    total_sub; operator = ModelingToolkit.Shift))
-                for e in 𝑑neighbors(graph, iv)
-                    rem_edge!(graph, e, iv)
-                    add_edge!(graph, e, lv)
-                end
-                push!(diff_eqs, eq)
-                total_sub[simplify_shifts(eq.lhs)] = eq.rhs
-                dx_sub[simplify_shifts(eq.lhs)] = eq.rhs
-                push!(diffeq_idxs, ieq)
-                push!(diff_vars, diff_to_var[iv])
-                continue
-            end
-            eq = neweqs[ieq]
-            var = fullvars[iv]
-            residual = eq.lhs - eq.rhs
-            a, b, islinear = linear_expansion(residual, var)
-            @assert islinear
-            # 0 ~ a * var + b
-            # var ~ -b/a
-            if ModelingToolkit._iszero(a)
-                @warn "Tearing: solving $eq for $var is singular!"
+                add_differential_equation!(structure, iv, neweqs, ieq, 
+                                           diff_vars, diff_eqs, diffeq_idxs, total_sub)
             else
-                rhs = -b / a
-                neweq = var ~ simplify_shifts(Symbolics.fixpoint_sub(
-                    simplify ?
-                    Symbolics.simplify(rhs) : rhs,
-                    dx_sub; operator = ModelingToolkit.Shift))
-                push!(subeqs, neweq)
-                push!(solved_equations, ieq)
-                push!(solved_variables, iv)
+                add_solved_equation!(structure, iv, neweqs, ieq, 
+                                     solved_vars, solved_eqs, solvedeq_idxs, total_sub)
             end
         else
-            eq = neweqs[ieq]
-            rhs = eq.rhs
-            if !(eq.lhs isa Number && eq.lhs == 0)
-                rhs = eq.rhs - eq.lhs
-            end
-            push!(alge_eqs, 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)))
-            push!(algeeq_idxs, ieq)
+            add_algebraic_equation!(structure, neweqs, ieq, 
+                                    alge_eqs, algeeq_idxs, total_sub)
         end
     end
-    @show neweqs
-    @show subeqs
 
-    # TODO: BLT sorting
+    # Generate new equations and orderings 
     neweqs = [diff_eqs; alge_eqs]
-    inveqsperm = [diffeq_idxs; algeeq_idxs]
-    eqsperm = zeros(Int, nsrcs(graph))
-    for (i, v) in enumerate(inveqsperm)
-        eqsperm[v] = i
-    end
+    eq_ordering = [diffeq_idxs; algeeq_idxs]
     diff_vars_set = BitSet(diff_vars)
     if length(diff_vars_set) != length(diff_vars)
         error("Tearing internal error: lowering DAE into semi-implicit ODE failed!")
     end
-    solved_variables_set = BitSet(solved_variables)
-    invvarsperm = [diff_vars;
+    solved_vars_set = BitSet(solved_vars)
+    var_ordering = [diff_vars;
                    setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
-                       solved_variables_set)]
+                       solved_vars_set)]
+
+    return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
+end
+
+function add_differential_equation!(s::SystemStructure, iv, neweqs, ieq, diff_vars, diff_eqs, diffeqs_idxs, total_sub)
+    diff_to_var = invview(s.var_to_diff)
+
+    order, lv = var_order(diff_to_var, iv)
+    dx = D(simplify_shifts(fullvars[lv]))
+    eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
+        Symbolics.symbolic_linear_solve(neweqs[ieq],
+            fullvars[iv]),
+        total_sub; operator = ModelingToolkit.Shift))
+    for e in 𝑑neighbors(s.graph, iv)
+        e == ieq && continue
+        rem_edge!(s.graph, e, iv)
+    end
+
+    push!(diff_eqs, eq)
+    total_sub[simplify_shifts(eq.lhs)] = eq.rhs
+    push!(diffeq_idxs, ieq)
+    push!(diff_vars, diff_to_var[iv])
+end
+
+function add_algebraic_equation!(s::SystemStructure, neweqs, ieq, alge_eqs, algeeq_idxs, total_sub)
+    eq = neweqs[ieq]
+    rhs = eq.rhs
+    if !(eq.lhs isa Number && eq.lhs == 0)
+        rhs = eq.rhs - eq.lhs
+    end
+    push!(alge_eqs, 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)))
+    push!(algeeq_idxs, ieq)
+end
+
+function add_solved_equation!(s::SystemStructure, iv, neweqs, ieq, solved_vars, solved_eqs, solvedeq_idxs, total_sub)
+    eq = neweqs[ieq]
+    var = fullvars[iv]
+    residual = eq.lhs - eq.rhs
+    a, b, islinear = linear_expansion(residual, var)
+    @assert islinear
+    # 0 ~ a * var + b
+    # var ~ -b/a
+    if ModelingToolkit._iszero(a)
+        @warn "Tearing: solving $eq for $var is singular!"
+    else
+        rhs = -b / a
+        neweq = var ~ simplify_shifts(Symbolics.fixpoint_sub(
+            simplify ?
+            Symbolics.simplify(rhs) : rhs,
+            total_sub; operator = ModelingToolkit.Shift))
+        push!(solved_eqs, neweq)
+        push!(solvedeq_idxs, ieq)
+        push!(solved_vars, iv)
+    end
+end
+
+"""
+Reorder the equations and unknowns to be:
+   [diffeqs; ...]
+   [diffvars; ...]
+such that the mass matrix is:
+   [I  0
+    0  0].
+
+Update the state to account for the new ordering and equations.
+"""
+# TODO: BLT sorting
+function reorder_vars!(s::SystemStructure, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
+
+    eqsperm = zeros(Int, nsrcs(graph))
+    for (i, v) in enumerate(eq_ordering)
+        eqsperm[v] = i
+    end
     varsperm = zeros(Int, ndsts(graph))
-    for (i, v) in enumerate(invvarsperm)
+    for (i, v) in enumerate(var_ordering)
         varsperm[v] = i
     end
 
-    deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
-                       for i in 1:length(solved_equations)]
-
     # Contract the vertices in the structure graph to make the structure match
     # the new reality of the system we've just created.
-    graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
-        length(solved_variables), length(solved_variables_set))
+    new_graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
+        nelim_eq, nelim_var)
 
-    new_var_to_diff = complete(DiffGraph(length(invvarsperm)))
+    new_var_to_diff = complete(DiffGraph(length(var_ordering)))
     for (v, d) in enumerate(var_to_diff)
         v′ = varsperm[v]
         (v′ > 0 && d !== nothing) || continue
         d′ = varsperm[d]
         new_var_to_diff[v′] = d′ > 0 ? d′ : nothing
     end
-    new_eq_to_diff = complete(DiffGraph(length(inveqsperm)))
+    new_eq_to_diff = complete(DiffGraph(length(eq_ordering)))
     for (v, d) in enumerate(eq_to_diff)
         v′ = eqsperm[v]
         (v′ > 0 && d !== nothing) || continue
@@ -663,7 +625,53 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
     end
     new_fullvars = fullvars[invvarsperm]
 
-    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph
+    # Update system structure 
+    @set! state.structure.graph = complete(new_graph)
+    @set! state.structure.var_to_diff = new_var_to_diff
+    @set! state.structure.eq_to_diff = new_eq_to_diff
+    @set! state.fullvars = new_fullvars
+end
+
+"""
+Set the system equations, unknowns, observables post-tearing.
+"""
+function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns)
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
+    diff_to_var = invview(var_to_diff)
+
+    ispresent = let var_to_diff = var_to_diff, graph = graph
+        i -> (!isempty(𝑑neighbors(graph, i)) ||
+              (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
+    end
+
+    sys = state.sys
+    obs_sub = dummy_sub
+    for eq in neweqs
+        isdiffeq(eq) || continue
+        obs_sub[eq.lhs] = eq.rhs
+    end
+    # TODO: compute the dependency correctly so that we don't have to do this
+    obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
+
+    unknowns = Any[v
+                   for (i, v) in enumerate(state.fullvars)
+                   if diff_to_var[i] === nothing && ispresent(i)]
+    unknowns = [unknowns; extra_unknowns]
+    @set! sys.unknowns = unknowns
+
+    obs, subeqs, deps = cse_and_array_hacks(
+        sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
+
+    @set! sys.eqs = neweqs
+    @set! sys.observed = obs
+    @set! sys.substitutions = Substitutions(subeqs, deps)
+
+    # Only makes sense for time-dependent
+    # TODO: generalize to SDE
+    if sys isa ODESystem
+        @set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
+    end
+    sys = schedule(sys)
 end
 
 # Terminology and Definition:
@@ -675,10 +683,8 @@ end
 # appear in the system. Algebraic variables are variables that are not
 # differential variables.
     
-import ModelingToolkit: Shift
-
 # Give the order of the variable indexed by dv
-function var_order(diff_to_var, dv) 
+function var_order(diff_to_var, dv)
     order = 0
     while (dv′ = diff_to_var[dv]) !== nothing
         order += 1
@@ -689,8 +695,7 @@ end
 
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
-    @unpack fullvars, sys, structure = state
-    @unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
+
     extra_vars = Int[]
     if full_var_eq_matching !== nothing
         for v in 𝑑vertices(state.structure.graph)
@@ -699,69 +704,21 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             push!(extra_vars, v)
         end
     end
+    extra_unknowns = fullvars[extra_vars]
     neweqs = collect(equations(state))
-    diff_to_var = invview(var_to_diff)
-    is_discrete = is_only_discrete(state.structure)
-
-    shift_sub = Dict() 
 
     # Structural simplification 
     dummy_sub = Dict()
-    substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
-
-    generate_derivative_variables!(state, neweqs, var_eq_matching; 
-                                   is_discrete, mm, shift_sub)
-
-    new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph = 
-        solve_and_generate_equations!(state, neweqs, var_eq_matching; simplify, shift_sub)
-
-    # Update system
-    var_to_diff = new_var_to_diff
-    eq_to_diff = new_eq_to_diff
-    diff_to_var = invview(var_to_diff)
-
-    old_fullvars = fullvars
-    @set! state.structure.graph = complete(graph)
-    @set! state.structure.var_to_diff = var_to_diff
-    @set! state.structure.eq_to_diff = eq_to_diff
-    @set! state.fullvars = fullvars = new_fullvars
-    ispresent = let var_to_diff = var_to_diff, graph = graph
-        i -> (!isempty(𝑑neighbors(graph, i)) ||
-              (var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
-    end
-
-    sys = state.sys
-    obs_sub = dummy_sub
-    for eq in neweqs
-        isdiffeq(eq) || continue
-        obs_sub[eq.lhs] = eq.rhs
-    end
-    # TODO: compute the dependency correctly so that we don't have to do this
-    obs = [fast_substitute(observed(sys), obs_sub); subeqs]
+    substitute_derivatives_algevars!(state, neweqs, dummy_sub, var_eq_matching)
 
-    unknowns = Any[v
-                   for (i, v) in enumerate(fullvars)
-                   if diff_to_var[i] === nothing && ispresent(i)]
-    if !isempty(extra_vars)
-        for v in extra_vars
-            push!(unknowns, old_fullvars[v])
-        end
-    end
-    @set! sys.unknowns = unknowns
+    generate_derivative_variables!(state, neweqs, var_eq_matching; mm)
 
-    obs, subeqs, deps = cse_and_array_hacks(
-        sys, obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
+    neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = 
+        generate_system_equations!(state, neweqs, var_eq_matching; simplify)
 
-    @set! sys.eqs = neweqs
-    @set! sys.observed = obs
-    @set! sys.substitutions = Substitutions(subeqs, deps)
+    reorder_vars!(state.structure, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
 
-    # Only makes sense for time-dependent
-    # TODO: generalize to SDE
-    if sys isa ODESystem
-        @set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
-    end
-    sys = schedule(sys)
+    sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns)
     @set! state.sys = sys
     @set! sys.tearing_state = state
     return invalidate_cache!(sys)
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index b947fa4269..77f0b62a1c 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -452,9 +452,13 @@ end
 ### Misc
 ###
 
-function lower_varname_withshift(var, iv, backshift; unshifted = nothing, allow_zero = true)
-    backshift <= 0 && return Shift(iv, -backshift)(unshifted, allow_zero)
-    ds = backshift > 0 ? "$iv-$backshift" : "$iv+$(-backshift)"
+function lower_varname_withshift(var, iv)
+    op = operation(var)
+    op isa Shift || return var
+    backshift = op.steps
+    backshift > 0 && return var
+
+    ds = "$iv-$(-backshift)"
     d_separator = 'ˍ'
 
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
@@ -477,15 +481,9 @@ function isdoubleshift(var)
            ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
 end
 
-### Rules
-# 1. x(t) -> x(t)
-# 2. Shift(t, 0)(x(t)) -> x(t)
-# 3. Shift(t, 3)(Shift(t, 2)(x(t)) -> Shift(t, 5)(x(t))
-
 function simplify_shifts(var)
     ModelingToolkit.hasshift(var) || return var
     var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
-    ((op = operation(var)) isa Shift) && op.steps == 0 && return simplify_shifts(arguments(var)[1])
     if isdoubleshift(var)
         op1 = operation(var)
         vv1 = arguments(var)[1]
@@ -501,58 +499,3 @@ function simplify_shifts(var)
             unwrap(var).metadata)
     end
 end
-
-"""
-Power expand the shifts. Used for substitution.
-
-Shift(t, -3)(x(t)) -> Shift(t, -1)(Shift(t, -1)(Shift(t, -1)(x)))
-"""
-function expand_shifts(var)
-    ModelingToolkit.hasshift(var) || return var
-    var = ModelingToolkit.value(var)
-
-    var isa Equation && return expand_shifts(var.lhs) ~ expand_shifts(var.rhs)
-    op = operation(var)
-    s = sign(op.steps)
-    arg = only(arguments(var))
-
-    if ModelingToolkit.isvariable(arg) && (ModelingToolkit.getvariabletype(arg) === VARIABLE) && isequal(op.t, only(arguments(arg)))
-        out = arg
-        for i in 1:op.steps
-            out = Shift(op.t, s)(out)
-        end
-        return out
-    elseif iscall(arg)
-        return maketerm(typeof(var), operation(var), expand_shifts.(arguments(var)),
-            unwrap(var).metadata)
-    else
-        return arg
-    end
-end
-
-"""
-Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
-"""
-function distribute_shift(var) 
-    ModelingToolkit.hasshift(var) || return var
-    var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs)
-    shift = operation(var)
-    expr = only(arguments(var))
-    _distribute_shift(expr, shift)
-end
-
-function _distribute_shift(expr, shift)
-    op = operation(expr)
-    args = arguments(expr)
-
-    if length(args) == 1 
-        if ModelingToolkit.isvariable(only(args)) && isequal(op.t, only(args))
-            return shift(only(args))
-        else
-            return only(args)
-        end
-    else iscall(op)
-        return maketerm(typeof(expr), operation(expr), _distribute_shift.(args, shift),
-            unwrap(var).metadata)
-    end
-end
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 89fcebf549..e5a227d9fb 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -154,7 +154,6 @@ Base.@kwdef mutable struct SystemStructure
     var_types::Union{Vector{VariableType}, Nothing}
     """Whether the system is discrete."""
     only_discrete::Bool
-    lowest_shift::Union{Dict, Nothing}
 end
 
 function Base.copy(structure::SystemStructure)
@@ -436,7 +435,7 @@ function TearingState(sys; quick_cancel = false, check = true)
 
     ts = TearingState(sys, fullvars,
         SystemStructure(complete(var_to_diff), complete(eq_to_diff),
-            complete(graph), nothing, var_types, sys isa DiscreteSystem, lowest_shift),
+            complete(graph), nothing, var_types, sys isa DiscreteSystem),
         Any[])
     if sys isa DiscreteSystem
         ts = shift_discrete_system(ts, lowest_shift)
@@ -466,9 +465,8 @@ end
     Shift variable x by the largest shift s such that x(k-s) appears in the system of equations.
     The lowest-shift term will have.
 """
-function shift_discrete_system(ts::TearingState, lowest_shift)
+function shift_discrete_system(ts::TearingState)
     @unpack fullvars, sys = ts
-    return ts
     discvars = OrderedSet()
     eqs = equations(sys)
 

From 240ab215b2ec457186b889ca53160e9d23db02e8 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 10 Feb 2025 18:38:22 -0500
Subject: [PATCH 078/111] fix: properly rename variables inside
 generate_system_equations

---
 .../symbolics_tearing.jl                      | 174 ++++++++++--------
 src/structural_transformation/utils.jl        |  14 +-
 src/systems/systemstructure.jl                |  16 +-
 3 files changed, 119 insertions(+), 85 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index b99e6ec21a..74205fcf86 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -248,7 +248,7 @@ called dummy derivatives.
 State selection is done. All non-differentiated variables are algebraic 
 variables, and all variables that appear differentiated are differential variables.
 """
-function substitute_derivatives_algevars!(ts::TearingState, neweqs, dummy_sub, var_eq_matching)
+function substitute_derivatives_algevars!(ts::TearingState, neweqs, var_eq_matching, dummy_sub)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     diff_to_var = invview(var_to_diff)
@@ -353,30 +353,32 @@ Effects on the system structure:
 - solvable_graph:
 - var_eq_matching: match D(x) to the added identity equation
 """
-function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
-        is_discrete = false, mm = nothing)
+function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
-    lower_name = is_discrete ? lower_varname_withshift : lower_varname_with_unit
-
+    is_discrete = is_only_discrete(structure)
+    lower_varname = is_discrete ? lower_shift_varname : lower_varname_with_unit
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
-    # Generate new derivative variables for all unsolved variables that have a derivative in the system 
+    # For variable x, make dummy derivative x_t if the
+    # derivative is in the system
     for v in 1:length(var_to_diff)
-        # Check if a derivative 1) exists and 2) is unsolved for
         dv = var_to_diff[v]
+        # For discrete systems, directly substitute lowest-order variable 
+        if is_discrete && diff_to_var[v] == nothing
+            fullvars[v] = lower_varname(fullvars[v], iv)
+        end
         dv isa Int || continue
         solved = var_eq_matching[dv] isa Int
         solved && continue
 
         # If there's `D(x) = x_t` already, update mappings and continue without
         # adding new equations/variables
-        dd = find_duplicate_dd(dv, lineareqs, mm)
-
+        dd = find_duplicate_dd(dv, solvable_graph, linear_eqs, mm)
         if !isnothing(dd)
             dummy_eq, v_t = dd
             var_to_diff[v_t] = var_to_diff[dv]
@@ -386,26 +388,25 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
         end
 
         dx = fullvars[dv]
-        order, lv = var_order(diff_to_var, dv)
-        x_t = is_discrete ? lower_name(fullvars[lv], iv)
-                      : lower_name(fullvars[lv], iv, order)
-        
+        order, lv = var_order(dv, diff_to_var)
+        x_t = is_discrete ? lower_varname(fullvars[dv], iv) : lower_varname(fullvars[lv], iv, order)
+
         # Add `x_t` to the graph
-        add_dd_variable!(structure, x_t, dv)
+        v_t = add_dd_variable!(structure, fullvars, x_t, dv)
         # Add `D(x) - x_t ~ 0` to the graph
-        add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv)
+        dummy_eq = add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv, v_t)
 
         # Update matching
         push!(var_eq_matching, unassigned)
         var_eq_matching[dv] = unassigned
-        eq_var_matching[dummy_eq] = dv
+        eq_var_matching[dummy_eq] = dv 
     end
 end
 
 """
-Check if there's `D(x) = x_t` already. 
+Check if there's `D(x) = x_t` already.
 """
-function find_duplicate_dd(dv, lineareqs, mm)
+function find_duplicate_dd(dv, solvable_graph, linear_eqs, mm)
     for eq in 𝑑neighbors(solvable_graph, dv)
         mi = get(linear_eqs, eq, 0)
         iszero(mi) && continue
@@ -424,28 +425,28 @@ function find_duplicate_dd(dv, lineareqs, mm)
     return nothing 
 end
 
-function add_dd_variable!(s::SystemStructure, x_t, dv)
-    push!(s.fullvars, simplify_shifts(x_t))
+function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv)
+    push!(fullvars, simplify_shifts(x_t))
+    v_t = length(fullvars)
     v_t_idx = add_vertex!(s.var_to_diff)
-    @assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
-        length(var_eq_matching)
     add_vertex!(s.graph, DST)
     # TODO: do we care about solvable_graph? We don't use them after
     # `dummy_derivative_graph`.
     add_vertex!(s.solvable_graph, DST)
-    var_to_diff[v_t] = var_to_diff[dv]
+    s.var_to_diff[v_t] = s.var_to_diff[dv]
+    v_t
 end
 
 # dv = index of D(x), v_t = index of x_t
-function add_dd_equation!(s::SystemStructure, neweqs, eq, dv)
+function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t)
     push!(neweqs, eq)
     add_vertex!(s.graph, SRC)
-    v_t = length(s.fullvars)
     dummy_eq = length(neweqs)
     add_edge!(s.graph, dummy_eq, dv)
     add_edge!(s.graph, dummy_eq, v_t)
     add_vertex!(s.solvable_graph, SRC)
     add_edge!(s.solvable_graph, dummy_eq, dv)
+    dummy_eq
 end
 
 """
@@ -463,6 +464,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
         iv = get_iv(sys)
         if is_only_discrete(structure)
             D = Shift(iv, 1)
+            for v in fullvars
+                op = operation(v)
+                op isa Shift && (op.steps < 0) && (total_sub[v] = lower_shift_varname(v, iv))
+            end
         else
             D = Differential(iv)
         end
@@ -493,24 +498,40 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     eqs = Iterators.reverse(toporder)
     idep = iv
 
-    # Generate differential equations.
-    # fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
-    # is solved to give the RHS.
+    # Generate equations.
+    #   Solvable equations of differential variables D(x) become differential equations
+    #   Solvable equations of non-differential variables become observable equations
+    #   Non-solvable equations become algebraic equations.
     for ieq in eqs
         iv = eq_var_matching[ieq]
-        if is_solvable(ieq, iv)
-            if isdervar(iv)
-                isnothing(D) &&
-                    error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
-                add_differential_equation!(structure, iv, neweqs, ieq, 
-                                           diff_vars, diff_eqs, diffeq_idxs, total_sub)
-            else
-                add_solved_equation!(structure, iv, neweqs, ieq, 
-                                     solved_vars, solved_eqs, solvedeq_idxs, total_sub)
+        var = fullvars[iv]
+        eq = neweqs[ieq]
+
+        if is_solvable(ieq, iv) && isdervar(iv)
+            isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq]))
+            order, lv = var_order(iv, diff_to_var)
+            dx = D(simplify_shifts(fullvars[lv]))
+
+            neweq = make_differential_equation(var, dx, eq, total_sub)
+            for e in 𝑑neighbors(graph, iv)
+                e == ieq && continue
+                rem_edge!(graph, e, iv)
+            end
+
+            push!(diff_eqs, neweq)
+            push!(diffeq_idxs, ieq)
+            push!(diff_vars, diff_to_var[iv])
+        elseif is_solvable(ieq, iv)
+            neweq = make_solved_equation(var, eq, total_sub; simplify)
+            !isnothing(neweq) && begin
+                push!(solved_eqs, neweq)
+                push!(solvedeq_idxs, ieq)
+                push!(solved_vars, iv)
             end
         else
-            add_algebraic_equation!(structure, neweqs, ieq, 
-                                    alge_eqs, algeeq_idxs, total_sub)
+            neweq = make_algebraic_equation(var, eq, total_sub)
+            push!(alge_eqs, neweq)
+            push!(algeeq_idxs, ieq)
         end
     end
 
@@ -529,39 +550,29 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
 end
 
-function add_differential_equation!(s::SystemStructure, iv, neweqs, ieq, diff_vars, diff_eqs, diffeqs_idxs, total_sub)
-    diff_to_var = invview(s.var_to_diff)
+struct UnexpectedDifferentialError
+    eq::Equation
+end
 
-    order, lv = var_order(diff_to_var, iv)
-    dx = D(simplify_shifts(fullvars[lv]))
-    eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
-        Symbolics.symbolic_linear_solve(neweqs[ieq],
-            fullvars[iv]),
-        total_sub; operator = ModelingToolkit.Shift))
-    for e in 𝑑neighbors(s.graph, iv)
-        e == ieq && continue
-        rem_edge!(s.graph, e, iv)
-    end
+function Base.showerror(io::IO, err::UnexpectedDifferentialError)
+    error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(err.eq)")
+end
 
-    push!(diff_eqs, eq)
-    total_sub[simplify_shifts(eq.lhs)] = eq.rhs
-    push!(diffeq_idxs, ieq)
-    push!(diff_vars, diff_to_var[iv])
+function make_differential_equation(var, dx, eq, total_sub)
+    dx ~ simplify_shifts(Symbolics.fixpoint_sub(
+        Symbolics.symbolic_linear_solve(eq, var),
+        total_sub; operator = ModelingToolkit.Shift))
 end
 
-function add_algebraic_equation!(s::SystemStructure, neweqs, ieq, alge_eqs, algeeq_idxs, total_sub)
-    eq = neweqs[ieq]
+function make_algebraic_equation(var, eq, total_sub)
     rhs = eq.rhs
     if !(eq.lhs isa Number && eq.lhs == 0)
         rhs = eq.rhs - eq.lhs
     end
-    push!(alge_eqs, 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)))
-    push!(algeeq_idxs, ieq)
+    0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub))
 end
 
-function add_solved_equation!(s::SystemStructure, iv, neweqs, ieq, solved_vars, solved_eqs, solvedeq_idxs, total_sub)
-    eq = neweqs[ieq]
-    var = fullvars[iv]
+function make_solved_equation(var, eq, total_sub; simplify = false)
     residual = eq.lhs - eq.rhs
     a, b, islinear = linear_expansion(residual, var)
     @assert islinear
@@ -569,15 +580,13 @@ function add_solved_equation!(s::SystemStructure, iv, neweqs, ieq, solved_vars,
     # var ~ -b/a
     if ModelingToolkit._iszero(a)
         @warn "Tearing: solving $eq for $var is singular!"
+        return nothing
     else
         rhs = -b / a
-        neweq = var ~ simplify_shifts(Symbolics.fixpoint_sub(
+        return var ~ simplify_shifts(Symbolics.fixpoint_sub(
             simplify ?
             Symbolics.simplify(rhs) : rhs,
             total_sub; operator = ModelingToolkit.Shift))
-        push!(solved_eqs, neweq)
-        push!(solvedeq_idxs, ieq)
-        push!(solved_vars, iv)
     end
 end
 
@@ -592,8 +601,8 @@ such that the mass matrix is:
 Update the state to account for the new ordering and equations.
 """
 # TODO: BLT sorting
-function reorder_vars!(s::SystemStructure, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
-    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
+function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
+    @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
 
     eqsperm = zeros(Int, nsrcs(graph))
     for (i, v) in enumerate(eq_ordering)
@@ -623,20 +632,26 @@ function reorder_vars!(s::SystemStructure, var_eq_matching, eq_ordering, var_ord
         d′ = eqsperm[d]
         new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing
     end
-    new_fullvars = fullvars[invvarsperm]
+    new_fullvars = state.fullvars[var_ordering]
 
+    @show new_graph
+    @show new_var_to_diff
     # Update system structure 
     @set! state.structure.graph = complete(new_graph)
     @set! state.structure.var_to_diff = new_var_to_diff
     @set! state.structure.eq_to_diff = new_eq_to_diff
     @set! state.fullvars = new_fullvars
+    state
 end
 
 """
 Set the system equations, unknowns, observables post-tearing.
 """
-function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns)
+function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; 
+        cse_hack = true, array_hack = true)
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
+    @show graph
+    @show var_to_diff
     diff_to_var = invview(var_to_diff)
 
     ispresent = let var_to_diff = var_to_diff, graph = graph
@@ -656,7 +671,12 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
     unknowns = Any[v
                    for (i, v) in enumerate(state.fullvars)
                    if diff_to_var[i] === nothing && ispresent(i)]
+    @show unknowns
+    @show state.fullvars
+    @show 𝑑neighbors(graph, 5)
+    @show neweqs
     unknowns = [unknowns; extra_unknowns]
+    @show unknowns
     @set! sys.unknowns = unknowns
 
     obs, subeqs, deps = cse_and_array_hacks(
@@ -684,7 +704,7 @@ end
 # differential variables.
     
 # Give the order of the variable indexed by dv
-function var_order(diff_to_var, dv)
+function var_order(dv, diff_to_var)
     order = 0
     while (dv′ = diff_to_var[dv]) !== nothing
         order += 1
@@ -695,7 +715,6 @@ end
 
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
-
     extra_vars = Int[]
     if full_var_eq_matching !== nothing
         for v in 𝑑vertices(state.structure.graph)
@@ -704,21 +723,22 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
             push!(extra_vars, v)
         end
     end
-    extra_unknowns = fullvars[extra_vars]
+    extra_unknowns = state.fullvars[extra_vars]
     neweqs = collect(equations(state))
+    dummy_sub = Dict()
 
     # Structural simplification 
-    dummy_sub = Dict()
-    substitute_derivatives_algevars!(state, neweqs, dummy_sub, var_eq_matching)
+    substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub)
 
     generate_derivative_variables!(state, neweqs, var_eq_matching; mm)
 
     neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = 
         generate_system_equations!(state, neweqs, var_eq_matching; simplify)
 
-    reorder_vars!(state.structure, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
+    state = reorder_vars!(state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
+
+    sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; cse_hack, array_hack)
 
-    sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns)
     @set! state.sys = sys
     @set! sys.tearing_state = state
     return invalidate_cache!(sys)
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 77f0b62a1c..fec57db080 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -452,9 +452,10 @@ end
 ### Misc
 ###
 
-function lower_varname_withshift(var, iv)
+# For discrete variables. Turn Shift(t, k)(x(t)) into xₜ₋ₖ(t)
+function lower_shift_varname(var, iv)
     op = operation(var)
-    op isa Shift || return var
+    op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t)
     backshift = op.steps
     backshift > 0 && return var
 
@@ -476,6 +477,14 @@ function lower_varname_withshift(var, iv)
     return ModelingToolkit._with_unit(identity, newvar, iv)
 end
 
+function lower_varname(var, iv, order; is_discrete = false)
+    if is_discrete
+        lower_shift_varname(var, iv)
+    else
+        lower_varname_with_unit(var, iv, order)
+    end
+end
+
 function isdoubleshift(var)
     return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
            ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
@@ -484,6 +493,7 @@ end
 function simplify_shifts(var)
     ModelingToolkit.hasshift(var) || return var
     var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
+    (op = operation(var)) isa Shift && op.steps == 0 && return first(arguments(var))
     if isdoubleshift(var)
         op1 = operation(var)
         vv1 = arguments(var)[1]
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index e5a227d9fb..ae4d21f225 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -142,15 +142,16 @@ has_equations(::TransformationState) = true
 Base.@kwdef mutable struct SystemStructure
     """Maps the (index of) a variable to the (index of) the variable describing its derivative."""
     var_to_diff::DiffGraph
-    """Maps the (index of) a """
+    """Maps the (index of) an equation."""
     eq_to_diff::DiffGraph
     # Can be access as
     # `graph` to automatically look at the bipartite graph
     # or as `torn` to assert that tearing has run.
-    """Incidence graph of the system of equations. An edge from equation x to variable y exists if variable y appears in equation x."""
+    """Graph that maps equations to variables that appear in them."""
     graph::BipartiteGraph{Int, Nothing}
-    """."""
+    """Graph that connects equations to the variable they will be solved for during simplification."""
     solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
+    """Variable types (brownian, variable, parameter) in the system."""
     var_types::Union{Vector{VariableType}, Nothing}
     """Whether the system is discrete."""
     only_discrete::Bool
@@ -200,7 +201,9 @@ function complete!(s::SystemStructure)
 end
 
 mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
+    """The system of equations."""
     sys::T
+    """The set of variables of the system."""
     fullvars::Vector
     structure::SystemStructure
     extra_eqs::Vector
@@ -438,7 +441,7 @@ function TearingState(sys; quick_cancel = false, check = true)
             complete(graph), nothing, var_types, sys isa DiscreteSystem),
         Any[])
     if sys isa DiscreteSystem
-        ts = shift_discrete_system(ts, lowest_shift)
+        ts = shift_discrete_system(ts)
     end
     return ts
 end
@@ -475,8 +478,9 @@ function shift_discrete_system(ts::TearingState)
     end
     iv = get_iv(sys)
 
-    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))    for k in discvars 
-    if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 
+    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))    
+                   for k in discvars 
+                   if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 
 
     for i in eachindex(fullvars)
         fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(

From 22a39bd41b2123c48a7f88016f25470041271b69 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 10 Feb 2025 18:42:01 -0500
Subject: [PATCH 079/111] delete comments

---
 src/structural_transformation/utils.jl  |  3 ---
 src/systems/systems.jl                  |  8 ++++----
 test/structural_transformation/utils.jl | 16 ----------------
 3 files changed, 4 insertions(+), 23 deletions(-)

diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index fec57db080..757eb9a8db 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -271,8 +271,6 @@ end
 
 function find_solvables!(state::TearingState; kwargs...)
     @assert state.structure.solvable_graph === nothing
-    println("in find_solvables")
-    @show eqs
     eqs = equations(state)
     graph = state.structure.graph
     state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
@@ -280,7 +278,6 @@ function find_solvables!(state::TearingState; kwargs...)
     for ieq in 1:length(eqs)
         find_eq_solvables!(state, ieq, to_rm; kwargs...)
     end
-    @show eqs
     return nothing
 end
 
diff --git a/src/systems/systems.jl b/src/systems/systems.jl
index 0a7e4264e4..9c8c272c5c 100644
--- a/src/systems/systems.jl
+++ b/src/systems/systems.jl
@@ -41,10 +41,10 @@ function structural_simplify(
     end
     if newsys isa DiscreteSystem &&
        any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
-        # error("""
-        #     Encountered algebraic equations when simplifying discrete system. This is \
-        #     not yet supported.
-        # """)
+        error("""
+            Encountered algebraic equations when simplifying discrete system. This is \
+            not yet supported.
+        """)
     end
     for pass in additional_passes
         newsys = pass(newsys)
diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl
index 98d986ebd4..67f6016bc4 100644
--- a/test/structural_transformation/utils.jl
+++ b/test/structural_transformation/utils.jl
@@ -162,19 +162,3 @@ end
     structural_simplify(sys; additional_passes = [pass])
     @test value[] == 1
 end
-
-@testset "Shift simplification" begin
-    @variables x(t) y(t) z(t)
-    @parameters a b c
-    
-    # Expand shifts
-    @test isequal(ST.expand_shifts(Shift(t, -3)(x)), Shift(t, -1)(Shift(t, -1)(Shift(t, -1)(x))))
-    expr = a * Shift(t, -2)(x) + Shift(t, 2)(y) + b
-    @test isequal(ST.expand_shifts(expr), 
-                    a * Shift(t, -1)(Shift(t, -1)(x)) + Shift(t, 1)(Shift(t, 1)(y)) + b)
-    @test isequal(ST.expand_shifts(Shift(t, 2)(Shift(t, 1)(a))), a)
-
-
-    # Distribute shifts
-
-end

From 9aaf04a512853d5c9336674455ade5d92487c2f5 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 10 Feb 2025 18:50:23 -0500
Subject: [PATCH 080/111] more cleanup

---
 docs/src/systems/DiscreteSystem.md            | 28 -------------------
 .../symbolics_tearing.jl                      |  9 ------
 src/systems/systemstructure.jl                | 11 ++------
 3 files changed, 3 insertions(+), 45 deletions(-)
 delete mode 100644 docs/src/systems/DiscreteSystem.md

diff --git a/docs/src/systems/DiscreteSystem.md b/docs/src/systems/DiscreteSystem.md
deleted file mode 100644
index b6a8061e50..0000000000
--- a/docs/src/systems/DiscreteSystem.md
+++ /dev/null
@@ -1,28 +0,0 @@
-# DiscreteSystem
-
-## System Constructors
-
-```@docs
-DiscreteSystem
-```
-
-## Composition and Accessor Functions
-
-  - `get_eqs(sys)` or `equations(sys)`: The equations that define the discrete system.
-  - `get_unknowns(sys)` or `unknowns(sys)`: The set of unknowns in the discrete system.
-  - `get_ps(sys)` or `parameters(sys)`: The parameters of the discrete system.
-  - `get_iv(sys)`: The independent variable of the discrete system.
-  - `discrete_events(sys)`: The set of discrete events in the discrete system.
-
-## Transformations
-
-```@docs; canonical=false
-structural_simplify
-```
-
-## Problem Constructors
-
-```@docs; canonical=false
-DiscreteProblem(sys::DiscreteSystem, u0map, tspan)
-DiscreteFunction(sys::DiscreteSystem, u0map, tspan)
-```
diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 74205fcf86..eed7bc3e29 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -634,8 +634,6 @@ function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_or
     end
     new_fullvars = state.fullvars[var_ordering]
 
-    @show new_graph
-    @show new_var_to_diff
     # Update system structure 
     @set! state.structure.graph = complete(new_graph)
     @set! state.structure.var_to_diff = new_var_to_diff
@@ -650,8 +648,6 @@ Set the system equations, unknowns, observables post-tearing.
 function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; 
         cse_hack = true, array_hack = true)
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
-    @show graph
-    @show var_to_diff
     diff_to_var = invview(var_to_diff)
 
     ispresent = let var_to_diff = var_to_diff, graph = graph
@@ -671,12 +667,7 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
     unknowns = Any[v
                    for (i, v) in enumerate(state.fullvars)
                    if diff_to_var[i] === nothing && ispresent(i)]
-    @show unknowns
-    @show state.fullvars
-    @show 𝑑neighbors(graph, 5)
-    @show neweqs
     unknowns = [unknowns; extra_unknowns]
-    @show unknowns
     @set! sys.unknowns = unknowns
 
     obs, subeqs, deps = cse_and_array_hacks(
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index ae4d21f225..b98bb8c616 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -140,14 +140,14 @@ get_fullvars(ts::TransformationState) = ts.fullvars
 has_equations(::TransformationState) = true
 
 Base.@kwdef mutable struct SystemStructure
-    """Maps the (index of) a variable to the (index of) the variable describing its derivative."""
+    """Maps the index of variable x to the index of variable D(x)."""
     var_to_diff::DiffGraph
-    """Maps the (index of) an equation."""
+    """Maps the index of an equation."""
     eq_to_diff::DiffGraph
     # Can be access as
     # `graph` to automatically look at the bipartite graph
     # or as `torn` to assert that tearing has run.
-    """Graph that maps equations to variables that appear in them."""
+    """Graph that connects equations to variables that appear in them."""
     graph::BipartiteGraph{Int, Nothing}
     """Graph that connects equations to the variable they will be solved for during simplification."""
     solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
@@ -464,15 +464,10 @@ function lower_order_var(dervar, t)
     diffvar
 end
 
-"""
-    Shift variable x by the largest shift s such that x(k-s) appears in the system of equations.
-    The lowest-shift term will have.
-"""
 function shift_discrete_system(ts::TearingState)
     @unpack fullvars, sys = ts
     discvars = OrderedSet()
     eqs = equations(sys)
-
     for eq in eqs
         vars!(discvars, eq; op = Union{Sample, Hold})
     end

From 66266e410c88a61a9c4e6af1c8c829433d5b4f76 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 10 Feb 2025 19:10:34 -0500
Subject: [PATCH 081/111] better comments

---
 src/structural_transformation/symbolics_tearing.jl | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index eed7bc3e29..5707ac8386 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -337,11 +337,19 @@ variables and equations, don't add them when they already exist.
 
 Documenting the differences to structural simplification for discrete systems:
 
-In discrete systems the lowest-order term is Shift(t, k)(x(t)), instead of x(t). 
-We want to substitute the k-1 lowest order terms instead of the k-1 highest order terms.
+In discrete systems everything gets shifted forward a timestep by `shift_discrete_system`
+in order to properly generate the difference equations. 
+
+In the system x(k) ~ x(k-1) + x(k-2), becomes Shift(t, 1)(x(t)) ~ x(t) + Shift(t, -1)(x(t))
+
+The lowest-order term is Shift(t, k)(x(t)), instead of x(t).
+As such we actually want dummy variables for the k-1 lowest order terms instead of the k-1 highest order terms.
 
-In the system x(k) ~ x(k-1) + x(k-2), we want to lower
 Shift(t, -1)(x(t)) -> x\_{t-1}(t)
+
+Since Shift(t, -1)(x) is not a derivative, it is directly substituted in `fullvars`. No equation or variable is added for it. 
+
+For ODESystems D(D(D(x))) in equations is recursively substituted as D(x) ~ x_t, D(x_t) ~ x_tt, etc. The analogue for discrete systems, Shift(t, 1)(Shift(t,1)(Shift(t,1)(Shift(t, -3)(x(t))))) does not actually appear. So `total_sub` in generate_system_equations` is directly initialized with all of the lowered variables `Shift(t, -3)(x) -> x_t-3(t)`, etc. 
 =#
 """
 Generate new derivative variables for the system.

From 1ab62463adb298386145f57933f4ccd4667d90d1 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 10 Feb 2025 19:30:56 -0500
Subject: [PATCH 082/111] fix unassigned indexing

---
 src/structural_transformation/symbolics_tearing.jl | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 5707ac8386..48f9141dbd 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -512,10 +512,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     #   Non-solvable equations become algebraic equations.
     for ieq in eqs
         iv = eq_var_matching[ieq]
-        var = fullvars[iv]
         eq = neweqs[ieq]
 
         if is_solvable(ieq, iv) && isdervar(iv)
+            var = fullvars[iv]
             isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq]))
             order, lv = var_order(iv, diff_to_var)
             dx = D(simplify_shifts(fullvars[lv]))
@@ -530,6 +530,7 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
             push!(diffeq_idxs, ieq)
             push!(diff_vars, diff_to_var[iv])
         elseif is_solvable(ieq, iv)
+            var = fullvars[iv]
             neweq = make_solved_equation(var, eq, total_sub; simplify)
             !isnothing(neweq) && begin
                 push!(solved_eqs, neweq)
@@ -537,7 +538,7 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
                 push!(solved_vars, iv)
             end
         else
-            neweq = make_algebraic_equation(var, eq, total_sub)
+            neweq = make_algebraic_equation(eq, total_sub)
             push!(alge_eqs, neweq)
             push!(algeeq_idxs, ieq)
         end
@@ -572,7 +573,7 @@ function make_differential_equation(var, dx, eq, total_sub)
         total_sub; operator = ModelingToolkit.Shift))
 end
 
-function make_algebraic_equation(var, eq, total_sub)
+function make_algebraic_equation(eq, total_sub)
     rhs = eq.rhs
     if !(eq.lhs isa Number && eq.lhs == 0)
         rhs = eq.rhs - eq.lhs

From 079901bd481b8ea4f505761bce6aa476602dd023 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 11 Feb 2025 15:13:17 -0500
Subject: [PATCH 083/111] move D, iv into tearing_reassemble

---
 .../symbolics_tearing.jl                      | 43 +++++++++----------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 48f9141dbd..45430595b6 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -248,11 +248,10 @@ called dummy derivatives.
 State selection is done. All non-differentiated variables are algebraic 
 variables, and all variables that appear differentiated are differential variables.
 """
-function substitute_derivatives_algevars!(ts::TearingState, neweqs, var_eq_matching, dummy_sub)
+function substitute_derivatives_algevars!(ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     diff_to_var = invview(var_to_diff)
-    iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
 
     for var in 1:length(fullvars)
         dv = var_to_diff[var]
@@ -361,12 +360,11 @@ Effects on the system structure:
 - solvable_graph:
 - var_eq_matching: match D(x) to the added identity equation
 """
-function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing)
+function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
-    iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
     is_discrete = is_only_discrete(structure)
     lower_varname = is_discrete ? lower_shift_varname : lower_varname_with_unit
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
@@ -461,27 +459,18 @@ end
 Solve the solvable equations of the system and generate differential (or discrete)
 equations in terms of the selected states.
 """
-function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false)
+function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = state 
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     total_sub = Dict()
-
-    if ModelingToolkit.has_iv(sys)
-        iv = get_iv(sys)
-        if is_only_discrete(structure)
-            D = Shift(iv, 1)
-            for v in fullvars
-                op = operation(v)
-                op isa Shift && (op.steps < 0) && (total_sub[v] = lower_shift_varname(v, iv))
-            end
-        else
-            D = Differential(iv)
+    if is_only_discrete(structure)
+        for v in fullvars
+            op = operation(v)
+            op isa Shift && (op.steps < 0) && (total_sub[v] = lower_shift_varname(v, iv))
         end
-    else
-        iv = D = nothing
-    end
+   end
 
     # if var is like D(x) or Shift(t, 1)(x)
     isdervar = let diff_to_var = diff_to_var
@@ -727,13 +716,23 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
     neweqs = collect(equations(state))
     dummy_sub = Dict()
 
+    if ModelingToolkit.has_iv(state.sys)
+        iv = get_iv(state.sys)
+        if !is_only_discrete(state.structure)
+            D = Differential(iv)
+        else
+            D = Shift(iv, 1)
+        end
+        iv = D = nothing
+    end
+
     # Structural simplification 
-    substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub)
+    substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D)
 
-    generate_derivative_variables!(state, neweqs, var_eq_matching; mm)
+    generate_derivative_variables!(state, neweqs, var_eq_matching; mm, iv, D)
 
     neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = 
-        generate_system_equations!(state, neweqs, var_eq_matching; simplify)
+        generate_system_equations!(state, neweqs, var_eq_matching; simplify, iv, D)
 
     state = reorder_vars!(state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
 

From 9217c62941226cd8dc24a80fe3723810f2463ab5 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 11 Feb 2025 15:29:46 -0500
Subject: [PATCH 084/111] refactor iv

---
 src/structural_transformation/symbolics_tearing.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 45430595b6..883166cfa8 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -723,6 +723,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
         else
             D = Shift(iv, 1)
         end
+    else
         iv = D = nothing
     end
 

From 54910988509cfdf2f65ed8b8d24f565b3dab8584 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 11 Feb 2025 16:13:47 -0500
Subject: [PATCH 085/111] fix comment for eq_to_diff

---
 src/systems/systemstructure.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index b98bb8c616..6fee78cfd6 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -142,7 +142,7 @@ has_equations(::TransformationState) = true
 Base.@kwdef mutable struct SystemStructure
     """Maps the index of variable x to the index of variable D(x)."""
     var_to_diff::DiffGraph
-    """Maps the index of an equation."""
+    """Maps the index of an algebraic equation to the index of the equation it is differentiated into."""
     eq_to_diff::DiffGraph
     # Can be access as
     # `graph` to automatically look at the bipartite graph

From 5878ad67ef6a78ad55034d8138a6fee741a9e7d6 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 11 Feb 2025 16:51:06 -0500
Subject: [PATCH 086/111] fix total_sub

---
 src/structural_transformation/symbolics_tearing.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 883166cfa8..7cb9798f01 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -464,6 +464,7 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
+
     total_sub = Dict()
     if is_only_discrete(structure)
         for v in fullvars
@@ -515,6 +516,7 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
                 rem_edge!(graph, e, iv)
             end
 
+            total_sub[simplify_shifts(neweq.lhs)] = neweq.rhs
             push!(diff_eqs, neweq)
             push!(diffeq_idxs, ieq)
             push!(diff_vars, diff_to_var[iv])

From 397b279df53c45a6c923ee89a24125237d4ee44c Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 11 Feb 2025 23:08:07 -0500
Subject: [PATCH 087/111] add diff_to_var as argument for find_dumplicate_dd,
 fix substitution of lowest-order variable

---
 src/structural_transformation/symbolics_tearing.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 7cb9798f01..ec1ff45fd7 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -374,9 +374,9 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     # derivative is in the system
     for v in 1:length(var_to_diff)
         dv = var_to_diff[v]
-        # For discrete systems, directly substitute lowest-order variable 
+        # For discrete systems, directly substitute lowest-order shift 
         if is_discrete && diff_to_var[v] == nothing
-            fullvars[v] = lower_varname(fullvars[v], iv)
+            operation(fullvars[v]) isa Shift && (fullvars[v] = lower_varname(fullvars[v], iv))
         end
         dv isa Int || continue
         solved = var_eq_matching[dv] isa Int
@@ -384,7 +384,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
 
         # If there's `D(x) = x_t` already, update mappings and continue without
         # adding new equations/variables
-        dd = find_duplicate_dd(dv, solvable_graph, linear_eqs, mm)
+        dd = find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
         if !isnothing(dd)
             dummy_eq, v_t = dd
             var_to_diff[v_t] = var_to_diff[dv]
@@ -412,7 +412,7 @@ end
 """
 Check if there's `D(x) = x_t` already.
 """
-function find_duplicate_dd(dv, solvable_graph, linear_eqs, mm)
+function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
     for eq in 𝑑neighbors(solvable_graph, dv)
         mi = get(linear_eqs, eq, 0)
         iszero(mi) && continue

From bb42719158e876d21a9ddf12b934141632fd2823 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 13 Feb 2025 15:46:43 -0500
Subject: [PATCH 088/111] fix: fix initialization of DiscreteSystem with
 renamed variables

---
 .../StructuralTransformations.jl              |  3 +-
 .../symbolics_tearing.jl                      | 20 +++++++-----
 src/structural_transformation/utils.jl        | 32 ++++++++++---------
 .../discrete_system/discrete_system.jl        |  6 ++--
 src/utils.jl                                  |  2 ++
 src/variables.jl                              |  7 +++-
 test/runtests.jl                              |  3 ++
 7 files changed, 45 insertions(+), 28 deletions(-)

diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl
index 1220d517cc..510352ad59 100644
--- a/src/structural_transformation/StructuralTransformations.jl
+++ b/src/structural_transformation/StructuralTransformations.jl
@@ -22,7 +22,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
                        get_postprocess_fbody, vars!,
                        IncrementalCycleTracker, add_edge_checked!, topological_sort,
                        invalidate_cache!, Substitutions, get_or_construct_tearing_state,
-                       filter_kwargs, lower_varname_with_unit, setio, SparseMatrixCLIL,
+                       filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
                        get_fullvars, has_equations, observed,
                        Schedule, schedule
 
@@ -63,6 +63,7 @@ export torn_system_jacobian_sparsity
 export full_equations
 export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask
 export computed_highest_diff_variables
+export shift2term, lower_shift_varname
 
 include("utils.jl")
 include("pantelides.jl")
diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index ec1ff45fd7..ffbd6e3452 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -366,7 +366,6 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     is_discrete = is_only_discrete(structure)
-    lower_varname = is_discrete ? lower_shift_varname : lower_varname_with_unit
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
@@ -375,9 +374,9 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     for v in 1:length(var_to_diff)
         dv = var_to_diff[v]
         # For discrete systems, directly substitute lowest-order shift 
-        if is_discrete && diff_to_var[v] == nothing
-            operation(fullvars[v]) isa Shift && (fullvars[v] = lower_varname(fullvars[v], iv))
-        end
+        #if is_discrete && diff_to_var[v] == nothing
+        #    operation(fullvars[v]) isa Shift && (fullvars[v] = lower_shift_varname_with_unit(fullvars[v], iv))
+        #end
         dv isa Int || continue
         solved = var_eq_matching[dv] isa Int
         solved && continue
@@ -395,7 +394,8 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
 
         dx = fullvars[dv]
         order, lv = var_order(dv, diff_to_var)
-        x_t = is_discrete ? lower_varname(fullvars[dv], iv) : lower_varname(fullvars[lv], iv, order)
+        x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv) : 
+                            Symbolics.diff2term(fullvars[dv])
 
         # Add `x_t` to the graph
         v_t = add_dd_variable!(structure, fullvars, x_t, dv)
@@ -467,11 +467,15 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
 
     total_sub = Dict()
     if is_only_discrete(structure)
-        for v in fullvars
+        for (i, v) in enumerate(fullvars)
             op = operation(v)
-            op isa Shift && (op.steps < 0) && (total_sub[v] = lower_shift_varname(v, iv))
+            op isa Shift && (op.steps < 0) && begin
+                lowered = lower_shift_varname_with_unit(v, iv)
+                total_sub[v] = lowered
+                fullvars[i] = lowered
+            end
         end
-   end
+    end
 
     # if var is like D(x) or Shift(t, 1)(x)
     isdervar = let diff_to_var = diff_to_var
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 757eb9a8db..f64a8da132 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -453,16 +453,25 @@ end
 function lower_shift_varname(var, iv)
     op = operation(var)
     op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t)
-    backshift = op.steps
-    backshift > 0 && return var
+    if op.steps < 0
+        return shift2term(var)
+    else
+        return var
+    end
+end
 
-    ds = "$iv-$(-backshift)"
-    d_separator = 'ˍ'
+function shift2term(var) 
+    backshift = operation(var).steps
+    iv = operation(var).t
+    num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift)))
+    ds = join([Char(0x209c), Char(0x208b), num])
+    #ds = "$iv-$(-backshift)"
+    #d_separator = 'ˍ'
 
     if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
         O = only(arguments(var))
         oldop = operation(O)
-        newname = Symbol(string(nameof(oldop)), d_separator, ds)
+        newname = Symbol(string(nameof(oldop)), ds)
     else
         O = var
         oldop = operation(var) 
@@ -470,16 +479,9 @@ function lower_shift_varname(var, iv)
         newname = Symbol(varname, d_separator, ds)
     end
     newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
-    setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
-    return ModelingToolkit._with_unit(identity, newvar, iv)
-end
-
-function lower_varname(var, iv, order; is_discrete = false)
-    if is_discrete
-        lower_shift_varname(var, iv)
-    else
-        lower_varname_with_unit(var, iv, order)
-    end
+    newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
+    newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
+    return newvar
 end
 
 function isdoubleshift(var)
diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index cd6ff45b6c..9a9ac47853 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -275,10 +275,10 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
     end
     for var in unknowns(sys)
         op = operation(var)
-        op isa Shift || continue
         haskey(updated, var) && continue
-        root = first(arguments(var))
-        haskey(defs, root) || error("Initial condition for $var not provided.")
+        root = getunshifted(var)
+        isnothing(root) && continue
+        haskey(defs, root) || error("Initial condition for $root not provided.")
         updated[var] = defs[root]
     end
     return updated
diff --git a/src/utils.jl b/src/utils.jl
index cf49d9f445..c3a0f637a1 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1028,6 +1028,8 @@ end
 
 diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
 lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)
+shift2term_with_unit(x, t) = _with_unit(shift2term, x, t)
+lower_shift_varname_with_unit(var, iv) = _with_unit(lower_shift_varname, var, iv, iv)
 
 """
     $(TYPEDSIGNATURES)
diff --git a/src/variables.jl b/src/variables.jl
index 536119f107..a29a119607 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -6,6 +6,7 @@ struct VariableOutput end
 struct VariableIrreducible end
 struct VariableStatePriority end
 struct VariableMisc end
+struct VariableUnshifted end
 Symbolics.option_to_metadata_type(::Val{:unit}) = VariableUnit
 Symbolics.option_to_metadata_type(::Val{:connect}) = VariableConnectType
 Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput
@@ -13,6 +14,7 @@ Symbolics.option_to_metadata_type(::Val{:output}) = VariableOutput
 Symbolics.option_to_metadata_type(::Val{:irreducible}) = VariableIrreducible
 Symbolics.option_to_metadata_type(::Val{:state_priority}) = VariableStatePriority
 Symbolics.option_to_metadata_type(::Val{:misc}) = VariableMisc
+Symbolics.option_to_metadata_type(::Val{:unshifted}) = VariableUnshifted
 
 """
     dump_variable_metadata(var)
@@ -133,7 +135,7 @@ function default_toterm(x)
     if iscall(x) && (op = operation(x)) isa Operator
         if !(op isa Differential)
             if op isa Shift && op.steps < 0
-                return x
+                return shift2term(x) 
             end
             x = normalize_to_differential(op)(arguments(x)...)
         end
@@ -600,3 +602,6 @@ getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing)
 Check if the variable `x` has a unit.
 """
 hasunit(x) = getunit(x) !== nothing
+
+getunshifted(x) = getunshifted(unwrap(x))
+getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)
diff --git a/test/runtests.jl b/test/runtests.jl
index 9537b1b44e..e600305232 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,6 +22,8 @@ function activate_downstream_env()
     Pkg.instantiate()
 end
 
+@testset begin include("discrete_system.jl") end
+#=
 @time begin
     if GROUP == "All" || GROUP == "InterfaceI"
         @testset "InterfaceI" begin
@@ -136,3 +138,4 @@ end
         @safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl")
     end
 end
+=#

From 09f31b23ca7338e3a646b26a3bc582bd58633a2a Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Thu, 13 Feb 2025 15:49:35 -0500
Subject: [PATCH 089/111] revert runtest

---
 test/runtests.jl | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index e600305232..9537b1b44e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,8 +22,6 @@ function activate_downstream_env()
     Pkg.instantiate()
 end
 
-@testset begin include("discrete_system.jl") end
-#=
 @time begin
     if GROUP == "All" || GROUP == "InterfaceI"
         @testset "InterfaceI" begin
@@ -138,4 +136,3 @@ end
         @safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl")
     end
 end
-=#

From daa789856714d3926c265307a94bc6f3f7eecbbd Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 14 Feb 2025 06:25:44 -0500
Subject: [PATCH 090/111] frefactor use toterm instead of lower_varname

---
 src/structural_transformation/symbolics_tearing.jl | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index ffbd6e3452..47c7da371f 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -366,6 +366,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     is_discrete = is_only_discrete(structure)
+    toterm = is_discrete ? shift2term_with_unit : diff2term_with_unit
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
@@ -373,10 +374,6 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     # derivative is in the system
     for v in 1:length(var_to_diff)
         dv = var_to_diff[v]
-        # For discrete systems, directly substitute lowest-order shift 
-        #if is_discrete && diff_to_var[v] == nothing
-        #    operation(fullvars[v]) isa Shift && (fullvars[v] = lower_shift_varname_with_unit(fullvars[v], iv))
-        #end
         dv isa Int || continue
         solved = var_eq_matching[dv] isa Int
         solved && continue
@@ -394,8 +391,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
 
         dx = fullvars[dv]
         order, lv = var_order(dv, diff_to_var)
-        x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv) : 
-                            Symbolics.diff2term(fullvars[dv])
+        x_t = toterm(fullvars[dv]) 
 
         # Add `x_t` to the graph
         v_t = add_dd_variable!(structure, fullvars, x_t, dv)
@@ -470,7 +466,7 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
         for (i, v) in enumerate(fullvars)
             op = operation(v)
             op isa Shift && (op.steps < 0) && begin
-                lowered = lower_shift_varname_with_unit(v, iv)
+                lowered = shift2term_with_unit(v)
                 total_sub[v] = lowered
                 fullvars[i] = lowered
             end

From ae19eeb84b3b2d3cf6ca6db1a658581b8f9ba714 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 14 Feb 2025 06:50:46 -0500
Subject: [PATCH 091/111] fix unit

---
 src/structural_transformation/StructuralTransformations.jl | 2 +-
 src/structural_transformation/symbolics_tearing.jl         | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl
index 510352ad59..7d2b9afa26 100644
--- a/src/structural_transformation/StructuralTransformations.jl
+++ b/src/structural_transformation/StructuralTransformations.jl
@@ -11,7 +11,7 @@ using SymbolicUtils: maketerm, iscall
 
 using ModelingToolkit
 using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential,
-                       unknowns, equations, vars, Symbolic, diff2term_with_unit, value,
+                       unknowns, equations, vars, Symbolic, diff2term_with_unit, shift2term_with_unit, value,
                        operation, arguments, Sym, Term, simplify, symbolic_linear_solve,
                        isdiffeq, isdifferential, isirreducible,
                        empty_substitutions, get_substitutions,
diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 47c7da371f..779f5931f5 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -366,7 +366,6 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
     is_discrete = is_only_discrete(structure)
-    toterm = is_discrete ? shift2term_with_unit : diff2term_with_unit
     linear_eqs = mm === nothing ? Dict{Int, Int}() :
                  Dict(reverse(en) for en in enumerate(mm.nzrows))
 
@@ -391,7 +390,8 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
 
         dx = fullvars[dv]
         order, lv = var_order(dv, diff_to_var)
-        x_t = toterm(fullvars[dv]) 
+        x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv)
+                          : lower_varname_with_unit(fullvars[lv], iv, order)
 
         # Add `x_t` to the graph
         v_t = add_dd_variable!(structure, fullvars, x_t, dv)
@@ -466,7 +466,7 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
         for (i, v) in enumerate(fullvars)
             op = operation(v)
             op isa Shift && (op.steps < 0) && begin
-                lowered = shift2term_with_unit(v)
+                lowered = lower_shift_varname_with_unit(v, iv)
                 total_sub[v] = lowered
                 fullvars[i] = lowered
             end

From 406d0a861d4904d10cf99ad11dbc118f38a0c718 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 14 Feb 2025 06:55:14 -0500
Subject: [PATCH 092/111] fix parse error

---
 src/structural_transformation/symbolics_tearing.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index 779f5931f5..a166bb064b 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -390,8 +390,8 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
 
         dx = fullvars[dv]
         order, lv = var_order(dv, diff_to_var)
-        x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv)
-                          : lower_varname_with_unit(fullvars[lv], iv, order)
+        x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv) :
+                          lower_varname_with_unit(fullvars[lv], iv, order)
 
         # Add `x_t` to the graph
         v_t = add_dd_variable!(structure, fullvars, x_t, dv)

From ae4e6f70200221e736bdeea732f50c9c8fbeced2 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 17 Feb 2025 11:44:51 -0500
Subject: [PATCH 093/111] add tests

---
 src/systems/problem_utils.jl | 32 ++++++++++++++++++++------------
 test/odesystem.jl            | 15 +++++++++++++++
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index 8541056272..3d5b00c418 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -687,11 +687,11 @@ function process_SciMLProblem(
 
     u0map = to_varmap(u0map, dvs)
     symbols_to_symbolics!(sys, u0map)
-    check_keys(sys, u0map)
 
     pmap = to_varmap(pmap, ps)
     symbols_to_symbolics!(sys, pmap)
-    check_keys(sys, pmap)
+
+    check_inputmap_keys(sys, u0map, pmap)
 
     defs = add_toterms(recursive_unwrap(defaults(sys)))
     cmap, cs = get_cmap(sys)
@@ -783,29 +783,37 @@ end
 
 # Check that the keys of a u0map or pmap are valid
 # (i.e. are symbolic keys, and are defined for the system.)
-function check_keys(sys, map) 
-    badkeys = Any[]
-    for k in keys(map)
+function check_inputmap_keys(sys, u0map, pmap)
+    badvarkeys = Any[]
+    for k in keys(u0map)
         if symbolic_type(k) === NotSymbolic()
-            push!(badkeys, k)
+            push!(badvarkeys, k)
         end
     end
 
-    isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
+    badparamkeys = Any[]
+    for k in keys(pmap)
+        if symbolic_type(k) === NotSymbolic()
+            push!(badparamkeys, k)
+        end
+    end
+    (isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
 end
 
 const BAD_KEY_MESSAGE = """
-                        Undefined keys found in the parameter or initial condition maps. 
-                        The following keys are either invalid or not parameters/states of the system:
+                        Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. 
+                        The following keys are invalid:
                         """
 
-struct BadKeyError <: Exception
+struct InvalidKeyError <: Exception
     vars::Any
+    params::Any
 end
 
-function Base.showerror(io::IO, e::BadKeyError) 
+function Base.showerror(io::IO, e::InvalidKeyError) 
     println(io, BAD_KEY_MESSAGE)
-    println(io, join(e.vars, ", "))
+    println(io, "u0map: $(join(e.vars, ", "))")
+    println(io, "pmap: $(join(e.params, ", "))")
 end
 
 
diff --git a/test/odesystem.jl b/test/odesystem.jl
index a635c3dad9..7b4d580718 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1626,3 +1626,18 @@ end
     prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...))
     @test prob.u0 isa SVector
 end
+
+@testset "input map validation" begin
+    import ModelingToolkit: InvalidKeyError 
+    @variables x(t) y(t) z(t)
+    @parameters a b c d 
+    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
+    @mtkbuild sys = ODESystem(eqs, t)
+    pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
+    u0map = [x => 1, y => 2, z => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+
+    pmap = [a => 1, b => 2, c => 3, d => 4]
+    u0map = [x => 1, y => 2, z => 3, :0 => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+end

From 83f572c42a3787a395675bd5a0de4daa5aa30063 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 17 Feb 2025 18:25:45 -0500
Subject: [PATCH 094/111] refactor tests

---
 test/odesystem.jl          | 15 ---------------
 test/problem_validation.jl | 21 +++++++++++++++------
 2 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/test/odesystem.jl b/test/odesystem.jl
index 62bcf4c355..de166ef0a1 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1673,18 +1673,3 @@ end
     prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...))
     @test prob.u0 isa SVector
 end
-
-@testset "input map validation" begin
-    import ModelingToolkit: InvalidKeyError 
-    @variables x(t) y(t) z(t)
-    @parameters a b c d 
-    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
-    @mtkbuild sys = ODESystem(eqs, t)
-    pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
-    u0map = [x => 1, y => 2, z => 3]
-    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
-
-    pmap = [a => 1, b => 2, c => 3, d => 4]
-    u0map = [x => 1, y => 2, z => 3, :0 => 3]
-    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
-end
diff --git a/test/problem_validation.jl b/test/problem_validation.jl
index f871327ae8..bce39b51d2 100644
--- a/test/problem_validation.jl
+++ b/test/problem_validation.jl
@@ -2,6 +2,7 @@ using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
 @testset "Input map validation" begin
+    import ModelingToolkit: InvalidKeyError, MissingParametersError 
     @variables X(t)
     @parameters p d
     eqs = [D(X) ~ p - d*X]
@@ -10,16 +11,24 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
     p = "I accidentally renamed p"
     u0 = [X => 1.0]
     ps = [p => 1.0, d => 0.5]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws MissingParametersError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
     
     @parameters p d
     ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
 
     u0 = [:X => 1.0, "random" => 3.0]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
 
-    @parameters k
-    ps = [p => 1., d => 0.5, k => 3.]
-    @test_throws ModelingToolkit.BadKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @variables x(t) y(t) z(t)
+    @parameters a b c d 
+    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
+    @mtkbuild sys = ODESystem(eqs, t)
+    pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
+    u0map = [x => 1, y => 2, z => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+
+    pmap = [a => 1, b => 2, c => 3, d => 4]
+    u0map = [x => 1, y => 2, z => 3, :0 => 3]
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
 end

From 2a97325d0b4e9b70f3a18de8049d14559504880d Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Tue, 18 Feb 2025 15:19:38 -0500
Subject: [PATCH 095/111] feat: initialization of DiscreteSystem

---
 .../symbolics_tearing.jl                      | 114 +++++++++++++-----
 src/structural_transformation/utils.jl        |  55 ++++++---
 .../discrete_system/discrete_system.jl        |   8 +-
 src/variables.jl                              |  16 ++-
 test/discrete_system.jl                       |  68 ++++++++---
 5 files changed, 193 insertions(+), 68 deletions(-)

diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index a166bb064b..bbb6853a7c 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -341,14 +341,18 @@ in order to properly generate the difference equations.
 
 In the system x(k) ~ x(k-1) + x(k-2), becomes Shift(t, 1)(x(t)) ~ x(t) + Shift(t, -1)(x(t))
 
-The lowest-order term is Shift(t, k)(x(t)), instead of x(t).
-As such we actually want dummy variables for the k-1 lowest order terms instead of the k-1 highest order terms.
+The lowest-order term is Shift(t, k)(x(t)), instead of x(t). As such we actually want 
+dummy variables for the k-1 lowest order terms instead of the k-1 highest order terms.
 
 Shift(t, -1)(x(t)) -> x\_{t-1}(t)
 
-Since Shift(t, -1)(x) is not a derivative, it is directly substituted in `fullvars`. No equation or variable is added for it. 
+Since Shift(t, -1)(x) is not a derivative, it is directly substituted in `fullvars`. 
+No equation or variable is added for it. 
 
-For ODESystems D(D(D(x))) in equations is recursively substituted as D(x) ~ x_t, D(x_t) ~ x_tt, etc. The analogue for discrete systems, Shift(t, 1)(Shift(t,1)(Shift(t,1)(Shift(t, -3)(x(t))))) does not actually appear. So `total_sub` in generate_system_equations` is directly initialized with all of the lowered variables `Shift(t, -3)(x) -> x_t-3(t)`, etc. 
+For ODESystems D(D(D(x))) in equations is recursively substituted as D(x) ~ x_t, D(x_t) ~ x_tt, etc. 
+The analogue for discrete systems, Shift(t, 1)(Shift(t,1)(Shift(t,1)(Shift(t, -3)(x(t))))) 
+does not actually appear. So `total_sub` in generate_system_equations` is directly 
+initialized with all of the lowered variables `Shift(t, -3)(x) -> x_t-3(t)`, etc. 
 =#
 """
 Generate new derivative variables for the system.
@@ -358,7 +362,7 @@ Effects on the system structure:
 - neweqs: add the identity equations for the new variables, D(x) ~ x_t
 - graph: update graph with the new equations and variables, and their connections
 - solvable_graph:
-- var_eq_matching: match D(x) to the added identity equation
+- var_eq_matching: match D(x) to the added identity equation D(x) ~ x_t
 """
 function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = ts
@@ -406,7 +410,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
 end
 
 """
-Check if there's `D(x) = x_t` already.
+Check if there's `D(x) ~ x_t` already.
 """
 function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
     for eq in 𝑑neighbors(solvable_graph, dv)
@@ -427,6 +431,10 @@ function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
     return nothing 
 end
 
+"""
+Add a dummy derivative variable x_t corresponding to symbolic variable D(x) 
+which has index dv in `fullvars`. Return the new index of x_t.
+"""
 function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv)
     push!(fullvars, simplify_shifts(x_t))
     v_t = length(fullvars)
@@ -439,7 +447,11 @@ function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv)
     v_t
 end
 
-# dv = index of D(x), v_t = index of x_t
+"""
+Add the equation D(x) - x_t ~ 0 to `neweqs`. `dv` and `v_t` are the indices
+of the higher-order derivative variable and the newly-introduced dummy
+derivative variable. Return the index of the new equation in `neweqs`.
+"""
 function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t)
     push!(neweqs, eq)
     add_vertex!(s.graph, SRC)
@@ -452,8 +464,33 @@ function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t)
 end
 
 """
-Solve the solvable equations of the system and generate differential (or discrete)
-equations in terms of the selected states.
+Solve the equations in `neweqs` to obtain the final equations of the 
+system.
+
+For each equation of `neweqs`, do one of the following: 
+   1. If the equation is solvable for a differentiated variable D(x),
+      then solve for D(x), and add D(x) ~ sol as a differential equation
+      of the system.
+   2. If the equation is solvable for an un-differentiated variable x, 
+      solve for x and then add x ~ sol as a solved equation. These will
+      become observables.
+   3. If the equation is not solvable, add it as an algebraic equation.
+
+Solved equations are added to `total_sub`. Occurrences of differential
+or solved variables on the RHS of the final equations will get substituted.
+The topological sort of the equations ensures that variables are solved for
+before they appear in equations. 
+
+Reorder the equations and unknowns to be:
+   [diffeqs; ...]
+   [diffvars; ...]
+such that the mass matrix is:
+   [I  0
+    0  0].
+
+Order the new equations and variables such that the differential equations
+and variables come first. Return the new equations, the solved equations,
+the new orderings, and the number of solved variables and equations.
 """
 function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = state 
@@ -550,6 +587,9 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
 end
 
+"""
+Occurs when a variable D(x) occurs in a non-differential system.
+"""
 struct UnexpectedDifferentialError
     eq::Equation
 end
@@ -558,12 +598,20 @@ function Base.showerror(io::IO, err::UnexpectedDifferentialError)
     error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(err.eq)")
 end
 
+"""
+Generate a first-order differential equation whose LHS is `dx`.
+
+`var` and `dx` represent the same variable, but `var` may be a higher-order differential and `dx` is always first-order. For example, if `var` is D(D(x)), then `dx` would be `D(x_t)`. Solve `eq` for `var`, substitute previously solved variables, and return the differential equation.
+"""
 function make_differential_equation(var, dx, eq, total_sub)
     dx ~ simplify_shifts(Symbolics.fixpoint_sub(
         Symbolics.symbolic_linear_solve(eq, var),
         total_sub; operator = ModelingToolkit.Shift))
 end
 
+"""
+Generate an algebraic equation. Substitute solved variables into `eq` and return the equation.
+"""
 function make_algebraic_equation(eq, total_sub)
     rhs = eq.rhs
     if !(eq.lhs isa Number && eq.lhs == 0)
@@ -572,6 +620,9 @@ function make_algebraic_equation(eq, total_sub)
     0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub))
 end
 
+"""
+Solve equation `eq` for `var`, substitute previously solved variables, and return the solved equation.
+"""
 function make_solved_equation(var, eq, total_sub; simplify = false)
     residual = eq.lhs - eq.rhs
     a, b, islinear = linear_expansion(residual, var)
@@ -591,17 +642,13 @@ function make_solved_equation(var, eq, total_sub; simplify = false)
 end
 
 """
-Reorder the equations and unknowns to be:
-   [diffeqs; ...]
-   [diffvars; ...]
-such that the mass matrix is:
-   [I  0
-    0  0].
-
-Update the state to account for the new ordering and equations.
+Given the ordering returned by `generate_system_equations!`, update the 
+tearing state to account for the new order. Permute the variables and equations.
+Eliminate the solved variables and equations from the graph and permute the
+graph's vertices to account for the new variable/equation ordering.
 """
 # TODO: BLT sorting
-function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
+function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nsolved_eq, nsolved_var)
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
 
     eqsperm = zeros(Int, nsrcs(graph))
@@ -616,7 +663,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_or
     # Contract the vertices in the structure graph to make the structure match
     # the new reality of the system we've just created.
     new_graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
-        nelim_eq, nelim_var)
+        nsolved_eq, nsolved_var)
 
     new_var_to_diff = complete(DiffGraph(length(var_ordering)))
     for (v, d) in enumerate(var_to_diff)
@@ -643,7 +690,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_or
 end
 
 """
-Set the system equations, unknowns, observables post-tearing.
+Update the system equations, unknowns, and observables after simplification.
 """
 function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; 
         cse_hack = true, array_hack = true)
@@ -685,16 +732,10 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
     sys = schedule(sys)
 end
 
-# Terminology and Definition:
-# A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
-# characterize variables in `u(t)` into two classes: differential variables
-# (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
-# variables are marked as `SelectedState` and they are differentiated in the
-# DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
-# appear in the system. Algebraic variables are variables that are not
-# differential variables.
     
-# Give the order of the variable indexed by dv
+"""
+Give the order of the variable indexed by dv.
+"""
 function var_order(dv, diff_to_var)
     order = 0
     while (dv′ = diff_to_var[dv]) !== nothing
@@ -704,6 +745,21 @@ function var_order(dv, diff_to_var)
     order, dv
 end
 
+"""
+Main internal function for structural simplification for DAE systems and discrete systems.
+Generate dummy derivative variables, new equations in terms of variables, return updated
+system and tearing state.
+
+Terminology and Definition:
+
+A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
+characterize variables in `u(t)` into two classes: differential variables
+(denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
+variables are marked as `SelectedState` and they are differentiated in the
+DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
+appear in the system. Algebraic variables are variables that are not
+differential variables.
+"""
 function tearing_reassemble(state::TearingState, var_eq_matching,
         full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
     extra_vars = Int[]
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index f64a8da132..2753446a94 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -449,7 +449,12 @@ end
 ### Misc
 ###
 
-# For discrete variables. Turn Shift(t, k)(x(t)) into xₜ₋ₖ(t)
+"""
+Handle renaming variable names for discrete structural simplification. Three cases: 
+- positive shift: do nothing
+- zero shift: x(t) => Shift(t, 0)(x(t))
+- negative shift: rename the variable
+"""
 function lower_shift_varname(var, iv)
     op = operation(var)
     op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t)
@@ -460,30 +465,46 @@ function lower_shift_varname(var, iv)
     end
 end
 
-function shift2term(var) 
+"""
+Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
+"""
+function shift2term(var)
     backshift = operation(var).steps
     iv = operation(var).t
-    num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift)))
-    ds = join([Char(0x209c), Char(0x208b), num])
-    #ds = "$iv-$(-backshift)"
-    #d_separator = 'ˍ'
-
-    if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
-        O = only(arguments(var))
-        oldop = operation(O)
-        newname = Symbol(string(nameof(oldop)), ds)
-    else
-        O = var
-        oldop = operation(var) 
-        varname = split(string(nameof(oldop)), d_separator)[1]
-        newname = Symbol(varname, d_separator, ds)
-    end
+    num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
+    ds = join([Char(0x209c), Char(0x208b), num]) 
+    # Char(0x209c) = ₜ
+    # Char(0x208b) = ₋ (subscripted minus)
+
+    O = only(arguments(var))
+    oldop = operation(O)
+    newname = Symbol(string(nameof(oldop)), ds)
+
     newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
     newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
     newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
+    newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift)
     return newvar
 end
 
+function term2shift(var)
+    var = Symbolics.unwrap(var)
+    name = Symbolics.getname(var)
+    O = only(arguments(var))
+    oldop = operation(O)
+    iv = only(arguments(x))
+    # Split on ₋
+    if occursin(Char(0x208b), name)
+        substrings = split(name, Char(0x208b))
+        shift = last(split(name, Char(0x208b)))
+        newname = join(substrings[1:end-1])[1:end-1]
+        newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
+        return Shift(iv, -shift)(newvar)
+    else
+        return var
+    end
+end
+
 function isdoubleshift(var)
     return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
            ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index 9a9ac47853..40bf632339 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -270,15 +270,19 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
         v = u0map[k]
         if !((op = operation(k)) isa Shift)
             error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
+        elseif op.steps > 0
+            error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")
         end
+
         updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
     end
     for var in unknowns(sys)
         op = operation(var)
-        haskey(updated, var) && continue
         root = getunshifted(var)
+        shift = getshift(var)
         isnothing(root) && continue
-        haskey(defs, root) || error("Initial condition for $root not provided.")
+        (haskey(updated, Shift(iv, shift)(root)) || haskey(updated, var)) && continue
+        haskey(defs, root) || error("Initial condition for $var not provided.")
         updated[var] = defs[root]
     end
     return updated
diff --git a/src/variables.jl b/src/variables.jl
index a29a119607..4e13ad2c5d 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -6,7 +6,9 @@ struct VariableOutput end
 struct VariableIrreducible end
 struct VariableStatePriority end
 struct VariableMisc end
+# Metadata for renamed shift variables xₜ₋₁
 struct VariableUnshifted end
+struct VariableShift end
 Symbolics.option_to_metadata_type(::Val{:unit}) = VariableUnit
 Symbolics.option_to_metadata_type(::Val{:connect}) = VariableConnectType
 Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput
@@ -15,6 +17,7 @@ Symbolics.option_to_metadata_type(::Val{:irreducible}) = VariableIrreducible
 Symbolics.option_to_metadata_type(::Val{:state_priority}) = VariableStatePriority
 Symbolics.option_to_metadata_type(::Val{:misc}) = VariableMisc
 Symbolics.option_to_metadata_type(::Val{:unshifted}) = VariableUnshifted
+Symbolics.option_to_metadata_type(::Val{:shift}) = VariableShift
 
 """
     dump_variable_metadata(var)
@@ -97,7 +100,7 @@ struct Stream <: AbstractConnectType end   # special stream connector
 
 Get the connect type of x. See also [`hasconnect`](@ref).
 """
-getconnect(x) = getconnect(unwrap(x))
+getconnect(x::Num) = getconnect(unwrap(x))
 getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing)
 """
     hasconnect(x)
@@ -264,7 +267,7 @@ end
 end
 
 struct IsHistory end
-ishistory(x) = ishistory(unwrap(x))
+ishistory(x::Num) = ishistory(unwrap(x))
 ishistory(x::Symbolic) = getmetadata(x, IsHistory, false)
 hist(x, t) = wrap(hist(unwrap(x), t))
 function hist(x::Symbolic, t)
@@ -575,7 +578,7 @@ end
 Fetch any miscellaneous data associated with symbolic variable `x`.
 See also [`hasmisc(x)`](@ref).
 """
-getmisc(x) = getmisc(unwrap(x))
+getmisc(x::Num) = getmisc(unwrap(x))
 getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing)
 """
     hasmisc(x)
@@ -594,7 +597,7 @@ setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata)
 
 Fetch the unit associated with variable `x`. This function is a metadata getter for an individual variable, while `get_unit` is used for unit inference on more complicated sdymbolic expressions.
 """
-getunit(x) = getunit(unwrap(x))
+getunit(x::Num) = getunit(unwrap(x))
 getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing)
 """
     hasunit(x)
@@ -603,5 +606,8 @@ Check if the variable `x` has a unit.
 """
 hasunit(x) = getunit(x) !== nothing
 
-getunshifted(x) = getunshifted(unwrap(x))
+getunshifted(x::Num) = getunshifted(unwrap(x))
 getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)
+
+getshift(x::Num) = getshift(unwrap(x))
+getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0)
diff --git a/test/discrete_system.jl b/test/discrete_system.jl
index b63f66d4e4..fa7ba993e6 100644
--- a/test/discrete_system.jl
+++ b/test/discrete_system.jl
@@ -220,21 +220,6 @@ sol = solve(prob, FunctionMap())
 
 @test reduce(vcat, sol.u) == 1:11
 
-# test that default values apply to the entire history
-@variables x(t) = 1.0
-@mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
-prob = DiscreteProblem(de, [], (0, 10))
-@test prob[x] == 2.0
-@test prob[x(k - 1)] == 1.0
-
-# must provide initial conditions for history
-@test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10))
-
-# initial values only affect _that timestep_, not the entire history
-prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
-@test prob[x] == 3.0
-@test prob[x(k - 1)] == 2.0
-
 # Issue#2585
 getdata(buffer, t) = buffer[mod1(Int(t), length(buffer))]
 @register_symbolic getdata(buffer::Vector, t)
@@ -272,6 +257,7 @@ k = ShiftIndex(t)
 @named sys = DiscreteSystem([x ~ x^2 + y^2, y ~ x(k - 1) + y(k - 1)], t)
 @test_throws ["algebraic equations", "not yet supported"] structural_simplify(sys)
 
+
 @testset "Passing `nothing` to `u0`" begin
     @variables x(t) = 1
     k = ShiftIndex()
@@ -279,3 +265,55 @@ k = ShiftIndex(t)
     prob = @test_nowarn DiscreteProblem(sys, nothing, (0.0, 1.0))
     @test_nowarn solve(prob, FunctionMap())
 end
+
+@testset "Initialization" begin
+    # test that default values apply to the entire history
+    @variables x(t) = 1.0
+    @mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
+    prob = DiscreteProblem(de, [], (0, 10))
+    @test prob[x] == 2.0
+    @test prob[x(k - 1)] == 1.0
+    
+    # must provide initial conditions for history
+    @test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10))
+    @test_throws ErrorException DiscreteProblem(de, [x(k+1) => 2.], (0, 10))
+    
+    # initial values only affect _that timestep_, not the entire history
+    prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
+    @test prob[x] == 3.0
+    @test prob[x(k - 1)] == 2.0
+    @test prob[xₜ₋₁] == 2.0
+
+    # Test initial assignment with lowered variable
+    @variables xₜ₋₁(t)
+    prob = DiscreteProblem(de, [xₜ₋₁(k-1) => 4.0], (0, 10))
+    @test prob[x(k-1)] == prob[xₜ₋₁] == 1.0
+    @test prob[x] == 5.
+
+    # Test missing initial throws error
+    @variables x(t)
+    @mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
+    @test_throws ErrorException prob = DiscreteProblem(de, [x(k-3) => 2.], (0, 10))
+    @test_throws ErrorException prob = DiscreteProblem(de, [x(k-3) => 2., x(k-1) => 3.], (0, 10))
+
+    # Test non-assigned initials are given default value
+    @variables x(t) = 2.
+    prob = DiscreteProblem(de, [x(k-3) => 12.], (0, 10))
+    @test prob[x] == 26.0
+    @test prob[x(k-1)] == 2.0
+    @test prob[x(k-2)] == 2.0
+
+    # Elaborate test
+    eqs = [x ~ x(k-1) + z(k-2), 
+           z ~ x(k-2) * x(k-3) - z(k-1)^2]
+    @mtkbuild de = DiscreteSystem(eqs, t)
+    @variables xₜ₋₂(t) zₜ₋₁(t)
+    u0 = [x(k-1) => 3, 
+          xₜ₋₂(k-1) => 4, 
+          x(k-2) => 1, 
+          z(k-1) => 5, 
+          zₜ₋₁(k-1) => 12]
+    prob = DiscreteProblem(de, u0, (0, 10))
+    @test prob[x] == 15
+    @test prob[z] == -21
+end

From fc2a309f3d9f8e607329b9ee640672ca4f872eb3 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Wed, 19 Feb 2025 10:12:41 -0800
Subject: [PATCH 096/111] fix tests and shift2term

---
 src/structural_transformation/utils.jl        | 31 ++++++-------------
 .../discrete_system/discrete_system.jl        |  7 +++--
 test/discrete_system.jl                       | 13 ++++++--
 3 files changed, 25 insertions(+), 26 deletions(-)

diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 2753446a94..96f7f78f99 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -469,16 +469,21 @@ end
 Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
 """
 function shift2term(var)
-    backshift = operation(var).steps
-    iv = operation(var).t
+    op = operation(var)
+    iv = op.t
+    arg = only(arguments(var))
+    is_lowered = !isnothing(ModelingToolkit.getunshifted(arg))
+
+    backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps
+
     num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
     ds = join([Char(0x209c), Char(0x208b), num]) 
     # Char(0x209c) = ₜ
     # Char(0x208b) = ₋ (subscripted minus)
 
-    O = only(arguments(var))
+    O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg
     oldop = operation(O)
-    newname = Symbol(string(nameof(oldop)), ds)
+    newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : Symbol(string(nameof(oldop)))
 
     newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
     newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
@@ -487,24 +492,6 @@ function shift2term(var)
     return newvar
 end
 
-function term2shift(var)
-    var = Symbolics.unwrap(var)
-    name = Symbolics.getname(var)
-    O = only(arguments(var))
-    oldop = operation(O)
-    iv = only(arguments(x))
-    # Split on ₋
-    if occursin(Char(0x208b), name)
-        substrings = split(name, Char(0x208b))
-        shift = last(split(name, Char(0x208b)))
-        newname = join(substrings[1:end-1])[1:end-1]
-        newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
-        return Shift(iv, -shift)(newvar)
-    else
-        return var
-    end
-end
-
 function isdoubleshift(var)
     return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
            ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index 40bf632339..3f8d9e85c6 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -269,12 +269,15 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
     for k in collect(keys(u0map))
         v = u0map[k]
         if !((op = operation(k)) isa Shift)
-            error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
+            isnothing(getunshifted(k)) && error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
+            
+            updated[Shift(iv, 1)(k)] = v
         elseif op.steps > 0
             error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")
+        else
+            updated[Shift(iv, op.steps + 1)(only(arguments(k)))] = v
         end
 
-        updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
     end
     for var in unknowns(sys)
         op = operation(var)
diff --git a/test/discrete_system.jl b/test/discrete_system.jl
index fa7ba993e6..756e5bca48 100644
--- a/test/discrete_system.jl
+++ b/test/discrete_system.jl
@@ -282,10 +282,10 @@ end
     prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
     @test prob[x] == 3.0
     @test prob[x(k - 1)] == 2.0
+    @variables xₜ₋₁(t)
     @test prob[xₜ₋₁] == 2.0
 
     # Test initial assignment with lowered variable
-    @variables xₜ₋₁(t)
     prob = DiscreteProblem(de, [xₜ₋₁(k-1) => 4.0], (0, 10))
     @test prob[x(k-1)] == prob[xₜ₋₁] == 1.0
     @test prob[x] == 5.
@@ -298,16 +298,17 @@ end
 
     # Test non-assigned initials are given default value
     @variables x(t) = 2.
+    @mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
     prob = DiscreteProblem(de, [x(k-3) => 12.], (0, 10))
     @test prob[x] == 26.0
     @test prob[x(k-1)] == 2.0
     @test prob[x(k-2)] == 2.0
 
     # Elaborate test
+    @variables xₜ₋₂(t) zₜ₋₁(t) z(t)
     eqs = [x ~ x(k-1) + z(k-2), 
            z ~ x(k-2) * x(k-3) - z(k-1)^2]
     @mtkbuild de = DiscreteSystem(eqs, t)
-    @variables xₜ₋₂(t) zₜ₋₁(t)
     u0 = [x(k-1) => 3, 
           xₜ₋₂(k-1) => 4, 
           x(k-2) => 1, 
@@ -316,4 +317,12 @@ end
     prob = DiscreteProblem(de, u0, (0, 10))
     @test prob[x] == 15
     @test prob[z] == -21
+
+    import ModelingToolkit: shift2term
+    # unknowns(de) = xₜ₋₁, x, zₜ₋₁, xₜ₋₂, z
+    vars = ModelingToolkit.value.(unknowns(de))
+    @test isequal(shift2term(Shift(t, 1)(vars[1])), vars[2])
+    @test isequal(shift2term(Shift(t, 1)(vars[4])), vars[1])
+    @test isequal(shift2term(Shift(t, -1)(vars[5])), vars[3])
+    @test isequal(shift2term(Shift(t, -2)(vars[2])), vars[4])
 end

From ceb9cdc7493f8c622fcfb9821437de2d74ef5e18 Mon Sep 17 00:00:00 2001
From: Fredrik Bagge Carlson <baggepinnen@gmail.com>
Date: Fri, 7 Feb 2025 09:20:28 +0100
Subject: [PATCH 097/111] add utilities and tests for disturbance modeling

Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

rm plot
---
 src/systems/analysis_points.jl            |  33 ++++
 src/systems/diffeqs/odesystem.jl          |  13 +-
 test/downstream/test_disturbance_model.jl | 215 ++++++++++++++++++++++
 test/runtests.jl                          |   1 +
 4 files changed, 261 insertions(+), 1 deletion(-)
 create mode 100644 test/downstream/test_disturbance_model.jl

diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl
index 022a0909ed..135bf81729 100644
--- a/src/systems/analysis_points.jl
+++ b/src/systems/analysis_points.jl
@@ -960,3 +960,36 @@ Compute the (linearized) loop-transfer function in analysis point `ap`, from `ap
 
 See also [`get_sensitivity`](@ref), [`get_comp_sensitivity`](@ref), [`open_loop`](@ref).
 """ get_looptransfer
+# 
+
+"""
+    generate_control_function(sys::ModelingToolkit.AbstractODESystem, input_ap_name::Union{Symbol, Vector{Symbol}, AnalysisPoint, Vector{AnalysisPoint}}, dist_ap_name::Union{Symbol, Vector{Symbol}, AnalysisPoint, Vector{AnalysisPoint}}; system_modifier = identity, kwargs)
+
+When called with analysis points as input arguments, we assume that all analysis points corresponds to connections that should be opened (broken). The use case for this is to get rid of input signal blocks, such as `Step` or `Sine`, since these are useful for simulation but are not needed when using the plant model in a controller or state estimator.
+"""
+function generate_control_function(
+        sys::ModelingToolkit.AbstractODESystem, input_ap_name::Union{
+            Symbol, Vector{Symbol}, AnalysisPoint, Vector{AnalysisPoint}},
+        dist_ap_name::Union{
+            Nothing, Symbol, Vector{Symbol}, AnalysisPoint, Vector{AnalysisPoint}} = nothing;
+        system_modifier = identity,
+        kwargs...)
+    input_ap_name = canonicalize_ap(sys, input_ap_name)
+    u = []
+    for input_ap in input_ap_name
+        sys, (du, _) = open_loop(sys, input_ap)
+        push!(u, du)
+    end
+    if dist_ap_name === nothing
+        return ModelingToolkit.generate_control_function(system_modifier(sys), u; kwargs...)
+    end
+
+    dist_ap_name = canonicalize_ap(sys, dist_ap_name)
+    d = []
+    for dist_ap in dist_ap_name
+        sys, (du, _) = open_loop(sys, dist_ap)
+        push!(d, du)
+    end
+
+    ModelingToolkit.generate_control_function(system_modifier(sys), u, d; kwargs...)
+end
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 1590fd6ebd..20d1e495dc 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -436,6 +436,8 @@ an array of inputs `inputs` is given, and `param_only` is false for a time-depen
 """
 function build_explicit_observed_function(sys, ts;
         inputs = nothing,
+        disturbance_inputs = nothing,
+        disturbance_argument = false,
         expression = false,
         eval_expression = false,
         eval_module = @__MODULE__,
@@ -512,13 +514,22 @@ function build_explicit_observed_function(sys, ts;
         ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
         inputs = (inputs,)
     end
+    if disturbance_inputs !== nothing
+        # Disturbance inputs may or may not be included as inputs, depending on disturbance_argument
+        ps = setdiff(ps, disturbance_inputs)
+    end
+    if disturbance_argument
+        disturbance_inputs = (disturbance_inputs,)
+    else
+        disturbance_inputs = ()
+    end
     ps = reorder_parameters(sys, ps)
     iv = if is_time_dependent(sys)
         (get_iv(sys),)
     else
         ()
     end
-    args = (dvs..., inputs..., ps..., iv...)
+    args = (dvs..., inputs..., ps..., iv..., disturbance_inputs...)
     p_start = length(dvs) + length(inputs) + 1
     p_end = length(dvs) + length(inputs) + length(ps)
     fns = build_function_wrapper(
diff --git a/test/downstream/test_disturbance_model.jl b/test/downstream/test_disturbance_model.jl
new file mode 100644
index 0000000000..10fcf9fc1f
--- /dev/null
+++ b/test/downstream/test_disturbance_model.jl
@@ -0,0 +1,215 @@
+#=
+This file implements and tests a typical workflow for state estimation with disturbance models
+The primary subject of the tests is the analysis-point features and the
+analysis-point specific method for `generate_control_function`.
+=#
+using ModelingToolkit, OrdinaryDiffEq, LinearAlgebra, Test
+using ModelingToolkitStandardLibrary.Mechanical.Rotational
+using ModelingToolkitStandardLibrary.Blocks
+using ModelingToolkit: connect
+# using Plots
+
+using ModelingToolkit: t_nounits as t, D_nounits as D
+
+indexof(sym, syms) = findfirst(isequal(sym), syms)
+
+## Build the system model ======================================================
+@mtkmodel SystemModel begin
+    @parameters begin
+        m1 = 1
+        m2 = 1
+        k = 10 # Spring stiffness
+        c = 3  # Damping coefficient
+    end
+    @components begin
+        inertia1 = Inertia(; J = m1, phi = 0, w = 0)
+        inertia2 = Inertia(; J = m2, phi = 0, w = 0)
+        spring = Spring(; c = k)
+        damper = Damper(; d = c)
+        torque = Torque(use_support = false)
+    end
+    @equations begin
+        connect(torque.flange, inertia1.flange_a)
+        connect(inertia1.flange_b, spring.flange_a, damper.flange_a)
+        connect(inertia2.flange_a, spring.flange_b, damper.flange_b)
+    end
+end
+
+@mtkmodel ModelWithInputs begin
+    @components begin
+        input_signal = Blocks.Sine(frequency = 1, amplitude = 1)
+        disturbance_signal1 = Blocks.Constant(k = 0)
+        disturbance_signal2 = Blocks.Constant(k = 0)
+        disturbance_torque1 = Torque(use_support = false)
+        disturbance_torque2 = Torque(use_support = false)
+        system_model = SystemModel()
+    end
+    @equations begin
+        connect(input_signal.output, :u, system_model.torque.tau)
+        connect(disturbance_signal1.output, :d1, disturbance_torque1.tau)
+        connect(disturbance_signal2.output, :d2, disturbance_torque2.tau)
+        connect(disturbance_torque1.flange, system_model.inertia1.flange_b)
+        connect(disturbance_torque2.flange, system_model.inertia2.flange_b)
+    end
+end
+
+@named model = ModelWithInputs() # Model with load disturbance
+ssys = structural_simplify(model)
+prob = ODEProblem(ssys, [], (0.0, 10.0))
+sol = solve(prob, Tsit5())
+# plot(sol)
+
+##
+using ControlSystemsBase, ControlSystemsMTK
+cmodel = complete(model)
+P = cmodel.system_model
+lsys = named_ss(
+    model, [:u, :d1], [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w])
+
+##
+# If we now want to add a disturbance model, we cannot do that since we have already connected a constant to the disturbance input. We have also already used the name `d` for an analysis point, but that might not be an issue since we create an outer model and get a new namespace.
+
+s = tf("s")
+dist(; name) = ODESystem(1 / s; name)
+
+@mtkmodel SystemModelWithDisturbanceModel begin
+    @components begin
+        input_signal = Blocks.Sine(frequency = 1, amplitude = 1)
+        disturbance_signal1 = Blocks.Constant(k = 0)
+        disturbance_signal2 = Blocks.Constant(k = 0)
+        disturbance_torque1 = Torque(use_support = false)
+        disturbance_torque2 = Torque(use_support = false)
+        disturbance_model = dist()
+        system_model = SystemModel()
+    end
+    @equations begin
+        connect(input_signal.output, :u, system_model.torque.tau)
+        connect(disturbance_signal1.output, :d1, disturbance_model.input)
+        connect(disturbance_model.output, disturbance_torque1.tau)
+        connect(disturbance_signal2.output, :d2, disturbance_torque2.tau)
+        connect(disturbance_torque1.flange, system_model.inertia1.flange_b)
+        connect(disturbance_torque2.flange, system_model.inertia2.flange_b)
+    end
+end
+
+@named model_with_disturbance = SystemModelWithDisturbanceModel()
+# ssys = structural_simplify(open_loop(model_with_disturbance, :d)) # Open loop worked, but it's a bit awkward that we have to use it here
+# lsys2 = named_ss(model_with_disturbance, [:u, :d1],
+# [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w])
+ssys = structural_simplify(model_with_disturbance)
+prob = ODEProblem(ssys, [], (0.0, 10.0))
+sol = solve(prob, Tsit5())
+@test SciMLBase.successful_retcode(sol)
+# plot(sol)
+
+## 
+# Now we only have an integrating disturbance affecting inertia1, what if we want both integrating and direct Gaussian? We'd need a "PI controller" disturbancemodel. If we add the disturbance model (s+1)/s we get the integrating and non-integrating noises being correlated which is fine, it reduces the dimensions of the sigma point by 1.
+
+dist3(; name) = ODESystem(ss(1 + 10 / s, balance = false); name)
+
+@mtkmodel SystemModelWithDisturbanceModel begin
+    @components begin
+        input_signal = Blocks.Sine(frequency = 1, amplitude = 1)
+        disturbance_signal1 = Blocks.Constant(k = 0)
+        disturbance_signal2 = Blocks.Constant(k = 0)
+        disturbance_torque1 = Torque(use_support = false)
+        disturbance_torque2 = Torque(use_support = false)
+        disturbance_model = dist3()
+        system_model = SystemModel()
+
+        y = Blocks.Add()
+        angle_sensor = AngleSensor()
+        output_disturbance = Blocks.Constant(k = 0)
+    end
+    @equations begin
+        connect(input_signal.output, :u, system_model.torque.tau)
+        connect(disturbance_signal1.output, :d1, disturbance_model.input)
+        connect(disturbance_model.output, disturbance_torque1.tau)
+        connect(disturbance_signal2.output, :d2, disturbance_torque2.tau)
+        connect(disturbance_torque1.flange, system_model.inertia1.flange_b)
+        connect(disturbance_torque2.flange, system_model.inertia2.flange_b)
+
+        connect(system_model.inertia1.flange_b, angle_sensor.flange)
+        connect(angle_sensor.phi, y.input1)
+        connect(output_disturbance.output, :dy, y.input2)
+    end
+end
+
+@named model_with_disturbance = SystemModelWithDisturbanceModel()
+# ssys = structural_simplify(open_loop(model_with_disturbance, :d)) # Open loop worked, but it's a bit awkward that we have to use it here
+# lsys3 = named_ss(model_with_disturbance, [:u, :d1],
+#     [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w])
+ssys = structural_simplify(model_with_disturbance)
+prob = ODEProblem(ssys, [], (0.0, 10.0))
+sol = solve(prob, Tsit5())
+@test SciMLBase.successful_retcode(sol)
+# plot(sol)
+
+## Generate function for an augmented Unscented Kalman Filter =====================
+# temp = open_loop(model_with_disturbance, :d)
+outputs = [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w]
+(f_oop1, f_ip), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
+    model_with_disturbance, [:u], [:d1, :d2, :dy], split = false)
+
+(f_oop2, f_ip2), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
+    model_with_disturbance, [:u], [:d1, :d2, :dy],
+    disturbance_argument = true, split = false)
+
+measurement = ModelingToolkit.build_explicit_observed_function(
+    io_sys, outputs, inputs = ModelingToolkit.inputs(io_sys)[1:1])
+measurement2 = ModelingToolkit.build_explicit_observed_function(
+    io_sys, [io_sys.y.output.u], inputs = ModelingToolkit.inputs(io_sys)[1:1],
+    disturbance_inputs = ModelingToolkit.inputs(io_sys)[2:end],
+    disturbance_argument = true)
+
+op = ModelingToolkit.inputs(io_sys) .=> 0
+x0, p = ModelingToolkit.get_u0_p(io_sys, op, op)
+x = zeros(5)
+u = zeros(1)
+d = zeros(3)
+@test f_oop2(x, u, p, t, d) == zeros(5)
+@test measurement(x, u, p, 0.0) == [0, 0, 0, 0]
+@test measurement2(x, u, p, 0.0, d) == [0]
+
+# Add to the integrating disturbance input
+d = [1, 0, 0]
+@test sort(f_oop2(x, u, p, 0.0, d)) == [0, 0, 0, 1, 1] # Affects disturbance state and one velocity
+@test measurement2(x, u, p, 0.0, d) == [0]
+
+d = [0, 1, 0]
+@test sort(f_oop2(x, u, p, 0.0, d)) == [0, 0, 0, 0, 1] # Affects one velocity
+@test measurement(x, u, p, 0.0) == [0, 0, 0, 0]
+@test measurement2(x, u, p, 0.0, d) == [0]
+
+d = [0, 0, 1]
+@test sort(f_oop2(x, u, p, 0.0, d)) == [0, 0, 0, 0, 0] # Affects nothing
+@test measurement(x, u, p, 0.0) == [0, 0, 0, 0]
+@test measurement2(x, u, p, 0.0, d) == [1] # We have now disturbed the output
+
+## Further downstream tests that the functions generated above actually have the properties required to use for state estimation
+# 
+# using LowLevelParticleFilters, SeeToDee
+# Ts = 0.001
+# discrete_dynamics = SeeToDee.Rk4(f_oop2, Ts)
+# nx = length(x_sym)
+# nu = 1
+# nw = 2
+# ny = length(outputs)
+# R1 = Diagonal([1e-5, 1e-5])
+# R2 = 0.1 * I(ny)
+# op = ModelingToolkit.inputs(io_sys) .=> 0
+# x0, p = ModelingToolkit.get_u0_p(io_sys, op, op)
+# d0 = LowLevelParticleFilters.SimpleMvNormal(x0, 10.0I(nx))
+# measurement_model = UKFMeasurementModel{Float64, false, false}(measurement, R2; nx, ny)
+# kf = UnscentedKalmanFilter{false, false, true, false}(
+#     discrete_dynamics, measurement_model, R1, d0; nu, Ts, p)
+
+# tvec = 0:Ts:sol.t[end]
+# u = vcat.(Array(sol(tvec, idxs = P.torque.tau.u)))
+# y = collect.(eachcol(Array(sol(tvec, idxs = outputs)) .+ 1e-2 .* randn.()))
+
+# inds = 1:5805
+# res = forward_trajectory(kf, u, y)
+
+# plot(res, size = (1000, 1000));
+# plot!(sol, idxs = x_sym, sp = (1:nx)', l = :dash);
diff --git a/test/runtests.jl b/test/runtests.jl
index 966b02cacb..11c78e43ca 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -120,6 +120,7 @@ end
         @safetestset "Linearization Dummy Derivative Tests" include("downstream/linearization_dd.jl")
         @safetestset "Inverse Models Test" include("downstream/inversemodel.jl")
         @safetestset "Analysis Points Test" include("downstream/analysis_points.jl")
+        @safetestset "Analysis Points Test" include("downstream/test_disturbance_model.jl")
     end
 
     if GROUP == "All" || GROUP == "FMI"

From 56b74353dec77231dc3f671c7358bba2a4e23734 Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Fri, 14 Feb 2025 17:43:32 +0530
Subject: [PATCH 098/111] refactor: remove `MTKHomotopyContinuationExt`

---
 Project.toml                      |   3 -
 ext/MTKHomotopyContinuationExt.jl | 225 ------------------------------
 2 files changed, 228 deletions(-)
 delete mode 100644 ext/MTKHomotopyContinuationExt.jl

diff --git a/Project.toml b/Project.toml
index ba47d83eb6..3196235379 100644
--- a/Project.toml
+++ b/Project.toml
@@ -65,7 +65,6 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
 FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
-HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
 InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
 LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
 
@@ -74,7 +73,6 @@ MTKBifurcationKitExt = "BifurcationKit"
 MTKChainRulesCoreExt = "ChainRulesCore"
 MTKDeepDiffsExt = "DeepDiffs"
 MTKFMIExt = "FMI"
-MTKHomotopyContinuationExt = "HomotopyContinuation"
 MTKInfiniteOptExt = "InfiniteOpt"
 MTKLabelledArraysExt = "LabelledArrays"
 
@@ -110,7 +108,6 @@ ForwardDiff = "0.10.3"
 FunctionWrappers = "1.1"
 FunctionWrappersWrappers = "0.1"
 Graphs = "1.5.2"
-HomotopyContinuation = "2.11"
 InfiniteOpt = "0.5"
 InteractiveUtils = "1"
 JuliaFormatter = "1.0.47"
diff --git a/ext/MTKHomotopyContinuationExt.jl b/ext/MTKHomotopyContinuationExt.jl
deleted file mode 100644
index 8f17c05b18..0000000000
--- a/ext/MTKHomotopyContinuationExt.jl
+++ /dev/null
@@ -1,225 +0,0 @@
-module MTKHomotopyContinuationExt
-
-using ModelingToolkit
-using ModelingToolkit.SciMLBase
-using ModelingToolkit.Symbolics: unwrap, symtype, BasicSymbolic, simplify_fractions
-using ModelingToolkit.SymbolicIndexingInterface
-using ModelingToolkit.DocStringExtensions
-using HomotopyContinuation
-using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0,
-                       get_u0_p, check_eqs_u0, CommonSolve
-
-const MTK = ModelingToolkit
-
-"""
-$(TYPEDSIGNATURES)
-
-Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`.
-"""
-function symbolics_to_hc(expr)
-    if iscall(expr)
-        if operation(expr) == getindex
-            args = arguments(expr)
-            return ModelKit.Variable(getname(args[1]), args[2:end]...)
-        else
-            return operation(expr)(symbolics_to_hc.(arguments(expr))...)
-        end
-    elseif symbolic_type(expr) == NotSymbolic()
-        return expr
-    else
-        return ModelKit.Variable(getname(expr))
-    end
-end
-
-"""
-$(TYPEDEF)
-
-A subtype of `HomotopyContinuation.AbstractSystem` used to solve `HomotopyContinuationProblem`s.
-"""
-struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem
-    """
-    The generated function for the residual of the polynomial system. In-place.
-    """
-    f::F
-    """
-    The parameter object.
-    """
-    p::P
-    """
-    The generated function for the jacobian of the polynomial system. In-place.
-    """
-    jac::J
-    """
-    The `HomotopyContinuation.ModelKit.Variable` representation of the unknowns of
-    the system.
-    """
-    vars::V
-    """
-    The number of polynomials in the system. Must also be equal to `length(vars)`.
-    """
-    nexprs::Int
-end
-
-Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars))
-ModelKit.variables(sys::MTKHomotopySystem) = sys.vars
-
-function (sys::MTKHomotopySystem)(x, p = nothing)
-    sys.f(x, sys.p)
-end
-
-function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing)
-    sys.f(u, x, sys.p)
-end
-
-function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing)
-    sys.f(u, x, sys.p)
-    sys.jac(U, x, sys.p)
-end
-
-SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p
-
-"""
-    $(TYPEDSIGNATURES)
-
-Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial equations.
-The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution`
-will contain the polynomial root closest to the point specified by `u0map` (if real roots
-exist for the system).
-
-Keyword arguments:
-- `eval_expression`: Whether to `eval` the generated functions or use a `RuntimeGeneratedFunction`.
-- `eval_module`: The module to use for `eval`/`@RuntimeGeneratedFunction`
-- `warn_parametric_exponent`: Whether to warn if the system contains a parametric
-  exponent preventing the homotopy from being cached.
-
-All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
-"""
-function MTK.HomotopyContinuationProblem(
-        sys::NonlinearSystem, u0map, parammap = nothing; kwargs...)
-    prob = MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap; kwargs...)
-    prob isa MTK.HomotopyContinuationProblem || throw(prob)
-    return prob
-end
-
-function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing;
-        fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...)
-    if !iscomplete(sys)
-        error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
-    end
-    transformation = MTK.PolynomialTransformation(sys)
-    if transformation isa MTK.NotPolynomialError
-        return transformation
-    end
-    result = MTK.transform_system(sys, transformation; fraction_cancel_fn)
-    if result isa MTK.NotPolynomialError
-        return result
-    end
-    MTK.HomotopyContinuationProblem(sys, transformation, result, u0map, parammap; kwargs...)
-end
-
-function MTK.HomotopyContinuationProblem(
-        sys::MTK.NonlinearSystem, transformation::MTK.PolynomialTransformation,
-        result::MTK.PolynomialTransformationResult, u0map,
-        parammap = nothing; eval_expression = false,
-        eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
-    sys2 = result.sys
-    denoms = result.denominators
-    polydata = transformation.polydata
-    new_dvs = transformation.new_dvs
-    all_solutions = transformation.all_solutions
-
-    _, u0, p = MTK.process_SciMLProblem(
-        MTK.EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module)
-    nlfn = NonlinearFunction{true}(sys2; jac = true, eval_expression, eval_module)
-
-    denominator = MTK.build_explicit_observed_function(sys2, denoms)
-    unpack_solution = MTK.build_explicit_observed_function(sys2, all_solutions)
-
-    hvars = symbolics_to_hc.(new_dvs)
-    mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(new_dvs))
-
-    obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
-
-    has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata)
-    if has_parametric_exponents
-        if warn_parametric_exponent
-            @warn """
-            The system has parametric exponents, preventing caching of the homotopy. \
-            This will cause `solve` to be slower. Pass `warn_parametric_exponent \
-            = false` to turn off this warning
-            """
-        end
-        solver_and_starts = nothing
-    else
-        solver_and_starts = HomotopyContinuation.solver_startsolutions(mtkhsys; kwargs...)
-    end
-    return MTK.HomotopyContinuationProblem(
-        u0, mtkhsys, denominator, sys, obsfn, solver_and_starts, unpack_solution)
-end
-
-"""
-$(TYPEDSIGNATURES)
-
-Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
-uses `HomotopyContinuation.jl`. The original solution as returned by
-`HomotopyContinuation.jl` will be available in the `.original` field of the returned
-`NonlinearSolution`.
-
-All keyword arguments except the ones listed below are forwarded to
-`HomotopyContinuation.solve`. Note that the solver and start solutions are precomputed,
-and only keyword arguments related to the solve process are valid. All keyword
-arguments have their default values in HomotopyContinuation.jl, except `show_progress`
-which defaults to `false`.
-
-Extra keyword arguments:
-- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
-  the denominator to be below `denominator_abstol` will be discarded.
-"""
-function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
-        alg = nothing; show_progress = false, denominator_abstol = 1e-7, kwargs...)
-    if prob.solver_and_starts === nothing
-        sol = HomotopyContinuation.solve(
-            prob.homotopy_continuation_system; show_progress, kwargs...)
-    else
-        solver, starts = prob.solver_and_starts
-        sol = HomotopyContinuation.solve(solver, starts; show_progress, kwargs...)
-    end
-    realsols = HomotopyContinuation.results(sol; only_real = true)
-    if isempty(realsols)
-        u = state_values(prob)
-        retcode = SciMLBase.ReturnCode.ConvergenceFailure
-        resid = prob.homotopy_continuation_system(u)
-    else
-        T = eltype(state_values(prob))
-        distance = T(Inf)
-        u = state_values(prob)
-        resid = nothing
-        for result in realsols
-            if any(<=(denominator_abstol),
-                prob.denominator(real.(result.solution), parameter_values(prob)))
-                continue
-            end
-            for truesol in prob.unpack_solution(result.solution, parameter_values(prob))
-                dist = norm(truesol - state_values(prob))
-                if dist < distance
-                    distance = dist
-                    u = T.(real.(truesol))
-                    resid = T.(real.(prob.homotopy_continuation_system(result.solution)))
-                end
-            end
-        end
-        # all roots cause denominator to be zero
-        if isinf(distance)
-            u = state_values(prob)
-            resid = prob.homotopy_continuation_system(u)
-            retcode = SciMLBase.ReturnCode.Infeasible
-        else
-            retcode = SciMLBase.ReturnCode.Success
-        end
-    end
-
-    return SciMLBase.build_solution(
-        prob, :HomotopyContinuation, u, resid; retcode, original = sol)
-end
-
-end

From ef5e7a00b3c22f051b2f53fc1a66a92588c2a9f2 Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Fri, 14 Feb 2025 17:43:54 +0530
Subject: [PATCH 099/111] refactor: remove `use_homotopy_continuation` keyword
 from `NonlinearProblem`

---
 src/systems/nonlinear/nonlinearsystem.jl | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl
index c68fefa7d1..73484eb0dd 100644
--- a/src/systems/nonlinear/nonlinearsystem.jl
+++ b/src/systems/nonlinear/nonlinearsystem.jl
@@ -512,17 +512,10 @@ end
 
 function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
         parammap = DiffEqBase.NullParameters();
-        check_length = true, use_homotopy_continuation = false, kwargs...) where {iip}
+        check_length = true, kwargs...) where {iip}
     if !iscomplete(sys)
         error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
     end
-    if use_homotopy_continuation
-        prob = safe_HomotopyContinuationProblem(
-            sys, u0map, parammap; check_length, kwargs...)
-        if prob isa HomotopyContinuationProblem
-            return prob
-        end
-    end
     f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
         check_length, kwargs...)
     pt = something(get_metadata(sys), StandardNonlinearProblem())

From 18029477936545a07ab95da6f8cb2f296bc9a1be Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Fri, 14 Feb 2025 17:44:36 +0530
Subject: [PATCH 100/111] refactor: rewrite `HomotopyContinuationProblem` to
 target NonlinearSolveHomotopyContinuation.jl

---
 .../nonlinear/homotopy_continuation.jl        | 179 +++++++++---------
 test/extensions/homotopy_continuation.jl      | 114 ++++++-----
 2 files changed, 142 insertions(+), 151 deletions(-)

diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl
index 09e11199e8..474224eacd 100644
--- a/src/systems/nonlinear/homotopy_continuation.jl
+++ b/src/systems/nonlinear/homotopy_continuation.jl
@@ -1,93 +1,3 @@
-"""
-$(TYPEDEF)
-
-A type of Nonlinear problem which specializes on polynomial systems and uses
-HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
-create and solve.
-"""
-struct HomotopyContinuationProblem{uType, H, D, O, SS, U} <:
-       SciMLBase.AbstractNonlinearProblem{uType, true}
-    """
-    The initial values of states in the system. If there are multiple real roots of
-    the system, the one closest to this point is returned.
-    """
-    u0::uType
-    """
-    A subtype of `HomotopyContinuation.AbstractSystem` to solve. Also contains the
-    parameter object.
-    """
-    homotopy_continuation_system::H
-    """
-    A function with signature `(u, p) -> resid`. In case of rational functions, this
-    is used to rule out roots of the system which would cause the denominator to be
-    zero.
-    """
-    denominator::D
-    """
-    The `NonlinearSystem` used to create this problem. Used for symbolic indexing.
-    """
-    sys::NonlinearSystem
-    """
-    A function which generates and returns observed expressions for the given system.
-    """
-    obsfn::O
-    """
-    The HomotopyContinuation.jl solver and start system, obtained through
-    `HomotopyContinuation.solver_startsystems`.
-    """
-    solver_and_starts::SS
-    """
-    A function which takes a solution of the transformed system, and returns a vector
-    of solutions for the original system. This is utilized when converting systems
-    to polynomials.
-    """
-    unpack_solution::U
-end
-
-function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...)
-    error("HomotopyContinuation.jl is required to create and solve `HomotopyContinuationProblem`s. Please run `Pkg.add(\"HomotopyContinuation\")` to continue.")
-end
-
-"""
-    $(TYPEDSIGNATURES)
-
-Utility function for `safe_HomotopyContinuationProblem`, implemented in the extension.
-"""
-function _safe_HomotopyContinuationProblem end
-
-"""
-    $(TYPEDSIGNATURES)
-
-Return a `HomotopyContinuationProblem` if the extension is loaded and the system is
-polynomial. If the extension is not loaded, return `nothing`. If the system is not
-polynomial, return the appropriate `NotPolynomialError`.
-"""
-function safe_HomotopyContinuationProblem(sys::NonlinearSystem, args...; kwargs...)
-    if Base.get_extension(ModelingToolkit, :MTKHomotopyContinuationExt) === nothing
-        return nothing
-    end
-    return _safe_HomotopyContinuationProblem(sys, args...; kwargs...)
-end
-
-SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys
-SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0
-function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...)
-    set_state!(p.u0, args...)
-end
-function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem)
-    parameter_values(p.homotopy_continuation_system)
-end
-function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...)
-    set_parameter!(parameter_values(p), args...)
-end
-function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym)
-    if p.obsfn !== nothing
-        return p.obsfn(sym)
-    else
-        return SymbolicIndexingInterface.observed(p.sys, sym)
-    end
-end
-
 function contains_variable(x, wrt)
     any(y -> occursin(y, x), wrt)
 end
@@ -562,3 +472,92 @@ function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fract
     end
     return num, den
 end
+
+function SciMLBase.HomotopyNonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
+    ODEFunction{true}(sys, args...; kwargs...)
+end
+
+function SciMLBase.HomotopyNonlinearFunction{true}(sys::NonlinearSystem, args...;
+        kwargs...)
+    ODEFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.HomotopyNonlinearFunction{false}(sys::NonlinearSystem, args...;
+        kwargs...)
+    ODEFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
+end
+
+function SciMLBase.HomotopyNonlinearFunction{iip, specialize}(
+        sys::NonlinearSystem, args...; eval_expression = false, eval_module = @__MODULE__,
+        p = nothing, fraction_cancel_fn = SymbolicUtils.simplify_fractions,
+        kwargs...) where {iip, specialize}
+    if !iscomplete(sys)
+        error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationFunction`")
+    end
+    transformation = PolynomialTransformation(sys)
+    if transformation isa NotPolynomialError
+        throw(transformation)
+    end
+    result = transform_system(sys, transformation; fraction_cancel_fn)
+    if result isa NotPolynomialError
+        throw(result)
+    end
+
+    sys2 = result.sys
+    denoms = result.denominators
+    polydata = transformation.polydata
+    new_dvs = transformation.new_dvs
+    all_solutions = transformation.all_solutions
+
+    # we want to create f, jac etc. according to `sys2` since that will do the solving
+    # but the `sys` inside for symbolic indexing should be the non-polynomial system
+    fn = NonlinearFunction{iip}(sys2; eval_expression, eval_module, kwargs...)
+    obsfn = ObservedFunctionCache(
+        sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
+    fn = remake(fn; sys = sys, observed = obsfn)
+
+    denominator = build_explicit_observed_function(sys2, denoms)
+    unpolynomialize = build_explicit_observed_function(sys2, all_solutions)
+
+    inv_mapping = Dict(v => k for (k, v) in transformation.substitution_rules)
+    polynomialize_terms = [get(inv_mapping, var, var) for var in unknowns(sys2)]
+    polynomialize = build_explicit_observed_function(sys, polynomialize_terms)
+
+    return HomotopyNonlinearFunction{iip, specialize}(
+        fn; polynomialize, unpolynomialize, denominator)
+end
+
+struct HomotopyContinuationProblem{iip, specialization} end
+
+function HomotopyContinuationProblem(sys::NonlinearSystem, args...; kwargs...)
+    HomotopyContinuationProblem{true}(sys, args...; kwargs...)
+end
+
+function HomotopyContinuationProblem(sys::NonlinearSystem, t,
+        u0map::StaticArray,
+        args...;
+        kwargs...)
+    HomotopyContinuationProblem{false, SciMLBase.FullSpecialize}(
+        sys, t, u0map, args...; kwargs...)
+end
+
+function HomotopyContinuationProblem{true}(sys::NonlinearSystem, args...; kwargs...)
+    HomotopyContinuationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
+end
+
+function HomotopyContinuationProblem{false}(sys::NonlinearSystem, args...; kwargs...)
+    HomotopyContinuationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
+end
+
+function HomotopyContinuationProblem{iip, spec}(
+        sys::NonlinearSystem, u0map, pmap = SciMLBase.NullParameters();
+        kwargs...) where {iip, spec}
+    if !iscomplete(sys)
+        error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
+    end
+    f, u0, p = process_SciMLProblem(
+        HomotopyNonlinearFunction{iip, spec}, sys, u0map, pmap; kwargs...)
+
+    kwargs = filter_kwargs(kwargs)
+    return NonlinearProblem{iip}(f, u0, p; kwargs...)
+end
diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl
index 9e15ea857e..ed630ea8e3 100644
--- a/test/extensions/homotopy_continuation.jl
+++ b/test/extensions/homotopy_continuation.jl
@@ -1,20 +1,32 @@
-using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface
+using ModelingToolkit, NonlinearSolve, NonlinearSolveHomotopyContinuation,
+      SymbolicIndexingInterface
 using SymbolicUtils
 import ModelingToolkit as MTK
 using LinearAlgebra
 using Test
 
-@testset "Safe HCProblem" begin
-    @variables x y z
-    eqs = [0 ~ x^2 + y^2 + 2x * y
-           0 ~ x^2 + 4x + 4
-           0 ~ y * z + 4x^2]
-    @mtkbuild sys = NonlinearSystem(eqs)
-    prob = MTK.safe_HomotopyContinuationProblem(sys, [x => 1.0, y => 1.0, z => 1.0], [])
-    @test prob === nothing
+allrootsalg = HomotopyContinuationJL{true}(; threading = false)
+singlerootalg = HomotopyContinuationJL{false}(; threading = false)
+
+function test_single_root(sol; atol = 1e-10)
+    @test SciMLBase.successful_retcode(sol)
+    @test norm(sol.resid)≈0.0 atol=atol
 end
 
-import HomotopyContinuation
+function test_all_roots(sol; atol = 1e-10)
+    @test sol.converged
+    for nlsol in sol.u
+        @test SciMLBase.successful_retcode(nlsol)
+        @test norm(nlsol.resid)≈0.0 atol=1e-10
+    end
+end
+
+function solve_allroots_closest(prob)
+    sol = solve(prob, allrootsalg)
+    return argmin(sol.u) do nlsol
+        return norm(nlsol.u - prob.u0)
+    end
+end
 
 @testset "No parameters" begin
     @variables x y z
@@ -24,19 +36,13 @@ import HomotopyContinuation
     @mtkbuild sys = NonlinearSystem(eqs)
     u0 = [x => 1.0, y => 1.0, z => 1.0]
     prob = HomotopyContinuationProblem(sys, u0)
+    @test prob isa NonlinearProblem
     @test prob[x] == prob[y] == prob[z] == 1.0
     @test prob[x + y] == 2.0
-    sol = solve(prob; threading = false)
-    @test SciMLBase.successful_retcode(sol)
-    @test norm(sol.resid)≈0.0 atol=1e-10
-
-    prob2 = NonlinearProblem(sys, u0; use_homotopy_continuation = true)
-    @test prob2 isa HomotopyContinuationProblem
-    sol = solve(prob2; threading = false)
-    @test SciMLBase.successful_retcode(sol)
-    @test norm(sol.resid)≈0.0 atol=1e-10
-
-    @test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem
+    sol = solve(prob, singlerootalg)
+    test_single_root(sol)
+    sol = solve(prob, allrootsalg)
+    test_all_roots(sol)
 end
 
 struct Wrapper
@@ -61,9 +67,10 @@ end
     @test prob.ps[q] == 4
     @test prob.ps[r].x == [1.0 1.0; 0.0 0.0]
     @test prob.ps[p * q] == 8.0
-    sol = solve(prob; threading = false)
-    @test SciMLBase.successful_retcode(sol)
-    @test norm(sol.resid)≈0.0 atol=1e-10
+    sol = solve(prob, singlerootalg)
+    test_single_root(sol)
+    sol = solve(prob, allrootsalg)
+    test_all_roots(sol)
 end
 
 @testset "Array variables" begin
@@ -79,7 +86,7 @@ end
     @test prob[x] == 2ones(3)
     prob.ps[p] = [2, 3, 4]
     @test prob.ps[p] == [2, 3, 4]
-    sol = @test_nowarn solve(prob; threading = false)
+    sol = @test_nowarn solve(prob, singlerootalg)
     @test sol.retcode == ReturnCode.ConvergenceFailure
 end
 
@@ -87,11 +94,11 @@ end
     @variables x = 1.0
     @parameters n::Integer = 4
     @mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
-    prob = @test_warn ["parametric", "exponent"] HomotopyContinuationProblem(sys, [])
-    @test prob.solver_and_starts === nothing
-    @test_nowarn HomotopyContinuationProblem(sys, []; warn_parametric_exponent = false)
-    sol = solve(prob; threading = false)
-    @test SciMLBase.successful_retcode(sol)
+    prob = HomotopyContinuationProblem(sys, [])
+    sol = solve(prob, singlerootalg)
+    test_single_root(sol)
+    sol = solve(prob, allrootsalg)
+    test_all_roots(sol)
 end
 
 @testset "Polynomial check and warnings" begin
@@ -100,45 +107,31 @@ end
     @test_throws ["Cannot convert", "Unable", "symbolically solve",
         "Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem(
         sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
 
     @mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
     @test_throws ["Cannot convert", "Unable", "symbolically solve",
         "Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem(
         sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
     @mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
     @test_throws ["Cannot convert", "both polynomial", "non-polynomial",
         "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
         sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
 
     @variables y = 2.0
     @mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)])
     @test_throws ["Cannot convert", "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
         sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
 
     @mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2 ~ 0, sin(x + y) ~ 0])
     @test_throws ["Cannot convert", "function of multiple unknowns"] HomotopyContinuationProblem(
         sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
 
     @mtkbuild sys = NonlinearSystem([sin(x)^2 + 1 ~ 0, cos(y) - cos(x) - 1 ~ 0])
     @test_throws ["Cannot convert", "multiple non-polynomial terms", "same unknown"] HomotopyContinuationProblem(
         sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
 
     @mtkbuild sys = NonlinearSystem([sin(x^2)^2 + sin(x^2) - 1 ~ 0])
     @test_throws ["import Nemo"] HomotopyContinuationProblem(sys, [])
-    @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError
-    @test NonlinearProblem(sys, []) isa NonlinearProblem
 end
 
 import Nemo
@@ -148,7 +141,8 @@ import Nemo
     @mtkbuild sys = NonlinearSystem([sin(x^2)^2 + sin(x^2) - 1 ~ 0])
     prob = HomotopyContinuationProblem(sys, [])
     @test prob[1] ≈ 2.0
-    sol = solve(prob; threading = false)
+    # singlerootalg doesn't converge
+    sol = solve(prob, allrootsalg).u[1]
     _x = sol[1]
     @test sin(_x^2)^2 + sin(_x^2) - 1≈0.0 atol=1e-12
 end
@@ -162,9 +156,7 @@ end
     prob = HomotopyContinuationProblem(sys, [])
     @test prob[x] ≈ 0.25
     @test prob[y] ≈ 0.125
-    sol = solve(prob; threading = false)
-    # can't replicate the solve failure locally, so CI logs might help
-    @show sol.u sol.original.path_results
+    sol = solve(prob, allrootsalg).u[1]
     @test SciMLBase.successful_retcode(sol)
     @test sol[a]≈0.5 atol=1e-6
     @test sol[b]≈0.25 atol=1e-6
@@ -177,12 +169,12 @@ end
         0 ~ (x^2 - n * x + n) * (x - 1) / (x - 2) / (x - 3)
     ])
     prob = HomotopyContinuationProblem(sys, [])
-    sol = solve(prob; threading = false)
+    sol = solve_allroots_closest(prob)
     @test sol[x] ≈ 1.0
     p = parameter_values(prob)
     for invalid in [2.0, 3.0]
         for err in [-9e-8, 0, 9e-8]
-            @test any(<=(1e-7), prob.denominator([invalid + err, 2.0], p))
+            @test any(<=(1e-7), prob.f.denominator([invalid + err, 2.0], p))
         end
     end
 
@@ -195,7 +187,7 @@ end
         [n])
     sys = complete(sys)
     prob = HomotopyContinuationProblem(sys, [])
-    sol = solve(prob; threading = false)
+    sol = solve(prob, singlerootalg)
     disallowed_x = [4, 5.5]
     disallowed_y = [7, 5, 4]
     @test all(!isapprox(sol[x]; atol = 1e-8), disallowed_x)
@@ -205,30 +197,30 @@ end
     p = parameter_values(prob)
     for val in disallowed_x
         for err in [-9e-8, 0, 9e-8]
-            @test any(<=(1e-7), prob.denominator([val + err, 2.0], p))
+            @test any(<=(1e-7), prob.f.denominator([val + err, 2.0], p))
         end
     end
     for val in disallowed_y
         for err in [-9e-8, 0, 9e-8]
-            @test any(<=(1e-7), prob.denominator([2.0, val + err], p))
+            @test any(<=(1e-7), prob.f.denominator([2.0, val + err], p))
         end
     end
-    @test prob.denominator([2.0, 4.0], p)[1] <= 1e-8
+    @test prob.f.denominator([2.0, 4.0], p)[1] <= 1e-8
 
     @testset "Rational function in observed" begin
         @variables x=1 y=1
         @mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
         prob = HomotopyContinuationProblem(sys, [])
-        @test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
-        @test SciMLBase.successful_retcode(solve(prob; threading = false))
+        @test any(prob.f.denominator([2.0], parameter_values(prob)) .≈ 0.0)
+        @test SciMLBase.successful_retcode(solve(prob, singlerootalg))
     end
 
     @testset "Rational function forced to common denominators" begin
         @variables x = 1
         @mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x])
         prob = HomotopyContinuationProblem(sys, [])
-        @test any(prob.denominator([-1.0], parameter_values(prob)) .≈ 0.0)
-        sol = solve(prob; threading = false)
+        @test any(prob.f.denominator([-1.0], parameter_values(prob)) .≈ 0.0)
+        sol = solve(prob, singlerootalg)
         @test SciMLBase.successful_retcode(sol)
         @test 1 / (1 + sol.u[1]) - sol.u[1]≈0.0 atol=1e-10
     end
@@ -238,7 +230,7 @@ end
     @variables x=1 y
     @mtkbuild sys = NonlinearSystem([x^2 - 2 ~ 0, y ~ sin(x)])
     prob = HomotopyContinuationProblem(sys, [])
-    sol = @test_nowarn solve(prob; threading = false)
+    sol = @test_nowarn solve(prob, singlerootalg)
     @test sol[x] ≈ √2.0
     @test sol[y] ≈ sin(√2.0)
 end
@@ -251,10 +243,10 @@ end
 
     @testset "`simplify_fractions`" begin
         prob = HomotopyContinuationProblem(sys, [])
-        @test prob.denominator([0.0], parameter_values(prob)) ≈ [4.0]
+        @test prob.f.denominator([0.0], parameter_values(prob)) ≈ [4.0]
     end
     @testset "`nothing`" begin
         prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing)
-        @test sort(prob.denominator([0.0], parameter_values(prob))) ≈ [2.0, 4.0^3]
+        @test sort(prob.f.denominator([0.0], parameter_values(prob))) ≈ [2.0, 4.0^3]
     end
 end

From 486bba7cc982bea2784b99553e67b21212651c75 Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Fri, 14 Feb 2025 22:51:17 +0530
Subject: [PATCH 101/111] test: add NonlinearSolveHomotopyContinuation to
 `test/extensions` env

---
 test/extensions/Project.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/extensions/Project.toml b/test/extensions/Project.toml
index 5f7afe222a..5b0de73cdf 100644
--- a/test/extensions/Project.toml
+++ b/test/extensions/Project.toml
@@ -10,6 +10,7 @@ JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
+NonlinearSolveHomotopyContinuation = "2ac3b008-d579-4536-8c91-a1a5998c2f8b"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"

From 18e766efa165188ea6df5b2ea9a095c85c9ff40d Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Fri, 21 Feb 2025 11:06:26 -0800
Subject: [PATCH 102/111] fix bug

---
 src/systems/problem_utils.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index e7345f5d5d..1f40f61eed 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -755,7 +755,7 @@ function process_SciMLProblem(
 
     u0map = to_varmap(u0map, dvs)
     symbols_to_symbolics!(sys, u0map)
-    pmap = to_varmap(pmap, ps)
+    pmap = to_varmap(pmap, parameters(sys))
     symbols_to_symbolics!(sys, pmap)
 
     check_inputmap_keys(sys, u0map, pmap)

From bb1522869cd80da051b7c3930f69384ef1880af5 Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Fri, 14 Feb 2025 13:28:35 +0530
Subject: [PATCH 103/111] feat: add `available_vars` to
 `observed_equations_used_by`

---
 src/utils.jl | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/utils.jl b/src/utils.jl
index 962801622a..1fe7da603d 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1068,14 +1068,24 @@ Keyword arguments:
   providing this keyword is not necessary and is only useful to avoid repeatedly calling
   `vars(exprs)`
 - `obs`: the list of observed equations.
+- `available_vars`: If `exprs` involves a variable `x[1]`, this function will look for
+  observed equations whose LHS is `x[1]` OR `x`. Sometimes, the latter is not required
+  since `x[1]` might already be present elsewhere in the generated code (e.g. an argument
+  to the function) but other elements of `x` are part of the observed equations, thus
+  requiring them to be obtained from the equation for `x`. Any variable present in
+  `available_vars` will not be searched for in the observed equations.
 """
 function observed_equations_used_by(sys::AbstractSystem, exprs;
-        involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys))
+        involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys), available_vars = [])
     obsvars = getproperty.(obs, :lhs)
     graph = observed_dependency_graph(obs)
+    if !(available_vars isa Set)
+        available_vars = Set(available_vars)
+    end
 
     obsidxs = BitSet()
     for sym in involved_vars
+        sym in available_vars && continue
         arrsym = iscall(sym) && operation(sym) === getindex ? arguments(sym)[1] : nothing
         idx = findfirst(v -> isequal(v, sym) || isequal(v, arrsym), obsvars)
         idx === nothing && continue

From 7ee87a5f4c9d81895099aeacb57d3b232877be88 Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Fri, 14 Feb 2025 13:28:50 +0530
Subject: [PATCH 104/111] fix: fix array varables split across SCCs in
 SCCNonlinearProblem

---
 src/systems/nonlinear/nonlinearsystem.jl | 17 ++++++++++-------
 test/scc_nonlinear_problem.jl            | 10 ++++++++++
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl
index c68fefa7d1..e0aa84dbdb 100644
--- a/src/systems/nonlinear/nonlinearsystem.jl
+++ b/src/systems/nonlinear/nonlinearsystem.jl
@@ -676,23 +676,23 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
     scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
     scc_eqs = Vector{Equation}[]
     scc_obs = Vector{Equation}[]
+    # variables solved in previous SCCs
+    available_vars = Set()
     for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
         # subset unknowns and equations
         _dvs = dvs[vscc]
         _eqs = eqs[escc]
         # get observed equations required by this SCC
-        obsidxs = observed_equations_used_by(sys, _eqs)
+        union!(available_vars, _dvs)
+        obsidxs = observed_equations_used_by(sys, _eqs; available_vars)
         # the ones used by previous SCCs can be precomputed into the cache
         setdiff!(obsidxs, prevobsidxs)
         _obs = obs[obsidxs]
+        union!(available_vars, getproperty.(_obs, (:lhs,)))
 
         # get all subexpressions in the RHS which we can precompute in the cache
         # precomputed subexpressions should not contain `banned_vars`
         banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
-        filter!(banned_vars) do var
-            symbolic_type(var) != ArraySymbolic() ||
-                all(j -> var[j] in banned_vars, eachindex(var))
-        end
         state = Dict()
         for i in eachindex(_obs)
             _obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
@@ -753,9 +753,12 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
         _obs = scc_obs[i]
         cachevars = scc_cachevars[i]
         cacheexprs = scc_cacheexprs[i]
+        available_vars = [dvs[reduce(vcat, var_sccs[1:(i - 1)]; init = Int[])];
+                          getproperty.(
+                              reduce(vcat, scc_obs[1:(i - 1)]; init = []), (:lhs,))]
         _prevobsidxs = vcat(_prevobsidxs,
-            observed_equations_used_by(sys, reduce(vcat, values(cacheexprs); init = [])))
-
+            observed_equations_used_by(
+                sys, reduce(vcat, values(cacheexprs); init = []); available_vars))
         if isempty(cachevars)
             push!(explicitfuns, Returns(nothing))
         else
diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl
index 57f3d72fb7..ef4638cdb0 100644
--- a/test/scc_nonlinear_problem.jl
+++ b/test/scc_nonlinear_problem.jl
@@ -253,3 +253,13 @@ import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC
     sol = solve(prob)
     @test SciMLBase.successful_retcode(sol)
 end
+
+@testset "Array variables split across SCCs" begin
+    @variables x[1:3]
+    @parameters (f::Function)(..)
+    @mtkbuild sys = NonlinearSystem([
+        0 ~ x[1]^2 - 9, x[2] ~ 2x[1], 0 ~ x[3]^2 - x[1]^2 + f(x)])
+    prob = SCCNonlinearProblem(sys, [x => ones(3)], [f => sum])
+    sol = solve(prob, NewtonRaphson())
+    @test SciMLBase.successful_retcode(sol)
+end

From a797008ab49846fcfa659842a2168c070866b31d Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Mon, 24 Feb 2025 18:11:46 +0530
Subject: [PATCH 105/111] refactor: format

---
 src/systems/diffeqs/abstractodesystem.jl |  14 +-
 src/systems/diffeqs/odesystem.jl         |  35 ++--
 src/systems/problem_utils.jl             |   7 +-
 test/bvproblem.jl                        | 193 ++++++++++++-----------
 test/odesystem.jl                        |  12 +-
 test/problem_validation.jl               |  22 +--
 6 files changed, 150 insertions(+), 133 deletions(-)

diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index 1f6014a357..205d7e4601 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -841,35 +841,35 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
         eval_expression = false,
         eval_module = @__MODULE__,
         kwargs...) where {iip, specialize}
-
     if !iscomplete(sys)
         error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
     end
     !isnothing(callback) && error("BVP solvers do not support callbacks.")
 
-    has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
+    has_alg_eqs(sys) &&
+        error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
 
     sts = unknowns(sys)
     ps = parameters(sys)
     constraintsys = get_constraintsystem(sys)
 
     if !isnothing(constraintsys)
-        (length(constraints(constraintsys)) + length(u0map) > length(sts)) && 
-        @warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
+        (length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
+            @warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
     end
 
     # ODESystems without algebraic equations should use both fixed values + guesses
     # for initialization.
-    _u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses)) 
+    _u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
     f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
         t = tspan !== nothing ? tspan[1] : tspan, guesses,
         check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
 
     stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
-    u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
+    u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k, v) in u0map]
 
     fns = generate_function_bc(sys, u0, u0_idxs, tspan)
-    bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module) 
+    bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
     bc(sol, p, t) = bc_oop(sol, p, t)
     bc(resid, u, p, t) = bc_iip(resid, u, p, t)
 
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 37f921303c..33cdc59909 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -193,7 +193,8 @@ struct ODESystem <: AbstractODESystem
     """
     parent::Any
 
-    function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
+    function ODESystem(
+            tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
             jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
             torn_matching, initializesystem, initialization_eqs, schedule,
             connector_type, preface, cevents,
@@ -214,7 +215,8 @@ struct ODESystem <: AbstractODESystem
             u = __get_unit_type(dvs, ps, iv)
             check_units(u, deqs)
         end
-        new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad, jac,
+        new(tag, deqs, iv, dvs, ps, tspan, var_to_name,
+            ctrls, observed, constraints, tgrad, jac,
             ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
             initializesystem, initialization_eqs, schedule, connector_type, preface,
             cevents, devents, parameter_dependencies, assertions, metadata,
@@ -300,16 +302,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
     if is_dde === nothing
         is_dde = _check_if_dde(deqs, iv′, systems)
     end
-            
+
     if !isempty(systems) && !isnothing(constraintsystem)
         conssystems = ConstraintsSystem[]
         for sys in systems
             cons = get_constraintsystem(sys)
-            cons !== nothing && push!(conssystems, cons) 
+            cons !== nothing && push!(conssystems, cons)
         end
         @show conssystems
         @set! constraintsystem.systems = conssystems
-    end        
+    end
 
     assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
 
@@ -359,9 +361,9 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
     if !isempty(constraints)
         constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
         for st in get_unknowns(constraintsystem)
-            iscall(st) ? 
-                !in(operation(st)(iv), allunknowns) && push!(consvars, st) :
-                !in(st, allunknowns) && push!(consvars, st)
+            iscall(st) ?
+            !in(operation(st)(iv), allunknowns) && push!(consvars, st) :
+            !in(st, allunknowns) && push!(consvars, st)
         end
         for p in parameters(constraintsystem)
             !in(p, new_ps) && push!(new_ps, p)
@@ -712,7 +714,8 @@ end
 # Validate that all the variables in the BVP constraints are well-formed states or parameters.
 #  - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
 #  - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
-function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
+function process_constraint_system(
+        constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
     isempty(constraints) && return nothing
 
     constraintsts = OrderedSet()
@@ -725,22 +728,26 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
     # Validate the states.
     for var in constraintsts
         if !iscall(var)
-            occursin(iv, var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
+            occursin(iv, var) && (var ∈ sts ||
+             throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
         elseif length(arguments(var)) > 1
             throw(ArgumentError("Too many arguments for variable $var."))
         elseif length(arguments(var)) == 1
             arg = only(arguments(var))
-            operation(var)(iv) ∈ sts || 
+            operation(var)(iv) ∈ sts ||
                 throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
 
-            isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat || 
+            isequal(arg, iv) || isparameter(arg) || arg isa Integer ||
+                arg isa AbstractFloat ||
                 throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
 
             isparameter(arg) && push!(constraintps, arg)
         else
-            var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
+            var ∈ sts &&
+                @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
         end
     end
 
-    ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
+    ConstraintsSystem(
+        constraints, collect(constraintsts), collect(constraintps); name = consname)
 end
diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl
index 1f40f61eed..77f4229696 100644
--- a/src/systems/problem_utils.jl
+++ b/src/systems/problem_utils.jl
@@ -872,7 +872,8 @@ function check_inputmap_keys(sys, u0map, pmap)
             push!(badparamkeys, k)
         end
     end
-    (isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
+    (isempty(badvarkeys) && isempty(badparamkeys)) ||
+        throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
 end
 
 const BAD_KEY_MESSAGE = """
@@ -885,14 +886,12 @@ struct InvalidKeyError <: Exception
     params::Any
 end
 
-function Base.showerror(io::IO, e::InvalidKeyError) 
+function Base.showerror(io::IO, e::InvalidKeyError)
     println(io, BAD_KEY_MESSAGE)
     println(io, "u0map: $(join(e.vars, ", "))")
     println(io, "pmap: $(join(e.params, ", "))")
 end
 
-
-
 ##############
 # Legacy functions for backward compatibility
 ##############
diff --git a/test/bvproblem.jl b/test/bvproblem.jl
index f05be90281..c5451f681b 100644
--- a/test/bvproblem.jl
+++ b/test/bvproblem.jl
@@ -12,70 +12,73 @@ solvers = [MIRK4]
 daesolvers = [Ascher2, Ascher4, Ascher6]
 
 let
-     @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
-     @variables x(t)=1.0 y(t)=2.0
-     
-     eqs = [D(x) ~ α * x - β * x * y,
-         D(y) ~ -γ * y + δ * x * y]
-     
-     u0map = [x => 1.0, y => 2.0]
-     parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
-     tspan = (0.0, 10.0)
-     
-     @mtkbuild lotkavolterra = ODESystem(eqs, t)
-     op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
-     osol = solve(op, Vern9())
-     
-     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
-     
-     for solver in solvers
-         sol = solve(bvp, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [1.0, 2.0]
-     end
-     
-     # Test out of place
-     bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
-     
-     for solver in solvers
-         sol = solve(bvp2, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [1.0, 2.0]
-     end
+    @parameters α=7.5 β=4.0 γ=8.0 δ=5.0
+    @variables x(t)=1.0 y(t)=2.0
+
+    eqs = [D(x) ~ α * x - β * x * y,
+        D(y) ~ -γ * y + δ * x * y]
+
+    u0map = [x => 1.0, y => 2.0]
+    parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0]
+    tspan = (0.0, 10.0)
+
+    @mtkbuild lotkavolterra = ODESystem(eqs, t)
+    op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
+    osol = solve(op, Vern9())
+
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+        lotkavolterra, u0map, tspan, parammap)
+
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        @test sol.u[1] == [1.0, 2.0]
+    end
+
+    # Test out of place
+    bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
+        lotkavolterra, u0map, tspan, parammap)
+
+    for solver in solvers
+        sol = solve(bvp2, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        @test sol.u[1] == [1.0, 2.0]
+    end
 end
 
 ### Testing on pendulum
 let
-     @parameters g=9.81 L=1.0
-     @variables θ(t) = π / 2 θ_t(t)
-     
-     eqs = [D(θ) ~ θ_t
-            D(θ_t) ~ -(g / L) * sin(θ)]
-     
-     @mtkbuild pend = ODESystem(eqs, t)
-     
-     u0map = [θ => π / 2, θ_t => π / 2]
-     parammap = [:L => 1.0, :g => 9.81]
-     tspan = (0.0, 6.0)
-     
-     op = ODEProblem(pend, u0map, tspan, parammap)
-     osol = solve(op, Vern9())
-     
-     bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
-     for solver in solvers
-         sol = solve(bvp, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [π / 2, π / 2]
-     end
-     
-     # Test out-of-place
-     bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
-     
-     for solver in solvers
-         sol = solve(bvp2, solver(), dt = 0.01)
-         @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
-         @test sol.u[1] == [π / 2, π / 2]
-     end
+    @parameters g=9.81 L=1.0
+    @variables θ(t)=π / 2 θ_t(t)
+
+    eqs = [D(θ) ~ θ_t
+           D(θ_t) ~ -(g / L) * sin(θ)]
+
+    @mtkbuild pend = ODESystem(eqs, t)
+
+    u0map = [θ => π / 2, θ_t => π / 2]
+    parammap = [:L => 1.0, :g => 9.81]
+    tspan = (0.0, 6.0)
+
+    op = ODEProblem(pend, u0map, tspan, parammap)
+    osol = solve(op, Vern9())
+
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
+    for solver in solvers
+        sol = solve(bvp, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        @test sol.u[1] == [π / 2, π / 2]
+    end
+
+    # Test out-of-place
+    bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(
+        pend, u0map, tspan, parammap)
+
+    for solver in solvers
+        sol = solve(bvp2, solver(), dt = 0.01)
+        @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
+        @test sol.u[1] == [π / 2, π / 2]
+    end
 end
 
 ##################################################################
@@ -87,40 +90,42 @@ let
     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
     @variables x(..) y(..)
     eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
-           D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
-    
-    tspan = (0., 1.)
+        D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
+
+    tspan = (0.0, 1.0)
     @mtkbuild lksys = ODESystem(eqs, t)
 
-    function lotkavolterra!(du, u, p, t) 
-        du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
-        du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
+    function lotkavolterra!(du, u, p, t)
+        du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
+        du[2] = -p[4] * u[2] + p[3] * u[1] * u[2]
     end
 
-    function lotkavolterra(u, p, t) 
-        [p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
+    function lotkavolterra(u, p, t)
+        [p[1] * u[1] - p[2] * u[1] * u[2], -p[4] * u[2] + p[3] * u[1] * u[2]]
     end
 
     # Test with a constraint.
-    constr = [y(0.5) ~ 2.]
+    constr = [y(0.5) ~ 2.0]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
 
-    function bc!(resid, u, p, t) 
-        resid[1] = u(0.0)[1] - 1.
-        resid[2] = u(0.5)[2] - 2.
+    function bc!(resid, u, p, t)
+        resid[1] = u(0.0)[1] - 1.0
+        resid[2] = u(0.5)[2] - 2.0
     end
     function bc(u, p, t)
-        [u(0.0)[1] - 1., u(0.5)[2] - 2.]
+        [u(0.0)[1] - 1.0, u(0.5)[2] - 2.0]
     end
 
-    u0 = [1., 1.]
-    tspan = (0., 1.)
-    p = [1.5, 1., 1., 3.]
+    u0 = [1.0, 1.0]
+    tspan = (0.0, 1.0)
+    p = [1.5, 1.0, 1.0, 3.0]
     bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
     bvpi2 = SciMLBase.BVProblem(lotkavolterra, bc, u0, tspan, p)
-    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
-    bvpi4 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
-    
+    bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+        lksys, [x(t) => 1.0], tspan; guesses = [y(t) => 1.0])
+    bvpi4 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(
+        lksys, [x(t) => 1.0], tspan; guesses = [y(t) => 1.0])
+
     sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
     sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
     sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
@@ -128,12 +133,15 @@ let
     @test sol1 ≈ sol2 ≈ sol3 ≈ sol4 # don't get true equality here, not sure why
 end
 
-function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-2)
+function test_solvers(
+        solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-2)
     for solver in solvers
         println("Solver: $solver")
         sol = @btime solve($prob, $solver(), dt = $dt, abstol = $atol)
         @test SciMLBase.successful_retcode(sol.retcode)
-        p = prob.p; t = sol.t; bc = prob.f.bc
+        p = prob.p
+        t = sol.t
+        bc = prob.f.bc
         ns = length(prob.u0)
         if isinplace(prob.f)
             resid = zeros(ns)
@@ -148,7 +156,7 @@ function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.
         for (k, v) in u0map
             @test sol[k][1] == v
         end
-         
+
         # for cons in constraints
         #     @test sol[cons.rhs - cons.lhs] ≈ 0
         # end
@@ -163,28 +171,31 @@ end
 let
     @parameters α=1.5 β=1.0 γ=3.0 δ=1.0
     @variables x(..) y(..)
-    
+
     eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
-           D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
-    
+        D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
+
     u0map = []
     tspan = (0.0, 1.0)
     guess = [x(t) => 4.0, y(t) => 2.0]
-    constr = [x(.6) ~ 3.5, x(.3) ~ 7.]
+    constr = [x(0.6) ~ 3.5, x(0.3) ~ 7.0]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
 
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+        lksys, u0map, tspan; guesses = guess)
     test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
     # Testing that more complicated constraints give correct solutions.
-    constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
+    constr = [y(0.2) + x(0.8) ~ 3.0, y(0.3) ~ 2.0]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
-    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses = guess)
+    bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(
+        lksys, u0map, tspan; guesses = guess)
     test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
 
-    constr = [α * β - x(.6) ~ 0.0, y(.2) ~ 3.]
+    constr = [α * β - x(0.6) ~ 0.0, y(0.2) ~ 3.0]
     @mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
-    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
+    bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
+        lksys, u0map, tspan; guesses = guess)
     test_solvers(solvers, bvp, u0map, constr)
 end
 
diff --git a/test/odesystem.jl b/test/odesystem.jl
index ae39aa4c5b..01df23ede8 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1676,9 +1676,9 @@ end
 
 @testset "Constraint system construction" begin
     @variables x(..) y(..) z(..)
-    @parameters a b c d e 
-    eqs = [D(x(t)) ~ 3*a*y(t), D(y(t)) ~ x(t) - z(t), D(z(t)) ~ e*x(t)^2]
-    cons = [x(0.3) ~ c*d, y(0.7) ~ 3]
+    @parameters a b c d e
+    eqs = [D(x(t)) ~ 3 * a * y(t), D(y(t)) ~ x(t) - z(t), D(z(t)) ~ e * x(t)^2]
+    cons = [x(0.3) ~ c * d, y(0.7) ~ 3]
 
     # Test variables + parameters infer correctly.
     @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
@@ -1688,12 +1688,12 @@ end
     @parameters t_c
     cons = [x(t_c) ~ 3]
     @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
-    @test issetequal(parameters(sys), [a, e, t_c]) 
+    @test issetequal(parameters(sys), [a, e, t_c])
 
     @parameters g(..) h i
     cons = [g(h, i) * x(3) ~ c]
     @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
-    @test issetequal(parameters(sys), [g, h, i, a, e, c]) 
+    @test issetequal(parameters(sys), [g, h, i, a, e, c])
 
     # Test that bad constraints throw errors.
     cons = [x(3, 4) ~ 3] # unknowns cannot have multiple args.
@@ -1716,7 +1716,7 @@ end
            2 0 0 2 1
            0 0 2 0 5]
     eqs = D(x(t)) ~ mat * x(t)
-    cons = [x(3) ~ [2,3,3,5,4]]
+    cons = [x(3) ~ [2, 3, 3, 5, 4]]
     @mtkbuild ode = ODESystem(D(x(t)) ~ mat * x(t), t; constraints = cons)
     @test length(constraints(ModelingToolkit.get_constraintsystem(ode))) == 5
 end
diff --git a/test/problem_validation.jl b/test/problem_validation.jl
index bce39b51d2..a0a7afaf3c 100644
--- a/test/problem_validation.jl
+++ b/test/problem_validation.jl
@@ -2,33 +2,33 @@ using ModelingToolkit
 using ModelingToolkit: t_nounits as t, D_nounits as D
 
 @testset "Input map validation" begin
-    import ModelingToolkit: InvalidKeyError, MissingParametersError 
+    import ModelingToolkit: InvalidKeyError, MissingParametersError
     @variables X(t)
     @parameters p d
-    eqs = [D(X) ~ p - d*X]
+    eqs = [D(X) ~ p - d * X]
     @mtkbuild osys = ODESystem(eqs, t)
-    
+
     p = "I accidentally renamed p"
     u0 = [X => 1.0]
     ps = [p => 1.0, d => 0.5]
-    @test_throws MissingParametersError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
-    
+    @test_throws MissingParametersError oprob=ODEProblem(osys, u0, (0.0, 1.0), ps)
+
     @parameters p d
     ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
-    @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws InvalidKeyError oprob=ODEProblem(osys, u0, (0.0, 1.0), ps)
 
     u0 = [:X => 1.0, "random" => 3.0]
-    @test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
+    @test_throws InvalidKeyError oprob=ODEProblem(osys, u0, (0.0, 1.0), ps)
 
     @variables x(t) y(t) z(t)
-    @parameters a b c d 
-    eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
+    @parameters a b c d
+    eqs = [D(x) ~ x * a, D(y) ~ y * c, D(z) ~ b + d]
     @mtkbuild sys = ODESystem(eqs, t)
     pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
     u0map = [x => 1, y => 2, z => 3]
-    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0.0, 1.0), pmap)
 
     pmap = [a => 1, b => 2, c => 3, d => 4]
     u0map = [x => 1, y => 2, z => 3, :0 => 3]
-    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
+    @test_throws InvalidKeyError ODEProblem(sys, u0map, (0.0, 1.0), pmap)
 end

From f29824e829041f01f60a32ada65c5249f91a9746 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 24 Feb 2025 09:25:21 -0500
Subject: [PATCH 106/111] Format

---
 .../StructuralTransformations.jl              |  6 ++-
 .../symbolics_tearing.jl                      | 54 +++++++++++--------
 src/structural_transformation/utils.jl        |  8 +--
 .../discrete_system/discrete_system.jl        |  6 +--
 src/systems/systemstructure.jl                |  6 +--
 src/variables.jl                              |  2 +-
 test/discrete_system.jl                       | 44 +++++++--------
 7 files changed, 69 insertions(+), 57 deletions(-)

diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl
index 7d2b9afa26..f0124d7f4b 100644
--- a/src/structural_transformation/StructuralTransformations.jl
+++ b/src/structural_transformation/StructuralTransformations.jl
@@ -11,7 +11,8 @@ using SymbolicUtils: maketerm, iscall
 
 using ModelingToolkit
 using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential,
-                       unknowns, equations, vars, Symbolic, diff2term_with_unit, shift2term_with_unit, value,
+                       unknowns, equations, vars, Symbolic, diff2term_with_unit,
+                       shift2term_with_unit, value,
                        operation, arguments, Sym, Term, simplify, symbolic_linear_solve,
                        isdiffeq, isdifferential, isirreducible,
                        empty_substitutions, get_substitutions,
@@ -22,7 +23,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
                        get_postprocess_fbody, vars!,
                        IncrementalCycleTracker, add_edge_checked!, topological_sort,
                        invalidate_cache!, Substitutions, get_or_construct_tearing_state,
-                       filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
+                       filter_kwargs, lower_varname_with_unit,
+                       lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
                        get_fullvars, has_equations, observed,
                        Schedule, schedule
 
diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl
index bbb6853a7c..ea594f2172 100644
--- a/src/structural_transformation/symbolics_tearing.jl
+++ b/src/structural_transformation/symbolics_tearing.jl
@@ -248,7 +248,8 @@ called dummy derivatives.
 State selection is done. All non-differentiated variables are algebraic 
 variables, and all variables that appear differentiated are differential variables.
 """
-function substitute_derivatives_algevars!(ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing)
+function substitute_derivatives_algevars!(
+        ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     diff_to_var = invview(var_to_diff)
@@ -288,7 +289,7 @@ end
 #= 
 There are three cases where we want to generate new variables to convert
 the system into first order (semi-implicit) ODEs.
-    
+
 1. To first order:
 Whenever higher order differentiated variable like `D(D(D(x)))` appears,
 we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations
@@ -364,7 +365,8 @@ Effects on the system structure:
 - solvable_graph:
 - var_eq_matching: match D(x) to the added identity equation D(x) ~ x_t
 """
-function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
+function generate_derivative_variables!(
+        ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
     @unpack fullvars, sys, structure = ts
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
@@ -395,7 +397,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
         dx = fullvars[dv]
         order, lv = var_order(dv, diff_to_var)
         x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv) :
-                          lower_varname_with_unit(fullvars[lv], iv, order)
+              lower_varname_with_unit(fullvars[lv], iv, order)
 
         # Add `x_t` to the graph
         v_t = add_dd_variable!(structure, fullvars, x_t, dv)
@@ -405,7 +407,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
         # Update matching
         push!(var_eq_matching, unassigned)
         var_eq_matching[dv] = unassigned
-        eq_var_matching[dummy_eq] = dv 
+        eq_var_matching[dummy_eq] = dv
     end
 end
 
@@ -428,7 +430,7 @@ function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
             return eq, v_t
         end
     end
-    return nothing 
+    return nothing
 end
 
 """
@@ -492,8 +494,9 @@ Order the new equations and variables such that the differential equations
 and variables come first. Return the new equations, the solved equations,
 the new orderings, and the number of solved variables and equations.
 """
-function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, iv = nothing, D = nothing)
-    @unpack fullvars, sys, structure = state 
+function generate_system_equations!(state::TearingState, neweqs, var_eq_matching;
+        simplify = false, iv = nothing, D = nothing)
+    @unpack fullvars, sys, structure = state
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
     eq_var_matching = invview(var_eq_matching)
     diff_to_var = invview(var_to_diff)
@@ -502,11 +505,12 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     if is_only_discrete(structure)
         for (i, v) in enumerate(fullvars)
             op = operation(v)
-            op isa Shift && (op.steps < 0) && begin
-                lowered = lower_shift_varname_with_unit(v, iv)
-                total_sub[v] = lowered
-                fullvars[i] = lowered
-            end
+            op isa Shift && (op.steps < 0) &&
+                begin
+                    lowered = lower_shift_varname_with_unit(v, iv)
+                    total_sub[v] = lowered
+                    fullvars[i] = lowered
+                end
         end
     end
 
@@ -581,10 +585,11 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
     end
     solved_vars_set = BitSet(solved_vars)
     var_ordering = [diff_vars;
-                   setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
-                       solved_vars_set)]
+                    setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
+                        solved_vars_set)]
 
-    return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
+    return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars),
+    length(solved_vars_set)
 end
 
 """
@@ -648,7 +653,8 @@ Eliminate the solved variables and equations from the graph and permute the
 graph's vertices to account for the new variable/equation ordering.
 """
 # TODO: BLT sorting
-function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nsolved_eq, nsolved_var)
+function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering,
+        var_ordering, nsolved_eq, nsolved_var)
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
 
     eqsperm = zeros(Int, nsrcs(graph))
@@ -692,7 +698,8 @@ end
 """
 Update the system equations, unknowns, and observables after simplification.
 """
-function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; 
+function update_simplified_system!(
+        state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
         cse_hack = true, array_hack = true)
     @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
     diff_to_var = invview(var_to_diff)
@@ -732,7 +739,6 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
     sys = schedule(sys)
 end
 
-    
 """
 Give the order of the variable indexed by dv.
 """
@@ -790,12 +796,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
 
     generate_derivative_variables!(state, neweqs, var_eq_matching; mm, iv, D)
 
-    neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = 
-        generate_system_equations!(state, neweqs, var_eq_matching; simplify, iv, D)
+    neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = generate_system_equations!(
+        state, neweqs, var_eq_matching; simplify, iv, D)
 
-    state = reorder_vars!(state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
+    state = reorder_vars!(
+        state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
 
-    sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; cse_hack, array_hack)
+    sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
+        extra_unknowns; cse_hack, array_hack)
 
     @set! state.sys = sys
     @set! sys.tearing_state = state
diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl
index 96f7f78f99..f3cea9c7ba 100644
--- a/src/structural_transformation/utils.jl
+++ b/src/structural_transformation/utils.jl
@@ -477,15 +477,17 @@ function shift2term(var)
     backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps
 
     num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
-    ds = join([Char(0x209c), Char(0x208b), num]) 
+    ds = join([Char(0x209c), Char(0x208b), num])
     # Char(0x209c) = ₜ
     # Char(0x208b) = ₋ (subscripted minus)
 
     O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg
     oldop = operation(O)
-    newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : Symbol(string(nameof(oldop)))
+    newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) :
+              Symbol(string(nameof(oldop)))
 
-    newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
+    newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname),
+        Symbolics.children(O), Symbolics.metadata(O))
     newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
     newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
     newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift)
diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl
index 3f8d9e85c6..e4949d812f 100644
--- a/src/systems/discrete_system/discrete_system.jl
+++ b/src/systems/discrete_system/discrete_system.jl
@@ -269,15 +269,15 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
     for k in collect(keys(u0map))
         v = u0map[k]
         if !((op = operation(k)) isa Shift)
-            isnothing(getunshifted(k)) && error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
-            
+            isnothing(getunshifted(k)) &&
+                error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
+
             updated[Shift(iv, 1)(k)] = v
         elseif op.steps > 0
             error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")
         else
             updated[Shift(iv, op.steps + 1)(only(arguments(k)))] = v
         end
-
     end
     for var in unknowns(sys)
         op = operation(var)
diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl
index 6fee78cfd6..0643f32ec4 100644
--- a/src/systems/systemstructure.jl
+++ b/src/systems/systemstructure.jl
@@ -473,9 +473,9 @@ function shift_discrete_system(ts::TearingState)
     end
     iv = get_iv(sys)
 
-    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))    
-                   for k in discvars 
-                   if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) 
+    discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
+    for k in discvars
+    if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
 
     for i in eachindex(fullvars)
         fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(
diff --git a/src/variables.jl b/src/variables.jl
index 4e13ad2c5d..a7dde165f9 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -138,7 +138,7 @@ function default_toterm(x)
     if iscall(x) && (op = operation(x)) isa Operator
         if !(op isa Differential)
             if op isa Shift && op.steps < 0
-                return shift2term(x) 
+                return shift2term(x)
             end
             x = normalize_to_differential(op)(arguments(x)...)
         end
diff --git a/test/discrete_system.jl b/test/discrete_system.jl
index 756e5bca48..f232faaf81 100644
--- a/test/discrete_system.jl
+++ b/test/discrete_system.jl
@@ -257,7 +257,6 @@ k = ShiftIndex(t)
 @named sys = DiscreteSystem([x ~ x^2 + y^2, y ~ x(k - 1) + y(k - 1)], t)
 @test_throws ["algebraic equations", "not yet supported"] structural_simplify(sys)
 
-
 @testset "Passing `nothing` to `u0`" begin
     @variables x(t) = 1
     k = ShiftIndex()
@@ -273,11 +272,11 @@ end
     prob = DiscreteProblem(de, [], (0, 10))
     @test prob[x] == 2.0
     @test prob[x(k - 1)] == 1.0
-    
+
     # must provide initial conditions for history
     @test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10))
-    @test_throws ErrorException DiscreteProblem(de, [x(k+1) => 2.], (0, 10))
-    
+    @test_throws ErrorException DiscreteProblem(de, [x(k + 1) => 2.0], (0, 10))
+
     # initial values only affect _that timestep_, not the entire history
     prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
     @test prob[x] == 3.0
@@ -286,34 +285,35 @@ end
     @test prob[xₜ₋₁] == 2.0
 
     # Test initial assignment with lowered variable
-    prob = DiscreteProblem(de, [xₜ₋₁(k-1) => 4.0], (0, 10))
-    @test prob[x(k-1)] == prob[xₜ₋₁] == 1.0
-    @test prob[x] == 5.
+    prob = DiscreteProblem(de, [xₜ₋₁(k - 1) => 4.0], (0, 10))
+    @test prob[x(k - 1)] == prob[xₜ₋₁] == 1.0
+    @test prob[x] == 5.0
 
     # Test missing initial throws error
     @variables x(t)
-    @mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
-    @test_throws ErrorException prob = DiscreteProblem(de, [x(k-3) => 2.], (0, 10))
-    @test_throws ErrorException prob = DiscreteProblem(de, [x(k-3) => 2., x(k-1) => 3.], (0, 10))
+    @mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t)
+    @test_throws ErrorException prob=DiscreteProblem(de, [x(k - 3) => 2.0], (0, 10))
+    @test_throws ErrorException prob=DiscreteProblem(
+        de, [x(k - 3) => 2.0, x(k - 1) => 3.0], (0, 10))
 
     # Test non-assigned initials are given default value
-    @variables x(t) = 2.
-    @mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
-    prob = DiscreteProblem(de, [x(k-3) => 12.], (0, 10))
+    @variables x(t) = 2.0
+    @mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t)
+    prob = DiscreteProblem(de, [x(k - 3) => 12.0], (0, 10))
     @test prob[x] == 26.0
-    @test prob[x(k-1)] == 2.0
-    @test prob[x(k-2)] == 2.0
+    @test prob[x(k - 1)] == 2.0
+    @test prob[x(k - 2)] == 2.0
 
     # Elaborate test
     @variables xₜ₋₂(t) zₜ₋₁(t) z(t)
-    eqs = [x ~ x(k-1) + z(k-2), 
-           z ~ x(k-2) * x(k-3) - z(k-1)^2]
+    eqs = [x ~ x(k - 1) + z(k - 2),
+        z ~ x(k - 2) * x(k - 3) - z(k - 1)^2]
     @mtkbuild de = DiscreteSystem(eqs, t)
-    u0 = [x(k-1) => 3, 
-          xₜ₋₂(k-1) => 4, 
-          x(k-2) => 1, 
-          z(k-1) => 5, 
-          zₜ₋₁(k-1) => 12]
+    u0 = [x(k - 1) => 3,
+        xₜ₋₂(k - 1) => 4,
+        x(k - 2) => 1,
+        z(k - 1) => 5,
+        zₜ₋₁(k - 1) => 12]
     prob = DiscreteProblem(de, u0, (0, 10))
     @test prob[x] == 15
     @test prob[z] == -21

From b40c2a1dbda36d481067ffa5450599540cefdc40 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 24 Feb 2025 11:58:05 -0500
Subject: [PATCH 107/111] fix Complex typecheck

---
 src/utils.jl      | 16 ++++++++++++----
 test/odesystem.jl |  4 ++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/utils.jl b/src/utils.jl
index 962801622a..4685e76ebc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -222,7 +222,8 @@ function collect_ivs_from_nested_operator!(ivs, x, target_op)
 end
 
 function iv_from_nested_derivative(x, op = Differential)
-    if iscall(x) && operation(x) == getindex
+    if iscall(x) &&
+       (operation(x) == getindex || operation(x) == real || operation(x) == imag)
         iv_from_nested_derivative(arguments(x)[1], op)
     elseif iscall(x)
         operation(x) isa op ? iv_from_nested_derivative(arguments(x)[1], op) :
@@ -1204,7 +1205,7 @@ end
 Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
 """
 function process_equations(eqs, iv)
-    eqs = collect(eqs)
+    eqs = collect(Iterators.flatten(eqs))
 
     diffvars = OrderedSet()
     allunknowns = OrderedSet()
@@ -1237,8 +1238,8 @@ function process_equations(eqs, iv)
                     throw(ArgumentError("An ODESystem can only have one independent variable."))
                 diffvar in diffvars &&
                     throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
-                !(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) &&
-                    throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
+                !has_diffvar_type(diffvar) &&
+                    throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should be of a continuous, non-concrete number type: Real, Complex, AbstractFloat, or Number."))
                 push!(diffvars, diffvar)
             end
             push!(diffeq, eq)
@@ -1250,6 +1251,13 @@ function process_equations(eqs, iv)
     diffvars, allunknowns, ps, Equation[diffeq; algeeq; compressed_eqs]
 end
 
+function has_diffvar_type(diffvar)
+    st = symtype(diffvar)
+    st === Real || eltype(st) === Real || st === Complex || eltype(st) === Complex ||
+        st === Number || eltype(st) === Number || st === AbstractFloat ||
+        eltype(st) === AbstractFloat
+end
+
 """
     $(TYPEDSIGNATURES)
 
diff --git a/test/odesystem.jl b/test/odesystem.jl
index 01df23ede8..b9a41e1c3d 100644
--- a/test/odesystem.jl
+++ b/test/odesystem.jl
@@ -1589,6 +1589,10 @@ end
     @variables Y(t)[1:3]::String
     eq = D(Y) ~ [p, p, p]
     @test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
+
+    @variables X(t)::Complex
+    eq = D(X) ~ p - d * X
+    @test_nowarn @named osys = ODESystem([eq], t)
 end
 
 # Test `isequal`

From 956079768e8a8c1785440a8b370b3acb5b1fb756 Mon Sep 17 00:00:00 2001
From: vyudu <vincent.duyuan@gmail.com>
Date: Mon, 24 Feb 2025 12:27:10 -0500
Subject: [PATCH 108/111] up

---
 src/utils.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/utils.jl b/src/utils.jl
index 4685e76ebc..950af5a1e2 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1205,7 +1205,8 @@ end
 Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
 """
 function process_equations(eqs, iv)
-    eqs = collect(Iterators.flatten(eqs))
+    eltype(eqs) <: Vector && (eqs = vcat(eqs...))
+    eqs = collect(eqs)
 
     diffvars = OrderedSet()
     allunknowns = OrderedSet()

From ddf0d12c06e9afc7777d83777f6537bdb214757c Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal <aayush.sabharwal@gmail.com>
Date: Wed, 26 Feb 2025 22:08:02 +0530
Subject: [PATCH 109/111] feat: mark `getproperty(::AbstractSystem, ::Symbol)`
 as non-differentiable

---
 ext/MTKChainRulesCoreExt.jl |  2 ++
 test/extensions/ad.jl       | 10 ++++++++++
 2 files changed, 12 insertions(+)

diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl
index f153ee77de..9cf17d203d 100644
--- a/ext/MTKChainRulesCoreExt.jl
+++ b/ext/MTKChainRulesCoreExt.jl
@@ -103,4 +103,6 @@ function ChainRulesCore.rrule(
     newbuf, pullback
 end
 
+ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol)
+
 end
diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl
index 0e72b2b7b7..adaf6117c6 100644
--- a/test/extensions/ad.jl
+++ b/test/extensions/ad.jl
@@ -124,3 +124,13 @@ fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
     nsol = solve(nprob, NewtonRaphson())
     @test nsol[1] ≈ 10.0 / 1.0 + 9.81 * 1.0 / 2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2
 end
+
+@testset "`sys.var` is non-differentiable" begin
+    @variables x(t)
+    @mtkbuild sys = ODESystem(D(x) ~ x, t)
+    prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
+
+    grad = Zygote.gradient(prob) do prob
+        prob[sys.x]
+    end
+end

From f289c5a5934766376e4b05bab4a0b777499a931c Mon Sep 17 00:00:00 2001
From: David Widmann <devmotion@users.noreply.github.com>
Date: Wed, 26 Feb 2025 19:31:41 +0100
Subject: [PATCH 110/111] Replace `vcat(eqs...)` with `reduce(vcat, eqs)`

---
 src/utils.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/utils.jl b/src/utils.jl
index 22ef4dc160..78c665ffca 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1217,7 +1217,9 @@ end
 Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
 """
 function process_equations(eqs, iv)
-    eltype(eqs) <: Vector && (eqs = vcat(eqs...))
+    if eltype(eqs) <: AbstractVector
+        eqs = reduce(vcat, eqs)
+    end
     eqs = collect(eqs)
 
     diffvars = OrderedSet()

From 3d9a8d847560717a8eab615f4ad16bb42451947c Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Wed, 26 Feb 2025 10:57:59 -0800
Subject: [PATCH 111/111] Update Project.toml

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 14ea5743fa..8c4a58045e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ModelingToolkit"
 uuid = "961ee093-0014-501f-94e3-6117800e7a78"
 authors = ["Yingbo Ma <mayingbo5@gmail.com>", "Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
-version = "9.64.1"
+version = "9.64.2"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"