Skip to content

Commit 6c33eb3

Browse files
Merge pull request #3453 from AayushSabharwal/as/anti-connect
feat: support analysis points duplicating existing connections
2 parents a553fd1 + 1de31d0 commit 6c33eb3

File tree

5 files changed

+408
-23
lines changed

5 files changed

+408
-23
lines changed

src/systems/abstractsystem.jl

+70
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,7 @@ for prop in [:eqs
878878
:assertions
879879
:solved_unknowns
880880
:split_idxs
881+
:ignored_connections
881882
:parent
882883
:is_dde
883884
:tstops
@@ -1394,6 +1395,75 @@ function assertions(sys::AbstractSystem)
13941395
return merge(asserts, namespaced_asserts)
13951396
end
13961397

1398+
const HierarchyVariableT = Vector{Union{BasicSymbolic, Symbol}}
1399+
const HierarchySystemT = Vector{Union{AbstractSystem, Symbol}}
1400+
"""
1401+
The type returned from `as_hierarchy`.
1402+
"""
1403+
const HierarchyT = Union{HierarchyVariableT, HierarchySystemT}
1404+
1405+
"""
1406+
$(TYPEDSIGNATURES)
1407+
1408+
The inverse operation of `as_hierarchy`.
1409+
"""
1410+
function from_hierarchy(hierarchy::HierarchyT)
1411+
namefn = hierarchy[1] isa AbstractSystem ? nameof : getname
1412+
foldl(@view hierarchy[2:end]; init = hierarchy[1]) do sys, name
1413+
rename(sys, Symbol(name, NAMESPACE_SEPARATOR, namefn(sys)))
1414+
end
1415+
end
1416+
1417+
"""
1418+
$(TYPEDSIGNATURES)
1419+
1420+
Represent a namespaced system (or variable) `sys` as a hierarchy. Return a vector, where
1421+
the first element is the unnamespaced system (variable) and subsequent elements are
1422+
`Symbol`s representing the parents of the unnamespaced system (variable) in order from
1423+
inner to outer.
1424+
"""
1425+
function as_hierarchy(sys::Union{AbstractSystem, BasicSymbolic})::HierarchyT
1426+
namefn = sys isa AbstractSystem ? nameof : getname
1427+
# get the hierarchy
1428+
hierarchy = namespace_hierarchy(namefn(sys))
1429+
# rename the system with unnamespaced name
1430+
newsys = rename(sys, hierarchy[end])
1431+
# and remove it from the list
1432+
pop!(hierarchy)
1433+
# reverse it to go from inner to outer
1434+
reverse!(hierarchy)
1435+
# concatenate
1436+
T = sys isa AbstractSystem ? AbstractSystem : BasicSymbolic
1437+
return Union{Symbol, T}[newsys; hierarchy]
1438+
end
1439+
1440+
"""
1441+
$(TYPEDSIGNATURES)
1442+
1443+
Get the connections to ignore for `sys` and its subsystems. The returned value is a
1444+
`Tuple` similar in structure to the `ignored_connections` field. Each system (variable)
1445+
in the first (second) element of the tuple is also passed through `as_hierarchy`.
1446+
"""
1447+
function ignored_connections(sys::AbstractSystem)
1448+
has_ignored_connections(sys) || return (HierarchySystemT[], HierarchyVariableT[])
1449+
1450+
ics = get_ignored_connections(sys)
1451+
if ics === nothing
1452+
ics = (HierarchySystemT[], HierarchyVariableT[])
1453+
end
1454+
# turn into hierarchies
1455+
ics = (map(as_hierarchy, ics[1]), map(as_hierarchy, ics[2]))
1456+
systems = get_systems(sys)
1457+
# for each subsystem, get its ignored connections, add the name of the subsystem
1458+
# to the hierarchy and concatenate corresponding buffers of the result
1459+
result = mapreduce(Broadcast.BroadcastFunction(vcat), systems; init = ics) do subsys
1460+
sub_ics = ignored_connections(subsys)
1461+
(map(Base.Fix2(push!, nameof(subsys)), sub_ics[1]),
1462+
map(Base.Fix2(push!, nameof(subsys)), sub_ics[2]))
1463+
end
1464+
return (Vector{HierarchySystemT}(result[1]), Vector{HierarchyVariableT}(result[2]))
1465+
end
1466+
13971467
"""
13981468
$(TYPEDSIGNATURES)
13991469

src/systems/analysis_points.jl

+27-3
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ function get_analysis_variable(var, name, iv; perturb = true)
412412
return pvar, default
413413
end
414414

415+
function with_analysis_point_ignored(sys::AbstractSystem, ap::AnalysisPoint)
416+
has_ignored_connections(sys) || return sys
417+
ignored = get_ignored_connections(sys)
418+
if ignored === nothing
419+
ignored = (ODESystem[], BasicSymbolic[])
420+
else
421+
ignored = copy.(ignored)
422+
end
423+
if ap.outputs === nothing
424+
error("Empty analysis point")
425+
end
426+
for x in ap.outputs
427+
if x isa ODESystem
428+
push!(ignored[1], x)
429+
else
430+
push!(ignored[2], unwrap(x))
431+
end
432+
end
433+
return @set sys.ignored_connections = ignored
434+
end
435+
415436
#### PRIMITIVE TRANSFORMATIONS
416437

417438
const DOC_WILL_REMOVE_AP = """
@@ -469,7 +490,9 @@ function apply_transformation(tf::Break, sys::AbstractSystem)
469490
ap = breaksys_eqs[ap_idx].rhs
470491
deleteat!(breaksys_eqs, ap_idx)
471492

472-
tf.add_input || return sys, ()
493+
breaksys = with_analysis_point_ignored(breaksys, ap)
494+
495+
tf.add_input || return breaksys, ()
473496

474497
ap_ivar = ap_var(ap.input)
475498
new_var, new_def = get_analysis_variable(ap_ivar, nameof(ap), get_iv(sys))
@@ -511,7 +534,7 @@ function apply_transformation(tf::GetInput, sys::AbstractSystem)
511534
ap_idx === nothing &&
512535
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
513536
# get the anlysis point
514-
ap_sys_eqs = copy(get_eqs(ap_sys))
537+
ap_sys_eqs = get_eqs(ap_sys)
515538
ap = ap_sys_eqs[ap_idx].rhs
516539

517540
# input variable
@@ -570,6 +593,7 @@ function apply_transformation(tf::PerturbOutput, sys::AbstractSystem)
570593
ap = ap_sys_eqs[ap_idx].rhs
571594
# remove analysis point
572595
deleteat!(ap_sys_eqs, ap_idx)
596+
ap_sys = with_analysis_point_ignored(ap_sys, ap)
573597

574598
# add equations involving new variable
575599
ap_ivar = ap_var(ap.input)
@@ -634,7 +658,7 @@ function apply_transformation(tf::AddVariable, sys::AbstractSystem)
634658
ap_idx = analysis_point_index(ap_sys, tf.ap)
635659
ap_idx === nothing &&
636660
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
637-
ap_sys_eqs = copy(get_eqs(ap_sys))
661+
ap_sys_eqs = get_eqs(ap_sys)
638662
ap = ap_sys_eqs[ap_idx].rhs
639663

640664
# add equations involving new variable

src/systems/connectors.jl

+128-12
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,34 @@ function ori(sys)
315315
end
316316
end
317317

318-
function connection2set!(connectionsets, namespace, ss, isouter)
318+
"""
319+
$(TYPEDSIGNATURES)
320+
321+
Populate `connectionsets` with connections between the connectors `ss`, all of which are
322+
namespaced by `namespace`.
323+
324+
# Keyword Arguments
325+
- `ignored_connects`: A tuple of the systems and variables for which connections should be
326+
ignored. Of the format returned from `as_hierarchy`.
327+
- `namespaced_ignored_systems`: The `from_hierarchy` versions of entries in
328+
`ignored_connects[1]`, purely to avoid unnecessary recomputation.
329+
"""
330+
function connection2set!(connectionsets, namespace, ss, isouter;
331+
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]),
332+
namespaced_ignored_systems = ODESystem[])
333+
ignored_systems, ignored_variables = ignored_connects
334+
# ignore specified systems
335+
ss = filter(ss) do s
336+
all(namespaced_ignored_systems) do igsys
337+
nameof(igsys) != nameof(s)
338+
end
339+
end
340+
# `ignored_variables` for each `s` in `ss`
341+
corresponding_ignored_variables = map(
342+
Base.Fix2(ignored_systems_for_subsystem, ignored_variables), ss)
343+
corresponding_namespaced_ignored_variables = map(
344+
Broadcast.BroadcastFunction(from_hierarchy), corresponding_ignored_variables)
345+
319346
regular_ss = []
320347
domain_ss = nothing
321348
for s in ss
@@ -340,9 +367,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
340367
for (i, s) in enumerate(ss)
341368
sts = unknowns(s)
342369
io = isouter(s)
343-
for (j, v) in enumerate(sts)
370+
_ignored_variables = corresponding_ignored_variables[i]
371+
_namespaced_ignored_variables = corresponding_namespaced_ignored_variables[i]
372+
for v in sts
344373
vtype = get_connection_type(v)
345374
(vtype === Flow && isequal(v, dv)) || continue
375+
any(isequal(v), _namespaced_ignored_variables) && continue
346376
push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false))
347377
push!(cset, T(LazyNamespace(namespace, s), v, io))
348378
end
@@ -360,6 +390,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
360390
end
361391
sts1 = Set(sts1v)
362392
num_unknowns = length(sts1)
393+
394+
# we don't filter here because `csets` should include the full set of unknowns.
395+
# not all of `ss` will have the same (or any) variables filtered so the ones
396+
# that aren't should still go in the right cset. Since `sts1` is only used for
397+
# validating that all systems being connected are of the same type, it has
398+
# unfiltered entries.
363399
csets = [T[] for _ in 1:num_unknowns] # Add 9 orientation variables if connection is between multibody frames
364400
for (i, s) in enumerate(ss)
365401
unknown_vars = unknowns(s)
@@ -372,7 +408,10 @@ function connection2set!(connectionsets, namespace, ss, isouter)
372408
all(Base.Fix2(in, sts1), unknown_vars)) ||
373409
connection_error(ss))
374410
io = isouter(s)
411+
# don't `filter!` here so that `j` points to the correct cset regardless of
412+
# which variables are filtered.
375413
for (j, v) in enumerate(unknown_vars)
414+
any(isequal(v), corresponding_namespaced_ignored_variables[i]) && continue
376415
push!(csets[j], T(LazyNamespace(namespace, s), v, io))
377416
end
378417
end
@@ -395,16 +434,48 @@ function generate_connection_set(
395434
connectionsets = ConnectionSet[]
396435
domain_csets = ConnectionSet[]
397436
sys = generate_connection_set!(
398-
connectionsets, domain_csets, sys, find, replace, scalarize)
437+
connectionsets, domain_csets, sys, find, replace, scalarize, nothing,
438+
# include systems to be ignored
439+
ignored_connections(sys))
399440
csets = merge(connectionsets)
400441
domain_csets = merge([csets; domain_csets], true)
401442

402443
sys, (csets, domain_csets)
403444
end
404445

446+
"""
447+
$(TYPEDSIGNATURES)
448+
449+
Generate connection sets from `connect` equations.
450+
451+
# Arguments
452+
453+
- `connectionsets` is the list of connection sets to be populated by recursively
454+
descending `sys`.
455+
- `domain_csets` is the list of connection sets for domain connections.
456+
- `sys` is the system whose equations are to be searched.
457+
- `namespace` is a system representing the namespace in which `sys` exists, or `nothing`
458+
for no namespace (if `sys` is top-level).
459+
- `ignored_connects` is a tuple. The first (second) element is a list of systems
460+
(variables) in the format returned by `as_hierarchy` to be ignored when generating
461+
connections. This is typically because the connections they are used in were removed by
462+
analysis point transformations.
463+
"""
405464
function generate_connection_set!(connectionsets, domain_csets,
406-
sys::AbstractSystem, find, replace, scalarize, namespace = nothing)
465+
sys::AbstractSystem, find, replace, scalarize, namespace = nothing,
466+
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]))
407467
subsys = get_systems(sys)
468+
ignored_systems, ignored_variables = ignored_connects
469+
# turn hierarchies into namespaced systems
470+
namespaced_ignored_systems = from_hierarchy.(ignored_systems)
471+
namespaced_ignored_variables = from_hierarchy.(ignored_variables)
472+
namespaced_ignored = (namespaced_ignored_systems, namespaced_ignored_variables)
473+
# filter the subsystems of `sys` to exclude ignored ones
474+
filtered_subsys = filter(subsys) do ss
475+
all(namespaced_ignored_systems) do igsys
476+
nameof(igsys) != nameof(ss)
477+
end
478+
end
408479

409480
isouter = generate_isouter(sys)
410481
eqs′ = get_eqs(sys)
@@ -430,7 +501,8 @@ function generate_connection_set!(connectionsets, domain_csets,
430501
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
431502
else
432503
if lhs isa Connection && get_systems(lhs) === :domain
433-
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
504+
connection2set!(domain_csets, namespace, get_systems(rhs), isouter;
505+
ignored_connects, namespaced_ignored_systems)
434506
elseif isconnection(rhs)
435507
push!(cts, get_systems(rhs))
436508
else
@@ -446,17 +518,23 @@ function generate_connection_set!(connectionsets, domain_csets,
446518

447519
# all connectors are eventually inside connectors.
448520
T = ConnectionElement
449-
for s in subsys
521+
# only generate connection sets for systems that are not ignored
522+
for s in filtered_subsys
450523
isconnector(s) || continue
451524
is_domain_connector(s) && continue
525+
_ignored_variables = ignored_systems_for_subsystem(s, ignored_variables)
526+
_namespaced_ignored_variables = from_hierarchy.(_ignored_variables)
452527
for v in unknowns(s)
453528
Flow === get_connection_type(v) || continue
529+
# ignore specified variables
530+
any(isequal(v), _namespaced_ignored_variables) && continue
454531
push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)]))
455532
end
456533
end
457534

458535
for ct in cts
459-
connection2set!(connectionsets, namespace, ct, isouter)
536+
connection2set!(connectionsets, namespace, ct, isouter;
537+
ignored_connects, namespaced_ignored_systems)
460538
end
461539

462540
# pre order traversal
@@ -465,12 +543,38 @@ function generate_connection_set!(connectionsets, domain_csets,
465543
end
466544
@set! sys.systems = map(
467545
s -> generate_connection_set!(connectionsets, domain_csets, s,
468-
find, replace, scalarize,
469-
renamespace(namespace, s)),
546+
find, replace, scalarize, renamespace(namespace, s),
547+
ignored_systems_for_subsystem.((s,), ignored_connects)),
470548
subsys)
471549
@set! sys.eqs = eqs
472550
end
473551

552+
"""
553+
$(TYPEDSIGNATURES)
554+
555+
Given a subsystem `subsys` of a parent system and a list of systems (variables) to be
556+
ignored by `generate_connection_set!` (`expand_variable_connections`), filter
557+
`ignored_systems` to only include those present in the subtree of `subsys` and update
558+
their hierarchy to not include `subsys`.
559+
"""
560+
function ignored_systems_for_subsystem(
561+
subsys::AbstractSystem, ignored_systems::Vector{<:HierarchyT})
562+
result = eltype(ignored_systems)[]
563+
# in case `subsys` is namespaced, get its hierarchy and compare suffixes
564+
# instead of the just the last element
565+
suffix = reverse!(namespace_hierarchy(nameof(subsys)))
566+
N = length(suffix)
567+
for igsys in ignored_systems
568+
if igsys[(end - N + 1):end] == suffix
569+
push!(result, copy(igsys))
570+
for i in 1:N
571+
pop!(result[end])
572+
end
573+
end
574+
end
575+
return result
576+
end
577+
474578
function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
475579
ele2idx = Dict{ConnectionElement, Int}()
476580
idx2ele = ConnectionElement[]
@@ -576,16 +680,25 @@ end
576680
Recursively descend through the hierarchy of `sys` and expand all connection equations
577681
of causal variables. Return the modified system.
578682
"""
579-
function expand_variable_connections(sys::AbstractSystem)
683+
function expand_variable_connections(sys::AbstractSystem; ignored_variables = nothing)
684+
if ignored_variables === nothing
685+
ignored_variables = ignored_connections(sys)[2]
686+
end
687+
namespaced_ignored = from_hierarchy.(ignored_variables)
580688
eqs = copy(get_eqs(sys))
581689
valid_idxs = trues(length(eqs))
582690
additional_eqs = Equation[]
583691

584692
for (i, eq) in enumerate(eqs)
585693
eq.lhs isa Connection || continue
586694
connection = eq.rhs
587-
elements = connection.systems
695+
elements = get_systems(connection)
588696
is_causal_variable_connection(connection) || continue
697+
elements = filter(elements) do el
698+
all(namespaced_ignored) do var
699+
getname(var) != getname(el.var)
700+
end
701+
end
589702

590703
valid_idxs[i] = false
591704
elements = map(x -> x.var, elements)
@@ -595,7 +708,10 @@ function expand_variable_connections(sys::AbstractSystem)
595708
end
596709
end
597710
eqs = [eqs[valid_idxs]; additional_eqs]
598-
subsystems = map(expand_variable_connections, get_systems(sys))
711+
subsystems = map(get_systems(sys)) do subsys
712+
expand_variable_connections(subsys;
713+
ignored_variables = ignored_systems_for_subsystem(subsys, ignored_variables))
714+
end
599715
@set! sys.eqs = eqs
600716
@set! sys.systems = subsystems
601717
return sys

0 commit comments

Comments
 (0)