Skip to content

Commit 3b4fb73

Browse files
Merge pull request SciML#2447 from AayushSabharwal/as/param-splitting
feat!: use SciMLStructures and add new `MTKParameters`
2 parents 88a1c32 + de504b5 commit 3b4fb73

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1159
-345
lines changed

Diff for: Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
3939
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
4040
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
4141
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
42+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
4243
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
4344
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
4445
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
@@ -99,6 +100,7 @@ SciMLBase = "2.0.1"
99100
Serialization = "1"
100101
Setfield = "0.7, 0.8, 1"
101102
SimpleNonlinearSolve = "0.1.0, 1"
103+
SciMLStructures = "1.0"
102104
SparseArrays = "1"
103105
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
104106
StaticArrays = "0.10, 0.11, 0.12, 1.0"

Diff for: ext/MTKBifurcationKitExt.jl

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
9191
if !ModelingToolkit.iscomplete(nsys)
9292
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
9393
end
94+
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
9495
# Creates F and J functions.
9596
ofun = NonlinearFunction(nsys; jac = jac)
9697
F = ofun.f

Diff for: src/ModelingToolkit.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using PrecompileTools, Reexport
3131
import Distributions
3232
import FunctionWrappersWrappers
3333
using URIs: URI
34+
using SciMLStructures
3435

3536
using RecursiveArrayTools
3637

@@ -62,7 +63,7 @@ using PrecompileTools, Reexport
6263
ParallelForm, SerialForm, MultithreadedForm, build_function,
6364
rhss, lhss, prettify_expr, gradient,
6465
jacobian, hessian, derivative, sparsejacobian, sparsehessian,
65-
substituter, scalarize, getparent
66+
substituter, scalarize, getparent, hasderiv, hasdiff
6667

6768
import DiffEqBase: @add_kwonly
6869
import OrdinaryDiffEq
@@ -128,6 +129,8 @@ include("constants.jl")
128129
include("utils.jl")
129130
include("domains.jl")
130131

132+
include("systems/index_cache.jl")
133+
include("systems/parameter_buffer.jl")
131134
include("systems/abstractsystem.jl")
132135
include("systems/model_parsing.jl")
133136
include("systems/connectors.jl")

Diff for: src/clock.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ true if `x` contains only discrete-domain signals.
8383
See also [`has_discrete_domain`](@ref)
8484
"""
8585
function is_discrete_domain(x)
86-
issym(x) && return getmetadata(x, TimeDomain, false) isa Discrete
86+
if hasmetadata(x, TimeDomain) || issym(x)
87+
return getmetadata(x, TimeDomain, false) isa AbstractDiscrete
88+
end
8789
!has_discrete_domain(x) && has_continuous_domain(x)
8890
end
8991

Diff for: src/discretedomain.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ Base.hash(D::Sample, u::UInt) = hash(D.clock, xor(u, 0x055640d6d952f101))
114114
115115
Returns true if the expression or equation `O` contains [`Sample`](@ref) terms.
116116
"""
117-
hassample(O) = recursive_hasoperator(Sample, O)
117+
hassample(O) = recursive_hasoperator(Sample, unwrap(O))
118118

119119
# Hold
120120

@@ -140,7 +140,7 @@ Hold(x) = Hold()(x)
140140
141141
Returns true if the expression or equation `O` contains [`Hold`](@ref) terms.
142142
"""
143-
hashold(O) = recursive_hasoperator(Hold, O)
143+
hashold(O) = recursive_hasoperator(Hold, unwrap(O))
144144

145145
# ShiftIndex
146146

Diff for: src/systems/abstractsystem.jl

+101-17
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187187
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
188188
return unwrap(sym) in 1:length(variable_symbols(sys))
189189
end
190-
return any(isequal(sym), variable_symbols(sys)) ||
190+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
191+
ic = get_index_cache(sys)
192+
h = getsymbolhash(sym)
193+
return haskey(ic.unknown_idx, h) || haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) || hasname(sym) && is_variable(sys, getname(sym))
194+
else
195+
return any(isequal(sym), variable_symbols(sys)) ||
191196
hasname(sym) && is_variable(sys, getname(sym))
197+
end
192198
end
193199

194200
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
@@ -202,6 +208,22 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
202208
if unwrap(sym) isa Int
203209
return unwrap(sym)
204210
end
211+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
212+
ic = get_index_cache(sys)
213+
h = getsymbolhash(sym)
214+
return if haskey(ic.unknown_idx, h)
215+
ic.unknown_idx[h]
216+
else
217+
h = getsymbolhash(default_toterm(sym))
218+
if haskey(ic.unknown_idx, h)
219+
ic.unknown_idx[h]
220+
elseif hasname(sym)
221+
variable_index(sys, getname(sym))
222+
else
223+
nothing
224+
end
225+
end
226+
end
205227
idx = findfirst(isequal(sym), variable_symbols(sys))
206228
if idx === nothing && hasname(sym)
207229
idx = variable_index(sys, getname(sym))
@@ -230,7 +252,19 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
230252
if unwrap(sym) isa Int
231253
return unwrap(sym) in 1:length(parameter_symbols(sys))
232254
end
233-
255+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
256+
ic = get_index_cache(sys)
257+
h = getsymbolhash(sym)
258+
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
259+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
260+
true
261+
else
262+
h = getsymbolhash(default_toterm(sym))
263+
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
264+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
265+
hasname(sym) && is_parameter(sys, getname(sym))
266+
end
267+
end
234268
return any(isequal(sym), parameter_symbols(sys)) ||
235269
hasname(sym) && is_parameter(sys, getname(sym))
236270
end
@@ -246,6 +280,33 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
246280
if unwrap(sym) isa Int
247281
return unwrap(sym)
248282
end
283+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
284+
ic = get_index_cache(sys)
285+
h = getsymbolhash(sym)
286+
return if haskey(ic.param_idx, h)
287+
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
288+
elseif haskey(ic.discrete_idx, h)
289+
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
290+
elseif haskey(ic.constant_idx, h)
291+
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
292+
elseif haskey(ic.dependent_idx, h)
293+
ParameterIndex(nothing, ic.dependent_idx[h])
294+
else
295+
h = getsymbolhash(default_toterm(sym))
296+
if haskey(ic.param_idx, h)
297+
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
298+
elseif haskey(ic.discrete_idx, h)
299+
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
300+
elseif haskey(ic.constant_idx, h)
301+
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
302+
elseif haskey(ic.dependent_idx, h)
303+
ParameterIndex(nothing, ic.dependent_idx[h])
304+
else
305+
nothing
306+
end
307+
end
308+
end
309+
249310
idx = findfirst(isequal(sym), parameter_symbols(sys))
250311
if idx === nothing && hasname(sym)
251312
idx = parameter_index(sys, getname(sym))
@@ -313,6 +374,9 @@ Mark a system as completed. If a system is complete, the system will no longer
313374
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
314375
"""
315376
function complete(sys::AbstractSystem)
377+
if has_index_cache(sys)
378+
@set! sys.index_cache = IndexCache(sys)
379+
end
316380
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
317381
end
318382

@@ -354,7 +418,8 @@ for prop in [:eqs
354418
:discrete_subsystems
355419
:solved_unknowns
356420
:split_idxs
357-
:parent]
421+
:parent
422+
:index_cache]
358423
fname1 = Symbol(:get_, prop)
359424
fname2 = Symbol(:has_, prop)
360425
@eval begin
@@ -1437,14 +1502,19 @@ function linearization_function(sys::AbstractSystem, inputs,
14371502
end
14381503
sys = ssys
14391504
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
1440-
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1441-
p, split_idxs = split_parameters_by_type(p)
1442-
ps = parameters(sys)
1443-
if p isa Tuple
1444-
ps = Base.Fix1(getindex, ps).(split_idxs)
1445-
ps = (ps...,) #if p is Tuple, ps should be Tuple
1505+
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1506+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1507+
p = MTKParameters(sys, p)
1508+
ps = reorder_parameters(sys, parameters(sys))
1509+
else
1510+
p = _p
1511+
p, split_idxs = split_parameters_by_type(p)
1512+
ps = parameters(sys)
1513+
if p isa Tuple
1514+
ps = Base.Fix1(getindex, ps).(split_idxs)
1515+
ps = (ps...,) #if p is Tuple, ps should be Tuple
1516+
end
14461517
end
1447-
14481518
lin_fun = let diff_idxs = diff_idxs,
14491519
alge_idxs = alge_idxs,
14501520
input_idxs = input_idxs,
@@ -1468,7 +1538,7 @@ function linearization_function(sys::AbstractSystem, inputs,
14681538
uf = SciMLBase.UJacobianWrapper(fun, t, p)
14691539
fg_xz = ForwardDiff.jacobian(uf, u)
14701540
h_xz = ForwardDiff.jacobian(let p = p, t = t
1471-
xz -> h(xz, p, t)
1541+
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
14721542
end, u)
14731543
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
14741544
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
@@ -1479,7 +1549,9 @@ function linearization_function(sys::AbstractSystem, inputs,
14791549
h_xz = fg_u = zeros(0, length(inputs))
14801550
end
14811551
hp = let u = u, t = t
1482-
p -> h(u, p, t)
1552+
_hp(p) = h(u, p, t)
1553+
_hp(p::MTKParameters) = h(u, p..., t)
1554+
_hp
14831555
end
14841556
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
14851557
(f_x = fg_xz[diff_idxs, diff_idxs],
@@ -1521,13 +1593,14 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
15211593
kwargs...)
15221594
sts = unknowns(sys)
15231595
t = get_iv(sys)
1524-
p = parameters(sys)
1596+
ps = parameters(sys)
1597+
p = reorder_parameters(sys, ps)
15251598

1526-
fun = generate_function(sys, sts, p; expression = Val{false})[1]
1527-
dx = fun(sts, p, t)
1599+
fun = generate_function(sys, sts, ps; expression = Val{false})[1]
1600+
dx = fun(sts, p..., t)
15281601

15291602
h = build_explicit_observed_function(sys, outputs)
1530-
y = h(sts, p, t)
1603+
y = h(sts, p..., t)
15311604

15321605
fg_xz = Symbolics.jacobian(dx, sts)
15331606
fg_u = Symbolics.jacobian(dx, inputs)
@@ -1722,7 +1795,18 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
17221795
p = DiffEqBase.NullParameters())
17231796
x0 = merge(defaults(sys), op)
17241797
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1725-
1798+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1799+
if p isa SciMLBase.NullParameters
1800+
p = op
1801+
elseif p isa Dict
1802+
p = merge(p, op)
1803+
elseif p isa Vector && eltype(p) <: Pair
1804+
p = merge(Dict(p), op)
1805+
elseif p isa Vector
1806+
p = merge(Dict(parameters(sys) .=> p), op)
1807+
end
1808+
p2 = MTKParameters(sys, p)
1809+
end
17261810
linres = lin_fun(u0, p2, t)
17271811
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
17281812

0 commit comments

Comments
 (0)