Skip to content

Commit 500cd49

Browse files
refactor: turn tunables portion into a Vector{T}
1 parent ec870e2 commit 500cd49

17 files changed

+294
-194
lines changed

src/inputoutput.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
243243
end
244244
process = get_postprocess_fbody(sys)
245245
f = build_function(rhss, args...; postprocess_fbody = process,
246-
expression = Val{true}, kwargs...)
246+
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
247247
f = eval_or_rgf.(f; eval_expression, eval_module)
248248
(; f, dvs, ps, io_sys = sys)
249249
end

src/systems/abstractsystem.jl

+53-9
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,44 @@ function wrap_assignments(isscalar, assignments; let_block = false)
223223
end
224224
end
225225

226-
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
226+
function wrap_array_vars(
227+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
227228
isscalar = !(exprs isa AbstractArray)
228229
array_vars = Dict{Any, AbstractArray{Int}}()
229-
for (j, x) in enumerate(dvs)
230-
if iscall(x) && operation(x) == getindex
231-
arg = arguments(x)[1]
232-
inds = get!(() -> Int[], array_vars, arg)
233-
push!(inds, j)
230+
if dvs !== nothing
231+
for (j, x) in enumerate(dvs)
232+
if iscall(x) && operation(x) == getindex
233+
arg = arguments(x)[1]
234+
inds = get!(() -> Int[], array_vars, arg)
235+
push!(inds, j)
236+
end
237+
end
238+
uind = 1
239+
else
240+
uind = 0
241+
end
242+
# tunables are scalarized and concatenated, so we need to have assignments
243+
# for the non-scalarized versions
244+
array_tunables = Dict{Any, AbstractArray{Int}}()
245+
for p in ps
246+
idx = parameter_index(sys, p)
247+
idx isa ParameterIndex || continue
248+
idx.portion isa SciMLStructures.Tunable || continue
249+
idx.idx isa AbstractArray || continue
250+
array_tunables[p] = idx.idx
251+
end
252+
# Other parameters may be scalarized arrays but used in the vector form
253+
other_array_parameters = Assignment[]
254+
for p in ps
255+
idx = parameter_index(sys, p)
256+
if Symbolics.isarraysymbolic(p)
257+
idx === nothing || continue
258+
push!(other_array_parameters, p collect(p))
259+
elseif iscall(p) && operation(p) == getindex
260+
idx === nothing && continue
261+
# all of the scalarized variables are in `ps`
262+
all(x -> any(isequal(x), ps), collect(p))|| continue
263+
push!(other_array_parameters, p collect(p))
234264
end
235265
end
236266
for (k, inds) in array_vars
@@ -244,7 +274,12 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
244274
expr.args,
245275
[],
246276
Let(
247-
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
277+
vcat(
278+
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
279+
[k :(view($(expr.args[uind + 1].name), $v))
280+
for (k, v) in array_tunables],
281+
other_array_parameters
282+
),
248283
expr.body,
249284
false
250285
)
@@ -256,7 +291,11 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
256291
expr.args,
257292
[],
258293
Let(
259-
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
294+
vcat(
295+
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
296+
[k :(view($(expr.args[uind + 1].name), $v))
297+
for (k, v) in array_tunables]
298+
),
260299
expr.body,
261300
false
262301
)
@@ -267,7 +306,12 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
267306
expr.args,
268307
[],
269308
Let(
270-
[k :(view($(expr.args[2].name), $v)) for (k, v) in array_vars],
309+
vcat(
310+
[k :(view($(expr.args[uind + 1].name), $v))
311+
for (k, v) in array_vars],
312+
[k :(view($(expr.args[uind + 2].name), $v))
313+
for (k, v) in array_tunables]
314+
),
271315
expr.body,
272316
false
273317
)

src/systems/callbacks.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
354354
condit = substitute(condit, cmap)
355355
end
356356
expr = build_function(
357-
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
357+
condit, u, t, p...; expression = Val{true},
358+
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
358359
kwargs...)
359360
if expression == Val{true}
360361
return expr
@@ -411,10 +412,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
411412
update_inds = map(sym -> unknownind[sym], update_vars)
412413
elseif isparameter(first(lhss)) && alleq
413414
if has_index_cache(sys) && get_index_cache(sys) !== nothing
414-
ic = get_index_cache(sys)
415415
update_inds = map(update_vars) do sym
416-
pind = parameter_index(sys, sym)
417-
discrete_linear_index(ic, pind)
416+
return parameter_index(sys, sym)
418417
end
419418
else
420419
psind = Dict(reverse(en) for en in enumerate(ps))
@@ -440,7 +439,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
440439
integ = gensym(:MTKIntegrator)
441440
pre = get_preprocess_constants(rhss)
442441
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
443-
wrap_code = add_integrator_header(sys, integ, outvar),
442+
wrap_code = add_integrator_header(sys, integ, outvar) .∘
443+
wrap_array_vars(sys, rhss; dvs, ps),
444444
outputidxs = update_inds,
445445
postprocess_fbody = pre,
446446
kwargs...)

src/systems/diffeqs/abstractodesystem.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787

8888
function generate_tgrad(
8989
sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys);
90-
simplify = false, kwargs...)
90+
simplify = false, wrap_code = identity, kwargs...)
9191
tgrad = calculate_tgrad(sys, simplify = simplify)
9292
pre = get_preprocess_constants(tgrad)
9393
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
@@ -97,29 +97,33 @@ function generate_tgrad(
9797
else
9898
(ps,)
9999
end
100+
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps)
100101
return build_function(tgrad,
101102
dvs,
102103
p...,
103104
get_iv(sys);
104105
postprocess_fbody = pre,
106+
wrap_code,
105107
kwargs...)
106108
end
107109

108110
function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
109111
ps = full_parameters(sys);
110-
simplify = false, sparse = false, kwargs...)
112+
simplify = false, sparse = false, wrap_code = identity, kwargs...)
111113
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
112114
pre = get_preprocess_constants(jac)
113115
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
114116
reorder_parameters(get_index_cache(sys), ps)
115117
else
116118
(ps,)
117119
end
120+
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps)
118121
return build_function(jac,
119122
dvs,
120123
p...,
121124
get_iv(sys);
122125
postprocess_fbody = pre,
126+
wrap_code,
123127
kwargs...)
124128
end
125129

@@ -188,12 +192,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
188192
if implicit_dae
189193
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
190194
states = sol_states,
191-
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
195+
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
192196
kwargs...)
193197
else
194198
build_function(rhss, u, p..., t; postprocess_fbody = pre,
195199
states = sol_states,
196-
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
200+
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
197201
kwargs...)
198202
end
199203
end

src/systems/diffeqs/odesystem.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ function build_explicit_observed_function(sys, ts;
485485
if inputs !== nothing
486486
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
487487
end
488+
_ps = ps
488489
if ps isa Tuple
489490
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
490491
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
@@ -505,19 +506,24 @@ function build_explicit_observed_function(sys, ts;
505506
end
506507
pre = get_postprocess_fbody(sys)
507508

509+
array_wrapper = if param_only
510+
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
511+
else
512+
wrap_array_vars(sys, ts; ps = _ps)
513+
end
508514
# Need to keep old method of building the function since it uses `output_type`,
509515
# which can't be provided to `build_function`
510516
oop_fn = Func(args, [],
511517
pre(Let(obsexprs,
512518
isscalar ? ts[1] : MakeArray(ts, output_type),
513-
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
519+
false))) |> array_wrapper[1] |> toexpr
514520
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
515521

516522
if !isscalar
517523
iip_fn = build_function(ts,
518524
args...;
519525
postprocess_fbody = pre,
520-
wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs),
526+
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
521527
expression = Val{true})[2]
522528
if !expression
523529
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)

src/systems/discrete_system/discrete_system.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ function flatten(sys::DiscreteSystem, noeqs = false)
218218
end
219219

220220
function generate_function(
221-
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...)
222-
generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...)
221+
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); wrap_code = identity, kwargs...)
222+
exprs = [eq.rhs for eq in equations(sys)]
223+
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs)
224+
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
223225
end
224226

225227
function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;

src/systems/index_cache.jl

+42-24
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,19 @@ ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
2424
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
2525
const UnknownIndexMap = Dict{
2626
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
27+
const TunableIndexMap = Dict{BasicSymbolic,
28+
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
2729

2830
struct IndexCache
2931
unknown_idx::UnknownIndexMap
3032
discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}}
31-
tunable_idx::ParamIndexMap
33+
tunable_idx::TunableIndexMap
3234
constant_idx::ParamIndexMap
3335
dependent_idx::ParamIndexMap
3436
nonnumeric_idx::ParamIndexMap
3537
observed_syms::Set{BasicSymbolic}
3638
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
37-
tunable_buffer_sizes::Vector{BufferTemplate}
39+
tunable_buffer_size::BufferTemplate
3840
constant_buffer_sizes::Vector{BufferTemplate}
3941
dependent_buffer_sizes::Vector{BufferTemplate}
4042
nonnumeric_buffer_sizes::Vector{BufferTemplate}
@@ -75,7 +77,7 @@ function IndexCache(sys::AbstractSystem)
7577
end
7678
end
7779

78-
observed_syms = Set{Union{Symbol, BasicSymbolic}}()
80+
observed_syms = Set{BasicSymbolic}()
7981
for eq in observed(sys)
8082
if symbolic_type(eq.lhs) != NotSymbolic()
8183
sym = eq.lhs
@@ -236,7 +238,10 @@ function IndexCache(sys::AbstractSystem)
236238
haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
237239
insert_by_type!(
238240
if ctype <: Real || ctype <: AbstractArray{<:Real}
239-
if istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown()
241+
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() &&
242+
(ctype == Real || ctype <: AbstractFloat ||
243+
ctype <: AbstractArray{Real} ||
244+
ctype <: AbstractArray{<:AbstractFloat})
240245
tunable_buffers
241246
else
242247
constant_buffers
@@ -292,11 +297,30 @@ function IndexCache(sys::AbstractSystem)
292297
return idxs, buffer_sizes
293298
end
294299

295-
tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
296300
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
297301
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
298302
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
299303

304+
tunable_idxs = TunableIndexMap()
305+
tunable_buffer_size = 0
306+
for (i, (_, buf)) in enumerate(tunable_buffers)
307+
for (j, p) in enumerate(buf)
308+
idx = if size(p) == ()
309+
tunable_buffer_size + 1
310+
else
311+
reshape(
312+
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
313+
end
314+
tunable_buffer_size += length(p)
315+
tunable_idxs[p] = idx
316+
tunable_idxs[default_toterm(p)] = idx
317+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
318+
symbol_to_variable[getname(p)] = p
319+
symbol_to_variable[getname(default_toterm(p))] = p
320+
end
321+
end
322+
end
323+
300324
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
301325
keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs),
302326
observed_syms, independent_variable_symbols(sys)))
@@ -314,7 +338,7 @@ function IndexCache(sys::AbstractSystem)
314338
nonnumeric_idxs,
315339
observed_syms,
316340
disc_buffer_sizes,
317-
tunable_buffer_sizes,
341+
BufferTemplate(Real, tunable_buffer_size),
318342
const_buffer_sizes,
319343
dependent_buffer_sizes,
320344
nonnumeric_buffer_sizes,
@@ -410,20 +434,6 @@ function check_index_map(idxmap, sym)
410434
end
411435
end
412436

413-
function discrete_linear_index(ic::IndexCache, idx::ParameterIndex)
414-
idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected")
415-
ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0)
416-
for clockbuftemps in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1)
417-
ind += sum(temp.length for temp in clockbuftemps; init = 0)
418-
end
419-
ind += sum(
420-
temp.length
421-
for temp in Iterators.take(ic.discrete_buffer_sizes[idx.idx[1]], idx.idx[2] - 1);
422-
init = 0)
423-
ind += idx.idx[3]
424-
return ind
425-
end
426-
427437
function reorder_parameters(sys::AbstractSystem, ps; kwargs...)
428438
if has_index_cache(sys) && get_index_cache(sys) !== nothing
429439
reorder_parameters(get_index_cache(sys), ps; kwargs...)
@@ -436,8 +446,12 @@ end
436446

437447
function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
438448
isempty(ps) && return ()
439-
param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
440-
for temp in ic.tunable_buffer_sizes)
449+
param_buf = if ic.tunable_buffer_size.length == 0
450+
()
451+
else
452+
(BasicSymbolic[unwrap(variable(:DEF))
453+
for _ in 1:(ic.tunable_buffer_size.length)],)
454+
end
441455
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
442456
for temp in Iterators.flatten(ic.discrete_buffer_sizes))
443457
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
@@ -453,8 +467,12 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
453467
i, j, k = ic.discrete_idx[p]
454468
disc_buf[(i - 1) * disc_offset + j][k] = p
455469
elseif haskey(ic.tunable_idx, p)
456-
i, j = ic.tunable_idx[p]
457-
param_buf[i][j] = p
470+
i = ic.tunable_idx[p]
471+
if i isa Int
472+
param_buf[1][i] = unwrap(p)
473+
else
474+
param_buf[1][i] = unwrap.(collect(p))
475+
end
458476
elseif haskey(ic.constant_idx, p)
459477
i, j = ic.constant_idx[p]
460478
const_buf[i][j] = p

src/systems/jumps/jumpsystem.jl

+1
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ function generate_rate_function(js::JumpSystem, rate)
203203
p = reorder_parameters(js, full_parameters(js))
204204
rf = build_function(rate, unknowns(js), p...,
205205
get_iv(js),
206+
wrap_code = wrap_array_vars(js, rate; dvs = unknowns(js), ps = parameters(js)),
206207
expression = Val{true})
207208
end
208209

0 commit comments

Comments
 (0)