Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add map_variables_to_equations #3417

Merged
merged 3 commits into from
Feb 27, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/basics/DependencyGraphs.md
Original file line number Diff line number Diff line change
@@ -22,3 +22,9 @@ asdigraph
eqeq_dependencies
varvar_dependencies
```

# Miscellaneous

```@docs
map_variables_to_equations
```
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
@@ -276,6 +276,7 @@ export TearingState
export BipartiteGraph, equation_dependencies, variable_dependencies
export eqeq_dependencies, varvar_dependencies
export asgraph, asdigraph
export map_variables_to_equations

export toexpr, get_variables
export simplify, substitute
54 changes: 54 additions & 0 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
@@ -158,3 +158,57 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
end
end

"""
$(TYPEDSIGNATURES)

Given a system that has been simplified via `structural_simplify`, return a `Dict` mapping
variables of the system to equations that are used to solve for them. This includes
observed variables.

# Keyword Arguments

- `rename_dummy_derivatives`: Whether to rename dummy derivative variable keys into their
`Differential` forms. For example, this would turn the key `yˍt(t)` into
`Differential(t)(y(t))`.
"""
function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivatives = true)
if !has_tearing_state(sys)
throw(ArgumentError("$(typeof(sys)) is not supported."))
end
ts = get_tearing_state(sys)
if ts === nothing
throw(ArgumentError("`map_variables_to_equations` requires a simplified system. Call `structural_simplify` on the system before calling this function."))
end

dummy_sub = Dict()
if rename_dummy_derivatives && has_schedule(sys) && (sc = get_schedule(sys)) !== nothing
dummy_sub = Dict(v => k for (k, v) in sc.dummy_sub if isequal(default_toterm(k), v))
end

mapping = Dict{Union{Num, BasicSymbolic}, Equation}()
eqs = equations(sys)
for eq in eqs
isdifferential(eq.lhs) || continue
var = arguments(eq.lhs)[1]
var = get(dummy_sub, var, var)
mapping[var] = eq
end

graph = ts.structure.graph
algvars = BitSet(findall(
Base.Fix1(StructuralTransformations.isalgvar, ts.structure), 1:ndsts(graph)))
algeqs = BitSet(findall(1:nsrcs(graph)) do eq
all(!Base.Fix1(isdervar, ts.structure), 𝑠neighbors(graph, eq))
end)
alge_var_eq_matching = complete(maximal_matching(graph, in(algeqs), in(algvars)))
for (i, eq) in enumerate(alge_var_eq_matching)
eq isa Unassigned && continue
mapping[get(dummy_sub, ts.fullvars[i], ts.fullvars[i])] = eqs[eq]
end
for eq in observed(sys)
mapping[get(dummy_sub, eq.lhs, eq.lhs)] = eq
end

return mapping
end
98 changes: 96 additions & 2 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
@@ -3,8 +3,8 @@ using ModelingToolkit
using Graphs
using SparseArrays
using UnPack
using ModelingToolkit: t_nounits as t, D_nounits as D
const ST = StructuralTransformations
using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm
using Symbolics: unwrap

# Define some variables
@parameters L g
@@ -162,3 +162,97 @@ end
structural_simplify(sys; additional_passes = [pass])
@test value[] == 1
end

@testset "`map_variables_to_equations`" begin
@testset "Not supported for systems without `.tearing_state`" begin
@variables x
@mtkbuild sys = OptimizationSystem(x^2)
@test_throws ArgumentError map_variables_to_equations(sys)
end
@testset "Requires simplified system" begin
@variables x(t) y(t)
@named sys = ODESystem([D(x) ~ x, y ~ 2x], t)
sys = complete(sys)
@test_throws ArgumentError map_variables_to_equations(sys)
end
@testset "`ODESystem`" begin
@variables x(t) y(t) z(t)
@mtkbuild sys = ODESystem([D(x) ~ 2x + y, y ~ x + z, z^3 + x^3 ~ 12], t)
mapping = map_variables_to_equations(sys)
@test mapping[x] == (D(x) ~ 2x + y)
@test mapping[y] == (y ~ x + z)
@test mapping[z] == (0 ~ 12 - z^3 - x^3)
@test length(mapping) == 3

@testset "With dummy derivatives" begin
@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild sys = ODESystem(eqs, t)
mapping = map_variables_to_equations(sys)

yt = default_toterm(unwrap(D(y)))
xt = default_toterm(unwrap(D(x)))
xtt = default_toterm(unwrap(D(D(x))))
@test mapping[x] == (0 ~ 1 - x^2 - y^2)
@test mapping[y] == (D(y) ~ yt)
@test mapping[D(y)] == (D(yt) ~ -g + y * λ)
@test mapping[D(x)] == (0 ~ -2xt * x - 2yt * y)
@test mapping[D(D(x))] == (xtt ~ x * λ)
@test length(mapping) == 5

@testset "`rename_dummy_derivatives = false`" begin
mapping = map_variables_to_equations(sys; rename_dummy_derivatives = false)

@test mapping[x] == (0 ~ 1 - x^2 - y^2)
@test mapping[y] == (D(y) ~ yt)
@test mapping[yt] == (D(yt) ~ -g + y * λ)
@test mapping[xt] == (0 ~ -2xt * x - 2yt * y)
@test mapping[xtt] == (xtt ~ x * λ)
@test length(mapping) == 5
end
end
@testset "DDEs" begin
function oscillator(; name, k = 1.0, τ = 0.01)
@parameters k=k τ=τ
@variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t)
eqs = [D(x(t)) ~ y,
D(y) ~ -k * x(t - τ) + jcn,
delx ~ x(t - τ)]
return System(eqs, t; name = name)
end

systems = @named begin
osc1 = oscillator(k = 1.0, τ = 0.01)
osc2 = oscillator(k = 2.0, τ = 0.04)
end
eqs = [osc1.jcn ~ osc2.delx,
osc2.jcn ~ osc1.delx]
@named coupledOsc = System(eqs, t)
@mtkbuild sys = compose(coupledOsc, systems)
mapping = map_variables_to_equations(sys)
x1 = operation(unwrap(osc1.x))
x2 = operation(unwrap(osc2.x))
@test mapping[osc1.x] == (D(osc1.x) ~ osc1.y)
@test mapping[osc1.y] == (D(osc1.y) ~ osc1.jcn - osc1.k * x1(t - osc1.τ))
@test mapping[osc1.delx] == (osc1.delx ~ x1(t - osc1.τ))
@test mapping[osc1.jcn] == (osc1.jcn ~ osc2.delx)
@test mapping[osc2.x] == (D(osc2.x) ~ osc2.y)
@test mapping[osc2.y] == (D(osc2.y) ~ osc2.jcn - osc2.k * x2(t - osc2.τ))
@test mapping[osc2.delx] == (osc2.delx ~ x2(t - osc2.τ))
@test mapping[osc2.jcn] == (osc2.jcn ~ osc1.delx)
@test length(mapping) == 8
end
end
@testset "`NonlinearSystem`" begin
@variables x y z
@mtkbuild sys = NonlinearSystem([x^2 ~ 2y^2 + 1, sin(z) ~ y, z^3 + 4z + 1 ~ 0])
mapping = map_variables_to_equations(sys)
@test mapping[x] == (0 ~ 2y^2 + 1 - x^2)
@test mapping[y] == (y ~ sin(z))
@test mapping[z] == (0 ~ -1 - 4z - z^3)
@test length(mapping) == 3
end
end
Loading