Skip to content

Commit 1fcf103

Browse files
Merge pull request #3417 from AayushSabharwal/as/associated-equation
feat: add `map_variables_to_equations`
2 parents 18d2362 + 55601e0 commit 1fcf103

File tree

4 files changed

+157
-2
lines changed

4 files changed

+157
-2
lines changed

docs/src/basics/DependencyGraphs.md

+6
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@ asdigraph
2222
eqeq_dependencies
2323
varvar_dependencies
2424
```
25+
26+
# Miscellaneous
27+
28+
```@docs
29+
map_variables_to_equations
30+
```

src/ModelingToolkit.jl

+1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ export TearingState
277277
export BipartiteGraph, equation_dependencies, variable_dependencies
278278
export eqeq_dependencies, varvar_dependencies
279279
export asgraph, asdigraph
280+
export map_variables_to_equations
280281

281282
export toexpr, get_variables
282283
export simplify, substitute

src/systems/systems.jl

+54
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,57 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
158158
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
159159
end
160160
end
161+
162+
"""
163+
$(TYPEDSIGNATURES)
164+
165+
Given a system that has been simplified via `structural_simplify`, return a `Dict` mapping
166+
variables of the system to equations that are used to solve for them. This includes
167+
observed variables.
168+
169+
# Keyword Arguments
170+
171+
- `rename_dummy_derivatives`: Whether to rename dummy derivative variable keys into their
172+
`Differential` forms. For example, this would turn the key `yˍt(t)` into
173+
`Differential(t)(y(t))`.
174+
"""
175+
function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivatives = true)
176+
if !has_tearing_state(sys)
177+
throw(ArgumentError("$(typeof(sys)) is not supported."))
178+
end
179+
ts = get_tearing_state(sys)
180+
if ts === nothing
181+
throw(ArgumentError("`map_variables_to_equations` requires a simplified system. Call `structural_simplify` on the system before calling this function."))
182+
end
183+
184+
dummy_sub = Dict()
185+
if rename_dummy_derivatives && has_schedule(sys) && (sc = get_schedule(sys)) !== nothing
186+
dummy_sub = Dict(v => k for (k, v) in sc.dummy_sub if isequal(default_toterm(k), v))
187+
end
188+
189+
mapping = Dict{Union{Num, BasicSymbolic}, Equation}()
190+
eqs = equations(sys)
191+
for eq in eqs
192+
isdifferential(eq.lhs) || continue
193+
var = arguments(eq.lhs)[1]
194+
var = get(dummy_sub, var, var)
195+
mapping[var] = eq
196+
end
197+
198+
graph = ts.structure.graph
199+
algvars = BitSet(findall(
200+
Base.Fix1(StructuralTransformations.isalgvar, ts.structure), 1:ndsts(graph)))
201+
algeqs = BitSet(findall(1:nsrcs(graph)) do eq
202+
all(!Base.Fix1(isdervar, ts.structure), 𝑠neighbors(graph, eq))
203+
end)
204+
alge_var_eq_matching = complete(maximal_matching(graph, in(algeqs), in(algvars)))
205+
for (i, eq) in enumerate(alge_var_eq_matching)
206+
eq isa Unassigned && continue
207+
mapping[get(dummy_sub, ts.fullvars[i], ts.fullvars[i])] = eqs[eq]
208+
end
209+
for eq in observed(sys)
210+
mapping[get(dummy_sub, eq.lhs, eq.lhs)] = eq
211+
end
212+
213+
return mapping
214+
end

test/structural_transformation/utils.jl

+96-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ using ModelingToolkit
33
using Graphs
44
using SparseArrays
55
using UnPack
6-
using ModelingToolkit: t_nounits as t, D_nounits as D
7-
const ST = StructuralTransformations
6+
using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm
7+
using Symbolics: unwrap
88

99
# Define some variables
1010
@parameters L g
@@ -162,3 +162,97 @@ end
162162
structural_simplify(sys; additional_passes = [pass])
163163
@test value[] == 1
164164
end
165+
166+
@testset "`map_variables_to_equations`" begin
167+
@testset "Not supported for systems without `.tearing_state`" begin
168+
@variables x
169+
@mtkbuild sys = OptimizationSystem(x^2)
170+
@test_throws ArgumentError map_variables_to_equations(sys)
171+
end
172+
@testset "Requires simplified system" begin
173+
@variables x(t) y(t)
174+
@named sys = ODESystem([D(x) ~ x, y ~ 2x], t)
175+
sys = complete(sys)
176+
@test_throws ArgumentError map_variables_to_equations(sys)
177+
end
178+
@testset "`ODESystem`" begin
179+
@variables x(t) y(t) z(t)
180+
@mtkbuild sys = ODESystem([D(x) ~ 2x + y, y ~ x + z, z^3 + x^3 ~ 12], t)
181+
mapping = map_variables_to_equations(sys)
182+
@test mapping[x] == (D(x) ~ 2x + y)
183+
@test mapping[y] == (y ~ x + z)
184+
@test mapping[z] == (0 ~ 12 - z^3 - x^3)
185+
@test length(mapping) == 3
186+
187+
@testset "With dummy derivatives" begin
188+
@parameters g
189+
@variables x(t) y(t) [state_priority = 10] λ(t)
190+
eqs = [D(D(x)) ~ λ * x
191+
D(D(y)) ~ λ * y - g
192+
x^2 + y^2 ~ 1]
193+
@mtkbuild sys = ODESystem(eqs, t)
194+
mapping = map_variables_to_equations(sys)
195+
196+
yt = default_toterm(unwrap(D(y)))
197+
xt = default_toterm(unwrap(D(x)))
198+
xtt = default_toterm(unwrap(D(D(x))))
199+
@test mapping[x] == (0 ~ 1 - x^2 - y^2)
200+
@test mapping[y] == (D(y) ~ yt)
201+
@test mapping[D(y)] == (D(yt) ~ -g + y * λ)
202+
@test mapping[D(x)] == (0 ~ -2xt * x - 2yt * y)
203+
@test mapping[D(D(x))] == (xtt ~ x * λ)
204+
@test length(mapping) == 5
205+
206+
@testset "`rename_dummy_derivatives = false`" begin
207+
mapping = map_variables_to_equations(sys; rename_dummy_derivatives = false)
208+
209+
@test mapping[x] == (0 ~ 1 - x^2 - y^2)
210+
@test mapping[y] == (D(y) ~ yt)
211+
@test mapping[yt] == (D(yt) ~ -g + y * λ)
212+
@test mapping[xt] == (0 ~ -2xt * x - 2yt * y)
213+
@test mapping[xtt] == (xtt ~ x * λ)
214+
@test length(mapping) == 5
215+
end
216+
end
217+
@testset "DDEs" begin
218+
function oscillator(; name, k = 1.0, τ = 0.01)
219+
@parameters k=k τ=τ
220+
@variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t)
221+
eqs = [D(x(t)) ~ y,
222+
D(y) ~ -k * x(t - τ) + jcn,
223+
delx ~ x(t - τ)]
224+
return System(eqs, t; name = name)
225+
end
226+
227+
systems = @named begin
228+
osc1 = oscillator(k = 1.0, τ = 0.01)
229+
osc2 = oscillator(k = 2.0, τ = 0.04)
230+
end
231+
eqs = [osc1.jcn ~ osc2.delx,
232+
osc2.jcn ~ osc1.delx]
233+
@named coupledOsc = System(eqs, t)
234+
@mtkbuild sys = compose(coupledOsc, systems)
235+
mapping = map_variables_to_equations(sys)
236+
x1 = operation(unwrap(osc1.x))
237+
x2 = operation(unwrap(osc2.x))
238+
@test mapping[osc1.x] == (D(osc1.x) ~ osc1.y)
239+
@test mapping[osc1.y] == (D(osc1.y) ~ osc1.jcn - osc1.k * x1(t - osc1.τ))
240+
@test mapping[osc1.delx] == (osc1.delx ~ x1(t - osc1.τ))
241+
@test mapping[osc1.jcn] == (osc1.jcn ~ osc2.delx)
242+
@test mapping[osc2.x] == (D(osc2.x) ~ osc2.y)
243+
@test mapping[osc2.y] == (D(osc2.y) ~ osc2.jcn - osc2.k * x2(t - osc2.τ))
244+
@test mapping[osc2.delx] == (osc2.delx ~ x2(t - osc2.τ))
245+
@test mapping[osc2.jcn] == (osc2.jcn ~ osc1.delx)
246+
@test length(mapping) == 8
247+
end
248+
end
249+
@testset "`NonlinearSystem`" begin
250+
@variables x y z
251+
@mtkbuild sys = NonlinearSystem([x^2 ~ 2y^2 + 1, sin(z) ~ y, z^3 + 4z + 1 ~ 0])
252+
mapping = map_variables_to_equations(sys)
253+
@test mapping[x] == (0 ~ 2y^2 + 1 - x^2)
254+
@test mapping[y] == (y ~ sin(z))
255+
@test mapping[z] == (0 ~ -1 - 4z - z^3)
256+
@test length(mapping) == 3
257+
end
258+
end

0 commit comments

Comments
 (0)