@@ -187,8 +187,14 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187
187
if unwrap (sym) isa Int # [x, 1] coerces 1 to a Num
188
188
return unwrap (sym) in 1 : length (variable_symbols (sys))
189
189
end
190
- return any (isequal (sym), variable_symbols (sys)) ||
190
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
191
+ ic = get_index_cache (sys)
192
+ h = getsymbolhash (sym)
193
+ return haskey (ic. unknown_idx, h) || haskey (ic. unknown_idx, getsymbolhash (default_toterm (sym))) || hasname (sym) && is_variable (sys, getname (sym))
194
+ else
195
+ return any (isequal (sym), variable_symbols (sys)) ||
191
196
hasname (sym) && is_variable (sys, getname (sym))
197
+ end
192
198
end
193
199
194
200
function SymbolicIndexingInterface. is_variable (sys:: AbstractSystem , sym:: Symbol )
@@ -202,6 +208,22 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
202
208
if unwrap (sym) isa Int
203
209
return unwrap (sym)
204
210
end
211
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
212
+ ic = get_index_cache (sys)
213
+ h = getsymbolhash (sym)
214
+ return if haskey (ic. unknown_idx, h)
215
+ ic. unknown_idx[h]
216
+ else
217
+ h = getsymbolhash (default_toterm (sym))
218
+ if haskey (ic. unknown_idx, h)
219
+ ic. unknown_idx[h]
220
+ elseif hasname (sym)
221
+ variable_index (sys, getname (sym))
222
+ else
223
+ nothing
224
+ end
225
+ end
226
+ end
205
227
idx = findfirst (isequal (sym), variable_symbols (sys))
206
228
if idx === nothing && hasname (sym)
207
229
idx = variable_index (sys, getname (sym))
@@ -230,7 +252,19 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
230
252
if unwrap (sym) isa Int
231
253
return unwrap (sym) in 1 : length (parameter_symbols (sys))
232
254
end
233
-
255
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
256
+ ic = get_index_cache (sys)
257
+ h = getsymbolhash (sym)
258
+ return if haskey (ic. param_idx, h) || haskey (ic. discrete_idx, h) ||
259
+ haskey (ic. constant_idx, h) || haskey (ic. dependent_idx, h)
260
+ true
261
+ else
262
+ h = getsymbolhash (default_toterm (sym))
263
+ haskey (ic. param_idx, h) || haskey (ic. discrete_idx, h) ||
264
+ haskey (ic. constant_idx, h) || haskey (ic. dependent_idx, h) ||
265
+ hasname (sym) && is_parameter (sys, getname (sym))
266
+ end
267
+ end
234
268
return any (isequal (sym), parameter_symbols (sys)) ||
235
269
hasname (sym) && is_parameter (sys, getname (sym))
236
270
end
@@ -246,6 +280,33 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
246
280
if unwrap (sym) isa Int
247
281
return unwrap (sym)
248
282
end
283
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
284
+ ic = get_index_cache (sys)
285
+ h = getsymbolhash (sym)
286
+ return if haskey (ic. param_idx, h)
287
+ ParameterIndex (SciMLStructures. Tunable (), ic. param_idx[h])
288
+ elseif haskey (ic. discrete_idx, h)
289
+ ParameterIndex (SciMLStructures. Discrete (), ic. discrete_idx[h])
290
+ elseif haskey (ic. constant_idx, h)
291
+ ParameterIndex (SciMLStructures. Constants (), ic. constant_idx[h])
292
+ elseif haskey (ic. dependent_idx, h)
293
+ ParameterIndex (nothing , ic. dependent_idx[h])
294
+ else
295
+ h = getsymbolhash (default_toterm (sym))
296
+ if haskey (ic. param_idx, h)
297
+ ParameterIndex (SciMLStructures. Tunable (), ic. param_idx[h])
298
+ elseif haskey (ic. discrete_idx, h)
299
+ ParameterIndex (SciMLStructures. Discrete (), ic. discrete_idx[h])
300
+ elseif haskey (ic. constant_idx, h)
301
+ ParameterIndex (SciMLStructures. Constants (), ic. constant_idx[h])
302
+ elseif haskey (ic. dependent_idx, h)
303
+ ParameterIndex (nothing , ic. dependent_idx[h])
304
+ else
305
+ nothing
306
+ end
307
+ end
308
+ end
309
+
249
310
idx = findfirst (isequal (sym), parameter_symbols (sys))
250
311
if idx === nothing && hasname (sym)
251
312
idx = parameter_index (sys, getname (sym))
@@ -313,6 +374,9 @@ Mark a system as completed. If a system is complete, the system will no longer
313
374
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
314
375
"""
315
376
function complete (sys:: AbstractSystem )
377
+ if has_index_cache (sys)
378
+ @set! sys. index_cache = IndexCache (sys)
379
+ end
316
380
isdefined (sys, :complete ) ? (@set! sys. complete = true ) : sys
317
381
end
318
382
@@ -354,7 +418,8 @@ for prop in [:eqs
354
418
:discrete_subsystems
355
419
:solved_unknowns
356
420
:split_idxs
357
- :parent ]
421
+ :parent
422
+ :index_cache ]
358
423
fname1 = Symbol (:get_ , prop)
359
424
fname2 = Symbol (:has_ , prop)
360
425
@eval begin
@@ -1437,14 +1502,19 @@ function linearization_function(sys::AbstractSystem, inputs,
1437
1502
end
1438
1503
sys = ssys
1439
1504
x0 = merge (defaults (sys), Dict (missing_variable_defaults (sys)), op)
1440
- u0, p, _ = get_u0_p (sys, x0, p; use_union = false , tofloat = true )
1441
- p, split_idxs = split_parameters_by_type (p)
1442
- ps = parameters (sys)
1443
- if p isa Tuple
1444
- ps = Base. Fix1 (getindex, ps).(split_idxs)
1445
- ps = (ps... ,) # if p is Tuple, ps should be Tuple
1505
+ u0, _p, _ = get_u0_p (sys, x0, p; use_union = false , tofloat = true )
1506
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
1507
+ p = MTKParameters (sys, p)
1508
+ ps = reorder_parameters (sys, parameters (sys))
1509
+ else
1510
+ p = _p
1511
+ p, split_idxs = split_parameters_by_type (p)
1512
+ ps = parameters (sys)
1513
+ if p isa Tuple
1514
+ ps = Base. Fix1 (getindex, ps).(split_idxs)
1515
+ ps = (ps... ,) # if p is Tuple, ps should be Tuple
1516
+ end
1446
1517
end
1447
-
1448
1518
lin_fun = let diff_idxs = diff_idxs,
1449
1519
alge_idxs = alge_idxs,
1450
1520
input_idxs = input_idxs,
@@ -1468,7 +1538,7 @@ function linearization_function(sys::AbstractSystem, inputs,
1468
1538
uf = SciMLBase. UJacobianWrapper (fun, t, p)
1469
1539
fg_xz = ForwardDiff. jacobian (uf, u)
1470
1540
h_xz = ForwardDiff. jacobian (let p = p, t = t
1471
- xz -> h (xz, p, t)
1541
+ xz -> p isa MTKParameters ? h (xz, p ... , t) : h (xz, p, t)
1472
1542
end , u)
1473
1543
pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
1474
1544
fg_u = jacobian_wrt_vars (pf, p, input_idxs, chunk)
@@ -1479,7 +1549,9 @@ function linearization_function(sys::AbstractSystem, inputs,
1479
1549
h_xz = fg_u = zeros (0 , length (inputs))
1480
1550
end
1481
1551
hp = let u = u, t = t
1482
- p -> h (u, p, t)
1552
+ _hp (p) = h (u, p, t)
1553
+ _hp (p:: MTKParameters ) = h (u, p... , t)
1554
+ _hp
1483
1555
end
1484
1556
h_u = jacobian_wrt_vars (hp, p, input_idxs, chunk)
1485
1557
(f_x = fg_xz[diff_idxs, diff_idxs],
@@ -1521,13 +1593,14 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
1521
1593
kwargs... )
1522
1594
sts = unknowns (sys)
1523
1595
t = get_iv (sys)
1524
- p = parameters (sys)
1596
+ ps = parameters (sys)
1597
+ p = reorder_parameters (sys, ps)
1525
1598
1526
- fun = generate_function (sys, sts, p ; expression = Val{false })[1 ]
1527
- dx = fun (sts, p, t)
1599
+ fun = generate_function (sys, sts, ps ; expression = Val{false })[1 ]
1600
+ dx = fun (sts, p... , t)
1528
1601
1529
1602
h = build_explicit_observed_function (sys, outputs)
1530
- y = h (sts, p, t)
1603
+ y = h (sts, p... , t)
1531
1604
1532
1605
fg_xz = Symbolics. jacobian (dx, sts)
1533
1606
fg_u = Symbolics. jacobian (dx, inputs)
@@ -1722,7 +1795,18 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
1722
1795
p = DiffEqBase. NullParameters ())
1723
1796
x0 = merge (defaults (sys), op)
1724
1797
u0, p2, _ = get_u0_p (sys, x0, p; use_union = false , tofloat = true )
1725
-
1798
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
1799
+ if p isa SciMLBase. NullParameters
1800
+ p = op
1801
+ elseif p isa Dict
1802
+ p = merge (p, op)
1803
+ elseif p isa Vector && eltype (p) <: Pair
1804
+ p = merge (Dict (p), op)
1805
+ elseif p isa Vector
1806
+ p = merge (Dict (parameters (sys) .=> p), op)
1807
+ end
1808
+ p2 = MTKParameters (sys, p)
1809
+ end
1726
1810
linres = lin_fun (u0, p2, t)
1727
1811
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
1728
1812
0 commit comments