Skip to content

Commit e4c4f11

Browse files
refactor: remove parameter dependencies from MTKParameters
1 parent b1852bb commit e4c4f11

17 files changed

+194
-273
lines changed

src/inputoutput.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
246246
end
247247
process = get_postprocess_fbody(sys)
248248
f = build_function(rhss, args...; postprocess_fbody = process,
249-
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
249+
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps) .∘
250+
wrap_parameter_dependencies(sys, false),
251+
kwargs...)
250252
f = eval_or_rgf.(f; eval_expression, eval_module)
251253
(; f, dvs, ps, io_sys = sys)
252254
end

src/systems/abstractsystem.jl

+58-54
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function calculate_hessian end
8282

8383
"""
8484
```julia
85-
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = full_parameters(sys),
85+
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = parameters(sys),
8686
expression = Val{true}; kwargs...)
8787
```
8888
@@ -93,7 +93,7 @@ function generate_tgrad end
9393

9494
"""
9595
```julia
96-
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
96+
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
9797
expression = Val{true}; kwargs...)
9898
```
9999
@@ -104,7 +104,7 @@ function generate_gradient end
104104

105105
"""
106106
```julia
107-
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
107+
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
108108
expression = Val{true}; sparse = false, kwargs...)
109109
```
110110
@@ -115,7 +115,7 @@ function generate_jacobian end
115115

116116
"""
117117
```julia
118-
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
118+
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
119119
expression = Val{true}; sparse = false, kwargs...)
120120
```
121121
@@ -126,7 +126,7 @@ function generate_factorized_W end
126126

127127
"""
128128
```julia
129-
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
129+
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
130130
expression = Val{true}; sparse = false, kwargs...)
131131
```
132132
@@ -137,7 +137,7 @@ function generate_hessian end
137137

138138
"""
139139
```julia
140-
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
140+
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
141141
expression = Val{true}; kwargs...)
142142
```
143143
@@ -148,7 +148,7 @@ function generate_function end
148148
"""
149149
```julia
150150
generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
151-
ps = full_parameters(sys); kwargs...)
151+
ps = parameters(sys); kwargs...)
152152
```
153153
154154
Generate a function to evaluate `exprs`. `exprs` is a symbolic expression or
@@ -187,7 +187,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
187187
postprocess_fbody,
188188
states,
189189
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
190-
wrap_array_vars(sys, exprs; dvs),
190+
wrap_array_vars(sys, exprs; dvs) .∘
191+
wrap_parameter_dependencies(sys, isscalar),
191192
expression = Val{true}
192193
)
193194
else
@@ -198,7 +199,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
198199
postprocess_fbody,
199200
states,
200201
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
201-
wrap_array_vars(sys, exprs; dvs),
202+
wrap_array_vars(sys, exprs; dvs) .∘
203+
wrap_parameter_dependencies(sys, isscalar),
202204
expression = Val{true}
203205
)
204206
end
@@ -223,6 +225,10 @@ function wrap_assignments(isscalar, assignments; let_block = false)
223225
end
224226
end
225227

228+
function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
229+
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
230+
end
231+
226232
function wrap_array_vars(
227233
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
228234
isscalar = !(exprs isa AbstractArray)
@@ -757,7 +763,7 @@ function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSyste
757763
end
758764

759765
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
760-
return full_parameters(sys)
766+
return parameters(sys)
761767
end
762768

763769
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
@@ -1214,11 +1220,6 @@ function namespace_guesses(sys)
12141220
Dict(unknowns(sys, k) => namespace_expr(v, sys) for (k, v) in guess)
12151221
end
12161222

1217-
function namespace_parameter_dependencies(sys)
1218-
pdeps = parameter_dependencies(sys)
1219-
Dict(parameters(sys, k) => namespace_expr(v, sys) for (k, v) in pdeps)
1220-
end
1221-
12221223
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
12231224
eqs = equations(sys)
12241225
isempty(eqs) && return Equation[]
@@ -1325,25 +1326,11 @@ function parameters(sys::AbstractSystem)
13251326
ps = first.(ps)
13261327
end
13271328
systems = get_systems(sys)
1328-
result = unique(isempty(systems) ? ps :
1329-
[ps; reduce(vcat, namespace_parameters.(systems))])
1330-
if has_parameter_dependencies(sys) &&
1331-
(pdeps = parameter_dependencies(sys)) !== nothing
1332-
filter(result) do sym
1333-
!haskey(pdeps, sym)
1334-
end
1335-
else
1336-
result
1337-
end
1329+
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
13381330
end
13391331

13401332
function dependent_parameters(sys::AbstractSystem)
1341-
if has_parameter_dependencies(sys) &&
1342-
!isempty(parameter_dependencies(sys))
1343-
collect(keys(parameter_dependencies(sys)))
1344-
else
1345-
[]
1346-
end
1333+
return map(eq -> eq.lhs, parameter_dependencies(sys))
13471334
end
13481335

13491336
"""
@@ -1353,17 +1340,19 @@ Get the parameter dependencies of the system `sys` and its subsystems.
13531340
See also [`defaults`](@ref) and [`ModelingToolkit.get_parameter_dependencies`](@ref).
13541341
"""
13551342
function parameter_dependencies(sys::AbstractSystem)
1356-
pdeps = get_parameter_dependencies(sys)
1357-
if isnothing(pdeps)
1358-
pdeps = Dict()
1343+
if !has_parameter_dependencies(sys)
1344+
return Equation[]
13591345
end
1346+
pdeps = get_parameter_dependencies(sys)
13601347
systems = get_systems(sys)
1361-
isempty(systems) && return pdeps
1362-
for subsys in systems
1363-
pdeps = merge(pdeps, namespace_parameter_dependencies(subsys))
1364-
end
1365-
# @info pdeps
1366-
return pdeps
1348+
# put pdeps after those of subsystems to maintain topological sorted order
1349+
return vcat(
1350+
reduce(vcat,
1351+
[map(eq -> namespace_equation(eq, s), parameter_dependencies(s))
1352+
for s in systems];
1353+
init = Equation[]),
1354+
pdeps
1355+
)
13671356
end
13681357

13691358
function full_parameters(sys::AbstractSystem)
@@ -2317,7 +2306,7 @@ function linearization_function(sys::AbstractSystem, inputs,
23172306
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
23182307
initprobmap = build_explicit_observed_function(
23192308
initsys, unknowns(sys); eval_expression, eval_module)
2320-
ps = full_parameters(sys)
2309+
ps = parameters(sys)
23212310
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
23222311
lin_fun = let diff_idxs = diff_idxs,
23232312
alge_idxs = alge_idxs,
@@ -2420,7 +2409,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
24202409
kwargs...)
24212410
sts = unknowns(sys)
24222411
t = get_iv(sys)
2423-
ps = full_parameters(sys)
2412+
ps = parameters(sys)
24242413
p = reorder_parameters(sys, ps)
24252414

24262415
fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
@@ -2852,7 +2841,7 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
28522841
eqs = union(get_eqs(basesys), get_eqs(sys))
28532842
sts = union(get_unknowns(basesys), get_unknowns(sys))
28542843
ps = union(get_ps(basesys), get_ps(sys))
2855-
dep_ps = union_nothing(parameter_dependencies(basesys), parameter_dependencies(sys))
2844+
dep_ps = union(parameter_dependencies(basesys), parameter_dependencies(sys))
28562845
obs = union(get_observed(basesys), get_observed(sys))
28572846
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
28582847
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
@@ -2956,15 +2945,28 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
29562945
end
29572946

29582947
function process_parameter_dependencies(pdeps, ps)
2959-
pdeps === nothing && return pdeps, ps
2960-
if pdeps isa Vector && eltype(pdeps) <: Pair
2961-
pdeps = Dict(pdeps)
2962-
elseif !(pdeps isa Dict)
2963-
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")
2948+
if pdeps === nothing || isempty(pdeps)
2949+
return Equation[], ps
2950+
elseif eltype(pdeps) <: Pair
2951+
pdeps = [lhs ~ rhs for (lhs, rhs) in pdeps]
29642952
end
2965-
2953+
if !(eltype(pdeps) <: Equation)
2954+
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
2955+
end
2956+
lhss = BasicSymbolic[]
2957+
for p in pdeps
2958+
if !isparameter(p.lhs)
2959+
error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).")
2960+
end
2961+
syms = vars(p.rhs)
2962+
if !all(isparameter, syms)
2963+
error("RHS of parameter dependency must only include parameters. Found $(p.rhs)")
2964+
end
2965+
push!(lhss, p.lhs)
2966+
end
2967+
pdeps = topsort_equations(pdeps, union(ps, lhss))
29662968
ps = filter(ps) do p
2967-
!haskey(pdeps, p)
2969+
!any(isequal(p), lhss)
29682970
end
29692971
return pdeps, ps
29702972
end
@@ -2997,12 +2999,14 @@ function dump_parameters(sys::AbstractSystem)
29972999
end
29983000
meta
29993001
end
3000-
pdep_metas = map(collect(keys(pdeps))) do sym
3001-
val = pdeps[sym]
3002+
pdep_metas = map(pdeps) do eq
3003+
sym = eq.lhs
3004+
val = eq.rhs
30023005
meta = dump_variable_metadata(sym)
3006+
defs[eq.lhs] = eq.rhs
30033007
meta = merge(meta,
3004-
(; dependency = pdeps[sym],
3005-
default = symbolic_evaluate(pdeps[sym], merge(defs, pdeps))))
3008+
(; dependency = val,
3009+
default = symbolic_evaluate(val, defs)))
30063010
return meta
30073011
end
30083012
return vcat(metas, pdep_metas)

src/systems/callbacks.jl

+16-9
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
411411
end
412412
expr = build_function(
413413
condit, u, t, p...; expression = Val{true},
414-
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
414+
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps) .∘
415+
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
415416
kwargs...)
416417
if expression == Val{true}
417418
return expr
@@ -497,7 +498,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497498
pre = get_preprocess_constants(rhss)
498499
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
499500
wrap_code = add_integrator_header(sys, integ, outvar) .∘
500-
wrap_array_vars(sys, rhss; dvs, ps = _ps),
501+
wrap_array_vars(sys, rhss; dvs, ps = _ps) .∘
502+
wrap_parameter_dependencies(sys, false),
501503
outputidxs = update_inds,
502504
postprocess_fbody = pre,
503505
kwargs...)
@@ -513,7 +515,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
513515
end
514516

515517
function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
516-
ps = full_parameters(sys); kwargs...)
518+
ps = parameters(sys); kwargs...)
517519
cbs = continuous_events(sys)
518520
isempty(cbs) && return nothing
519521
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
@@ -524,7 +526,7 @@ generate_rootfinding_callback and thus we can produce a ContinuousCallback inste
524526
"""
525527
function generate_single_rootfinding_callback(
526528
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
527-
ps = full_parameters(sys); kwargs...)
529+
ps = parameters(sys); kwargs...)
528530
if !isequal(eq.lhs, 0)
529531
eq = 0 ~ eq.lhs - eq.rhs
530532
end
@@ -547,7 +549,7 @@ end
547549

548550
function generate_vector_rootfinding_callback(
549551
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
550-
ps = full_parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
552+
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
551553
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
552554
num_eqs = length.(eqs)
553555
# fuse equations to create VectorContinuousCallback
@@ -617,7 +619,7 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
617619
end
618620

619621
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
620-
ps = full_parameters(sys); kwargs...)
622+
ps = parameters(sys); kwargs...)
621623
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
622624
num_eqs = length.(eqs)
623625
total_eqs = sum(num_eqs)
@@ -660,10 +662,15 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
660662
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
661663

662664
if has_index_cache(sys) && get_index_cache(sys) !== nothing
663-
p_inds = [parameter_index(sys, sym) for sym in parameters(affect)]
665+
p_inds = [if (pind = parameter_index(sys, sym)) === nothing
666+
sym
667+
else
668+
pind
669+
end
670+
for sym in parameters(affect)]
664671
else
665672
ps_ind = Dict(reverse(en) for en in enumerate(ps))
666-
p_inds = map(sym -> ps_ind[sym], parameters(affect))
673+
p_inds = map(sym -> get(ps_ind, sym, sym), parameters(affect))
667674
end
668675
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
669676
# (MTK should keep these symbols)
@@ -711,7 +718,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711718
end
712719

713720
function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
714-
ps = full_parameters(sys); kwargs...)
721+
ps = parameters(sys); kwargs...)
715722
has_discrete_events(sys) || return nothing
716723
symcbs = discrete_events(sys)
717724
isempty(symcbs) && return nothing

0 commit comments

Comments
 (0)