Skip to content

Commit b701f62

Browse files
Merge pull request #3389 from AayushSabharwal/as/vector-ps
fix: allow vector of parameters for split system of pure tunables
2 parents f051109 + 11e3945 commit b701f62

12 files changed

+109
-110
lines changed

src/systems/abstractsystem.jl

+2-33
Original file line numberDiff line numberDiff line change
@@ -422,28 +422,7 @@ function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSyste
422422
end
423423

424424
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
425-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
426-
rawobs = build_explicit_observed_function(
427-
sys, sym; param_only = true, return_inplace = true)
428-
if rawobs isa Tuple
429-
if is_time_dependent(sys)
430-
obsfn = let oop = rawobs[1], iip = rawobs[2]
431-
f1a(p, t) = oop(p, t)
432-
f1a(out, p, t) = iip(out, p, t)
433-
end
434-
else
435-
obsfn = let oop = rawobs[1], iip = rawobs[2]
436-
f1b(p) = oop(p)
437-
f1b(out, p) = iip(out, p)
438-
end
439-
end
440-
else
441-
obsfn = rawobs
442-
end
443-
else
444-
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
445-
end
446-
return obsfn
425+
return build_explicit_observed_function(sys, sym; param_only = true)
447426
end
448427

449428
function has_observed_with_lhs(sys, sym)
@@ -579,18 +558,8 @@ function SymbolicIndexingInterface.observed(
579558
end
580559
end
581560
end
582-
_fn = build_explicit_observed_function(
561+
return build_explicit_observed_function(
583562
sys, sym; eval_expression, eval_module, checkbounds)
584-
585-
if is_time_dependent(sys)
586-
return _fn
587-
else
588-
return let _fn = _fn
589-
fn2(u, p) = _fn(u, p)
590-
fn2(::Nothing, p) = _fn([], p)
591-
fn2
592-
end
593-
end
594563
end
595564

596565
function SymbolicIndexingInterface.default_values(sys::AbstractSystem)

src/systems/codegen_utils.jl

+56
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,59 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
230230
end
231231
return build_function(expr, args...; wrap_code, similarto, kwargs...)
232232
end
233+
234+
"""
235+
$(TYPEDEF)
236+
237+
A wrapper around a generated in-place and out-of-place function. The type-parameter `P`
238+
must be a 3-tuple where the first element is the index of the parameter object in the
239+
arguments, the second is the expected number of arguments in the out-of-place variant
240+
of the function, and the third is a boolean indicating whether the generated functions
241+
are for a split system. For scalar functions, the inplace variant can be `nothing`.
242+
"""
243+
struct GeneratedFunctionWrapper{P, O, I} <: Function
244+
f_oop::O
245+
f_iip::I
246+
end
247+
248+
function GeneratedFunctionWrapper{P}(foop::O, fiip::I) where {P, O, I}
249+
GeneratedFunctionWrapper{P, O, I}(foop, fiip)
250+
end
251+
252+
function (gfw::GeneratedFunctionWrapper)(args...)
253+
_generated_call(gfw, args...)
254+
end
255+
256+
@generated function _generated_call(gfw::GeneratedFunctionWrapper{P}, args...) where {P}
257+
paramidx, nargs, issplit = P
258+
iip = false
259+
# IIP case has one more argument
260+
if length(args) == nargs + 1
261+
nargs += 1
262+
paramidx += 1
263+
iip = true
264+
end
265+
if length(args) != nargs
266+
throw(ArgumentError("Expected $nargs arguments, got $(length(args))."))
267+
end
268+
269+
# the function to use
270+
f = iip ? :(gfw.f_iip) : :(gfw.f_oop)
271+
# non-split systems just call it as-is
272+
if !issplit
273+
return :($f(args...))
274+
end
275+
if args[paramidx] <: Union{Tuple, MTKParameters} &&
276+
!(args[paramidx] <: Tuple{Vararg{Number}})
277+
# for split systems, call it as-is if the parameter object is a tuple or MTKParameters
278+
# but not if it is a tuple of numbers
279+
return :($f(args...))
280+
else
281+
# The user provided a single buffer/tuple for the parameter object, so wrap that
282+
# one in a tuple
283+
fargs = ntuple(Val(length(args))) do i
284+
i == paramidx ? :((args[$i],)) : :(args[$i])
285+
end
286+
return :($f($(fargs...)))
287+
end
288+
end

src/systems/diffeqs/abstractodesystem.jl

+15-37
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
320320
expression_module = eval_module, checkbounds = checkbounds,
321321
kwargs...)
322322
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
323-
324-
f(u, p, t) = f_oop(u, p, t)
325-
f(du, u, p, t) = f_iip(du, u, p, t)
323+
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
326324

327325
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
328326
if u0 === nothing || p === nothing || t === nothing
@@ -338,10 +336,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
338336
expression_module = eval_module,
339337
checkbounds = checkbounds, kwargs...)
340338
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
341-
342-
___tgrad(u, p, t) = tgrad_oop(u, p, t)
343-
___tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
344-
_tgrad = ___tgrad
339+
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
345340
else
346341
_tgrad = nothing
347342
end
@@ -354,8 +349,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
354349
checkbounds = checkbounds, kwargs...)
355350
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
356351

357-
_jac(u, p, t) = jac_oop(u, p, t)
358-
_jac(J, u, p, t) = jac_iip(J, u, p, t)
352+
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
359353
else
360354
_jac = nothing
361355
end
@@ -435,8 +429,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
435429
expression_module = eval_module, checkbounds = checkbounds,
436430
kwargs...)
437431
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
438-
f(du, u, p, t) = f_oop(du, u, p, t)
439-
f(out, du, u, p, t) = f_iip(out, du, u, p, t)
432+
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)
440433

441434
if jac
442435
jac_gen = generate_dae_jacobian(sys, dvs, ps;
@@ -446,8 +439,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
446439
checkbounds = checkbounds, kwargs...)
447440
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
448441

449-
_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)
450-
_jac(J, du, u, p, ˍ₋gamma, t) = jac_iip(J, du, u, p, ˍ₋gamma, t)
442+
_jac = GeneratedFunctionWrapper{(3, 5, is_split(sys))}(jac_oop, jac_iip)
451443
else
452444
_jac = nothing
453445
end
@@ -496,8 +488,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
496488
expression_module = eval_module, checkbounds = checkbounds,
497489
kwargs...)
498490
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
499-
f(u, h, p, t) = f_oop(u, h, p, t)
500-
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
491+
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)
501492

502493
DDEFunction{iip}(f; sys = sys, initialization_data)
503494
end
@@ -521,14 +512,12 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
521512
expression_module = eval_module, checkbounds = checkbounds,
522513
kwargs...)
523514
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
524-
f(u, h, p, t) = f_oop(u, h, p, t)
525-
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
515+
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)
526516

527517
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
528518
isdde = true, kwargs...)
529519
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)
530-
g(u, h, p, t) = g_oop(u, h, p, t)
531-
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
520+
g = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(g_oop, g_iip)
532521

533522
SDDEFunction{iip}(f, g; sys = sys, initialization_data)
534523
end
@@ -549,13 +538,6 @@ variable and parameter vectors, respectively.
549538
"""
550539
struct ODEFunctionExpr{iip, specialize} end
551540

552-
struct ODEFunctionClosure{O, I} <: Function
553-
f_oop::O
554-
f_iip::I
555-
end
556-
(f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
557-
(f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
558-
559541
function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys),
560542
ps = parameters(sys), u0 = nothing;
561543
version = nothing, tgrad = false,
@@ -572,13 +554,14 @@ function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns
572554
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
573555

574556
fsym = gensym(:f)
575-
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
557+
_f = :($fsym = $(GeneratedFunctionWrapper{(2, 3, is_split(sys))})($f_oop, $f_iip))
576558
tgradsym = gensym(:tgrad)
577559
if tgrad
578560
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
579561
simplify = simplify,
580562
expression = Val{true}, kwargs...)
581-
_tgrad = :($tgradsym = $ODEFunctionClosure($tgrad_oop, $tgrad_iip))
563+
_tgrad = :($tgradsym = $(GeneratedFunctionWrapper{(2, 3, is_split(sys))})(
564+
$tgrad_oop, $tgrad_iip))
582565
else
583566
_tgrad = :($tgradsym = nothing)
584567
end
@@ -588,7 +571,8 @@ function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns
588571
jac_oop, jac_iip = generate_jacobian(sys, dvs, ps;
589572
sparse = sparse, simplify = simplify,
590573
expression = Val{true}, kwargs...)
591-
_jac = :($jacsym = $ODEFunctionClosure($jac_oop, $jac_iip))
574+
_jac = :($jacsym = $(GeneratedFunctionWrapper{(2, 3, is_split(sys))})(
575+
$jac_oop, $jac_iip))
592576
else
593577
_jac = :($jacsym = nothing)
594578
end
@@ -647,13 +631,6 @@ variable and parameter vectors, respectively.
647631
"""
648632
struct DAEFunctionExpr{iip} end
649633

650-
struct DAEFunctionClosure{O, I} <: Function
651-
f_oop::O
652-
f_iip::I
653-
end
654-
(f::DAEFunctionClosure)(du, u, p, t) = f.f_oop(du, u, p, t)
655-
(f::DAEFunctionClosure)(out, du, u, p, t) = f.f_iip(out, du, u, p, t)
656-
657634
function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
658635
ps = parameters(sys), u0 = nothing;
659636
version = nothing, tgrad = false,
@@ -667,7 +644,7 @@ function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
667644
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true},
668645
implicit_dae = true, kwargs...)
669646
fsym = gensym(:f)
670-
_f = :($fsym = $DAEFunctionClosure($f_oop, $f_iip))
647+
_f = :($fsym = $(GeneratedFunctionWrapper{(3, 4, is_split(sys))})($f_oop, $f_iip))
671648
ex = quote
672649
$_f
673650
ODEFunction{$iip}($fsym)
@@ -708,6 +685,7 @@ function SymbolicTstops(
708685
expression = Val{true},
709686
p_start = 1, p_end = length(rps), add_observed = false, force_SA = true)
710687
tstops = eval_or_rgf(tstops; eval_expression, eval_module)
688+
tstops = GeneratedFunctionWrapper{(1, 3, is_split(sys))}(tstops, nothing)
711689
return SymbolicTstops(tstops)
712690
end
713691

src/systems/diffeqs/odesystem.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -521,9 +521,14 @@ function build_explicit_observed_function(sys, ts;
521521
output_type, mkarray, try_namespaced = true, expression = Val{true})
522522
if fns isa Tuple
523523
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
524-
return return_inplace ? (oop, iip) : oop
524+
f = GeneratedFunctionWrapper{(
525+
p_start, length(args) - length(ps) + 1, is_split(sys))}(oop, iip)
526+
return return_inplace ? (f, f) : f
525527
else
526-
return eval_or_rgf(fns; eval_expression, eval_module)
528+
f = eval_or_rgf(fns; eval_expression, eval_module)
529+
f = GeneratedFunctionWrapper{(
530+
p_start, length(args) - length(ps) + 1, is_split(sys))}(f, nothing)
531+
return f
527532
end
528533
end
529534

src/systems/diffeqs/sdesystem.jl

+6-13
Original file line numberDiff line numberDiff line change
@@ -604,18 +604,14 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
604604
kwargs...)
605605
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)
606606

607-
f(u, p, t) = f_oop(u, p, t)
608-
f(du, u, p, t) = f_iip(du, u, p, t)
609-
g(u, p, t) = g_oop(u, p, t)
610-
g(du, u, p, t) = g_iip(du, u, p, t)
607+
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
608+
g = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(g_oop, g_iip)
611609

612610
if tgrad
613611
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true},
614612
kwargs...)
615613
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
616-
617-
_tgrad(u, p, t) = tgrad_oop(u, p, t)
618-
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
614+
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
619615
else
620616
_tgrad = nothing
621617
end
@@ -625,8 +621,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
625621
sparse = sparse, kwargs...)
626622
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
627623

628-
_jac(u, p, t) = jac_oop(u, p, t)
629-
_jac(J, u, p, t) = jac_iip(J, u, p, t)
624+
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
630625
else
631626
_jac = nothing
632627
end
@@ -637,10 +632,8 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
637632
Wfact_oop, Wfact_iip = eval_or_rgf.(tmp_Wfact; eval_expression, eval_module)
638633
Wfact_oop_t, Wfact_iip_t = eval_or_rgf.(tmp_Wfact_t; eval_expression, eval_module)
639634

640-
_Wfact(u, p, dtgamma, t) = Wfact_oop(u, p, dtgamma, t)
641-
_Wfact(W, u, p, dtgamma, t) = Wfact_iip(W, u, p, dtgamma, t)
642-
_Wfact_t(u, p, dtgamma, t) = Wfact_oop_t(u, p, dtgamma, t)
643-
_Wfact_t(W, u, p, dtgamma, t) = Wfact_iip_t(W, u, p, dtgamma, t)
635+
_Wfact = GeneratedFunctionWrapper{(2, 4, is_split(sys))}(Wfact_oop, Wfact_iip)
636+
_Wfact_t = GeneratedFunctionWrapper{(2, 4, is_split(sys))}(Wfact_oop_t, Wfact_iip_t)
644637
else
645638
_Wfact, _Wfact_t = nothing, nothing
646639
end

src/systems/discrete_system/discrete_system.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
354354
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
355355
expression_module = eval_module, kwargs...)
356356
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
357-
f(u, p, t) = f_oop(u, p, t)
358-
f(du, u, p, t) = f_iip(du, u, p, t)
357+
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
359358

360359
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
361360
if u0 === nothing || p === nothing || t === nothing

src/systems/jumps/jumpsystem.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ function generate_rate_function(js::JumpSystem, rate)
283283
rate = substitute(rate, csubs)
284284
end
285285
p = reorder_parameters(js)
286-
rf = build_function_wrapper(js, rate, unknowns(js), p...,
286+
build_function_wrapper(js, rate, unknowns(js), p...,
287287
get_iv(js),
288288
expression = Val{true})
289289
end
@@ -302,7 +302,7 @@ end
302302
function assemble_vrj(
303303
js, vrj, unknowntoid; eval_expression = false, eval_module = @__MODULE__)
304304
rate = eval_or_rgf(generate_rate_function(js, vrj.rate); eval_expression, eval_module)
305-
305+
rate = GeneratedFunctionWrapper{(2, 3, is_split(js))}(rate, nothing)
306306
outputvars = (value(affect.lhs) for affect in vrj.affect!)
307307
outputidxs = [unknowntoid[var] for var in outputvars]
308308
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
@@ -326,7 +326,7 @@ end
326326
function assemble_crj(
327327
js, crj, unknowntoid; eval_expression = false, eval_module = @__MODULE__)
328328
rate = eval_or_rgf(generate_rate_function(js, crj.rate); eval_expression, eval_module)
329-
329+
rate = GeneratedFunctionWrapper{(2, 3, is_split(js))}(rate, nothing)
330330
outputvars = (value(affect.lhs) for affect in crj.affect!)
331331
outputidxs = [unknowntoid[var] for var in outputvars]
332332
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);

0 commit comments

Comments
 (0)