Skip to content

Commit 8936c97

Browse files
Merge pull request #3329 from AayushSabharwal/as/jumpsys-initsys
fix: fix initialization of `DiscreteProblem(::JumpSystem)`
2 parents 4c86290 + 29d1a18 commit 8936c97

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ NonlinearSolve = "4.3"
123123
OffsetArrays = "1"
124124
OrderedCollections = "1"
125125
OrdinaryDiffEq = "6.82.0"
126-
OrdinaryDiffEqCore = "1.13.0"
126+
OrdinaryDiffEqCore = "1.15.0"
127127
OrdinaryDiffEqDefault = "1.2"
128128
OrdinaryDiffEqNonlinearSolve = "1.3.0"
129129
PrecompileTools = "1"
@@ -132,7 +132,7 @@ RecursiveArrayTools = "3.26"
132132
Reexport = "0.2, 1"
133133
RuntimeGeneratedFunctions = "0.5.9"
134134
SCCNonlinearSolve = "1.0.0"
135-
SciMLBase = "2.68.1"
135+
SciMLBase = "2.71"
136136
SciMLStructures = "1.0"
137137
Serialization = "1"
138138
Setfield = "0.7, 0.8, 1"

src/systems/jumps/jumpsystem.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -425,14 +425,15 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
425425
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
426426
end
427427

428-
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
428+
_f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
429429
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
430430
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
431431

432432
observedfun = ObservedFunctionCache(
433433
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
434434

435-
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
435+
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun,
436+
initialization_data = get(_f.kwargs, :initialization_data, nothing))
436437
DiscreteProblem(df, u0, tspan, p; kwargs...)
437438
end
438439

src/systems/nonlinear/initializesystem.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ function generate_initializesystem(sys::AbstractSystem;
1313
check_units = true, check_defguess = false,
1414
name = nameof(sys), extra_metadata = (;), kwargs...)
1515
eqs = equations(sys)
16-
eqs = filter(x -> x isa Equation, eqs)
16+
if !(eqs isa Vector{Equation})
17+
eqs = Equation[x for x in eqs if x isa Equation]
18+
end
1719
trueobs, eqs = unhack_observed(observed(sys), eqs)
1820
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
1921
vars_set = Set(vars) # for efficient in-lookup

test/initializationsystem.jl

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
2-
using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq
2+
using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq, JumpProcesses
33
using ForwardDiff
44
using SymbolicIndexingInterface, SciMLStructures
55
using SciMLStructures: Tunable
@@ -1306,3 +1306,29 @@ end
13061306
@test integ[X] 4.0
13071307
@test integ[Y] 7.0
13081308
end
1309+
1310+
@testset "Issue#3297: `generate_initializesystem(::JumpSystem)`" begin
1311+
@parameters β γ S0
1312+
@variables S(t)=S0 I(t) R(t)
1313+
rate₁ = β * S * I
1314+
affect₁ = [S ~ S - 1, I ~ I + 1]
1315+
rate₂ = γ * I
1316+
affect₂ = [I ~ I - 1, R ~ R + 1]
1317+
j₁ = ConstantRateJump(rate₁, affect₁)
1318+
j₂ = ConstantRateJump(rate₂, affect₂)
1319+
j₃ = MassActionJump(2 * β + γ, [R => 1], [S => 1, R => -1])
1320+
@mtkbuild js = JumpSystem([j₁, j₂, j₃], t, [S, I, R], [β, γ, S0])
1321+
1322+
u0s = [I => 1, R => 0]
1323+
ps = [S0 => 999, β => 0.01, γ => 0.001]
1324+
dprob = DiscreteProblem(js, u0s, (0.0, 10.0), ps)
1325+
@test dprob.f.initialization_data !== nothing
1326+
sol = solve(dprob, FunctionMap())
1327+
@test sol[S, 1] 999
1328+
@test SciMLBase.successful_retcode(sol)
1329+
1330+
jprob = JumpProblem(js, dprob)
1331+
sol = solve(jprob, SSAStepper())
1332+
@test sol[S, 1] 999
1333+
@test SciMLBase.successful_retcode(sol)
1334+
end

0 commit comments

Comments
 (0)