Skip to content

Commit f7f0221

Browse files
Merge pull request #3324 from AayushSabharwal/as/scc-fix
feat: support caching of different types of subexpressions in `SCCNonlinearProblem`
2 parents 083a639 + 6ce9e85 commit f7f0221

File tree

5 files changed

+204
-38
lines changed

5 files changed

+204
-38
lines changed

src/ModelingToolkit.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ import SCCNonlinearSolve
5454
using Reexport
5555
using RecursiveArrayTools
5656
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
57-
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
57+
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
58+
undef_blocks, blocks
5859
import CommonSolve
5960
import EnumX
6061

src/systems/abstractsystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,7 @@ function Base.show(
19191919
nrows > 0 && hint && print(io, " see hierarchy($name)")
19201920
for i in 1:nrows
19211921
sub = subs[i]
1922-
name = String(nameof(sub))
1922+
local name = String(nameof(sub))
19231923
print(io, "\n ", name)
19241924
desc = description(sub)
19251925
if !isempty(desc)

src/systems/nonlinear/nonlinearsystem.jl

+91-27
Original file line numberDiff line numberDiff line change
@@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
573573
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
574574
end
575575

576+
const TypeT = Union{DataType, UnionAll}
577+
576578
struct CacheWriter{F}
577579
fn::F
578580
end
579581

580582
function (cw::CacheWriter)(p, sols)
581-
cw.fn(p.caches[1], sols, p...)
583+
cw.fn(p.caches, sols, p...)
582584
end
583585

584-
function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation};
586+
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
587+
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
585588
eval_expression = false, eval_module = @__MODULE__)
586589
ps = parameters(sys)
587590
rps = reorder_parameters(sys, ps)
588591
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
589592
cmap, cs = get_cmap(sys)
590593
cmap_assigns = [eq.lhs eq.rhs for eq in cmap]
594+
595+
outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)]
596+
body = map(eachindex(buffer_types), buffer_types) do i, T
597+
Symbol(:tmp, i) SetArray(true, :(out[$i]), get(exprs, T, []))
598+
end
591599
fn = Func(
592600
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
593601
DestructuredArgs.(rps)...],
594602
[],
595-
SetArray(true, :out, exprs)
603+
Let(body, :())
596604
) |> wrap_assignments(false, obs_assigns)[2] |>
597605
wrap_parameter_dependencies(sys, false)[2] |>
598-
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |>
606+
wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
599607
wrap_assignments(false, cmap_assigns)[2] |> toexpr
600608
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
601609
end
@@ -677,8 +685,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
677685

678686
explicitfuns = []
679687
nlfuns = []
680-
prevobsidxs = Int[]
681-
cachesize = 0
688+
prevobsidxs = BlockArray(undef_blocks, Vector{Int}, Int[])
689+
# Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
690+
# dict to maintain a consistent order of buffers across SCCs
691+
cachetypes = TypeT[]
692+
cachesizes = Int[]
693+
# explicitfun! related information for each SCC
694+
# We need to compute buffer sizes before doing any codegen
695+
scc_cachevars = Dict{TypeT, Vector{Any}}[]
696+
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
697+
scc_eqs = Vector{Equation}[]
698+
scc_obs = Vector{Equation}[]
682699
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
683700
# subset unknowns and equations
684701
_dvs = dvs[vscc]
@@ -690,11 +707,10 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
690707
_obs = obs[obsidxs]
691708

692709
# get all subexpressions in the RHS which we can precompute in the cache
710+
# precomputed subexpressions should not contain `banned_vars`
693711
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
694-
for var in banned_vars
695-
iscall(var) || continue
696-
operation(var) === getindex || continue
697-
push!(banned_vars, arguments(var)[1])
712+
filter!(banned_vars) do var
713+
symbolic_type(var) != ArraySymbolic() || all(x -> var[i] in banned_vars, eachindex(var))
698714
end
699715
state = Dict()
700716
for i in eachindex(_obs)
@@ -706,37 +722,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
706722
_eqs[i].rhs, banned_vars, state)
707723
end
708724

709-
# cached variables and their corresponding expressions
710-
cachevars = Any[obs[i].lhs for i in prevobsidxs]
711-
cacheexprs = Any[obs[i].lhs for i in prevobsidxs]
725+
# map from symtype to cached variables and their expressions
726+
cachevars = Dict{Union{DataType, UnionAll}, Vector{Any}}()
727+
cacheexprs = Dict{Union{DataType, UnionAll}, Vector{Any}}()
728+
# observed of previous SCCs are in the cache
729+
# NOTE: When we get proper CSE, we can substitute these
730+
# and then use `subexpressions_not_involving_vars!`
731+
for i in prevobsidxs
732+
T = symtype(obs[i].lhs)
733+
buf = get!(() -> Any[], cachevars, T)
734+
push!(buf, obs[i].lhs)
735+
736+
buf = get!(() -> Any[], cacheexprs, T)
737+
push!(buf, obs[i].lhs)
738+
end
739+
712740
for (k, v) in state
713-
push!(cachevars, unwrap(v))
714-
push!(cacheexprs, unwrap(k))
741+
k = unwrap(k)
742+
v = unwrap(v)
743+
T = symtype(k)
744+
buf = get!(() -> Any[], cachevars, T)
745+
push!(buf, v)
746+
buf = get!(() -> Any[], cacheexprs, T)
747+
push!(buf, k)
715748
end
716-
cachesize = max(cachesize, length(cachevars))
749+
750+
# update the sizes of cache buffers
751+
for (T, buf) in cachevars
752+
idx = findfirst(isequal(T), cachetypes)
753+
if idx === nothing
754+
push!(cachetypes, T)
755+
push!(cachesizes, 0)
756+
idx = lastindex(cachetypes)
757+
end
758+
cachesizes[idx] = max(cachesizes[idx], length(buf))
759+
end
760+
761+
push!(scc_cachevars, cachevars)
762+
push!(scc_cacheexprs, cacheexprs)
763+
push!(scc_eqs, _eqs)
764+
push!(scc_obs, _obs)
765+
blockpush!(prevobsidxs, obsidxs)
766+
end
767+
768+
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
769+
_dvs = dvs[vscc]
770+
_eqs = scc_eqs[i]
771+
_prevobsidxs = reduce(vcat, blocks(prevobsidxs)[1:(i - 1)]; init = Int[])
772+
_obs = scc_obs[i]
773+
cachevars = scc_cachevars[i]
774+
cacheexprs = scc_cacheexprs[i]
717775

718776
if isempty(cachevars)
719777
push!(explicitfuns, Returns(nothing))
720778
else
721779
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
722780
push!(explicitfuns,
723-
CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs];
781+
CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
724782
eval_expression, eval_module))
725783
end
784+
785+
cachebufsyms = Tuple(map(cachetypes) do T
786+
get(cachevars, T, [])
787+
end)
726788
f = SCCNonlinearFunction{iip}(
727-
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)
789+
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, kwargs...)
728790
push!(nlfuns, f)
729-
append!(cachevars, _dvs)
730-
append!(cacheexprs, _dvs)
731-
for i in obsidxs
732-
push!(cachevars, obs[i].lhs)
733-
push!(cacheexprs, obs[i].rhs)
734-
end
735-
append!(prevobsidxs, obsidxs)
736791
end
737792

738-
if cachesize != 0
739-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
793+
if !isempty(cachetypes)
794+
templates = map(cachetypes, cachesizes) do T, n
795+
# Real refers to `eltype(u0)`
796+
if T == Real
797+
T = eltype(u0)
798+
elseif T <: Array && eltype(T) == Real
799+
T = Array{eltype(u0), ndims(T)}
800+
end
801+
BufferTemplate(T, n)
802+
end
803+
p = rebuild_with_caches(p, templates...)
740804
end
741805

742806
subprobs = []

src/utils.jl

+18-9
Original file line numberDiff line numberDiff line change
@@ -1108,23 +1108,33 @@ returns the modified `expr`.
11081108
"""
11091109
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
11101110
expr = unwrap(expr)
1111-
symbolic_type(expr) == NotSymbolic() && return expr
1111+
if symbolic_type(expr) == NotSymbolic()
1112+
if is_array_of_symbolics(expr)
1113+
return map(expr) do el
1114+
subexpressions_not_involving_vars!(el, vars, state)
1115+
end
1116+
end
1117+
return expr
1118+
end
1119+
any(isequal(expr), vars) && return expr
11121120
iscall(expr) || return expr
1113-
is_variable_floatingpoint(expr) || return expr
1114-
symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr
11151121
Symbolics.shape(expr) == Symbolics.Unknown() && return expr
11161122
haskey(state, expr) && return state[expr]
1117-
vs = ModelingToolkit.vars(expr)
1118-
intersect!(vs, vars)
1119-
if isempty(vs)
1123+
op = operation(expr)
1124+
args = arguments(expr)
1125+
# if this is a `getindex` and the getindex-ed value is a `Sym`
1126+
# or it is not a called parameter
1127+
# OR
1128+
# none of `vars` are involved in `expr`
1129+
if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) ||
1130+
(vs = ModelingToolkit.vars(expr); intersect!(vs, vars); isempty(vs))
11201131
sym = gensym(:subexpr)
11211132
stype = symtype(expr)
11221133
var = similar_variable(expr, sym)
11231134
state[expr] = var
11241135
return var
11251136
end
1126-
op = operation(expr)
1127-
args = arguments(expr)
1137+
11281138
if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
11291139
indep_args = []
11301140
dep_args = []
@@ -1143,7 +1153,6 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
11431153
return op(indep_term, dep_term)
11441154
end
11451155
newargs = map(args) do arg
1146-
symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg
11471156
subexpressions_not_involving_vars!(arg, vars, state)
11481157
end
11491158
return maketerm(typeof(expr), op, newargs, metadata(expr))

test/scc_nonlinear_problem.jl

+92
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,95 @@ end
161161
@test SciMLBase.successful_retcode(sccsol)
162162
@test val[] == 1
163163
end
164+
165+
import ModelingToolkitStandardLibrary.Blocks as B
166+
import ModelingToolkitStandardLibrary.Mechanical.Translational as T
167+
import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC
168+
169+
@testset "Caching of subexpressions of different types" begin
170+
liquid_pressure(rho, rho_0, bulk) = (rho / rho_0 - 1) * bulk
171+
gas_pressure(rho, rho_0, p_gas, rho_gas) = rho * ((0 - p_gas) / (rho_0 - rho_gas))
172+
full_pressure(rho, rho_0, bulk, p_gas, rho_gas) = ifelse(
173+
rho >= rho_0, liquid_pressure(rho, rho_0, bulk),
174+
gas_pressure(rho, rho_0, p_gas, rho_gas))
175+
176+
@component function Volume(;
177+
#parameters
178+
area,
179+
direction = +1,
180+
x_int,
181+
name)
182+
pars = @parameters begin
183+
area = area
184+
x_int = x_int
185+
rho_0 = 1000
186+
bulk = 1e9
187+
p_gas = -1000
188+
rho_gas = 1
189+
end
190+
191+
vars = @variables begin
192+
x(t) = x_int
193+
dx(t), [guess = 0]
194+
p(t), [guess = 0]
195+
f(t), [guess = 0]
196+
rho(t), [guess = 0]
197+
m(t), [guess = 0]
198+
dm(t), [guess = 0]
199+
end
200+
201+
systems = @named begin
202+
port = IC.HydraulicPort()
203+
flange = T.MechanicalPort()
204+
end
205+
206+
eqs = [
207+
# connectors
208+
port.p ~ p
209+
port.dm ~ dm
210+
flange.v * direction ~ dx
211+
flange.f * direction ~ -f
212+
213+
# differentials
214+
D(x) ~ dx
215+
D(m) ~ dm
216+
217+
# physics
218+
p ~ full_pressure(rho, rho_0, bulk, p_gas, rho_gas)
219+
f ~ p * area
220+
m ~ rho * x * area]
221+
222+
return ODESystem(eqs, t, vars, pars; name, systems)
223+
end
224+
225+
systems = @named begin
226+
fluid = IC.HydraulicFluid(; bulk_modulus = 1e9)
227+
228+
src1 = IC.Pressure(;)
229+
src2 = IC.Pressure(;)
230+
231+
vol1 = Volume(; area = 0.01, direction = +1, x_int = 0.1)
232+
vol2 = Volume(; area = 0.01, direction = +1, x_int = 0.1)
233+
234+
mass = T.Mass(; m = 10)
235+
236+
sin1 = B.Sine(; frequency = 0.5, amplitude = +0.5e5, offset = 10e5)
237+
sin2 = B.Sine(; frequency = 0.5, amplitude = -0.5e5, offset = 10e5)
238+
end
239+
240+
eqs = [connect(fluid, src1.port)
241+
connect(fluid, src2.port)
242+
connect(src1.port, vol1.port)
243+
connect(src2.port, vol2.port)
244+
connect(vol1.flange, mass.flange, vol2.flange)
245+
connect(src1.p, sin1.output)
246+
connect(src2.p, sin2.output)]
247+
248+
initialization_eqs = [mass.s ~ 0.0
249+
mass.v ~ 0.0]
250+
251+
@mtkbuild sys = ODESystem(eqs, t, [], []; systems, initialization_eqs)
252+
prob = ODEProblem(sys, [], (0, 5))
253+
sol = solve(prob)
254+
@test SciMLBase.successful_retcode(sol)
255+
end

0 commit comments

Comments
 (0)