Skip to content

Commit 387df59

Browse files
Merge pull request #3368 from BenChung/fix-parameter-arrays
Fix parameter array handling in ImperativeAffect
2 parents 935ad12 + 2c1997b commit 387df59

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

src/systems/imperative_affect.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@ func(f::ImperativeAffect) = f.f
7575
context(a::ImperativeAffect) = a.ctx
7676
observed(a::ImperativeAffect) = a.obs
7777
observed_syms(a::ImperativeAffect) = a.obs_syms
78-
discretes(a::ImperativeAffect) = filter(ModelingToolkit.isparameter, a.modified)
78+
function discretes(a::ImperativeAffect)
79+
Iterators.filter(ModelingToolkit.isparameter,
80+
Iterators.flatten(Iterators.map(
81+
x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x],
82+
a.modified)))
83+
end
7984
modified(a::ImperativeAffect) = a.modified
8085
modified_syms(a::ImperativeAffect) = a.mod_syms
8186

test/symbolic_events.jl

+57
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,60 @@ end
13771377
prob = ODEProblem(decay, [], (0.0, 10.0), [])
13781378
@test_nowarn solve(prob, Tsit5(), tstops = [1.0])
13791379
end
1380+
1381+
@testset "Array parameter updates in ImperativeEffect" begin
1382+
function weird1(max_time; name)
1383+
params = @parameters begin
1384+
θ(t) = 0.0
1385+
end
1386+
vars = @variables begin
1387+
x(t) = 0.0
1388+
end
1389+
eqs = reduce(vcat, Symbolics.scalarize.([
1390+
D(x) ~ 1.0
1391+
]))
1392+
reset = ModelingToolkit.ImperativeAffect(
1393+
modified = (; x, θ)) do m, o, _, i
1394+
@set! m.θ = 0.0
1395+
@set! m.x = 0.0
1396+
return m
1397+
end
1398+
return ODESystem(eqs, t, vars, params; name = name,
1399+
continuous_events = [[x ~ max_time] => reset])
1400+
end
1401+
1402+
function weird2(max_time; name)
1403+
params = @parameters begin
1404+
θ(t) = 0.0
1405+
end
1406+
vars = @variables begin
1407+
x(t) = 0.0
1408+
end
1409+
eqs = reduce(vcat, Symbolics.scalarize.([
1410+
D(x) ~ 1.0
1411+
]))
1412+
return ODESystem(eqs, t, vars, params; name = name) # note no event
1413+
end
1414+
1415+
@named wd1 = weird1(0.021)
1416+
@named wd2 = weird2(0.021)
1417+
1418+
sys1 = structural_simplify(ODESystem([], t; name = :parent,
1419+
discrete_events = [0.01 => ModelingToolkit.ImperativeAffect(
1420+
modified = (; θs = reduce(vcat, [[wd1.θ]])), ctx = [1]) do m, o, c, i
1421+
@set! m.θs[1] = c[] += 1
1422+
end],
1423+
systems = [wd1]))
1424+
sys2 = structural_simplify(ODESystem([], t; name = :parent,
1425+
discrete_events = [0.01 => ModelingToolkit.ImperativeAffect(
1426+
modified = (; θs = reduce(vcat, [[wd2.θ]])), ctx = [1]) do m, o, c, i
1427+
@set! m.θs[1] = c[] += 1
1428+
end],
1429+
systems = [wd2]))
1430+
1431+
sol1 = solve(ODEProblem(sys1, [], (0.0, 1.0)), Tsit5())
1432+
@test 100.0 sol1[sys1.wd1.θ]
1433+
1434+
sol2 = solve(ODEProblem(sys2, [], (0.0, 1.0)), Tsit5())
1435+
@test 100.0 sol2[sys2.wd2.θ]
1436+
end

0 commit comments

Comments
 (0)