Skip to content

Commit a569f50

Browse files
Merge pull request #3439 from AayushSabharwal/as/initials
fix: separate `Initial` parameters into `initials` portion
2 parents ccb04d8 + bde633c commit a569f50

16 files changed

+240
-70
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ Reexport = "0.2, 1"
136136
RuntimeGeneratedFunctions = "0.5.9"
137137
SCCNonlinearSolve = "1.0.0"
138138
SciMLBase = "2.75"
139-
SciMLStructures = "1.0"
139+
SciMLStructures = "1.7"
140140
Serialization = "1"
141141
Setfield = "0.7, 0.8, 1"
142142
SimpleNonlinearSolve = "0.1.0, 1, 2"

ext/MTKChainRulesCoreExt.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ function ChainRulesCore.rrule(
7979
end
8080
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
8181
tunable_idxs = reduce(
82-
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
82+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
83+
init = Union{Int, AbstractVector{Int}}[])
84+
initials_idxs = reduce(
85+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
86+
init = Union{Int, AbstractVector{Int}}[])
8387
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
8488
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
8589
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
@@ -91,10 +95,12 @@ function ChainRulesCore.rrule(
9195
indp′ = NoTangent()
9296

9397
tunable = selected_tangents(buf′.tunable, tunable_idxs)
98+
initials = selected_tangents(buf′.initials, initials_idxs)
9499
discrete = selected_tangents(buf′.discrete, disc_idxs)
95100
constant = selected_tangents(buf′.constant, const_idxs)
96101
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
97-
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
102+
oldbuf′ = Tangent{typeof(oldbuf)}(;
103+
tunable, initials, discrete, constant, nonnumeric)
98104
idxs′ = NoTangent()
99105
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
100106
return f′, indp′, oldbuf′, idxs′, vals′

src/inputoutput.jl

+3
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
250250
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
251251
p_end = length(p) + 2 + implicit_dae)
252252
f = eval_or_rgf.(f; eval_expression, eval_module)
253+
f = GeneratedFunctionWrapper{(
254+
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
255+
f = f, f
253256
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
254257
(; f, dvs, ps, io_sys = sys)
255258
end

src/systems/abstractsystem.jl

+46-30
Original file line numberDiff line numberDiff line change
@@ -773,38 +773,21 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
773773
if !isempty(all_ps)
774774
# reorder parameters by portions
775775
ps_split = reorder_parameters(sys, all_ps)
776+
# if there are tunables, they will all be in `ps_split[1]`
777+
# and the arrays will have been scalarized
778+
ordered_ps = eltype(all_ps)[]
776779
# if there are no tunables, vcat them
777-
if isempty(get_index_cache(sys).tunable_idx)
778-
ordered_ps = reduce(vcat, ps_split)
779-
else
780-
# if there are tunables, they will all be in `ps_split[1]`
781-
# and the arrays will have been scalarized
782-
ordered_ps = eltype(all_ps)[]
783-
i = 1
784-
# go through all the tunables
785-
while i <= length(ps_split[1])
786-
sym = ps_split[1][i]
787-
# if the sym is not a scalarized array symbolic OR it was already scalarized,
788-
# just push it as-is
789-
if !iscall(sym) || operation(sym) != getindex ||
790-
any(isequal(sym), all_ps)
791-
push!(ordered_ps, sym)
792-
i += 1
793-
continue
794-
end
795-
# the next `length(sym)` symbols should be scalarized versions of the same
796-
# array symbolic
797-
if !allequal(first(arguments(x))
798-
for x in view(ps_split[1], i:(i + length(sym) - 1)))
799-
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
800-
end
801-
arrsym = first(arguments(sym))
802-
push!(ordered_ps, arrsym)
803-
i += length(arrsym)
804-
end
805-
ordered_ps = vcat(
806-
ordered_ps, reduce(vcat, ps_split[2:end]; init = eltype(ordered_ps)[]))
780+
if !isempty(get_index_cache(sys).tunable_idx)
781+
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
782+
ps_split = Base.tail(ps_split)
807783
end
784+
# unflatten initial parameters
785+
if !isempty(get_index_cache(sys).initials_idx)
786+
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
787+
ps_split = Base.tail(ps_split)
788+
end
789+
ordered_ps = vcat(
790+
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
808791
@set! sys.ps = ordered_ps
809792
end
810793
elseif has_index_cache(sys)
@@ -816,6 +799,39 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
816799
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
817800
end
818801

802+
"""
803+
$(TYPEDSIGNATURES)
804+
805+
Given a flattened array of parameters `params` and a collection of all (unscalarized)
806+
parameters in the system `all_ps`, unscalarize the elements in `params` and append
807+
to `buffer` in the same order as they are present in `params`. Effectively, if
808+
`params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`.
809+
"""
810+
function unflatten_parameters!(buffer, params, all_ps)
811+
i = 1
812+
# go through all the tunables
813+
while i <= length(params)
814+
sym = params[i]
815+
# if the sym is not a scalarized array symbolic OR it was already scalarized,
816+
# just push it as-is
817+
if !iscall(sym) || operation(sym) != getindex ||
818+
any(isequal(sym), all_ps)
819+
push!(buffer, sym)
820+
i += 1
821+
continue
822+
end
823+
# the next `length(sym)` symbols should be scalarized versions of the same
824+
# array symbolic
825+
if !allequal(first(arguments(x))
826+
for x in view(params, i:(i + length(sym) - 1)))
827+
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
828+
end
829+
arrsym = first(arguments(sym))
830+
push!(buffer, arrsym)
831+
i += length(arrsym)
832+
end
833+
end
834+
819835
for prop in [:eqs
820836
:tag
821837
:noiseeqs

src/systems/codegen_utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ end
287287
# The user provided a single buffer/tuple for the parameter object, so wrap that
288288
# one in a tuple
289289
fargs = ntuple(Val(length(args))) do i
290-
i == paramidx ? :((args[$i],)) : :(args[$i])
290+
i == paramidx ? :((args[$i], nothing)) : :(args[$i])
291291
end
292292
return :($f($(fargs...)))
293293
end

src/systems/diffeqs/modelingtoolkitize.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ function modelingtoolkitize(
6262
fill!(rhs, 0)
6363
if prob.f isa ODEFunction &&
6464
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
65-
prob.f.f.fw[1].obj[](rhs, vars, p isa MTKParameters ? (params,) : params, t)
65+
prob.f.f.fw[1].obj[](rhs, vars, params, t)
6666
else
67-
prob.f(rhs, vars, p isa MTKParameters ? (params,) : params, t)
67+
prob.f(rhs, vars, params, t)
6868
end
6969
else
7070
rhs = prob.f(vars, params, t)
@@ -255,14 +255,14 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem; kwargs...)
255255
if DiffEqBase.isinplace(prob)
256256
lhs = similar(vars, Any)
257257

258-
prob.f(lhs, vars, p isa MTKParameters ? (params,) : params, t)
258+
prob.f(lhs, vars, params, t)
259259

260260
if DiffEqBase.is_diagonal_noise(prob)
261261
neqs = similar(vars, Any)
262-
prob.g(neqs, vars, p isa MTKParameters ? (params,) : params, t)
262+
prob.g(neqs, vars, params, t)
263263
else
264264
neqs = similar(vars, Any, size(prob.noise_rate_prototype))
265-
prob.g(neqs, vars, p isa MTKParameters ? (params,) : params, t)
265+
prob.g(neqs, vars, params, t)
266266
end
267267
else
268268
lhs = prob.f(vars, params, t)

src/systems/index_cache.jl

+68-2
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ struct IndexCache
4949
# sym => (clockidx, idx_in_clockbuffer)
5050
callback_to_clocks::Dict{Any, Vector{Int}}
5151
tunable_idx::TunableIndexMap
52+
initials_idx::TunableIndexMap
5253
constant_idx::ParamIndexMap
5354
nonnumeric_idx::NonnumericMap
5455
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
5556
dependent_pars_to_timeseries::Dict{
5657
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
5758
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5859
tunable_buffer_size::BufferTemplate
60+
initials_buffer_size::BufferTemplate
5961
constant_buffer_sizes::Vector{BufferTemplate}
6062
nonnumeric_buffer_sizes::Vector{BufferTemplate}
6163
symbol_to_variable::Dict{Symbol, SymbolicParam}
@@ -251,7 +253,9 @@ function IndexCache(sys::AbstractSystem)
251253

252254
tunable_idxs = TunableIndexMap()
253255
tunable_buffer_size = 0
254-
for buffers in (tunable_buffers, initial_param_buffers)
256+
bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) :
257+
(tunable_buffers,)
258+
for buffers in bufferlist
255259
for (i, (_, buf)) in enumerate(buffers)
256260
for (j, p) in enumerate(buf)
257261
idx = if size(p) == ()
@@ -271,6 +275,43 @@ function IndexCache(sys::AbstractSystem)
271275
end
272276
end
273277

278+
initials_idxs = TunableIndexMap()
279+
initials_buffer_size = 0
280+
if !is_initializesystem(sys)
281+
for (i, (_, buf)) in enumerate(initial_param_buffers)
282+
for (j, p) in enumerate(buf)
283+
idx = if size(p) == ()
284+
initials_buffer_size + 1
285+
else
286+
reshape(
287+
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
288+
end
289+
initials_buffer_size += length(p)
290+
initials_idxs[p] = idx
291+
initials_idxs[default_toterm(p)] = idx
292+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
293+
symbol_to_variable[getname(p)] = p
294+
symbol_to_variable[getname(default_toterm(p))] = p
295+
end
296+
end
297+
end
298+
end
299+
300+
for k in collect(keys(tunable_idxs))
301+
v = tunable_idxs[k]
302+
v isa AbstractArray || continue
303+
for (kk, vv) in zip(collect(k), v)
304+
tunable_idxs[kk] = vv
305+
end
306+
end
307+
for k in collect(keys(initials_idxs))
308+
v = initials_idxs[k]
309+
v isa AbstractArray || continue
310+
for (kk, vv) in zip(collect(k), v)
311+
initials_idxs[kk] = vv
312+
end
313+
end
314+
274315
dependent_pars_to_timeseries = Dict{
275316
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
276317

@@ -341,12 +382,14 @@ function IndexCache(sys::AbstractSystem)
341382
disc_idxs,
342383
callback_to_clocks,
343384
tunable_idxs,
385+
initials_idxs,
344386
const_idxs,
345387
nonnumeric_idxs,
346388
observed_syms_to_timeseries,
347389
dependent_pars_to_timeseries,
348390
disc_buffer_templates,
349391
BufferTemplate(Real, tunable_buffer_size),
392+
BufferTemplate(Real, initials_buffer_size),
350393
const_buffer_sizes,
351394
nonnumeric_buffer_sizes,
352395
symbol_to_variable
@@ -385,6 +428,8 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
385428
Symbolics.shape(sym) !== Symbolics.Unknown()
386429
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
387430
ParameterIndex(SciMLStructures.Tunable(), idx, validate_size)
431+
elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing
432+
ParameterIndex(SciMLStructures.Initials(), idx, validate_size)
388433
elseif (idx = check_index_map(ic.discrete_idx, sym)) !== nothing
389434
ParameterIndex(
390435
SciMLStructures.Discrete(), (idx.buffer_idx, idx.idx_in_buffer), validate_size)
@@ -465,6 +510,12 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
465510
(BasicSymbolic[unwrap(variable(:DEF))
466511
for _ in 1:(ic.tunable_buffer_size.length)],)
467512
end
513+
initials_buf = if ic.initials_buffer_size.length == 0
514+
()
515+
else
516+
(BasicSymbolic[unwrap(variable(:DEF))
517+
for _ in 1:(ic.initials_buffer_size.length)],)
518+
end
468519

469520
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF))
470521
for _ in 1:(sum(x -> x.length, temp))]
@@ -486,6 +537,13 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
486537
else
487538
param_buf[1][i] = unwrap.(collect(p))
488539
end
540+
elseif haskey(ic.initials_idx, p)
541+
i = ic.initials_idx[p]
542+
if i isa Int
543+
initials_buf[1][i] = unwrap(p)
544+
else
545+
initials_buf[1][i] = unwrap.(collect(p))
546+
end
489547
elseif haskey(ic.constant_idx, p)
490548
i, j = ic.constant_idx[p]
491549
const_buf[i][j] = p
@@ -498,7 +556,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
498556
end
499557

500558
result = broadcast.(
501-
unwrap, (param_buf..., disc_buf..., const_buf..., nonnumeric_buf...))
559+
unwrap, (
560+
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...))
502561
if drop_missing
503562
result = map(result) do buf
504563
filter(buf) do sym
@@ -521,6 +580,11 @@ function iterated_buffer_index(ic::IndexCache, ind::ParameterIndex)
521580
elseif ic.tunable_buffer_size.length > 0
522581
idx += 1
523582
end
583+
if ind.portion isa SciMLStructures.Initials
584+
return idx + 1
585+
elseif ic.initials_buffer_size.length > 0
586+
idx += 1
587+
end
524588
if ind.portion isa SciMLStructures.Discrete
525589
return idx + ind.idx[1]
526590
elseif !isempty(ic.discrete_buffer_sizes)
@@ -542,6 +606,8 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
542606

543607
if portion isa SciMLStructures.Tunable
544608
return ic.tunable_buffer_size
609+
elseif portion isa SciMLStructures.Initials
610+
return ic.initials_buffer_size
545611
elseif portion isa SciMLStructures.Discrete
546612
return ic.discrete_buffer_sizes[idx[1]][1]
547613
elseif portion isa SciMLStructures.Constants

src/systems/nonlinear/initializesystem.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,15 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
338338
end
339339
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
340340
if eltype(buf) != T
341-
newp = repack(T.(buf))
341+
newbuf = similar(buf, T)
342+
copyto!(newbuf, buf)
343+
newp = repack(newbuf)
344+
end
345+
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
346+
if eltype(buf) != T
347+
newbuf = similar(buf, T)
348+
copyto!(newbuf, buf)
349+
newp = repack(newbuf)
342350
end
343351
return u0, newp
344352
end
@@ -520,6 +528,9 @@ function SciMLBase.late_binding_update_u0_p(
520528
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
521529
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
522530
newp = repack(tunables)
531+
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
532+
initials = DiffEqBase.promote_u0(initials, newu0, t0)
533+
newp = repack(initials)
523534

524535
allsyms = all_symbols(sys)
525536
for (k, v) in u0
@@ -538,6 +549,15 @@ function SciMLBase.late_binding_update_u0_p(
538549
return newu0, newp
539550
end
540551

552+
"""
553+
$(TYPEDSIGNATURES)
554+
555+
Check if the given system is an initialization system.
556+
"""
557+
function is_initializesystem(sys::AbstractSystem)
558+
sys isa NonlinearSystem && get_metadata(sys) isa InitializationSystemMetadata
559+
end
560+
541561
"""
542562
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
543563
initialization.

0 commit comments

Comments
 (0)