@@ -230,3 +230,59 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
230
230
end
231
231
return build_function (expr, args... ; wrap_code, similarto, kwargs... )
232
232
end
233
+
234
+ """
235
+ $(TYPEDEF)
236
+
237
+ A wrapper around a generated in-place and out-of-place function. The type-parameter `P`
238
+ must be a 3-tuple where the first element is the index of the parameter object in the
239
+ arguments, the second is the expected number of arguments in the out-of-place variant
240
+ of the function, and the third is a boolean indicating whether the generated functions
241
+ are for a split system. For scalar functions, the inplace variant can be `nothing`.
242
+ """
243
+ struct GeneratedFunctionWrapper{P, O, I} <: Function
244
+ f_oop:: O
245
+ f_iip:: I
246
+ end
247
+
248
+ function GeneratedFunctionWrapper {P} (foop:: O , fiip:: I ) where {P, O, I}
249
+ GeneratedFunctionWrapper {P, O, I} (foop, fiip)
250
+ end
251
+
252
+ function (gfw:: GeneratedFunctionWrapper )(args... )
253
+ _generated_call (gfw, args... )
254
+ end
255
+
256
+ @generated function _generated_call (gfw:: GeneratedFunctionWrapper{P} , args... ) where {P}
257
+ paramidx, nargs, issplit = P
258
+ iip = false
259
+ # IIP case has one more argument
260
+ if length (args) == nargs + 1
261
+ nargs += 1
262
+ paramidx += 1
263
+ iip = true
264
+ end
265
+ if length (args) != nargs
266
+ throw (ArgumentError (" Expected $nargs arguments, got $(length (args)) ." ))
267
+ end
268
+
269
+ # the function to use
270
+ f = iip ? :(gfw. f_iip) : :(gfw. f_oop)
271
+ # non-split systems just call it as-is
272
+ if ! issplit
273
+ return :($ f (args... ))
274
+ end
275
+ if args[paramidx] <: Union{Tuple, MTKParameters} &&
276
+ ! (args[paramidx] <: Tuple{Vararg{Number}} )
277
+ # for split systems, call it as-is if the parameter object is a tuple or MTKParameters
278
+ # but not if it is a tuple of numbers
279
+ return :($ f (args... ))
280
+ else
281
+ # The user provided a single buffer/tuple for the parameter object, so wrap that
282
+ # one in a tuple
283
+ fargs = ntuple (Val (length (args))) do i
284
+ i == paramidx ? :((args[$ i],)) : :(args[$ i])
285
+ end
286
+ return :($ f ($ (fargs... )))
287
+ end
288
+ end
0 commit comments