@@ -24,17 +24,19 @@ ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
24
24
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
25
25
const UnknownIndexMap = Dict{
26
26
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
27
+ const TunableIndexMap = Dict{BasicSymbolic,
28
+ Union{Int, UnitRange{Int}, Base. ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
27
29
28
30
struct IndexCache
29
31
unknown_idx:: UnknownIndexMap
30
32
discrete_idx:: Dict{BasicSymbolic, Tuple{Int, Int, Int}}
31
- tunable_idx:: ParamIndexMap
33
+ tunable_idx:: TunableIndexMap
32
34
constant_idx:: ParamIndexMap
33
35
dependent_idx:: ParamIndexMap
34
36
nonnumeric_idx:: ParamIndexMap
35
37
observed_syms:: Set{BasicSymbolic}
36
38
discrete_buffer_sizes:: Vector{Vector{BufferTemplate}}
37
- tunable_buffer_sizes :: Vector{ BufferTemplate}
39
+ tunable_buffer_size :: BufferTemplate
38
40
constant_buffer_sizes:: Vector{BufferTemplate}
39
41
dependent_buffer_sizes:: Vector{BufferTemplate}
40
42
nonnumeric_buffer_sizes:: Vector{BufferTemplate}
@@ -75,7 +77,7 @@ function IndexCache(sys::AbstractSystem)
75
77
end
76
78
end
77
79
78
- observed_syms = Set {Union{Symbol, BasicSymbolic} } ()
80
+ observed_syms = Set {BasicSymbolic} ()
79
81
for eq in observed (sys)
80
82
if symbolic_type (eq. lhs) != NotSymbolic ()
81
83
sym = eq. lhs
@@ -236,7 +238,10 @@ function IndexCache(sys::AbstractSystem)
236
238
haskey (dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
237
239
insert_by_type! (
238
240
if ctype <: Real || ctype <: AbstractArray{<:Real}
239
- if istunable (p, true ) && Symbolics. shape (p) != = Symbolics. Unknown ()
241
+ if istunable (p, true ) && Symbolics. shape (p) != Symbolics. Unknown () &&
242
+ (ctype == Real || ctype <: AbstractFloat ||
243
+ ctype <: AbstractArray{Real} ||
244
+ ctype <: AbstractArray{<:AbstractFloat} )
240
245
tunable_buffers
241
246
else
242
247
constant_buffers
@@ -292,11 +297,30 @@ function IndexCache(sys::AbstractSystem)
292
297
return idxs, buffer_sizes
293
298
end
294
299
295
- tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs (tunable_buffers)
296
300
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs (constant_buffers)
297
301
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs (dependent_buffers)
298
302
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs (nonnumeric_buffers)
299
303
304
+ tunable_idxs = TunableIndexMap ()
305
+ tunable_buffer_size = 0
306
+ for (i, (_, buf)) in enumerate (tunable_buffers)
307
+ for (j, p) in enumerate (buf)
308
+ idx = if size (p) == ()
309
+ tunable_buffer_size + 1
310
+ else
311
+ reshape (
312
+ (tunable_buffer_size + 1 ): (tunable_buffer_size + length (p)), size (p))
313
+ end
314
+ tunable_buffer_size += length (p)
315
+ tunable_idxs[p] = idx
316
+ tunable_idxs[default_toterm (p)] = idx
317
+ if hasname (p) && (! iscall (p) || operation (p) != = getindex)
318
+ symbol_to_variable[getname (p)] = p
319
+ symbol_to_variable[getname (default_toterm (p))] = p
320
+ end
321
+ end
322
+ end
323
+
300
324
for sym in Iterators. flatten ((keys (unk_idxs), keys (disc_idxs), keys (tunable_idxs),
301
325
keys (const_idxs), keys (dependent_idxs), keys (nonnumeric_idxs),
302
326
observed_syms, independent_variable_symbols (sys)))
@@ -314,7 +338,7 @@ function IndexCache(sys::AbstractSystem)
314
338
nonnumeric_idxs,
315
339
observed_syms,
316
340
disc_buffer_sizes,
317
- tunable_buffer_sizes ,
341
+ BufferTemplate (Real, tunable_buffer_size) ,
318
342
const_buffer_sizes,
319
343
dependent_buffer_sizes,
320
344
nonnumeric_buffer_sizes,
@@ -410,20 +434,6 @@ function check_index_map(idxmap, sym)
410
434
end
411
435
end
412
436
413
- function discrete_linear_index (ic:: IndexCache , idx:: ParameterIndex )
414
- idx. portion isa SciMLStructures. Discrete || error (" Discrete variable index expected" )
415
- ind = sum (temp. length for temp in ic. tunable_buffer_sizes; init = 0 )
416
- for clockbuftemps in Iterators. take (ic. discrete_buffer_sizes, idx. idx[1 ] - 1 )
417
- ind += sum (temp. length for temp in clockbuftemps; init = 0 )
418
- end
419
- ind += sum (
420
- temp. length
421
- for temp in Iterators. take (ic. discrete_buffer_sizes[idx. idx[1 ]], idx. idx[2 ] - 1 );
422
- init = 0 )
423
- ind += idx. idx[3 ]
424
- return ind
425
- end
426
-
427
437
function reorder_parameters (sys:: AbstractSystem , ps; kwargs... )
428
438
if has_index_cache (sys) && get_index_cache (sys) != = nothing
429
439
reorder_parameters (get_index_cache (sys), ps; kwargs... )
436
446
437
447
function reorder_parameters (ic:: IndexCache , ps; drop_missing = false )
438
448
isempty (ps) && return ()
439
- param_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
440
- for temp in ic. tunable_buffer_sizes)
449
+ param_buf = if ic. tunable_buffer_size. length == 0
450
+ ()
451
+ else
452
+ (BasicSymbolic[unwrap (variable (:DEF ))
453
+ for _ in 1 : (ic. tunable_buffer_size. length)],)
454
+ end
441
455
disc_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
442
456
for temp in Iterators. flatten (ic. discrete_buffer_sizes))
443
457
const_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
@@ -453,8 +467,12 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
453
467
i, j, k = ic. discrete_idx[p]
454
468
disc_buf[(i - 1 ) * disc_offset + j][k] = p
455
469
elseif haskey (ic. tunable_idx, p)
456
- i, j = ic. tunable_idx[p]
457
- param_buf[i][j] = p
470
+ i = ic. tunable_idx[p]
471
+ if i isa Int
472
+ param_buf[1 ][i] = unwrap (p)
473
+ else
474
+ param_buf[1 ][i] = unwrap .(collect (p))
475
+ end
458
476
elseif haskey (ic. constant_idx, p)
459
477
i, j = ic. constant_idx[p]
460
478
const_buf[i][j] = p
0 commit comments