Skip to content

Commit 25c8001

Browse files
Merge pull request #3358 from SciML/as/arr-param-init
fix: handle scalarized array parameters in initialization
2 parents b25edc0 + 47b636c commit 25c8001

File tree

7 files changed

+58
-17
lines changed

7 files changed

+58
-17
lines changed

src/systems/diffeqs/odesystem.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ function ODESystem(eqs, iv; kwargs...)
327327
end
328328
algevars = setdiff(allunknowns, diffvars)
329329

330-
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
330+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
331+
collect(new_ps); kwargs...)
331332
end
332333

333334
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

+11-5
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
273273
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
274274
end
275275

276-
function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
276+
function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
277277
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
278278

279279
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -309,15 +309,21 @@ function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...
309309
noiseps = OrderedSet()
310310
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
311311
for dv in noisedvs
312-
dv allunknowns || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
312+
dv allunknowns ||
313+
throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
313314
end
314315
algevars = setdiff(allunknowns, diffvars)
315316

316-
return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)), [ps; collect(noiseps)]; kwargs...)
317+
return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)),
318+
[ps; collect(noiseps)]; kwargs...)
317319
end
318320

319-
SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...) = SDESystem([eq], noiseeqs, args...; kwargs...)
320-
SDESystem(eq::Equation, noiseeq, args...; kwargs...) = SDESystem([eq], [noiseeq], args...; kwargs...)
321+
function SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...)
322+
SDESystem([eq], noiseeqs, args...; kwargs...)
323+
end
324+
function SDESystem(eq::Equation, noiseeq, args...; kwargs...)
325+
SDESystem([eq], [noiseeq], args...; kwargs...)
326+
end
321327

322328
function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
323329
sys1 === sys2 && return true

src/systems/nonlinear/initializesystem.jl

+27-8
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ function generate_initializesystem(sys::AbstractSystem;
112112
# If either of them are `missing` the parameter is an unknown
113113
# But if the parameter is passed a value, use that as an additional
114114
# equation in the system
115-
_val1 = get(pmap, p, nothing)
116-
_val2 = get(defs, p, nothing)
117-
_val3 = get(guesses, p, nothing)
115+
_val1 = get_possibly_array_fallback_singletons(pmap, p)
116+
_val2 = get_possibly_array_fallback_singletons(defs, p)
117+
_val3 = get_possibly_array_fallback_singletons(guesses, p)
118118
varp = tovar(p)
119119
paramsubs[p] = varp
120120
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
@@ -139,7 +139,7 @@ function generate_initializesystem(sys::AbstractSystem;
139139
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
140140
end
141141
# given a symbolic value to ODEProblem
142-
elseif symbolic_type(_val1) != NotSymbolic()
142+
elseif symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1)
143143
push!(eqs_ics, varp ~ _val1)
144144
push!(defs, varp => _val3)
145145
# No value passed to `ODEProblem`, but a default and a guess are present
@@ -268,16 +268,35 @@ struct InitializationSystemMetadata
268268
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
269269
end
270270

271+
function get_possibly_array_fallback_singletons(varmap, p)
272+
if haskey(varmap, p)
273+
return varmap[p]
274+
end
275+
symbolic_type(p) == ArraySymbolic() || return nothing
276+
scal = collect(p)
277+
if all(x -> haskey(varmap, x), scal)
278+
res = [varmap[x] for x in scal]
279+
if any(x -> x === nothing, res)
280+
return nothing
281+
elseif any(x -> x === missing, res)
282+
return missing
283+
end
284+
return res
285+
end
286+
return nothing
287+
end
288+
271289
function is_parameter_solvable(p, pmap, defs, guesses)
272290
p = unwrap(p)
273291
is_variable_floatingpoint(p) || return false
274-
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
275-
_val2 = get(defs, p, nothing)
276-
_val3 = get(guesses, p, nothing)
292+
_val1 = pmap isa AbstractDict ? get_possibly_array_fallback_singletons(pmap, p) :
293+
nothing
294+
_val2 = get_possibly_array_fallback_singletons(defs, p)
295+
_val3 = get_possibly_array_fallback_singletons(guesses, p)
277296
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
278297
# the ODEProblem and it has a default and a guess)
279298
return ((_val1 === missing || _val2 === missing) ||
280-
(symbolic_type(_val1) != NotSymbolic() ||
299+
(symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1) ||
281300
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
282301
end
283302

src/utils.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,8 @@ function process_equations(eqs, iv)
12251225
throw(ArgumentError("An ODESystem can only have one independent variable."))
12261226
diffvar in diffvars &&
12271227
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
1228-
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
1228+
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) &&
1229+
throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
12291230
push!(diffvars, diffvar)
12301231
end
12311232
push!(diffeq, eq)

test/initializationsystem.jl

+14
Original file line numberDiff line numberDiff line change
@@ -1281,3 +1281,17 @@ end
12811281
@test sol[S, 1] 999
12821282
@test SciMLBase.successful_retcode(sol)
12831283
end
1284+
1285+
@testset "Solvable array parameters with scalarized guesses" begin
1286+
@variables x(t)
1287+
@parameters p[1:2] q
1288+
@mtkbuild sys = ODESystem(
1289+
D(x) ~ p[1] + p[2] + q, t; defaults = [p[1] => q, p[2] => 2q],
1290+
guesses = [p[1] => q, p[2] => 2q])
1291+
@test ModelingToolkit.is_parameter_solvable(p, Dict(), defaults(sys), guesses(sys))
1292+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [q => 2.0])
1293+
@test length(ModelingToolkit.observed(prob.f.initialization_data.initializeprob.f.sys)) ==
1294+
3
1295+
sol = solve(prob, Tsit5())
1296+
@test sol.ps[p] [2.0, 4.0]
1297+
end

test/odesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1547,7 +1547,7 @@ end
15471547
@testset "Validate input types" begin
15481548
@parameters p d
15491549
@variables X(t)::Int64
1550-
eq = D(X) ~ p - d*X
1550+
eq = D(X) ~ p - d * X
15511551
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
15521552
@variables Y(t)[1:3]::String
15531553
eq = D(Y) ~ [p, p, p]

test/sdesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ end
874874
@parameters p d
875875
@variables X(t)::Int64
876876
@brownian z
877-
eq2 = D(X) ~ p - d*X + z
877+
eq2 = D(X) ~ p - d * X + z
878878
@test_throws ArgumentError @mtkbuild ssys = System([eq2], t)
879879
noiseeq = [1]
880880
@test_throws ArgumentError @named ssys = SDESystem([eq2], [noiseeq], t)

0 commit comments

Comments
 (0)