Skip to content

Commit d0bb07a

Browse files
Merge pull request #453 from SciML/jumpevalfunc
Generate JumpProblem with EvalFunc to avoid GG
2 parents c7112d3 + cfb489b commit d0bb07a

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ js = JumpSystem([j₁,j₂,j₃], t, [S,I,R], [β,γ])
2727
"""
2828
struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
2929
"""
30-
The jumps of the system. Allowable types are `ConstantRateJump`,
30+
The jumps of the system. Allowable types are `ConstantRateJump`,
3131
`VariableRateJump`, `MassActionJump`.
3232
"""
3333
eqs::U
@@ -48,7 +48,7 @@ function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
4848

4949
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
5050
for eq in eqs
51-
if eq isa MassActionJump
51+
if eq isa MassActionJump
5252
push!(ap.x[1], eq)
5353
elseif eq isa ConstantRateJump
5454
push!(ap.x[2], eq)
@@ -63,17 +63,19 @@ function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
6363
end
6464

6565

66-
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
67-
independent_variable(js),
68-
expression=Val{false})
66+
generate_rate_function(js, rate) = ModelingToolkit.eval(
67+
build_function(rate, states(js), parameters(js),
68+
independent_variable(js),
69+
expression=Val{true}))
6970

70-
generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js),
71+
generate_affect_function(js, affect, outputidxs) = ModelingToolkit.eval(
72+
build_function(affect, states(js),
7173
parameters(js),
7274
independent_variable(js),
73-
expression=Val{false},
75+
expression=Val{true},
7476
headerfun=add_integrator_header,
75-
outputidxs=outputidxs)[2]
76-
77+
outputidxs=outputidxs)[2])
78+
7779
function assemble_vrj(js, vrj, statetoid)
7880
rate = generate_rate_function(js, vrj.rate)
7981
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
@@ -124,11 +126,13 @@ end
124126

125127
"""
126128
```julia
127-
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan,
129+
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
128130
parammap=DiffEqBase.NullParameters; kwargs...)
129131
```
130132
131-
Generates a DiscreteProblem from an AbstractSystem.
133+
Generates a black DiscreteProblem for a JumpSystem to utilize as its
134+
solving `prob.prob`. This is used in the case where there are no ODEs
135+
and no SDEs associated with the system.
132136
133137
Continuing the example from the [`JumpSystem`](@ref) definition:
134138
```julia
@@ -139,11 +143,13 @@ tspan = (0.0, 250.0)
139143
dprob = DiscreteProblem(js, u₀map, tspan, parammap)
140144
```
141145
"""
142-
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan::Tuple,
146+
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Tuple,
143147
parammap=DiffEqBase.NullParameters(); kwargs...)
144148
u0 = varmap_to_vars(u0map, states(sys))
145149
p = varmap_to_vars(parammap, parameters(sys))
146-
f = (du,u,p,t) -> du.=u # identity function to make syms works
150+
# identity function to make syms works
151+
# EvalFunc because we know that the jump functions are generated via eval
152+
f = DiffEqBase.EvalFunc(DiffEqBase.DISCRETE_INPLACE_DEFAULT)
147153
df = DiscreteFunction(f, syms=Symbol.(states(sys)))
148154
DiscreteProblem(df, u0, tspan, p; kwargs...)
149155
end
@@ -164,19 +170,19 @@ sol = solve(jprob, SSAStepper())
164170
"""
165171
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
166172

167-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
173+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
168174
eqs = equations(js)
169175
invttype = typeof(1 / prob.tspan[2])
170176

171177
# handling parameter substition and empty param vecs
172178
p = (prob.p == DiffEqBase.NullParameters()) ? Operation[] : prob.p
173179
parammap = map((x,y)->Pair(x(),y), parameters(js), p)
174180
subber = substituter(first.(parammap), last.(parammap))
175-
181+
176182
majs = MassActionJump[assemble_maj(js, j, statetoid, subber, invttype) for j in eqs.x[1]]
177183
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
178184
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
179-
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
185+
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
180186
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
181187

182188
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
@@ -207,7 +213,7 @@ end
207213
### Functions to determine which states are modified by a given jump
208214
function modified_states!(mstates, jump::Union{ConstantRateJump,VariableRateJump}, sts)
209215
for eq in jump.affect!
210-
st = eq.lhs
216+
st = eq.lhs
211217
(st.op in sts) && push!(mstates, st)
212218
end
213219
end

0 commit comments

Comments
 (0)