Skip to content

Commit 27d7f70

Browse files
feat: cache subexpressions dependent only on previous SCCs
1 parent beb7070 commit 27d7f70

File tree

3 files changed

+140
-5
lines changed

3 files changed

+140
-5
lines changed

src/systems/nonlinear/nonlinearsystem.jl

+34-5
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,11 @@ function SCCNonlinearFunction{iip}(
583583
f(resid, u, p) = f_iip(resid, u, p)
584584
f(resid, u, p::MTKParameters) = f_iip(resid, u, p...)
585585

586-
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs, parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
586+
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs,
587+
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
587588
if get_index_cache(sys) !== nothing
588-
@set! subsys.index_cache = subset_unknowns_observed(get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
589+
@set! subsys.index_cache = subset_unknowns_observed(
590+
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
589591
@set! subsys.complete = true
590592
end
591593

@@ -624,8 +626,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
624626
explicitfuns = []
625627
nlfuns = []
626628
prevobsidxs = Int[]
627-
cachevars = []
628-
cacheexprs = []
629+
cachesize = 0
629630
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
630631
# subset unknowns and equations
631632
_dvs = dvs[vscc]
@@ -636,6 +637,32 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
636637
setdiff!(obsidxs, prevobsidxs)
637638
_obs = obs[obsidxs]
638639

640+
# get all subexpressions in the RHS which we can precompute in the cache
641+
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
642+
for var in banned_vars
643+
iscall(var) || continue
644+
operation(var) === getindex || continue
645+
push!(banned_vars, arguments(var)[1])
646+
end
647+
state = Dict()
648+
for i in eachindex(_obs)
649+
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
650+
_obs[i].rhs, banned_vars, state)
651+
end
652+
for i in eachindex(_eqs)
653+
_eqs[i] = _eqs[i].lhs ~ subexpressions_not_involving_vars!(
654+
_eqs[i].rhs, banned_vars, state)
655+
end
656+
657+
# cached variables and their corresponding expressions
658+
cachevars = Any[obs[i].lhs for i in prevobsidxs]
659+
cacheexprs = Any[obs[i].rhs for i in prevobsidxs]
660+
for (k, v) in state
661+
push!(cachevars, unwrap(v))
662+
push!(cacheexprs, unwrap(k))
663+
end
664+
cachesize = max(cachesize, length(cachevars))
665+
639666
if isempty(cachevars)
640667
push!(explicitfuns, Returns(nothing))
641668
else
@@ -655,7 +682,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
655682
append!(prevobsidxs, obsidxs)
656683
end
657684

658-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars)))
685+
if cachesize != 0
686+
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
687+
end
659688

660689
subprobs = []
661690
for (f, vscc) in zip(nlfuns, var_sccs)

src/utils.jl

+89
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,11 @@ end
10011001
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
10021002
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)
10031003

1004+
"""
1005+
$(TYPEDSIGNATURES)
1006+
1007+
Check if `sym` represents a symbolic floating point number or array of such numbers.
1008+
"""
10041009
function is_variable_floatingpoint(sym)
10051010
sym = unwrap(sym)
10061011
T = symtype(sym)
@@ -1052,3 +1057,87 @@ function observed_equations_used_by(sys::AbstractSystem, exprs)
10521057
sort!(obsidxs)
10531058
return obsidxs
10541059
end
1060+
1061+
"""
1062+
$(TYPEDSIGNATURES)
1063+
1064+
Given an expression `expr`, return a dictionary mapping subexpressions of `expr` that do
1065+
not involve variables in `vars` to anonymous symbolic variables. Also return the modified
1066+
`expr` with the substitutions indicated by the dictionary. If `expr` is a function
1067+
of only `vars`, then all of the returned subexpressions can be precomputed.
1068+
1069+
Note that this will only process subexpressions floating point value. Additionally,
1070+
array variables must be passed in both scalarized and non-scalarized forms in `vars`.
1071+
"""
1072+
function subexpressions_not_involving_vars(expr, vars)
1073+
expr = unwrap(expr)
1074+
vars = map(unwrap, vars)
1075+
state = Dict()
1076+
newexpr = subexpressions_not_involving_vars!(expr, vars, state)
1077+
return state, newexpr
1078+
end
1079+
1080+
"""
1081+
$(TYPEDSIGNATURES)
1082+
1083+
Mutating version of `subexpressions_not_involving_vars` which writes to `state`. Only
1084+
returns the modified `expr`.
1085+
"""
1086+
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
1087+
expr = unwrap(expr)
1088+
symbolic_type(expr) == NotSymbolic() && return expr
1089+
iscall(expr) || return expr
1090+
is_variable_floatingpoint(expr) || return expr
1091+
symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr
1092+
Symbolics.shape(expr) == Symbolics.Unknown() && return expr
1093+
haskey(state, expr) && return state[expr]
1094+
vs = ModelingToolkit.vars(expr)
1095+
intersect!(vs, vars)
1096+
if isempty(vs)
1097+
sym = gensym(:subexpr)
1098+
stype = symtype(expr)
1099+
var = similar_variable(expr, sym)
1100+
state[expr] = var
1101+
return var
1102+
end
1103+
op = operation(expr)
1104+
args = arguments(expr)
1105+
if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
1106+
indep_args = []
1107+
dep_args = []
1108+
for arg in args
1109+
_vs = ModelingToolkit.vars(arg)
1110+
intersect!(_vs, vars)
1111+
if !isempty(_vs)
1112+
push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state))
1113+
else
1114+
push!(indep_args, arg)
1115+
end
1116+
end
1117+
indep_term = reduce(op, indep_args; init = Int(op == (*)))
1118+
indep_term = subexpressions_not_involving_vars!(indep_term, vars, state)
1119+
dep_term = reduce(op, dep_args; init = Int(op == (*)))
1120+
return op(indep_term, dep_term)
1121+
end
1122+
newargs = map(args) do arg
1123+
symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg
1124+
subexpressions_not_involving_vars!(arg, vars, state)
1125+
end
1126+
return maketerm(typeof(expr), op, newargs, metadata(expr))
1127+
end
1128+
1129+
"""
1130+
$(TYPEDSIGNATURES)
1131+
1132+
Create an anonymous symbolic variable of the same shape, size and symtype as `var`, with
1133+
name `gensym(name)`. Does not support unsized array symbolics.
1134+
"""
1135+
function similar_variable(var::BasicSymbolic, name = :anon)
1136+
name = gensym(name)
1137+
stype = symtype(var)
1138+
sym = Symbolics.variable(name; T = stype)
1139+
if size(var) !== ()
1140+
sym = setmetadata(sym, Symbolics.ArrayShapeCtx, map(Base.OneTo, size(var)))
1141+
end
1142+
return sym
1143+
end

test/scc_nonlinear_problem.jl

+17
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,20 @@ end
143143
@test sol.usccsol.u atol=1e-10
144144
end
145145

146+
@testset "Expression caching" begin
147+
@variables x[1:4] = rand(4)
148+
val = Ref(0)
149+
function func(x, y)
150+
val[] += 1
151+
x + y
152+
end
153+
@register_symbolic func(x, y)
154+
@mtkbuild sys = NonlinearSystem([0 ~ x[1]^3 + x[2]^3 - 5
155+
0 ~ sin(x[1] - x[2]) - 0.5
156+
0 ~ func(x[1], x[2]) * exp(x[3]) - x[4]^3 - 5
157+
0 ~ func(x[1], x[2]) * exp(x[4]) - x[3]^3 - 4])
158+
sccprob = SCCNonlinearProblem(sys, [])
159+
sccsol = solve(sccprob, NewtonRaphson())
160+
@test SciMLBase.successful_retcode(sccsol)
161+
@test val[] == 1
162+
end

0 commit comments

Comments
 (0)