Skip to content

Commit b32de83

Browse files
authored
Merge pull request SciML#2122 from SciML/myb/domain
Implement domain
2 parents d7b3bfb + c7a1140 commit b32de83

File tree

2 files changed

+182
-20
lines changed

2 files changed

+182
-20
lines changed

Diff for: src/systems/abstractsystem.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,8 @@ function defaults(sys::AbstractSystem)
555555
isempty(systems) ? defs : mapfoldr(namespace_defaults, merge, systems; init = defs)
556556
end
557557

558-
states(sys::AbstractSystem, v) = renamespace(sys, v)
559-
parameters(sys::AbstractSystem, v) = toparam(states(sys, v))
558+
states(sys::Union{AbstractSystem, Nothing}, v) = renamespace(sys, v)
559+
parameters(sys::Union{AbstractSystem, Nothing}, v) = toparam(states(sys, v))
560560
for f in [:states, :parameters]
561561
@eval function $f(sys::AbstractSystem, vs::AbstractArray)
562562
map(v -> $f(sys, v), vs)
@@ -806,7 +806,7 @@ end
806806
# TODO: what about inputs?
807807
function n_extra_equations(sys::AbstractSystem)
808808
isconnector(sys) && return length(get_states(sys))
809-
sys, csets = generate_connection_set(sys)
809+
sys, (csets, _) = generate_connection_set(sys)
810810
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
811811
n_outer_stream_variables = 0
812812
for cset in instream_csets

Diff for: src/systems/connectors.jl

+179-17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ end
1313
abstract type AbstractConnectorType end
1414
struct StreamConnector <: AbstractConnectorType end
1515
struct RegularConnector <: AbstractConnectorType end
16+
struct DomainConnector <: AbstractConnectorType end
1617

1718
function connector_type(sys::AbstractSystem)
1819
sts = get_states(sys)
@@ -31,7 +32,10 @@ function connector_type(sys::AbstractSystem)
3132
end
3233
end
3334
(n_stream > 0 && n_flow > 1) &&
34-
error("There are multiple flow variables in $(nameof(sys))!")
35+
error("There are multiple flow variables in the stream connector $(nameof(sys))!")
36+
if n_flow == 1 && length(sts) == 1
37+
return DomainConnector()
38+
end
3539
if n_flow != n_regular
3640
@warn "$(nameof(sys)) contains $n_flow flow variables, yet $n_regular regular " *
3741
"(non-flow, non-stream, non-input, non-output) variables. " *
@@ -41,6 +45,8 @@ function connector_type(sys::AbstractSystem)
4145
n_stream > 0 ? StreamConnector() : RegularConnector()
4246
end
4347

48+
is_domain_connector(s) = isconnector(s) && get_connector_type(s) === DomainConnector()
49+
4450
get_systems(c::Connection) = c.systems
4551

4652
instream(a) = term(instream, unwrap(a), type = symtype(a))
@@ -117,18 +123,21 @@ function generate_isouter(sys::AbstractSystem)
117123
end
118124

119125
struct LazyNamespace
120-
namespace::Union{Nothing, Symbol}
126+
namespace::Union{Nothing, AbstractSystem}
121127
sys::Any
122128
end
123129

124-
Base.copy(l::LazyNamespace) = renamespace(l.namespace, l.sys)
125-
Base.nameof(l::LazyNamespace) = renamespace(l.namespace, nameof(l.sys))
130+
_getname(::Nothing) = nothing
131+
_getname(sys) = nameof(sys)
132+
Base.copy(l::LazyNamespace) = renamespace(_getname(l.namespace), l.sys)
133+
Base.nameof(l::LazyNamespace) = renamespace(_getname(l.namespace), nameof(l.sys))
126134

127135
struct ConnectionElement
128136
sys::LazyNamespace
129137
v::Any
130138
isouter::Bool
131139
end
140+
Base.nameof(l::ConnectionElement) = renamespace(nameof(l.sys), getname(l.v))
132141
function Base.hash(l::ConnectionElement, salt::UInt)
133142
hash(nameof(l.sys)) hash(l.v) hash(l.isouter) salt
134143
end
@@ -142,6 +151,7 @@ states(l::ConnectionElement, v) = states(copy(l.sys), v)
142151
struct ConnectionSet
143152
set::Vector{ConnectionElement} # namespace.sys, var, isouter
144153
end
154+
Base.copy(c::ConnectionSet) = ConnectionSet(copy(c.set))
145155

146156
function Base.show(io::IO, c::ConnectionSet)
147157
print(io, "<")
@@ -173,7 +183,41 @@ function ori(sys)
173183
end
174184

175185
function connection2set!(connectionsets, namespace, ss, isouter)
176-
nn = map(nameof, ss)
186+
regular_ss = []
187+
domain_ss = nothing
188+
for s in ss
189+
if is_domain_connector(s)
190+
if domain_ss === nothing
191+
domain_ss = s
192+
else
193+
names = join(string(map(name, ss)), ",")
194+
error("connect($names) contains multiple source domain connectors. There can only be one!")
195+
end
196+
else
197+
push!(regular_ss, s)
198+
end
199+
end
200+
T = ConnectionElement
201+
@assert !isempty(regular_ss)
202+
ss = regular_ss
203+
# domain connections don't generate any equations
204+
if domain_ss !== nothing
205+
cset = ConnectionElement[]
206+
dv = only(states(domain_ss))
207+
for (i, s) in enumerate(ss)
208+
sts = states(s)
209+
io = isouter(s)
210+
for (j, v) in enumerate(sts)
211+
vtype = get_connection_type(v)
212+
(vtype === Flow && isequal(v, dv)) || continue
213+
push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false))
214+
push!(cset, T(LazyNamespace(namespace, s), v, io))
215+
end
216+
end
217+
@assert length(cset) > 0
218+
push!(connectionsets, ConnectionSet(cset))
219+
return connectionsets
220+
end
177221
s1 = first(ss)
178222
sts1v = states(s1)
179223
if isframe(s1) # Multibody
@@ -182,7 +226,6 @@ function connection2set!(connectionsets, namespace, ss, isouter)
182226
sts1v = [sts1v; orientation_vars]
183227
end
184228
sts1 = Set(sts1v)
185-
T = ConnectionElement
186229
num_statevars = length(sts1)
187230
csets = [T[] for _ in 1:num_statevars] # Add 9 orientation variables if connection is between multibody frames
188231
for (i, s) in enumerate(ss)
@@ -200,7 +243,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
200243
end
201244
end
202245
for cset in csets
203-
vtype = get_connection_type(first(cset).v)
246+
v = first(cset).v
247+
vtype = get_connection_type(v)
248+
if domain_ss !== nothing && vtype === Flow &&
249+
(dv = only(states(domain_ss)); isequal(v, dv))
250+
push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false))
251+
end
204252
for k in 2:length(cset)
205253
vtype === get_connection_type(cset[k].v) || connection_error(ss)
206254
end
@@ -211,7 +259,11 @@ end
211259
function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
212260
connectionsets = ConnectionSet[]
213261
sys = generate_connection_set!(connectionsets, sys, find, replace)
214-
sys, merge(connectionsets)
262+
domain_free_connectionsets = filter(connectionsets) do cset
263+
!any(s -> is_domain_connector(s.sys.sys), cset.set)
264+
end
265+
_, domainset = merge(connectionsets, true)
266+
sys, (merge(domain_free_connectionsets), domainset)
215267
end
216268

217269
function generate_connection_set!(connectionsets, sys::AbstractSystem, find, replace,
@@ -227,8 +279,8 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
227279
for eq in eqs′
228280
lhs = eq.lhs
229281
rhs = eq.rhs
230-
if find !== nothing && find(rhs, namespace)
231-
neweq, extra_state = replace(rhs, namespace)
282+
if find !== nothing && find(rhs, _getname(namespace))
283+
neweq, extra_state = replace(rhs, _getname(namespace))
232284
if extra_state isa AbstractArray
233285
append!(extra_states, unwrap.(extra_state))
234286
elseif extra_state !== nothing
@@ -263,12 +315,12 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
263315
@set! sys.states = [get_states(sys); extra_states]
264316
end
265317
@set! sys.systems = map(s -> generate_connection_set!(connectionsets, s, find, replace,
266-
renamespace(namespace, nameof(s))),
318+
renamespace(namespace, s)),
267319
subsys)
268320
@set! sys.eqs = eqs
269321
end
270322

271-
function Base.merge(csets::AbstractVector{<:ConnectionSet})
323+
function Base.merge(csets::AbstractVector{<:ConnectionSet}, domain = false)
272324
mcsets = ConnectionSet[]
273325
ele2idx = Dict{ConnectionElement, Int}()
274326
cacheset = Set{ConnectionElement}()
@@ -298,7 +350,94 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet})
298350
empty!(cacheset)
299351
end
300352
end
301-
mcsets
353+
csets = mcsets
354+
domain || return csets
355+
g, roots = rooted_system_domain_graph(csets)
356+
domain_csets = []
357+
root_ijs = Set(g.id2cset[r] for r in roots)
358+
for r in roots
359+
nh = neighborhood(g, r, Inf)
360+
sources_idxs = intersect(nh, roots)
361+
# TODO: error reporting when length(sources_idxs) > 1
362+
length(sources_idxs) > 1 && error()
363+
i′, j′ = g.id2cset[r]
364+
source = csets[i′].set[j′]
365+
domain = source => []
366+
push!(domain_csets, domain)
367+
# get unique cset indices that `r` is (implicitly) connected to.
368+
idxs = BitSet(g.id2cset[i][1] for i in nh)
369+
for i in idxs
370+
for (j, ele) in enumerate(csets[i].set)
371+
(i, j) == (i′, j′) && continue
372+
if (i, j) in root_ijs
373+
error("Domain source $(nameof(source)) and $(nameof(ele)) are connected!")
374+
end
375+
push!(domain[2], ele)
376+
end
377+
end
378+
end
379+
csets, domain_csets
380+
end
381+
382+
struct SystemDomainGraph{C <: AbstractVector{<:ConnectionSet}} <: Graphs.AbstractGraph{Int}
383+
id2cset::Vector{NTuple{2, Int}}
384+
cset2id::Vector{Vector{Int}}
385+
csets::C
386+
sys2id::Dict{Symbol, Int}
387+
outne::Vector{Union{Nothing, Vector{Int}}}
388+
end
389+
390+
Graphs.nv(g::SystemDomainGraph) = length(g.id2cset)
391+
function Graphs.outneighbors(g::SystemDomainGraph, n::Int)
392+
i, j = g.id2cset[n]
393+
ids = copy(g.cset2id[i])
394+
visited = BitSet(n)
395+
for s in g.csets[i].set
396+
sys = s.sys.sys
397+
is_domain_connector(s.sys.sys) && continue
398+
ts = TearingState(s.sys.namespace)
399+
graph = ts.structure.graph
400+
mm = linear_subsys_adjmat!(ts)
401+
lineqs = BitSet(mm.nzrows)
402+
var2idx = Dict(reverse(en) for en in enumerate(ts.fullvars))
403+
vidx = get(var2idx, states(sys, s.v), 0)
404+
iszero(vidx) && continue
405+
ies = 𝑑neighbors(graph, vidx)
406+
for ie in ies
407+
ie in lineqs || continue
408+
for iv in 𝑠neighbors(graph, ie)
409+
iv == vidx && continue
410+
fv = ts.fullvars[iv]
411+
vtype = get_connection_type(fv)
412+
@assert vtype === Flow
413+
n′ = g.sys2id[renamespace(_getname(s.sys.namespace), getname(fv))]
414+
n′ in visited && continue
415+
push!(visited, n′)
416+
append!(ids, g.cset2id[g.id2cset[n′][1]])
417+
end
418+
end
419+
end
420+
ids
421+
end
422+
function rooted_system_domain_graph(csets::AbstractVector{<:ConnectionSet})
423+
id2cset = NTuple{2, Int}[]
424+
cset2id = Vector{Int}[]
425+
sys2id = Dict{Symbol, Int}()
426+
roots = BitSet()
427+
for (i, c) in enumerate(csets)
428+
cset2id′ = Int[]
429+
for (j, s) in enumerate(c.set)
430+
ij = (i, j)
431+
push!(id2cset, ij)
432+
n = length(id2cset)
433+
push!(cset2id′, n)
434+
sys2id[nameof(s)] = n
435+
is_domain_connector(s.sys.sys) && push!(roots, n)
436+
end
437+
push!(cset2id, cset2id′)
438+
end
439+
outne = Vector{Union{Nothing, Vector{Int}}}(undef, length(id2cset))
440+
SystemDomainGraph(id2cset, cset2id, csets, sys2id, outne), roots
302441
end
303442

304443
function generate_connection_equations_and_stream_connections(csets::AbstractVector{
@@ -331,13 +470,34 @@ function generate_connection_equations_and_stream_connections(csets::AbstractVec
331470
eqs, stream_connections
332471
end
333472

473+
function domain_defaults(domain_csets)
474+
def = Dict()
475+
for (s, mods) in domain_csets
476+
s_def = defaults(s.sys.sys)
477+
for m in mods
478+
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
479+
for p in parameters(m.sys.namespace)
480+
d_p = get(ns_s_def, p, nothing)
481+
if d_p !== nothing
482+
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
483+
parameters(s.sys.sys,
484+
d_p))
485+
end
486+
end
487+
end
488+
end
489+
def
490+
end
491+
334492
function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
335493
debug = false, tol = 1e-10)
336-
sys, csets = generate_connection_set(sys, find, replace)
494+
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace)
495+
d_defs = domain_defaults(domain_csets)
337496
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
338497
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)
339498
sys = flatten(sys, true)
340499
@set! sys.eqs = [equations(_sys); ceqs]
500+
@set! sys.defaults = merge(get_defaults(sys), d_defs)
341501
end
342502

343503
function unnamespace(root, namespace)
@@ -387,7 +547,7 @@ function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSy
387547
expr_cset = Dict()
388548
for cset in csets
389549
crep = first(cset.set)
390-
current = namespace == crep.sys.namespace
550+
current = namespace == _getname(crep.sys.namespace)
391551
for v in cset.set
392552
if (current || !v.isouter)
393553
expr_cset[namespaced_var(v)] = cset.set
@@ -445,7 +605,8 @@ function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSy
445605

446606
# additional equations
447607
additional_eqs = Equation[]
448-
csets = filter(cset -> any(e -> e.sys.namespace === namespace, cset.set), csets)
608+
csets = filter(cset -> any(e -> _getname(e.sys.namespace) === namespace, cset.set),
609+
csets)
449610
for cset′ in csets
450611
cset = cset′.set
451612
connectors = Vector{Any}(undef, length(cset))
@@ -524,7 +685,8 @@ function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSy
524685
end
525686

526687
function get_current_var(namespace, cele, sv)
527-
states(renamespace(unnamespace(namespace, cele.sys.namespace), cele.sys.sys), sv)
688+
states(renamespace(unnamespace(namespace, _getname(cele.sys.namespace)), cele.sys.sys),
689+
sv)
528690
end
529691

530692
function get_cset_sv(full_name_sv, cset)

0 commit comments

Comments
 (0)