Skip to content

Commit e7917b8

Browse files
refactor: update implementation of discrete save interface
1 parent 51c5446 commit e7917b8

6 files changed

+80
-55
lines changed

src/systems/abstractsystem.jl

+40-20
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
447447
sym = unwrap(sym)
448448
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
449449
return sym isa ParameterIndex || is_parameter(ic, sym) ||
450-
iscall(sym) && operation(sym) === getindex &&
450+
iscall(sym) &&
451+
operation(sym) === getindex &&
451452
is_parameter(ic, first(arguments(sym)))
452453
end
453454
if unwrap(sym) isa Int
@@ -526,34 +527,19 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
526527
end
527528

528529
function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
530+
is_time_dependent(sys) || return false
529531
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
530532
is_timeseries_parameter(ic, sym)
531533
end
532534

533535
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
536+
is_time_dependent(sys) || return nothing
534537
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
535538
timeseries_parameter_index(ic, sym)
536539
end
537540

538541
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
539542
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
540-
allvars = vars(sym; op = Symbolics.Operator)
541-
ts_idxs = Set{Int}()
542-
for var in allvars
543-
var = unwrap(var)
544-
# FIXME: Shouldn't have to shift systems
545-
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
546-
var = only(arguments(var))
547-
end
548-
ts_idx = check_index_map(ic.discrete_idx, unwrap(var))
549-
ts_idx === nothing && continue
550-
push!(ts_idxs, ts_idx[1])
551-
end
552-
if length(ts_idxs) == 1
553-
ts_idx = only(ts_idxs)
554-
else
555-
ts_idx = nothing
556-
end
557543
rawobs = build_explicit_observed_function(
558544
sys, sym; param_only = true, return_inplace = true)
559545
if rawobs isa Tuple
@@ -580,10 +566,44 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
580566
end
581567
end
582568
else
583-
ts_idx = nothing
584569
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
585570
end
586-
return ParameterObservedFunction(ts_idx, obsfn)
571+
return obsfn
572+
end
573+
574+
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
575+
if is_variable(sys, sym)
576+
push!(ts_idxs, ContinuousTimeseries())
577+
elseif is_timeseries_parameter(sys, sym)
578+
push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx)
579+
end
580+
end
581+
# Need this to avoid ambiguity with the array case
582+
for traitT in [
583+
ScalarSymbolic,
584+
ArraySymbolic
585+
]
586+
@eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym)
587+
allsyms = vars(sym; op = Symbolics.Operator)
588+
foreach(allsyms) do s
589+
_all_ts_idxs!(ts_idxs, sys, s)
590+
end
591+
end
592+
end
593+
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray)
594+
foreach(sym) do s
595+
_all_ts_idxs!(ts_idxs, sys, s)
596+
end
597+
end
598+
_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym)
599+
600+
function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym)
601+
if !is_time_dependent(sys)
602+
return Set()
603+
end
604+
ts_idxs = Set()
605+
_all_ts_idxs!(ts_idxs, sys, sym)
606+
return ts_idxs
587607
end
588608

589609
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)

src/systems/index_cache.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function IndexCache(sys::AbstractSystem)
113113
error("Discrete subsystem $i input $inp is not a parameter")
114114
disc_clocks[inp] = i
115115
disc_clocks[default_toterm(inp)] = i
116-
if hasname(inp) && (!istree(inp) || operation(inp) !== getindex)
116+
if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex)
117117
disc_clocks[getname(inp)] = i
118118
disc_clocks[default_toterm(inp)] = i
119119
end
@@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem)
126126
error("Discrete subsystem $i unknown $sym is not a parameter")
127127
disc_clocks[sym] = i
128128
disc_clocks[default_toterm(sym)] = i
129-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
129+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
130130
disc_clocks[getname(sym)] = i
131131
disc_clocks[getname(default_toterm(sym))] = i
132132
end
@@ -138,13 +138,13 @@ function IndexCache(sys::AbstractSystem)
138138
# FIXME: This shouldn't be necessary
139139
eq.rhs === -0.0 && continue
140140
sym = eq.lhs
141-
if istree(sym) && operation(sym) == Shift(t, 1)
141+
if iscall(sym) && operation(sym) == Shift(t, 1)
142142
sym = only(arguments(sym))
143143
end
144144
disc_clocks[sym] = i
145145
disc_clocks[sym] = i
146146
disc_clocks[default_toterm(sym)] = i
147-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
147+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
148148
disc_clocks[getname(sym)] = i
149149
disc_clocks[getname(default_toterm(sym))] = i
150150
end
@@ -153,7 +153,7 @@ function IndexCache(sys::AbstractSystem)
153153

154154
for par in inputs[continuous_id]
155155
is_parameter(sys, par) || error("Discrete subsystem input is not a parameter")
156-
istree(par) && operation(par) isa Hold ||
156+
iscall(par) && operation(par) isa Hold ||
157157
error("Continuous subsystem input is not a Hold")
158158
if haskey(disc_clocks, par)
159159
sym = par
@@ -176,7 +176,7 @@ function IndexCache(sys::AbstractSystem)
176176
disc_clocks[affect.lhs] = user_affect_clock
177177
disc_clocks[default_toterm(affect.lhs)] = user_affect_clock
178178
if hasname(affect.lhs) &&
179-
(!istree(affect.lhs) || operation(affect.lhs) !== getindex)
179+
(!iscall(affect.lhs) || operation(affect.lhs) !== getindex)
180180
disc_clocks[getname(affect.lhs)] = user_affect_clock
181181
disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock
182182
end
@@ -190,7 +190,7 @@ function IndexCache(sys::AbstractSystem)
190190
disc = unwrap(disc)
191191
disc_clocks[disc] = user_affect_clock
192192
disc_clocks[default_toterm(disc)] = user_affect_clock
193-
if hasname(disc) && (!istree(disc) || operation(disc) !== getindex)
193+
if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex)
194194
disc_clocks[getname(disc)] = user_affect_clock
195195
disc_clocks[getname(default_toterm(disc))] = user_affect_clock
196196
end
@@ -245,7 +245,7 @@ function IndexCache(sys::AbstractSystem)
245245
for (j, sym) in enumerate(buffer[btype])
246246
disc_idxs[sym] = (clockidx, i, j)
247247
disc_idxs[default_toterm(sym)] = (clockidx, i, j)
248-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
248+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
249249
disc_idxs[getname(sym)] = (clockidx, i, j)
250250
disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j)
251251
end
@@ -256,7 +256,7 @@ function IndexCache(sys::AbstractSystem)
256256
haskey(disc_idxs, sym) && continue
257257
disc_idxs[sym] = (clockid, 0, 0)
258258
disc_idxs[default_toterm(sym)] = (clockid, 0, 0)
259-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
259+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
260260
disc_idxs[getname(sym)] = (clockid, 0, 0)
261261
disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0)
262262
end

src/systems/parameter_buffer.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ function SymbolicIndexingInterface.set_parameter!(
363363
if validate_size && size(val) !== size(p.discrete[i][j][k])
364364
throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val)))
365365
end
366-
p.discrete[i][j][k][l...] = val
366+
p.discrete[i][j][k] = val
367367
else
368368
p.discrete[i][j][k][l...] = val
369369
end
@@ -586,7 +586,8 @@ end
586586
Base.size(::NestedGetIndex) = ()
587587

588588
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
589-
ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex}
589+
::AbstractSystem, ps::MTKParameters, args::Pair{A, B}...) where {
590+
A, B <: NestedGetIndex}
590591
for (i, val) in args
591592
ps.discrete[i] = val.x
592593
end

test/mtkparameters.jl

+9-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using StaticArrays: SizedVector
56
using OrdinaryDiffEq
67
using ForwardDiff
78
using JET
@@ -309,19 +310,22 @@ end
309310
end
310311

311312
# Parameter timeseries
312-
ps = MTKParameters(([1.0, 1.0],), SizedArray{2}([([0.0, 0.0],), ([0.0, 0.0],)]), (), (), (), nothing, nothing)
313+
ps = MTKParameters(([1.0, 1.0],), SizedVector{2}([([0.0, 0.0],), ([0.0, 0.0],)]),
314+
(), (), (), nothing, nothing)
313315
with_updated_parameter_timeseries_values(
314-
ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],)))
316+
sys, ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],)))
315317
@test ps.discrete[1][1] == [5.0, 10.0]
316318
with_updated_parameter_timeseries_values(
317-
ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)),
319+
sys, ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)),
318320
2 => ModelingToolkit.NestedGetIndex(([4.0, 40.0],)))
319321
@test ps.discrete[1][1] == [3.0, 30.0]
320322
@test ps.discrete[2][1] == [4.0, 40.0]
321323
@test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1]
322324

323325
# With multiple types and clocks
324-
ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), (), (), (), nothing, nothing)
326+
ps = MTKParameters(
327+
(), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]),
328+
(), (), (), nothing, nothing)
325329
@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector}
326330
tsidx1 = 1
327331
tsidx2 = 2
@@ -330,6 +334,6 @@ tsidx2 = 2
330334
@test length(ps.discrete[tsidx2][1]) == 3
331335
@test length(ps.discrete[tsidx2][2]) == 0
332336
with_updated_parameter_timeseries_values(
333-
ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
337+
sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
334338
@test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0]
335339
@test ps.discrete[tsidx1][2][] == false

test/parameter_dependencies.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,21 @@ end
173173
@test_skip begin
174174
Tf = 1.0
175175
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
176-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
176+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
177+
yd(k - 2) => 2.0])
177178
@test_nowarn solve(prob, Tsit5())
178179

179180
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
180181
discrete_events = [[0.5] => [kp ~ 2.0]])
181182
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
182-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
183+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
184+
yd(k - 2) => 2.0])
183185
@test prob.ps[kp] == 1.0
184186
@test prob.ps[kq] == 2.0
185187
@test_nowarn solve(prob, Tsit5())
186188
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
187-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
189+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
190+
yd(k - 2) => 2.0])
188191
integ = init(prob, Tsit5())
189192
@test integ.ps[kp] == 1.0
190193
@test integ.ps[kq] == 2.0

test/symbolic_indexing_interface.jl

+13-16
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ using SciMLStructures: Tunable
3838
odesys = complete(odesys)
3939
@test default_values(odesys)[xy] == 3.0
4040
pobs = parameter_observed(odesys, a + b)
41-
@test pobs.timeseries_idx === nothing
42-
@test pobs.observed_fn(
41+
@test isempty(get_all_timeseries_indexes(odesys, a + b))
42+
@test pobs(
4343
ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) 3.0
4444
pobs = parameter_observed(odesys, [a + b, a - b])
45-
@test pobs.timeseries_idx === nothing
46-
@test pobs.observed_fn(
45+
@test isempty(get_all_timeseries_indexes(odesys, [a + b, a - b]))
46+
@test pobs(
4747
ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) [3.0, -1.0]
4848
end
4949

@@ -102,11 +102,11 @@ end
102102
@test !is_time_dependent(ns)
103103
ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0])
104104
pobs = parameter_observed(ns, σ + ρ)
105-
@test pobs.timeseries_idx === nothing
106-
@test pobs.observed_fn(ps) == 3.0
105+
@test isempty(get_all_timeseries_indexes(ns, σ + ρ))
106+
@test pobs(ps) == 3.0
107107
pobs = parameter_observed(ns, [σ + ρ, ρ + β])
108-
@test pobs.timeseries_idx === nothing
109-
@test pobs.observed_fn(ps) == [3.0, 5.0]
108+
@test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β]))
109+
@test pobs(ps) == [3.0, 5.0]
110110
end
111111

112112
@testset "PDESystem" begin
@@ -147,6 +147,11 @@ end
147147
domains = [t (0.0, 1.0),
148148
x (0.0, 1.0)]
149149

150+
analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
151+
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)
152+
153+
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic)
154+
150155
@test isequal(pdesys.ps, [h])
151156
@test isequal(parameter_symbols(pdesys), [h])
152157
@test isequal(parameters(pdesys), [h])
@@ -179,12 +184,4 @@ get_dep = @test_nowarn getu(prob, 2p1)
179184
@test getu(prob, z)(prob) == getu(prob, :z)(prob)
180185
@test getu(prob, p1)(prob) == getu(prob, :p1)(prob)
181186
@test getu(prob, p2)(prob) == getu(prob, :p2)(prob)
182-
analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
183-
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)
184-
185-
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic)
186-
187-
@test isequal(pdesys.ps, [h])
188-
@test isequal(parameter_symbols(pdesys), [h])
189-
@test isequal(parameters(pdesys), [h])
190187
end

0 commit comments

Comments
 (0)