Skip to content

fix: separate Initial parameters into initials portion #3439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
5567ab2
feat: separate initial parameters for non-initialization-systems
AayushSabharwal Mar 4, 2025
b52ecc2
feat: add `Initials` portion to `MTKParameters`
AayushSabharwal Mar 4, 2025
e2251af
refactor: remove dead code
AayushSabharwal Mar 4, 2025
7f3b1e1
fix: properly handle initial parameters in `complete`
AayushSabharwal Mar 4, 2025
73fb137
fix: handle initial parameters in symbolic indexing
AayushSabharwal Mar 4, 2025
e563466
fix: handle initial parameters in codegen
AayushSabharwal Mar 4, 2025
793abb5
fix: fix calling `MTKParameters` functions with just tunables vector
AayushSabharwal Mar 4, 2025
ef5326b
fix: promote `Initials` portion in `ReconstructInitializeprob`
AayushSabharwal Mar 4, 2025
0c9f032
fix: promote `Initials` portion in `late_binding_update_u0_p`
AayushSabharwal Mar 4, 2025
d830ab2
fix: simplify `modelingtoolkitize`
AayushSabharwal Mar 4, 2025
d7da992
fix: return `GeneratedFunctionWrapper` from `generate_control_function`
AayushSabharwal Mar 4, 2025
43c3ccb
fix: support initials in adjoints
AayushSabharwal Mar 4, 2025
5647dee
test: mark test as not broken
AayushSabharwal Mar 4, 2025
d6b3f27
test: remove unnecessary wrapping of parameter vector
AayushSabharwal Mar 4, 2025
5186afd
test: test promotion of initials portion in `remake`
AayushSabharwal Mar 4, 2025
84f7309
test: remove splatting of parameter object
AayushSabharwal Mar 4, 2025
25d52da
test: test indexing of `Initials` portion
AayushSabharwal Mar 4, 2025
6686524
test: update mtkparameters tests with new portion
AayushSabharwal Mar 5, 2025
66ec113
fix: handle empty `reduce` in rrule
AayushSabharwal Mar 5, 2025
bde633c
build: bump SciMLStructures compat
AayushSabharwal Mar 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.75"
SciMLStructures = "1.0"
SciMLStructures = "1.7"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1, 2"
Expand Down
10 changes: 8 additions & 2 deletions ext/MTKChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ function ChainRulesCore.rrule(
end
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
tunable_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
init = Union{Int, AbstractVector{Int}}[])
initials_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
init = Union{Int, AbstractVector{Int}}[])
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
Expand All @@ -91,10 +95,12 @@ function ChainRulesCore.rrule(
indp′ = NoTangent()

tunable = selected_tangents(buf′.tunable, tunable_idxs)
initials = selected_tangents(buf′.initials, initials_idxs)
discrete = selected_tangents(buf′.discrete, disc_idxs)
constant = selected_tangents(buf′.constant, const_idxs)
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
oldbuf′ = Tangent{typeof(oldbuf)}(;
tunable, initials, discrete, constant, nonnumeric)
idxs′ = NoTangent()
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
return f′, indp′, oldbuf′, idxs′, vals′
Expand Down
3 changes: 3 additions & 0 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
p_end = length(p) + 2 + implicit_dae)
f = eval_or_rgf.(f; eval_expression, eval_module)
f = GeneratedFunctionWrapper{(
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
f = f, f
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
76 changes: 46 additions & 30 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@
"""
Initial(x)

The `Initial` operator. Used by initialization to store constant constraints on variables

Check warning on line 625 in src/systems/abstractsystem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"initializaton" should be "initialization".
of a system. See the documentation section on initialization for more information.
"""
struct Initial <: Symbolics.Operator end
Expand Down Expand Up @@ -773,38 +773,21 @@
if !isempty(all_ps)
# reorder parameters by portions
ps_split = reorder_parameters(sys, all_ps)
# if there are tunables, they will all be in `ps_split[1]`
# and the arrays will have been scalarized
ordered_ps = eltype(all_ps)[]
# if there are no tunables, vcat them
if isempty(get_index_cache(sys).tunable_idx)
ordered_ps = reduce(vcat, ps_split)
else
# if there are tunables, they will all be in `ps_split[1]`
# and the arrays will have been scalarized
ordered_ps = eltype(all_ps)[]
i = 1
# go through all the tunables
while i <= length(ps_split[1])
sym = ps_split[1][i]
# if the sym is not a scalarized array symbolic OR it was already scalarized,
# just push it as-is
if !iscall(sym) || operation(sym) != getindex ||
any(isequal(sym), all_ps)
push!(ordered_ps, sym)
i += 1
continue
end
# the next `length(sym)` symbols should be scalarized versions of the same
# array symbolic
if !allequal(first(arguments(x))
for x in view(ps_split[1], i:(i + length(sym) - 1)))
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
end
arrsym = first(arguments(sym))
push!(ordered_ps, arrsym)
i += length(arrsym)
end
ordered_ps = vcat(
ordered_ps, reduce(vcat, ps_split[2:end]; init = eltype(ordered_ps)[]))
if !isempty(get_index_cache(sys).tunable_idx)
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
ps_split = Base.tail(ps_split)
end
# unflatten initial parameters
if !isempty(get_index_cache(sys).initials_idx)
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
ps_split = Base.tail(ps_split)
end
ordered_ps = vcat(
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
@set! sys.ps = ordered_ps
end
elseif has_index_cache(sys)
Expand All @@ -816,6 +799,39 @@
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
end

"""
$(TYPEDSIGNATURES)

Given a flattened array of parameters `params` and a collection of all (unscalarized)
parameters in the system `all_ps`, unscalarize the elements in `params` and append
to `buffer` in the same order as they are present in `params`. Effectively, if
`params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`.
"""
function unflatten_parameters!(buffer, params, all_ps)
    # Walk `params` left to right, appending to `buffer`. This is applied to any
    # flattened portion (e.g. both the tunables and the initials in `complete`),
    # not only to tunables.
    #
    # Throws (via `error`) if a scalarized array element is not followed by the
    # remaining elements of the same array symbolic.
    i = 1
    while i <= length(params)
        sym = params[i]
        # if the sym is not a scalarized array symbolic OR it was already scalarized,
        # just push it as-is
        if !iscall(sym) || operation(sym) != getindex ||
           any(isequal(sym), all_ps)
            push!(buffer, sym)
            i += 1
            continue
        end
        # `sym` is an element `arr[idx...]`; recover the array symbolic it indexes into
        arrsym = first(arguments(sym))
        stop = i + length(arrsym) - 1
        # The next `length(arrsym)` symbols should be scalarized versions of the same
        # array symbolic. The window is sized by the parent array rather than by
        # `sym`: `sym` is a single element, so sizing by `sym` would presumably
        # inspect nothing but `sym` itself and the check could never fail.
        # The bounds test comes first so that a truncated run reports the same
        # invariant violation instead of a `BoundsError` from `view`.
        if stop > length(params) ||
           !all(iscall(x) && operation(x) == getindex &&
                isequal(first(arguments(x)), arrsym)
                for x in view(params, i:stop))
            error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
        end
        push!(buffer, arrsym)
        # skip past every scalarized element of `arrsym`
        i = stop + 1
    end
    return nothing
end

for prop in [:eqs
:tag
:noiseeqs
Expand Down
2 changes: 1 addition & 1 deletion src/systems/codegen_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ end
# The user provided a single buffer/tuple for the parameter object, so wrap that
# one in a tuple
fargs = ntuple(Val(length(args))) do i
i == paramidx ? :((args[$i],)) : :(args[$i])
i == paramidx ? :((args[$i], nothing)) : :(args[$i])
end
return :($f($(fargs...)))
end
Expand Down
10 changes: 5 additions & 5 deletions src/systems/diffeqs/modelingtoolkitize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ function modelingtoolkitize(
fill!(rhs, 0)
if prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
prob.f.f.fw[1].obj[](rhs, vars, p isa MTKParameters ? (params,) : params, t)
prob.f.f.fw[1].obj[](rhs, vars, params, t)
else
prob.f(rhs, vars, p isa MTKParameters ? (params,) : params, t)
prob.f(rhs, vars, params, t)
end
else
rhs = prob.f(vars, params, t)
Expand Down Expand Up @@ -255,14 +255,14 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem; kwargs...)
if DiffEqBase.isinplace(prob)
lhs = similar(vars, Any)

prob.f(lhs, vars, p isa MTKParameters ? (params,) : params, t)
prob.f(lhs, vars, params, t)

if DiffEqBase.is_diagonal_noise(prob)
neqs = similar(vars, Any)
prob.g(neqs, vars, p isa MTKParameters ? (params,) : params, t)
prob.g(neqs, vars, params, t)
else
neqs = similar(vars, Any, size(prob.noise_rate_prototype))
prob.g(neqs, vars, p isa MTKParameters ? (params,) : params, t)
prob.g(neqs, vars, params, t)
end
else
lhs = prob.f(vars, params, t)
Expand Down
70 changes: 68 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ struct IndexCache
# sym => (clockidx, idx_in_clockbuffer)
callback_to_clocks::Dict{Any, Vector{Int}}
tunable_idx::TunableIndexMap
initials_idx::TunableIndexMap
constant_idx::ParamIndexMap
nonnumeric_idx::NonnumericMap
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
dependent_pars_to_timeseries::Dict{
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_size::BufferTemplate
initials_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
nonnumeric_buffer_sizes::Vector{BufferTemplate}
symbol_to_variable::Dict{Symbol, SymbolicParam}
Expand Down Expand Up @@ -251,7 +253,9 @@ function IndexCache(sys::AbstractSystem)

tunable_idxs = TunableIndexMap()
tunable_buffer_size = 0
for buffers in (tunable_buffers, initial_param_buffers)
bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) :
(tunable_buffers,)
for buffers in bufferlist
for (i, (_, buf)) in enumerate(buffers)
for (j, p) in enumerate(buf)
idx = if size(p) == ()
Expand All @@ -271,6 +275,43 @@ function IndexCache(sys::AbstractSystem)
end
end

initials_idxs = TunableIndexMap()
initials_buffer_size = 0
if !is_initializesystem(sys)
for (i, (_, buf)) in enumerate(initial_param_buffers)
for (j, p) in enumerate(buf)
idx = if size(p) == ()
initials_buffer_size + 1
else
reshape(
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
end
initials_buffer_size += length(p)
initials_idxs[p] = idx
initials_idxs[default_toterm(p)] = idx
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(default_toterm(p))] = p
end
end
end
end

for k in collect(keys(tunable_idxs))
v = tunable_idxs[k]
v isa AbstractArray || continue
for (kk, vv) in zip(collect(k), v)
tunable_idxs[kk] = vv
end
end
for k in collect(keys(initials_idxs))
v = initials_idxs[k]
v isa AbstractArray || continue
for (kk, vv) in zip(collect(k), v)
initials_idxs[kk] = vv
end
end

dependent_pars_to_timeseries = Dict{
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()

Expand Down Expand Up @@ -341,12 +382,14 @@ function IndexCache(sys::AbstractSystem)
disc_idxs,
callback_to_clocks,
tunable_idxs,
initials_idxs,
const_idxs,
nonnumeric_idxs,
observed_syms_to_timeseries,
dependent_pars_to_timeseries,
disc_buffer_templates,
BufferTemplate(Real, tunable_buffer_size),
BufferTemplate(Real, initials_buffer_size),
const_buffer_sizes,
nonnumeric_buffer_sizes,
symbol_to_variable
Expand Down Expand Up @@ -385,6 +428,8 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
Symbolics.shape(sym) !== Symbolics.Unknown()
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
ParameterIndex(SciMLStructures.Tunable(), idx, validate_size)
elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing
ParameterIndex(SciMLStructures.Initials(), idx, validate_size)
elseif (idx = check_index_map(ic.discrete_idx, sym)) !== nothing
ParameterIndex(
SciMLStructures.Discrete(), (idx.buffer_idx, idx.idx_in_buffer), validate_size)
Expand Down Expand Up @@ -465,6 +510,12 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
(BasicSymbolic[unwrap(variable(:DEF))
for _ in 1:(ic.tunable_buffer_size.length)],)
end
initials_buf = if ic.initials_buffer_size.length == 0
()
else
(BasicSymbolic[unwrap(variable(:DEF))
for _ in 1:(ic.initials_buffer_size.length)],)
end

disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF))
for _ in 1:(sum(x -> x.length, temp))]
Expand All @@ -486,6 +537,13 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
else
param_buf[1][i] = unwrap.(collect(p))
end
elseif haskey(ic.initials_idx, p)
i = ic.initials_idx[p]
if i isa Int
initials_buf[1][i] = unwrap(p)
else
initials_buf[1][i] = unwrap.(collect(p))
end
elseif haskey(ic.constant_idx, p)
i, j = ic.constant_idx[p]
const_buf[i][j] = p
Expand All @@ -498,7 +556,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
end

result = broadcast.(
unwrap, (param_buf..., disc_buf..., const_buf..., nonnumeric_buf...))
unwrap, (
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...))
if drop_missing
result = map(result) do buf
filter(buf) do sym
Expand All @@ -521,6 +580,11 @@ function iterated_buffer_index(ic::IndexCache, ind::ParameterIndex)
elseif ic.tunable_buffer_size.length > 0
idx += 1
end
if ind.portion isa SciMLStructures.Initials
return idx + 1
elseif ic.initials_buffer_size.length > 0
idx += 1
end
if ind.portion isa SciMLStructures.Discrete
return idx + ind.idx[1]
elseif !isempty(ic.discrete_buffer_sizes)
Expand All @@ -542,6 +606,8 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)

if portion isa SciMLStructures.Tunable
return ic.tunable_buffer_size
elseif portion isa SciMLStructures.Initials
return ic.initials_buffer_size
elseif portion isa SciMLStructures.Discrete
return ic.discrete_buffer_sizes[idx[1]][1]
elseif portion isa SciMLStructures.Constants
Expand Down
22 changes: 21 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,15 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
end
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
if eltype(buf) != T
newp = repack(T.(buf))
newbuf = similar(buf, T)
copyto!(newbuf, buf)
newp = repack(newbuf)
end
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
if eltype(buf) != T
newbuf = similar(buf, T)
copyto!(newbuf, buf)
newp = repack(newbuf)
end
return u0, newp
end
Expand Down Expand Up @@ -520,6 +528,9 @@ function SciMLBase.late_binding_update_u0_p(
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
newp = repack(tunables)
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
initials = DiffEqBase.promote_u0(initials, newu0, t0)
newp = repack(initials)

allsyms = all_symbols(sys)
for (k, v) in u0
Expand All @@ -538,6 +549,15 @@ function SciMLBase.late_binding_update_u0_p(
return newu0, newp
end

"""
$(TYPEDSIGNATURES)

Check if the given system is an initialization system.
"""
function is_initializesystem(sys::AbstractSystem)
    # Anything that is not a `NonlinearSystem` is never reported as an
    # initialization system; return early without inspecting its metadata.
    sys isa NonlinearSystem || return false
    # A nonlinear system counts only when its metadata is the initialization marker.
    return get_metadata(sys) isa InitializationSystemMetadata
end

"""
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
initialization.
Expand Down
Loading
Loading