Skip to content

refactor: centralize all code generation #3360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 51 commits into from
Feb 3, 2025
Merged
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
d90dda3
feat: allow specifying observed equations to `observed_equations_used…
AayushSabharwal Jan 30, 2025
cf40621
fix: fix `observed_equations_used_by` for `DiscreteSystem`
AayushSabharwal Jan 30, 2025
b50cb27
feat: add `build_function_wrapper` for centralized codegen
AayushSabharwal Jan 30, 2025
9166c74
refactor: use `build_function_wrapper` in `generate_tgrad`
AayushSabharwal Jan 30, 2025
2887ca4
refactor: use `build_function` in `generate_jacobian(::AbstractODESys…
AayushSabharwal Jan 30, 2025
3574e3f
refactor: use `build_function_wrapper` in `generate_control_jacobian`
AayushSabharwal Jan 30, 2025
f183e8e
refactor: use `build_function_wrapper` in `generate_dae_jacobian`
AayushSabharwal Jan 30, 2025
d162999
refactor: use `build_function_wrapper` in `generate_function(::Abstra…
AayushSabharwal Jan 30, 2025
7010d12
refactor: use new codegen in `ODEFunction`
AayushSabharwal Jan 30, 2025
92364b8
refactor: use new codegen in `DAEFunction`
AayushSabharwal Jan 30, 2025
141fcab
refactor: use new codegen in `ODEFunctionClosure`, `DAEFunctionClosure`
AayushSabharwal Jan 30, 2025
a3ff92f
refactor: use `build_function_wrapper` in `generate_diffusion_function`
AayushSabharwal Jan 30, 2025
b0fdacf
refactor: use new codegen in `SDEFunction`
AayushSabharwal Jan 30, 2025
cff74a9
refactor: use `build_function_wrapper` in `generate_jacobian(::Nonlin…
AayushSabharwal Jan 30, 2025
517c1a5
refactor: use `build_function_wrapper` in `generate_hessian(::Nonline…
AayushSabharwal Jan 30, 2025
4c7cffd
refactor: use `build_function_wrapper` in `generate_function(::Nonlin…
AayushSabharwal Jan 30, 2025
0bc240d
refactor: use new codegen in `NonlinearFunction`, `IntervalNonlinearF…
AayushSabharwal Jan 30, 2025
de83f70
fix: handle extra constants in `build_function_wrapper`
AayushSabharwal Jan 30, 2025
a845c2b
feat: better observed handling, optional `DestructuredArgs` binding i…
AayushSabharwal Jan 30, 2025
60d501b
fix: use `time_varying_as_func` in `build_function_wrapper`
AayushSabharwal Jan 30, 2025
1cab1c3
feat: better delay handling, preface handling in `build_function_wrap…
AayushSabharwal Jan 30, 2025
106b05a
refactor: use `build_function_wrapper` in callbacks
AayushSabharwal Jan 30, 2025
d6607f9
refactor: use `build_function_wrapper` in `generate_custom_function`
AayushSabharwal Jan 30, 2025
f9a5027
fix: don't consider parameters as delays
AayushSabharwal Jan 30, 2025
5637a0c
refactor: use `build_function_wrapper` in `SymbolicTstops`
AayushSabharwal Jan 30, 2025
6491c17
refactor: use `build_function_wrapper` for history functions
AayushSabharwal Jan 30, 2025
ead141a
refactor: refactor `generate_function(::DiscreteSystem)` to use new c…
AayushSabharwal Jan 30, 2025
dc09909
refactor: use `build_function_wrapper` in `JumpSystem` codegen
AayushSabharwal Jan 30, 2025
95a75ab
refactor: fix `linearize_symbolic` to respect new codegen
AayushSabharwal Jan 30, 2025
1600013
fix: refactor `modelingtoolkitize` to use new codegen
AayushSabharwal Jan 30, 2025
a42cec7
test: use new codegen in index reduction test
AayushSabharwal Jan 30, 2025
5e7b583
test: use new codegen in distributed test
AayushSabharwal Jan 30, 2025
8ce0eb2
test: use new codegen in jumpsystem test
AayushSabharwal Jan 30, 2025
f9069b1
test: use new codegen in odesystem test
AayushSabharwal Jan 30, 2025
ddce985
test: use new codegen in sdesystem test
AayushSabharwal Jan 30, 2025
b8d9de5
test: use new codegen in labelledarrays tests
AayushSabharwal Jan 30, 2025
5fb0eff
fix: add `collect_constants!` fallback for `Symbol`
AayushSabharwal Jan 31, 2025
3e33d96
feat: support `get_cmap(::OptimizationSystem)`
AayushSabharwal Jan 31, 2025
29ef316
fix: handle array symbolics in `observed_equations_used_by`
AayushSabharwal Jan 31, 2025
322b5db
feat: add output type support to `build_function_wrapper`
AayushSabharwal Jan 31, 2025
07f8206
fix: fix `all_symbols` not returning dependent parameters
AayushSabharwal Jan 31, 2025
0774860
feat: use `build_function_wrapper` in `build_explicit_observed_function`
AayushSabharwal Jan 31, 2025
14020bd
feat: allow opting out of destructuring `MTKParameters` in `build_fun…
AayushSabharwal Jan 31, 2025
4686893
refactor: use `build_function_wrapper` in `generate_control_function`
AayushSabharwal Jan 31, 2025
f28df16
fix: refactor `modelingtoolkitize(::OptimizationProblem)` to use new …
AayushSabharwal Jan 31, 2025
cff02c9
refactor: use `build_function_wrapper` in `OptimizationProblem` codegen
AayushSabharwal Jan 31, 2025
4274ec9
docs: add docstring for code generation utils
AayushSabharwal Jan 31, 2025
dfc49b4
feat: add `extra_assignments` to `build_function_wrapper`
AayushSabharwal Jan 31, 2025
884a0fb
refactor: use `build_function_wrapper` in `SCCNonlinearProblem`
AayushSabharwal Jan 31, 2025
4c39175
fix: search through filtered observed equations in `build_function_wr…
AayushSabharwal Feb 3, 2025
1fc081f
build: bump Symbolics compat
AayushSabharwal Feb 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
docs: add docstring for code generation utils
  • Loading branch information
AayushSabharwal committed Feb 3, 2025
commit 4274ec9bb9d741e5767f514c92014159546df6c3
92 changes: 86 additions & 6 deletions src/systems/codegen_utils.jl
Original file line number Diff line number Diff line change
@@ -1,66 +1,135 @@
"""
$(TYPEDSIGNATURES)

Return the name for the `i`th argument in a function generated by `build_function_wrapper`.
"""
function generated_argument_name(i::Int)
return Symbol(:__mtk_arg_, i)
end

"""
$(TYPEDSIGNATURES)

Given the arguments to `build_function_wrapper`, return a list of assignments which
reconstruct array variables if they are present scalarized in `args`.
"""
function array_variable_assignments(args...)
# map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
var_to_arridxs = Dict{BasicSymbolic, Array{Tuple{Int, Int}}}()
for (i, arg) in enumerate(args)
# filter out non-arrays
# any element of args which is not an array is assumed to not contain a
# scalarized array symbolic. This works because the only non-array element
# is the independent variable
symbolic_type(arg) == NotSymbolic() || continue
arg isa AbstractArray || continue

# go through symbolics
for (j, var) in enumerate(arg)
var = unwrap(var)
# filter out non-array-symbolics
iscall(var) || continue
operation(var) == getindex || continue
arrvar = arguments(var)[1]
idxbuffer = get!(() -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
# get and/or construct the buffer storing indexes
idxbuffer = get!(
() -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
idxbuffer[arguments(var)[2:end]...] = (i, j)
end
end

assignments = Assignment[]
for (arrvar, idxs) in var_to_arridxs
# all elements of the array need to be present in `args` to form the
# reconstructing assignment
any(iszero ∘ first, idxs) && continue

# if they are all in the same buffer, we can take a shortcut and `view` into it
if allequal(Iterators.map(first, idxs))
buffer_idx = first(first(idxs))
idxs = map(last, idxs)
# if all the elements are contiguous and ordered, turn the array of indexes into a range
# to help reduce allocations
if first(idxs) < last(idxs) && vec(idxs) == first(idxs):last(idxs)
idxs = first(idxs):last(idxs)
elseif vec(idxs) == last(idxs):-1:first(idxs)
idxs = last(idxs):-1:first(idxs)
else
# Otherwise, turn the indexes into an `SArray` so they're stack-allocated
idxs = SArray{Tuple{size(idxs)...}}(idxs)
end
# view and reshape
push!(assignments, arrvar ← term(reshape, term(view, generated_argument_name(buffer_idx), idxs), size(arrvar)))
else
elems = map(idxs) do idx
i, j = idx
term(getindex, generated_argument_name(i), j)
end
# use `MakeArray` and generate a stack-allocated array
push!(assignments, arrvar ← MakeArray(elems, SArray))
end
end

return assignments
end

"""
$(TYPEDSIGNATURES)

A wrapper around `build_function` which performs the necessary transformations for
code generation of all types of systems. `expr` is the expression returned from the
generated functions, and `args` are the arguments.

# Keyword Arguments

- `p_start`, `p_end`: Denotes the indexes in `args` where the buffers of the splatted
`MTKParameters` object are present. These are collapsed into a single argument and
destructured inside the function. `p_start` must also be provided for non-split systems
since it is used by `wrap_delays`.
- `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
calls to a history function. The history function is added to the list of arguments
right before parameters, at the index `p_start`.
- `wrap_code`: Forwarded to `build_function`.
- `add_observed`: Whether to add assignment statements for observed equations in the
generated code.
- `filter_observed`: A predicate function to filter out observed equations which should
not be added to the generated code.
- `create_bindings`: Whether to explicitly destructure arrays of symbolics present in
`args` in the generated code. If `false`, all usages of the individual symbolics will
instead call `getindex` on the relevant argument. This is useful if the generated
function writes to one of its arguments and expects subsequent code to use the new
values. Note that the collapsed `MTKParameters` argument will always be explicitly
destructured regardless of this keyword argument.
- `output_type`: The type of the output buffer. If `mkarray` (see below) is `nothing`,
this will be passed to the `similarto` argument of `build_function`. If `output_type`
is `Tuple`, `expr` will be wrapped in `SymbolicUtils.Code.MakeTuple` (regardless of
whether it is scalar or an array).
- `mkarray`: A function which accepts `expr` and `output_type` and returns a code
generation object similar to `MakeArray` or `MakeTuple` to be used to generate
code for `expr`.
- `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
argument.

All other keyword arguments are forwarded to `build_function`.
"""
function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), wrap_delays = is_dde(sys), wrap_code = identity, add_observed = true, filter_observed = Returns(true), create_bindings = true, output_type = nothing, mkarray = nothing, wrap_mtkparameters = true, kwargs...)
isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())

# filter observed equations
obs = filter(filter_observed, observed(sys))
# turn delayed unknowns into calls to the history function
if wrap_delays
history_arg = is_split(sys) ? MTKPARAMETERS_ARG : generated_argument_name(p_start)
obs = map(obs) do eq
delay_to_function(sys, eq; history_arg)
end
expr = delay_to_function(sys, expr; history_arg)
args = (args[1:p_start-1]..., DDE_HISTORY_FUN, args[p_start:end]...)
# add extra argument
args = (args[1:(p_start - 1)]..., DDE_HISTORY_FUN, args[p_start:end]...)
p_start += 1
p_end += 1
end
pdeps = parameter_dependencies(sys)

# get the constants to add to the code
cmap, _ = get_cmap(sys)
extra_constants = collect_constants(expr)
filter!(extra_constants) do c
Expand All @@ -69,13 +138,15 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
for c in extra_constants
push!(cmap, c ~ getdefault(c))
end
# only get the necessary observed equations, avoiding extra computation
if add_observed
obsidxs = observed_equations_used_by(sys, expr)
else
obsidxs = Int[]
end
# similarly for parameter dependency equations
pdepidxs = observed_equations_used_by(sys, expr; obs = pdeps)

# assignments for reconstructing scalarized array symbolics
assignments = array_variable_assignments(args...)

for eq in Iterators.flatten((cmap, pdeps[pdepidxs], obs[obsidxs]))
Expand All @@ -84,6 +155,9 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,

args = ntuple(Val(length(args))) do i
arg = args[i]
# for time-dependent systems, all arguments are passed through `time_varying_as_func`
# TODO: This is legacy behavior and a candidate for removal in v10 since we have callable
# parameters now.
if is_time_dependent(sys)
arg = if symbolic_type(arg) == NotSymbolic()
arg isa AbstractArray ?
Expand All @@ -92,16 +166,19 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
time_varying_as_func(unwrap(arg), sys)
end
end
# Make sure to use the proper names for arguments
if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
DestructuredArgs(arg, generated_argument_name(i); create_bindings)
else
arg
end
end

# wrap into a single MTKParameters argument
if is_split(sys) && wrap_mtkparameters
if p_start > p_end
args = (args[1:p_start-1]..., MTKPARAMETERS_ARG, args[p_end+1:end]...)
# In case there are no parameter buffers, still insert an argument
args = (args[1:(p_start - 1)]..., MTKPARAMETERS_ARG, args[(p_end + 1):end]...)
else
# cannot apply `create_bindings` here since it doesn't nest
args = (args[1:(p_start - 1)]...,
Expand All @@ -110,12 +187,14 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
end
end

# add preface assignments
if has_preface(sys) && (pref = preface(sys)) !== nothing
append!(assignments, pref)
end

wrap_code = wrap_code .∘ wrap_assignments(isscalar, assignments)

# handling of `output_type` and `mkarray`
similarto = nothing
if output_type === Tuple
expr = MakeTuple(Tuple(expr))
Expand All @@ -127,6 +206,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
wrap_code = wrap_code[2]
end

# scalar `build_function` only accepts a single function for `wrap_code`.
if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
wrap_code = wrap_code[1]
end
Expand Down