Skip to content

Commit 14f2603

Browse files
Merge pull request #3340 from vyudu/validate-inputs-odesys
Validate types of state variables in ODESystem/SDESystem construction
2 parents 1b7261a + 84759d8 commit 14f2603

File tree

5 files changed

+126
-41
lines changed

5 files changed

+126
-41
lines changed

src/systems/diffeqs/odesystem.jl

+7-40
Original file line numberDiff line numberDiff line change
@@ -296,53 +296,21 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
296296
end
297297

298298
function ODESystem(eqs, iv; kwargs...)
299-
eqs = collect(eqs)
300-
# NOTE: this assumes that the order of algebraic equations doesn't matter
301-
diffvars = OrderedSet()
302-
allunknowns = OrderedSet()
303-
ps = OrderedSet()
304-
# reorder equations such that it is in the form of `diffeq, algeeq`
305-
diffeq = Equation[]
306-
algeeq = Equation[]
307-
# initial loop for finding `iv`
308-
if iv === nothing
309-
for eq in eqs
310-
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
311-
iv = iv_from_nested_derivative(eq.lhs)
312-
break
313-
end
314-
end
315-
end
316-
iv = value(iv)
317-
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
318-
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
319-
for eq in eqs
320-
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
321-
collect_vars!(allunknowns, ps, eq, iv)
322-
if isdiffeq(eq)
323-
diffvar, _ = var_from_nested_derivative(eq.lhs)
324-
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
325-
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
326-
throw(ArgumentError("An ODESystem can only have one independent variable."))
327-
diffvar in diffvars &&
328-
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
329-
push!(diffvars, diffvar)
330-
end
331-
push!(diffeq, eq)
332-
else
333-
push!(algeeq, eq)
334-
end
335-
end
299+
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
300+
336301
for eq in get(kwargs, :parameter_dependencies, Equation[])
337302
collect_vars!(allunknowns, ps, eq, iv)
338303
end
304+
339305
for ssys in get(kwargs, :systems, ODESystem[])
340306
collect_scoped_vars!(allunknowns, ps, ssys, iv)
341307
end
308+
342309
for v in allunknowns
343310
isdelay(v, iv) || continue
344311
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
345312
end
313+
346314
new_ps = OrderedSet()
347315
for p in ps
348316
if iscall(p) && operation(p) === getindex
@@ -358,9 +326,8 @@ function ODESystem(eqs, iv; kwargs...)
358326
end
359327
end
360328
algevars = setdiff(allunknowns, diffvars)
361-
# the orders here are very important!
362-
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
363-
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
329+
330+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
364331
end
365332

366333
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

+46
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,52 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
273273
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
274274
end
275275

276+
function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
277+
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
278+
279+
for eq in get(kwargs, :parameter_dependencies, Equation[])
280+
collect_vars!(allunknowns, ps, eq, iv)
281+
end
282+
283+
for ssys in get(kwargs, :systems, ODESystem[])
284+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
285+
end
286+
287+
for v in allunknowns
288+
isdelay(v, iv) || continue
289+
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
290+
end
291+
292+
new_ps = OrderedSet()
293+
for p in ps
294+
if iscall(p) && operation(p) === getindex
295+
par = arguments(p)[begin]
296+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
297+
all(par[i] in ps for i in eachindex(par))
298+
push!(new_ps, par)
299+
else
300+
push!(new_ps, p)
301+
end
302+
else
303+
push!(new_ps, p)
304+
end
305+
end
306+
307+
# validate noise equations
308+
noisedvs = OrderedSet()
309+
noiseps = OrderedSet()
310+
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
311+
for dv in noisedvs
312+
dv allunknowns || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
313+
end
314+
algevars = setdiff(allunknowns, diffvars)
315+
316+
return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)), [ps; collect(noiseps)]; kwargs...)
317+
end
318+
319+
SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...) = SDESystem([eq], noiseeqs, args...; kwargs...)
320+
SDESystem(eq::Equation, noiseeq, args...; kwargs...) = SDESystem([eq], [noiseeq], args...; kwargs...)
321+
276322
function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
277323
sys1 === sys2 && return true
278324
iv1 = get_iv(sys1)

src/utils.jl

+51
Original file line numberDiff line numberDiff line change
@@ -1185,3 +1185,54 @@ function guesses_from_metadata!(guesses, vars)
11851185
guesses[vars[i]] = varguesses[i]
11861186
end
11871187
end
1188+
1189+
"""
1190+
$(TYPEDSIGNATURES)
1191+
1192+
Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
1193+
"""
1194+
function process_equations(eqs, iv)
1195+
eqs = collect(eqs)
1196+
1197+
diffvars = OrderedSet()
1198+
allunknowns = OrderedSet()
1199+
ps = OrderedSet()
1200+
1201+
# NOTE: this assumes that the order of algebraic equations doesn't matter
1202+
# reorder equations such that it is in the form of `diffeq, algeeq`
1203+
diffeq = Equation[]
1204+
algeeq = Equation[]
1205+
# initial loop for finding `iv`
1206+
if iv === nothing
1207+
for eq in eqs
1208+
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
1209+
iv = iv_from_nested_derivative(eq.lhs)
1210+
break
1211+
end
1212+
end
1213+
end
1214+
iv = value(iv)
1215+
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
1216+
1217+
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
1218+
for eq in eqs
1219+
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
1220+
collect_vars!(allunknowns, ps, eq, iv)
1221+
if isdiffeq(eq)
1222+
diffvar, _ = var_from_nested_derivative(eq.lhs)
1223+
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
1224+
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
1225+
throw(ArgumentError("An ODESystem can only have one independent variable."))
1226+
diffvar in diffvars &&
1227+
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
1228+
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
1229+
push!(diffvars, diffvar)
1230+
end
1231+
push!(diffeq, eq)
1232+
else
1233+
push!(algeeq, eq)
1234+
end
1235+
end
1236+
1237+
diffvars, allunknowns, ps, Equation[diffeq; algeeq; compressed_eqs]
1238+
end

test/odesystem.jl

+10
Original file line numberDiff line numberDiff line change
@@ -1544,6 +1544,16 @@ end
15441544
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
15451545
end
15461546

1547+
@testset "Validate input types" begin
1548+
@parameters p d
1549+
@variables X(t)::Int64
1550+
eq = D(X) ~ p - d*X
1551+
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
1552+
@variables Y(t)[1:3]::String
1553+
eq = D(Y) ~ [p, p, p]
1554+
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
1555+
end
1556+
15471557
# Test `isequal`
15481558
@testset "`isequal`" begin
15491559
@variables X(t)

test/sdesystem.jl

+12-1
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,17 @@ end
869869
@test length(observed(sys)) == 1
870870
end
871871

872+
# Test validating types of states
873+
@testset "Validate input types" begin
874+
@parameters p d
875+
@variables X(t)::Int64
876+
@brownian z
877+
eq2 = D(X) ~ p - d*X + z
878+
@test_throws ArgumentError @mtkbuild ssys = System([eq2], t)
879+
noiseeq = [1]
880+
@test_throws ArgumentError @named ssys = SDESystem([eq2], [noiseeq], t)
881+
end
882+
872883
@testset "SDEFunctionExpr" begin
873884
@parameters σ ρ β
874885
@variables x(tt) y(tt) z(tt)
@@ -953,4 +964,4 @@ end
953964
@test_throws ErrorException("SDESystem constructed by defining Brownian variables with @brownian must be simplified by calling `structural_simplify` before a SDEProblem can be constructed.") SDEProblem(de, u0map, (0.0, 100.0), parammap)
954965
de = structural_simplify(de)
955966
@test SDEProblem(de, u0map, (0.0, 100.0), parammap) isa SDEProblem
956-
end
967+
end

0 commit comments

Comments
 (0)