Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle scalarized array parameters in initialization #3358

Merged
merged 2 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ function ODESystem(eqs, iv; kwargs...)
end
algevars = setdiff(allunknowns, diffvars)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
collect(new_ps); kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand Down
16 changes: 11 additions & 5 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
end

function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
Expand Down Expand Up @@ -309,15 +309,21 @@ function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...
noiseps = OrderedSet()
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
for dv in noisedvs
dv ∈ allunknowns || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
dv ∈ allunknowns ||
throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
end
algevars = setdiff(allunknowns, diffvars)

return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)), [ps; collect(noiseps)]; kwargs...)
return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)),
[ps; collect(noiseps)]; kwargs...)
end

SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...) = SDESystem([eq], noiseeqs, args...; kwargs...)
SDESystem(eq::Equation, noiseeq, args...; kwargs...) = SDESystem([eq], [noiseeq], args...; kwargs...)
function SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...)
SDESystem([eq], noiseeqs, args...; kwargs...)
end
function SDESystem(eq::Equation, noiseeq, args...; kwargs...)
SDESystem([eq], [noiseeq], args...; kwargs...)
end

function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
sys1 === sys2 && return true
Expand Down
35 changes: 27 additions & 8 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ function generate_initializesystem(sys::AbstractSystem;
# If either of them are `missing` the parameter is an unknown
# But if the parameter is passed a value, use that as an additional
# equation in the system
_val1 = get(pmap, p, nothing)
_val2 = get(defs, p, nothing)
_val3 = get(guesses, p, nothing)
_val1 = get_possibly_array_fallback_singletons(pmap, p)
_val2 = get_possibly_array_fallback_singletons(defs, p)
_val3 = get_possibly_array_fallback_singletons(guesses, p)
varp = tovar(p)
paramsubs[p] = varp
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
Expand All @@ -139,7 +139,7 @@ function generate_initializesystem(sys::AbstractSystem;
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
end
# given a symbolic value to ODEProblem
elseif symbolic_type(_val1) != NotSymbolic()
elseif symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1)
push!(eqs_ics, varp ~ _val1)
push!(defs, varp => _val3)
# No value passed to `ODEProblem`, but a default and a guess are present
Expand Down Expand Up @@ -268,16 +268,35 @@ struct InitializationSystemMetadata
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
end

function get_possibly_array_fallback_singletons(varmap, p)
if haskey(varmap, p)
return varmap[p]
end
symbolic_type(p) == ArraySymbolic() || return nothing
scal = collect(p)
if all(x -> haskey(varmap, x), scal)
res = [varmap[x] for x in scal]
if any(x -> x === nothing, res)
return nothing
elseif any(x -> x === missing, res)
return missing
end
return res
end
return nothing
end

function is_parameter_solvable(p, pmap, defs, guesses)
p = unwrap(p)
is_variable_floatingpoint(p) || return false
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
_val2 = get(defs, p, nothing)
_val3 = get(guesses, p, nothing)
_val1 = pmap isa AbstractDict ? get_possibly_array_fallback_singletons(pmap, p) :
nothing
_val2 = get_possibly_array_fallback_singletons(defs, p)
_val3 = get_possibly_array_fallback_singletons(guesses, p)
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
# the ODEProblem and it has a default and a guess)
return ((_val1 === missing || _val2 === missing) ||
(symbolic_type(_val1) != NotSymbolic() ||
(symbolic_type(_val1) != NotSymbolic() || is_array_of_symbolics(_val1) ||
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
end

Expand Down
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,8 @@ function process_equations(eqs, iv)
throw(ArgumentError("An ODESystem can only have one independent variable."))
diffvar in diffvars &&
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) &&
throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
push!(diffvars, diffvar)
end
push!(diffeq, eq)
Expand Down
14 changes: 14 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1281,3 +1281,17 @@ end
@test sol[S, 1] ≈ 999
@test SciMLBase.successful_retcode(sol)
end

@testset "Solvable array parameters with scalarized guesses" begin
@variables x(t)
@parameters p[1:2] q
@mtkbuild sys = ODESystem(
D(x) ~ p[1] + p[2] + q, t; defaults = [p[1] => q, p[2] => 2q],
guesses = [p[1] => q, p[2] => 2q])
@test ModelingToolkit.is_parameter_solvable(p, Dict(), defaults(sys), guesses(sys))
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [q => 2.0])
@test length(ModelingToolkit.observed(prob.f.initialization_data.initializeprob.f.sys)) ==
3
sol = solve(prob, Tsit5())
@test sol.ps[p] ≈ [2.0, 4.0]
end
2 changes: 1 addition & 1 deletion test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ end
@testset "Validate input types" begin
@parameters p d
@variables X(t)::Int64
eq = D(X) ~ p - d*X
eq = D(X) ~ p - d * X
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
@variables Y(t)[1:3]::String
eq = D(Y) ~ [p, p, p]
Expand Down
2 changes: 1 addition & 1 deletion test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ end
@parameters p d
@variables X(t)::Int64
@brownian z
eq2 = D(X) ~ p - d*X + z
eq2 = D(X) ~ p - d * X + z
@test_throws ArgumentError @mtkbuild ssys = System([eq2], t)
noiseeq = [1]
@test_throws ArgumentError @named ssys = SDESystem([eq2], [noiseeq], t)
Expand Down
Loading