Skip to content

Commit d9a401f

Browse files
committed
Support observed function building for split parameters
1 parent 983aaac commit d9a401f

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

src/systems/diffeqs/abstractodesystem.jl

+39-7
Original file line numberDiff line numberDiff line change
@@ -397,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
397397

398398
obs = observed(sys)
399399
observedfun = if steady_state
400-
let sys = sys, dict = Dict()
400+
let sys = sys, dict = Dict(), ps = ps
401401
function generated_observed(obsvar, args...)
402402
obs = get!(dict, value(obsvar)) do
403403
build_explicit_observed_function(sys, obsvar)
404404
end
405405
if args === ()
406406
let obs = obs
407-
(u, p, t = Inf) -> obs(u, p, t)
407+
(u, p, t = Inf) -> if ps isa Tuple
408+
obs(u, p..., t)
409+
else
410+
obs(u, p, t)
411+
end
408412
end
409413
else
410-
length(args) == 2 ? obs(args..., Inf) : obs(args...)
414+
if ps isa Tuple
415+
if length(args) == 2
416+
u, p = args
417+
obs(u, p..., Inf)
418+
else
419+
u, p, t = args
420+
obs(u, p..., t)
421+
end
422+
else
423+
if length(args) == 2
424+
u, p = args
425+
obs(u, p, Inf)
426+
else
427+
u, p, t = args
428+
obs(u, p, t)
429+
end
430+
end
411431
end
412432
end
413433
end
414434
else
415-
let sys = sys, dict = Dict()
435+
let sys = sys, dict = Dict(), ps = ps
416436
function generated_observed(obsvar, args...)
417437
obs = get!(dict, value(obsvar)) do
418-
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
438+
build_explicit_observed_function(sys,
439+
obsvar;
440+
checkbounds = checkbounds,
441+
ps)
419442
end
420443
if args === ()
421444
let obs = obs
422-
(u, p, t) -> obs(u, p, t)
445+
(u, p, t) -> if ps isa Tuple
446+
obs(u, p..., t)
447+
else
448+
obs(u, p, t)
449+
end
423450
end
424451
else
425-
obs(args...)
452+
if ps isa Tuple # split parameters
453+
u, p, t = args
454+
obs(u, p..., t)
455+
else
456+
obs(args...)
457+
end
426458
end
427459
end
428460
end

src/systems/diffeqs/odesystem.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts;
314314
output_type = Array,
315315
checkbounds = true,
316316
drop_expr = drop_expr,
317+
ps = paramteres(sys),
317318
throw = true)
318319
if (isscalar = !(ts isa AbstractVector))
319320
ts = [ts]
@@ -389,13 +390,17 @@ function build_explicit_observed_function(sys, ts;
389390
if inputs !== nothing
390391
pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
391392
end
392-
ps = DestructuredArgs(pars, inbounds = !checkbounds)
393+
if ps isa Tuple
394+
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
395+
else
396+
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
397+
end
393398
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
394399
if inputs === nothing
395-
args = [dvs, ps, ivs...]
400+
args = [dvs, ps..., ivs...]
396401
else
397402
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
398-
args = [dvs, ipts, ps, ivs...]
403+
args = [dvs, ipts, ps..., ivs...]
399404
end
400405
pre = get_postprocess_fbody(sys)
401406

test/split_parameters.jl

+4-5
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ eqs = [D(y) ~ dy * a
5959
D(dy) ~ ddy * b
6060
ddy ~ sin(t) * c]
6161

62-
@named sys = ODESystem(eqs, t, vars, pars)
63-
sys = structural_simplify(sys)
62+
@named model = ODESystem(eqs, t, vars, pars)
63+
sys = structural_simplify(model)
6464

6565
tspan = (0.0, t_end)
6666
prob = ODEProblem(sys, [], tspan, [])
@@ -76,8 +76,7 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false)
7676
@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}}
7777
sol = solve(prob, ImplicitEuler());
7878
@test sol.retcode == ReturnCode.Success
79-
80-
79+
sol[states(model)]
8180

8281
# ------------------------- Observables
8382

@@ -94,4 +93,4 @@ sys = structural_simplify(model)
9493
prob = ODEProblem(sys, Pair[int.x => 0.0], (0.0, 1.0))
9594
sol = solve(prob, Rodas4())
9695
@test sol.retcode == ReturnCode.Success
97-
sol[absb.output.u]
96+
sol[absb.output.u]

0 commit comments

Comments
 (0)