@@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
573
573
NonlinearLeastSquaresProblem {iip} (f, u0, p; filter_kwargs (kwargs)... )
574
574
end
575
575
576
+ const TypeT = Union{DataType, UnionAll}
577
+
576
578
struct CacheWriter{F}
577
579
fn:: F
578
580
end
579
581
580
582
function (cw:: CacheWriter )(p, sols)
581
- cw. fn (p. caches[ 1 ] , sols, p... )
583
+ cw. fn (p. caches, sols, p... )
582
584
end
583
585
584
- function CacheWriter (sys:: AbstractSystem , exprs, solsyms, obseqs:: Vector{Equation} ;
586
+ function CacheWriter (sys:: AbstractSystem , buffer_types:: Vector{TypeT} ,
587
+ exprs:: Dict{TypeT, Vector{Any}} , solsyms, obseqs:: Vector{Equation} ;
585
588
eval_expression = false , eval_module = @__MODULE__ )
586
589
ps = parameters (sys)
587
590
rps = reorder_parameters (sys, ps)
588
591
obs_assigns = [eq. lhs ← eq. rhs for eq in obseqs]
589
592
cmap, cs = get_cmap (sys)
590
593
cmap_assigns = [eq. lhs ← eq. rhs for eq in cmap]
594
+
595
+ outsyms = [Symbol (:out , i) for i in eachindex (buffer_types)]
596
+ body = map (eachindex (buffer_types), buffer_types) do i, T
597
+ Symbol (:tmp , i) ← SetArray (true , :(out[$ i]), get (exprs, T, []))
598
+ end
591
599
fn = Func (
592
600
[:out , DestructuredArgs (DestructuredArgs .(solsyms)),
593
601
DestructuredArgs .(rps)... ],
594
602
[],
595
- SetArray ( true , :out , exprs )
603
+ Let (body , :() )
596
604
) |> wrap_assignments (false , obs_assigns)[2 ] |>
597
605
wrap_parameter_dependencies (sys, false )[2 ] |>
598
- wrap_array_vars (sys, exprs ; dvs = nothing , inputs = [])[2 ] |>
606
+ wrap_array_vars (sys, [] ; dvs = nothing , inputs = [])[2 ] |>
599
607
wrap_assignments (false , cmap_assigns)[2 ] |> toexpr
600
608
return CacheWriter (eval_or_rgf (fn; eval_expression, eval_module))
601
609
end
@@ -677,8 +685,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
677
685
678
686
explicitfuns = []
679
687
nlfuns = []
680
- prevobsidxs = Int[]
681
- cachesize = 0
688
+ prevobsidxs = BlockArray (undef_blocks, Vector{Int}, Int[])
689
+ # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
690
+ # dict to maintain a consistent order of buffers across SCCs
691
+ cachetypes = TypeT[]
692
+ cachesizes = Int[]
693
+ # explicitfun! related information for each SCC
694
+ # We need to compute buffer sizes before doing any codegen
695
+ scc_cachevars = Dict{TypeT, Vector{Any}}[]
696
+ scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
697
+ scc_eqs = Vector{Equation}[]
698
+ scc_obs = Vector{Equation}[]
682
699
for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
683
700
# subset unknowns and equations
684
701
_dvs = dvs[vscc]
@@ -690,11 +707,10 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
690
707
_obs = obs[obsidxs]
691
708
692
709
# get all subexpressions in the RHS which we can precompute in the cache
710
+ # precomputed subexpressions should not contain `banned_vars`
693
711
banned_vars = Set {Any} (vcat (_dvs, getproperty .(_obs, (:lhs ,))))
694
- for var in banned_vars
695
- iscall (var) || continue
696
- operation (var) === getindex || continue
697
- push! (banned_vars, arguments (var)[1 ])
712
+ filter! (banned_vars) do var
713
+ symbolic_type (var) != ArraySymbolic () || all (x -> var[i] in banned_vars, eachindex (var))
698
714
end
699
715
state = Dict ()
700
716
for i in eachindex (_obs)
@@ -706,37 +722,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
706
722
_eqs[i]. rhs, banned_vars, state)
707
723
end
708
724
709
- # cached variables and their corresponding expressions
710
- cachevars = Any[obs[i]. lhs for i in prevobsidxs]
711
- cacheexprs = Any[obs[i]. lhs for i in prevobsidxs]
725
+ # map from symtype to cached variables and their expressions
726
+ cachevars = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
727
+ cacheexprs = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
728
+ # observed of previous SCCs are in the cache
729
+ # NOTE: When we get proper CSE, we can substitute these
730
+ # and then use `subexpressions_not_involving_vars!`
731
+ for i in prevobsidxs
732
+ T = symtype (obs[i]. lhs)
733
+ buf = get! (() -> Any[], cachevars, T)
734
+ push! (buf, obs[i]. lhs)
735
+
736
+ buf = get! (() -> Any[], cacheexprs, T)
737
+ push! (buf, obs[i]. lhs)
738
+ end
739
+
712
740
for (k, v) in state
713
- push! (cachevars, unwrap (v))
714
- push! (cacheexprs, unwrap (k))
741
+ k = unwrap (k)
742
+ v = unwrap (v)
743
+ T = symtype (k)
744
+ buf = get! (() -> Any[], cachevars, T)
745
+ push! (buf, v)
746
+ buf = get! (() -> Any[], cacheexprs, T)
747
+ push! (buf, k)
715
748
end
716
- cachesize = max (cachesize, length (cachevars))
749
+
750
+ # update the sizes of cache buffers
751
+ for (T, buf) in cachevars
752
+ idx = findfirst (isequal (T), cachetypes)
753
+ if idx === nothing
754
+ push! (cachetypes, T)
755
+ push! (cachesizes, 0 )
756
+ idx = lastindex (cachetypes)
757
+ end
758
+ cachesizes[idx] = max (cachesizes[idx], length (buf))
759
+ end
760
+
761
+ push! (scc_cachevars, cachevars)
762
+ push! (scc_cacheexprs, cacheexprs)
763
+ push! (scc_eqs, _eqs)
764
+ push! (scc_obs, _obs)
765
+ blockpush! (prevobsidxs, obsidxs)
766
+ end
767
+
768
+ for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
769
+ _dvs = dvs[vscc]
770
+ _eqs = scc_eqs[i]
771
+ _prevobsidxs = reduce (vcat, blocks (prevobsidxs)[1 : (i - 1 )]; init = Int[])
772
+ _obs = scc_obs[i]
773
+ cachevars = scc_cachevars[i]
774
+ cacheexprs = scc_cacheexprs[i]
717
775
718
776
if isempty (cachevars)
719
777
push! (explicitfuns, Returns (nothing ))
720
778
else
721
779
solsyms = getindex .((dvs,), view (var_sccs, 1 : (i - 1 )))
722
780
push! (explicitfuns,
723
- CacheWriter (sys, cacheexprs, solsyms, obs[prevobsidxs ];
781
+ CacheWriter (sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs ];
724
782
eval_expression, eval_module))
725
783
end
784
+
785
+ cachebufsyms = Tuple (map (cachetypes) do T
786
+ get (cachevars, T, [])
787
+ end )
726
788
f = SCCNonlinearFunction {iip} (
727
- sys, _eqs, _dvs, _obs, (cachevars,) ; eval_expression, eval_module, kwargs... )
789
+ sys, _eqs, _dvs, _obs, cachebufsyms ; eval_expression, eval_module, kwargs... )
728
790
push! (nlfuns, f)
729
- append! (cachevars, _dvs)
730
- append! (cacheexprs, _dvs)
731
- for i in obsidxs
732
- push! (cachevars, obs[i]. lhs)
733
- push! (cacheexprs, obs[i]. rhs)
734
- end
735
- append! (prevobsidxs, obsidxs)
736
791
end
737
792
738
- if cachesize != 0
739
- p = rebuild_with_caches (p, BufferTemplate (eltype (u0), cachesize))
793
+ if ! isempty (cachetypes)
794
+ templates = map (cachetypes, cachesizes) do T, n
795
+ # Real refers to `eltype(u0)`
796
+ if T == Real
797
+ T = eltype (u0)
798
+ elseif T <: Array && eltype (T) == Real
799
+ T = Array{eltype (u0), ndims (T)}
800
+ end
801
+ BufferTemplate (T, n)
802
+ end
803
+ p = rebuild_with_caches (p, templates... )
740
804
end
741
805
742
806
subprobs = []
0 commit comments