diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 366e8a6d09..fb2eb1d804 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -737,6 +737,7 @@ function DiffEqBase.SDEProblem{iip, specialize}( if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`") end + f, u0, p = process_SciMLProblem( SDEFunction{iip, specialize}, sys, u0map, parammap; check_length, t = tspan === nothing ? nothing : tspan[1], kwargs...) @@ -767,6 +768,15 @@ function DiffEqBase.SDEProblem{iip, specialize}( noise_rate_prototype = noise_rate_prototype, kwargs...) end +function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...) + + if any(ModelingToolkit.isbrownian, unknowns(sys)) + error("SDESystem constructed by defining Brownian variables with @brownian must be simplified by calling `structural_simplify` before a SDEProblem can be constructed.") + else + error("Cannot construct SDEProblem from a normal ODESystem.") + end +end + """ ```julia DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan, p = parammap; diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 8069581dcc..94f3e9cbfc 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -868,3 +868,22 @@ end @test length(ModelingToolkit.get_noiseeqs(sys)) == 1 @test length(observed(sys)) == 1 end + +@testset "Error when constructing SDESystem without `structural_simplify`" begin + @parameters σ ρ β + @variables x(tt) y(tt) z(tt) + @brownian a + eqs = [D(x) ~ σ * (y - x) + 0.1a * x, + D(y) ~ x * (ρ - z) - y + 0.1a * y, + D(z) ~ x * y - β * z + 0.1a * z] + + @named de = System(eqs, t) + de = complete(de) + + u0map = [x => 1.0, y => 0.0, z => 0.0] + parammap = [σ => 10.0, β => 26.0, ρ => 2.33] + + @test_throws ErrorException("SDESystem constructed by defining Brownian variables with @brownian must be simplified by calling `structural_simplify` before a SDEProblem can be constructed.") SDEProblem(de, u0map, (0.0, 100.0), parammap) + de = structural_simplify(de) + @test SDEProblem(de, u0map, (0.0, 100.0), parammap) isa SDEProblem +end