Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow vector of parameters for split system of pure tunables #3389

Merged
merged 3 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 2 additions & 33 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,28 +422,7 @@
end

function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
rawobs = build_explicit_observed_function(
sys, sym; param_only = true, return_inplace = true)
if rawobs isa Tuple
if is_time_dependent(sys)
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1a(p, t) = oop(p, t)
f1a(out, p, t) = iip(out, p, t)
end
else
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1b(p) = oop(p)
f1b(out, p) = iip(out, p)
end
end
else
obsfn = rawobs
end
else
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
end
return obsfn
return build_explicit_observed_function(sys, sym; param_only = true)
end

function has_observed_with_lhs(sys, sym)
Expand Down Expand Up @@ -579,18 +558,8 @@
end
end
end
_fn = build_explicit_observed_function(
return build_explicit_observed_function(
sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
else
return let _fn = _fn
fn2(u, p) = _fn(u, p)
fn2(::Nothing, p) = _fn([], p)
fn2
end
end
end

function SymbolicIndexingInterface.default_values(sys::AbstractSystem)
Expand Down Expand Up @@ -653,7 +622,7 @@
"""
Initial(x)

The `Initial` operator. Used by initializaton to store constant constraints on variables

Check warning on line 625 in src/systems/abstractsystem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"initializaton" should be "initialization".
of a system. See the documentation section on initialization for more information.
"""
struct Initial <: Symbolics.Operator end
Expand Down
56 changes: 56 additions & 0 deletions src/systems/codegen_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,59 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
end
return build_function(expr, args...; wrap_code, similarto, kwargs...)
end

"""
$(TYPEDEF)

A wrapper around a generated in-place and out-of-place function. The type-parameter `P`
must be a 3-tuple where the first element is the index of the parameter object in the
arguments, the second is the expected number of arguments in the out-of-place variant
of the function, and the third is a boolean indicating whether the generated functions
are for a split system. For scalar functions, the inplace variant can be `nothing`.
"""
struct GeneratedFunctionWrapper{P, O, I} <: Function
f_oop::O
f_iip::I
end

function GeneratedFunctionWrapper{P}(foop::O, fiip::I) where {P, O, I}
GeneratedFunctionWrapper{P, O, I}(foop, fiip)
end

function (gfw::GeneratedFunctionWrapper)(args...)
_generated_call(gfw, args...)
end

@generated function _generated_call(gfw::GeneratedFunctionWrapper{P}, args...) where {P}
paramidx, nargs, issplit = P
iip = false
# IIP case has one more argument
if length(args) == nargs + 1
nargs += 1
paramidx += 1
iip = true
end
if length(args) != nargs
throw(ArgumentError("Expected $nargs arguments, got $(length(args))."))
end

# the function to use
f = iip ? :(gfw.f_iip) : :(gfw.f_oop)
# non-split systems just call it as-is
if !issplit
return :($f(args...))
end
if args[paramidx] <: Union{Tuple, MTKParameters} &&
!(args[paramidx] <: Tuple{Vararg{Number}})
# for split systems, call it as-is if the parameter object is a tuple or MTKParameters
# but not if it is a tuple of numbers
return :($f(args...))
else
# The user provided a single buffer/tuple for the parameter object, so wrap that
# one in a tuple
fargs = ntuple(Val(length(args))) do i
i == paramidx ? :((args[$i],)) : :(args[$i])
end
return :($f($(fargs...)))
end
end
52 changes: 15 additions & 37 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)

f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)

if specialize === SciMLBase.FunctionWrapperSpecialize && iip
if u0 === nothing || p === nothing || t === nothing
Expand All @@ -338,10 +336,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)

___tgrad(u, p, t) = tgrad_oop(u, p, t)
___tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
_tgrad = ___tgrad
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
else
_tgrad = nothing
end
Expand All @@ -354,8 +349,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)

_jac(u, p, t) = jac_oop(u, p, t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
else
_jac = nothing
end
Expand Down Expand Up @@ -435,8 +429,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f(du, u, p, t) = f_oop(du, u, p, t)
f(out, du, u, p, t) = f_iip(out, du, u, p, t)
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)

if jac
jac_gen = generate_dae_jacobian(sys, dvs, ps;
Expand All @@ -446,8 +439,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)

_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)
_jac(J, du, u, p, ˍ₋gamma, t) = jac_iip(J, du, u, p, ˍ₋gamma, t)
_jac = GeneratedFunctionWrapper{(3, 5, is_split(sys))}(jac_oop, jac_iip)
else
_jac = nothing
end
Expand Down Expand Up @@ -496,8 +488,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)

DDEFunction{iip}(f; sys = sys, initialization_data)
end
Expand All @@ -521,14 +512,12 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)

g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
isdde = true, kwargs...)
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)
g(u, h, p, t) = g_oop(u, h, p, t)
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
g = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(g_oop, g_iip)

SDDEFunction{iip}(f, g; sys = sys, initialization_data)
end
Expand All @@ -549,13 +538,6 @@ variable and parameter vectors, respectively.
"""
struct ODEFunctionExpr{iip, specialize} end

struct ODEFunctionClosure{O, I} <: Function
f_oop::O
f_iip::I
end
(f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
(f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)

function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing;
version = nothing, tgrad = false,
Expand All @@ -572,13 +554,14 @@ function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)

fsym = gensym(:f)
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
_f = :($fsym = $(GeneratedFunctionWrapper{(2, 3, is_split(sys))})($f_oop, $f_iip))
tgradsym = gensym(:tgrad)
if tgrad
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
simplify = simplify,
expression = Val{true}, kwargs...)
_tgrad = :($tgradsym = $ODEFunctionClosure($tgrad_oop, $tgrad_iip))
_tgrad = :($tgradsym = $(GeneratedFunctionWrapper{(2, 3, is_split(sys))})(
$tgrad_oop, $tgrad_iip))
else
_tgrad = :($tgradsym = nothing)
end
Expand All @@ -588,7 +571,8 @@ function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns
jac_oop, jac_iip = generate_jacobian(sys, dvs, ps;
sparse = sparse, simplify = simplify,
expression = Val{true}, kwargs...)
_jac = :($jacsym = $ODEFunctionClosure($jac_oop, $jac_iip))
_jac = :($jacsym = $(GeneratedFunctionWrapper{(2, 3, is_split(sys))})(
$jac_oop, $jac_iip))
else
_jac = :($jacsym = nothing)
end
Expand Down Expand Up @@ -647,13 +631,6 @@ variable and parameter vectors, respectively.
"""
struct DAEFunctionExpr{iip} end

struct DAEFunctionClosure{O, I} <: Function
f_oop::O
f_iip::I
end
(f::DAEFunctionClosure)(du, u, p, t) = f.f_oop(du, u, p, t)
(f::DAEFunctionClosure)(out, du, u, p, t) = f.f_iip(out, du, u, p, t)

function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing;
version = nothing, tgrad = false,
Expand All @@ -667,7 +644,7 @@ function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true},
implicit_dae = true, kwargs...)
fsym = gensym(:f)
_f = :($fsym = $DAEFunctionClosure($f_oop, $f_iip))
_f = :($fsym = $(GeneratedFunctionWrapper{(3, 4, is_split(sys))})($f_oop, $f_iip))
ex = quote
$_f
ODEFunction{$iip}($fsym)
Expand Down Expand Up @@ -708,6 +685,7 @@ function SymbolicTstops(
expression = Val{true},
p_start = 1, p_end = length(rps), add_observed = false, force_SA = true)
tstops = eval_or_rgf(tstops; eval_expression, eval_module)
tstops = GeneratedFunctionWrapper{(1, 3, is_split(sys))}(tstops, nothing)
return SymbolicTstops(tstops)
end

Expand Down
9 changes: 7 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,14 @@ function build_explicit_observed_function(sys, ts;
output_type, mkarray, try_namespaced = true, expression = Val{true})
if fns isa Tuple
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
return return_inplace ? (oop, iip) : oop
f = GeneratedFunctionWrapper{(
p_start, length(args) - length(ps) + 1, is_split(sys))}(oop, iip)
return return_inplace ? (f, f) : f
else
return eval_or_rgf(fns; eval_expression, eval_module)
f = eval_or_rgf(fns; eval_expression, eval_module)
f = GeneratedFunctionWrapper{(
p_start, length(args) - length(ps) + 1, is_split(sys))}(f, nothing)
return f
end
end

Expand Down
19 changes: 6 additions & 13 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -604,18 +604,14 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
kwargs...)
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)

f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
g(u, p, t) = g_oop(u, p, t)
g(du, u, p, t) = g_iip(du, u, p, t)
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
g = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(g_oop, g_iip)

if tgrad
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true},
kwargs...)
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)

_tgrad(u, p, t) = tgrad_oop(u, p, t)
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
else
_tgrad = nothing
end
Expand All @@ -625,8 +621,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
sparse = sparse, kwargs...)
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)

_jac(u, p, t) = jac_oop(u, p, t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
else
_jac = nothing
end
Expand All @@ -637,10 +632,8 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
Wfact_oop, Wfact_iip = eval_or_rgf.(tmp_Wfact; eval_expression, eval_module)
Wfact_oop_t, Wfact_iip_t = eval_or_rgf.(tmp_Wfact_t; eval_expression, eval_module)

_Wfact(u, p, dtgamma, t) = Wfact_oop(u, p, dtgamma, t)
_Wfact(W, u, p, dtgamma, t) = Wfact_iip(W, u, p, dtgamma, t)
_Wfact_t(u, p, dtgamma, t) = Wfact_oop_t(u, p, dtgamma, t)
_Wfact_t(W, u, p, dtgamma, t) = Wfact_iip_t(W, u, p, dtgamma, t)
_Wfact = GeneratedFunctionWrapper{(2, 4, is_split(sys))}(Wfact_oop, Wfact_iip)
_Wfact_t = GeneratedFunctionWrapper{(2, 4, is_split(sys))}(Wfact_oop_t, Wfact_iip_t)
else
_Wfact, _Wfact_t = nothing, nothing
end
Expand Down
3 changes: 1 addition & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
expression_module = eval_module, kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)

if specialize === SciMLBase.FunctionWrapperSpecialize && iip
if u0 === nothing || p === nothing || t === nothing
Expand Down
6 changes: 3 additions & 3 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ function generate_rate_function(js::JumpSystem, rate)
rate = substitute(rate, csubs)
end
p = reorder_parameters(js)
rf = build_function_wrapper(js, rate, unknowns(js), p...,
build_function_wrapper(js, rate, unknowns(js), p...,
get_iv(js),
expression = Val{true})
end
Expand All @@ -302,7 +302,7 @@ end
function assemble_vrj(
js, vrj, unknowntoid; eval_expression = false, eval_module = @__MODULE__)
rate = eval_or_rgf(generate_rate_function(js, vrj.rate); eval_expression, eval_module)

rate = GeneratedFunctionWrapper{(2, 3, is_split(js))}(rate, nothing)
outputvars = (value(affect.lhs) for affect in vrj.affect!)
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
Expand All @@ -326,7 +326,7 @@ end
function assemble_crj(
js, crj, unknowntoid; eval_expression = false, eval_module = @__MODULE__)
rate = eval_or_rgf(generate_rate_function(js, crj.rate); eval_expression, eval_module)

rate = GeneratedFunctionWrapper{(2, 3, is_split(js))}(rate, nothing)
outputvars = (value(affect.lhs) for affect in crj.affect!)
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);
Expand Down
Loading
Loading