Skip to content

Commit c982267

Browse files
feat: add GeneratedFunctionWrapper
1 parent 3327ec3 commit c982267

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

src/systems/codegen_utils.jl

+56
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,59 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
230230
end
231231
return build_function(expr, args...; wrap_code, similarto, kwargs...)
232232
end
233+
234+
"""
235+
$(TYPEDEF)
236+
237+
A wrapper around a generated in-place and out-of-place function. The type-parameter `P`
238+
must be a 3-tuple where the first element is the index of the parameter object in the
239+
arguments, the second is the expected number of arguments in the out-of-place variant
240+
of the function, and the third is a boolean indicating whether the generated functions
241+
are for a split system. For scalar functions, the inplace variant can be `nothing`.
242+
"""
243+
struct GeneratedFunctionWrapper{P, O, I} <: Function
244+
f_oop::O
245+
f_iip::I
246+
end
247+
248+
function GeneratedFunctionWrapper{P}(foop::O, fiip::I) where {P, O, I}
249+
GeneratedFunctionWrapper{P, O, I}(foop, fiip)
250+
end
251+
252+
function (gfw::GeneratedFunctionWrapper)(args...)
253+
_generated_call(gfw, args...)
254+
end
255+
256+
@generated function _generated_call(gfw::GeneratedFunctionWrapper{P}, args...) where {P}
257+
paramidx, nargs, issplit = P
258+
iip = false
259+
# IIP case has one more argument
260+
if length(args) == nargs + 1
261+
nargs += 1
262+
paramidx += 1
263+
iip = true
264+
end
265+
if length(args) != nargs
266+
throw(ArgumentError("Expected $nargs arguments, got $(length(args))."))
267+
end
268+
269+
# the function to use
270+
f = iip ? :(gfw.f_iip) : :(gfw.f_oop)
271+
# non-split systems just call it as-is
272+
if !issplit
273+
return :($f(args...))
274+
end
275+
if args[paramidx] <: Union{Tuple, MTKParameters} &&
276+
!(args[paramidx] <: Tuple{Vararg{Number}})
277+
# for split systems, call it as-is if the parameter object is a tuple or MTKParameters
278+
# but not if it is a tuple of numbers
279+
return :($f(args...))
280+
else
281+
# The user provided a single buffer/tuple for the parameter object, so wrap that
282+
# one in a tuple
283+
fargs = ntuple(Val(length(args))) do i
284+
i == paramidx ? :((args[$i],)) : :(args[$i])
285+
end
286+
return :($f($(fargs...)))
287+
end
288+
end

0 commit comments

Comments
 (0)