Skip to content

Commit be9c8c8

Browse files
committedJan 30, 2025
fix: handle scalarized array parameters in initialization
1 parent 1b7261a commit be9c8c8

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed
 

‎src/systems/nonlinear/initializesystem.jl

+26-8
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ function generate_initializesystem(sys::AbstractSystem;
112112
# If either of them are `missing` the parameter is an unknown
113113
# But if the parameter is passed a value, use that as an additional
114114
# equation in the system
115-
_val1 = get(pmap, p, nothing)
116-
_val2 = get(defs, p, nothing)
117-
_val3 = get(guesses, p, nothing)
115+
_val1 = get_possibly_array_fallback_singletons(pmap, p)
116+
_val2 = get_possibly_array_fallback_singletons(defs, p)
117+
_val3 = get_possibly_array_fallback_singletons(guesses, p)
118118
varp = tovar(p)
119119
paramsubs[p] = varp
120120
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
@@ -139,7 +139,7 @@ function generate_initializesystem(sys::AbstractSystem;
139139
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
140140
end
141141
# given a symbolic value to ODEProblem
142-
elseif symbolic_type(_val1) != NotSymbolic()
142+
elseif symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1)
143143
push!(eqs_ics, varp ~ _val1)
144144
push!(defs, varp => _val3)
145145
# No value passed to `ODEProblem`, but a default and a guess are present
@@ -268,16 +268,34 @@ struct InitializationSystemMetadata
268268
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
269269
end
270270

271+
function get_possibly_array_fallback_singletons(varmap, p)
272+
if haskey(varmap, p)
273+
return varmap[p]
274+
end
275+
symbolic_type(p) == ArraySymbolic() || return nothing
276+
scal = collect(p)
277+
if all(x -> haskey(varmap, x), scal)
278+
res = [varmap[x] for x in scal]
279+
if any(x -> x === nothing, res)
280+
return nothing
281+
elseif any(x -> x === missing, res)
282+
return missing
283+
end
284+
return res
285+
end
286+
return nothing
287+
end
288+
271289
function is_parameter_solvable(p, pmap, defs, guesses)
272290
p = unwrap(p)
273291
is_variable_floatingpoint(p) || return false
274-
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
275-
_val2 = get(defs, p, nothing)
276-
_val3 = get(guesses, p, nothing)
292+
_val1 = pmap isa AbstractDict ? get_possibly_array_fallback_singletons(pmap, p) : nothing
293+
_val2 = get_possibly_array_fallback_singletons(defs, p)
294+
_val3 = get_possibly_array_fallback_singletons(guesses, p)
277295
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
278296
# the ODEProblem and it has a default and a guess)
279297
return ((_val1 === missing || _val2 === missing) ||
280-
(symbolic_type(_val1) != NotSymbolic() ||
298+
(symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1) ||
281299
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
282300
end
283301

‎test/initializationsystem.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1281,3 +1281,14 @@ end
12811281
@test sol[S, 1] 999
12821282
@test SciMLBase.successful_retcode(sol)
12831283
end
1284+
1285+
@testset "Solvable array parameters with scalarized guesses" begin
1286+
@variables x(t)
1287+
@parameters p[1:2] q
1288+
@mtkbuild sys = ODESystem(D(x) ~ p[1] + p[2] + q, t; defaults = [p[1] => q, p[2] => 2q], guesses = [p[1] => q, p[2] => 2q])
1289+
@test ModelingToolkit.is_parameter_solvable(p, Dict(), defaults(sys), guesses(sys))
1290+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [q => 2.0])
1291+
@test length(observed(prob.f.initialization_data.initializeprob.f.sys)) == 3
1292+
sol = solve(prob, Tsit5())
1293+
@test sol.ps[p] [2.0, 4.0]
1294+
end

0 commit comments

Comments
 (0)