Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support analysis points duplicating existing connections #3453

Merged
merged 6 commits into from
Mar 12, 2025
70 changes: 70 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@
"""
Initial(x)

The `Initial` operator. Used by initializaton to store constant constraints on variables

Check warning on line 625 in src/systems/abstractsystem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"initializaton" should be "initialization".
of a system. See the documentation section on initialization for more information.
"""
struct Initial <: Symbolics.Operator end
Expand Down Expand Up @@ -878,6 +878,7 @@
:assertions
:solved_unknowns
:split_idxs
:ignored_connections
:parent
:is_dde
:tstops
Expand Down Expand Up @@ -1394,6 +1395,75 @@
return merge(asserts, namespaced_asserts)
end

const HierarchyVariableT = Vector{Union{BasicSymbolic, Symbol}}
const HierarchySystemT = Vector{Union{AbstractSystem, Symbol}}
"""
The type returned from `as_hierarchy`.
"""
const HierarchyT = Union{HierarchyVariableT, HierarchySystemT}

"""
$(TYPEDSIGNATURES)

The inverse operation of `as_hierarchy`.
"""
function from_hierarchy(hierarchy::HierarchyT)
namefn = hierarchy[1] isa AbstractSystem ? nameof : getname
foldl(@view hierarchy[2:end]; init = hierarchy[1]) do sys, name
rename(sys, Symbol(name, NAMESPACE_SEPARATOR, namefn(sys)))
end
end

"""
$(TYPEDSIGNATURES)

Represent a namespaced system (or variable) `sys` as a hierarchy. Return a vector, where
the first element is the unnamespaced system (variable) and subsequent elements are
`Symbol`s representing the parents of the unnamespaced system (variable) in order from
inner to outer.
"""
function as_hierarchy(sys::Union{AbstractSystem, BasicSymbolic})::HierarchyT
namefn = sys isa AbstractSystem ? nameof : getname
# get the hierarchy
hierarchy = namespace_hierarchy(namefn(sys))
# rename the system with unnamespaced name
newsys = rename(sys, hierarchy[end])
# and remove it from the list
pop!(hierarchy)
# reverse it to go from inner to outer
reverse!(hierarchy)
# concatenate
T = sys isa AbstractSystem ? AbstractSystem : BasicSymbolic
return Union{Symbol, T}[newsys; hierarchy]
end

"""
$(TYPEDSIGNATURES)

Get the connections to ignore for `sys` and its subsystems. The returned value is a
`Tuple` similar in structure to the `ignored_connections` field. Each system (variable)
in the first (second) element of the tuple is also passed through `as_hierarchy`.
"""
function ignored_connections(sys::AbstractSystem)
has_ignored_connections(sys) || return (HierarchySystemT[], HierarchyVariableT[])

ics = get_ignored_connections(sys)
if ics === nothing
ics = (HierarchySystemT[], HierarchyVariableT[])
end
# turn into hierarchies
ics = (map(as_hierarchy, ics[1]), map(as_hierarchy, ics[2]))
systems = get_systems(sys)
# for each subsystem, get its ignored connections, add the name of the subsystem
# to the hierarchy and concatenate corresponding buffers of the result
result = mapreduce(Broadcast.BroadcastFunction(vcat), systems; init = ics) do subsys
sub_ics = ignored_connections(subsys)
(map(Base.Fix2(push!, nameof(subsys)), sub_ics[1]),
map(Base.Fix2(push!, nameof(subsys)), sub_ics[2]))
end
return (Vector{HierarchySystemT}(result[1]), Vector{HierarchyVariableT}(result[2]))
end

"""
$(TYPEDSIGNATURES)

Expand Down
30 changes: 27 additions & 3 deletions src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,27 @@
return pvar, default
end

function with_analysis_point_ignored(sys::AbstractSystem, ap::AnalysisPoint)
has_ignored_connections(sys) || return sys
ignored = get_ignored_connections(sys)
if ignored === nothing
ignored = (ODESystem[], BasicSymbolic[])
else
ignored = copy.(ignored)
end
if ap.outputs === nothing
error("Empty analysis point")
end
for x in ap.outputs
if x isa ODESystem
push!(ignored[1], x)
else
push!(ignored[2], unwrap(x))
end
end
return @set sys.ignored_connections = ignored
end

#### PRIMITIVE TRANSFORMATIONS

const DOC_WILL_REMOVE_AP = """
Expand Down Expand Up @@ -469,7 +490,9 @@
ap = breaksys_eqs[ap_idx].rhs
deleteat!(breaksys_eqs, ap_idx)

tf.add_input || return sys, ()
breaksys = with_analysis_point_ignored(breaksys, ap)

tf.add_input || return breaksys, ()

ap_ivar = ap_var(ap.input)
new_var, new_def = get_analysis_variable(ap_ivar, nameof(ap), get_iv(sys))
Expand Down Expand Up @@ -510,8 +533,8 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# get the anlysis point

Check warning on line 536 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"anlysis" should be "analysis".
ap_sys_eqs = copy(get_eqs(ap_sys))
ap_sys_eqs = get_eqs(ap_sys)
ap = ap_sys_eqs[ap_idx].rhs

# input variable
Expand Down Expand Up @@ -564,12 +587,13 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# modified quations

Check warning on line 590 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"quations" should be "equations".
ap_sys_eqs = copy(get_eqs(ap_sys))
@set! ap_sys.eqs = ap_sys_eqs
ap = ap_sys_eqs[ap_idx].rhs
# remove analysis point
deleteat!(ap_sys_eqs, ap_idx)
ap_sys = with_analysis_point_ignored(ap_sys, ap)

# add equations involving new variable
ap_ivar = ap_var(ap.input)
Expand Down Expand Up @@ -634,7 +658,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
ap_sys_eqs = copy(get_eqs(ap_sys))
ap_sys_eqs = get_eqs(ap_sys)
ap = ap_sys_eqs[ap_idx].rhs

# add equations involving new variable
Expand Down Expand Up @@ -875,7 +899,7 @@
# Keyword Arguments

- `system_modifier`: a function which takes the modified system and returns a new system
with any required further modifications peformed.

Check warning on line 902 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"peformed" should be "performed".
"""
function open_loop(sys, ap::Union{Symbol, AnalysisPoint}; system_modifier = identity)
ap = only(canonicalize_ap(sys, ap))
Expand Down
140 changes: 128 additions & 12 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,34 @@ function ori(sys)
end
end

function connection2set!(connectionsets, namespace, ss, isouter)
"""
$(TYPEDSIGNATURES)

Populate `connectionsets` with connections between the connectors `ss`, all of which are
namespaced by `namespace`.

# Keyword Arguments
- `ignored_connects`: A tuple of the systems and variables for which connections should be
ignored. Of the format returned from `as_hierarchy`.
- `namespaced_ignored_systems`: The `from_hierarchy` versions of entries in
`ignored_connects[1]`, purely to avoid unnecessary recomputation.
"""
function connection2set!(connectionsets, namespace, ss, isouter;
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]),
namespaced_ignored_systems = ODESystem[])
ignored_systems, ignored_variables = ignored_connects
# ignore specified systems
ss = filter(ss) do s
all(namespaced_ignored_systems) do igsys
nameof(igsys) != nameof(s)
end
end
# `ignored_variables` for each `s` in `ss`
corresponding_ignored_variables = map(
Base.Fix2(ignored_systems_for_subsystem, ignored_variables), ss)
corresponding_namespaced_ignored_variables = map(
Broadcast.BroadcastFunction(from_hierarchy), corresponding_ignored_variables)

regular_ss = []
domain_ss = nothing
for s in ss
Expand All @@ -340,9 +367,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
for (i, s) in enumerate(ss)
sts = unknowns(s)
io = isouter(s)
for (j, v) in enumerate(sts)
_ignored_variables = corresponding_ignored_variables[i]
_namespaced_ignored_variables = corresponding_namespaced_ignored_variables[i]
for v in sts
vtype = get_connection_type(v)
(vtype === Flow && isequal(v, dv)) || continue
any(isequal(v), _namespaced_ignored_variables) && continue
push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false))
push!(cset, T(LazyNamespace(namespace, s), v, io))
end
Expand All @@ -360,6 +390,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
end
sts1 = Set(sts1v)
num_unknowns = length(sts1)

# we don't filter here because `csets` should include the full set of unknowns.
# not all of `ss` will have the same (or any) variables filtered so the ones
# that aren't should still go in the right cset. Since `sts1` is only used for
# validating that all systems being connected are of the same type, it has
# unfiltered entries.
csets = [T[] for _ in 1:num_unknowns] # Add 9 orientation variables if connection is between multibody frames
for (i, s) in enumerate(ss)
unknown_vars = unknowns(s)
Expand All @@ -372,7 +408,10 @@ function connection2set!(connectionsets, namespace, ss, isouter)
all(Base.Fix2(in, sts1), unknown_vars)) ||
connection_error(ss))
io = isouter(s)
# don't `filter!` here so that `j` points to the correct cset regardless of
# which variables are filtered.
for (j, v) in enumerate(unknown_vars)
any(isequal(v), corresponding_namespaced_ignored_variables[i]) && continue
push!(csets[j], T(LazyNamespace(namespace, s), v, io))
end
end
Expand All @@ -395,16 +434,48 @@ function generate_connection_set(
connectionsets = ConnectionSet[]
domain_csets = ConnectionSet[]
sys = generate_connection_set!(
connectionsets, domain_csets, sys, find, replace, scalarize)
connectionsets, domain_csets, sys, find, replace, scalarize, nothing,
# include systems to be ignored
ignored_connections(sys))
csets = merge(connectionsets)
domain_csets = merge([csets; domain_csets], true)

sys, (csets, domain_csets)
end

"""
$(TYPEDSIGNATURES)

Generate connection sets from `connect` equations.

# Arguments

- `connectionsets` is the list of connection sets to be populated by recursively
descending `sys`.
- `domain_csets` is the list of connection sets for domain connections.
- `sys` is the system whose equations are to be searched.
- `namespace` is a system representing the namespace in which `sys` exists, or `nothing`
for no namespace (if `sys` is top-level).
- `ignored_connects` is a tuple. The first (second) element is a list of systems
(variables) in the format returned by `as_hierarchy` to be ignored when generating
connections. This is typically because the connections they are used in were removed by
analysis point transformations.
"""
function generate_connection_set!(connectionsets, domain_csets,
sys::AbstractSystem, find, replace, scalarize, namespace = nothing)
sys::AbstractSystem, find, replace, scalarize, namespace = nothing,
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]))
subsys = get_systems(sys)
ignored_systems, ignored_variables = ignored_connects
# turn hierarchies into namespaced systems
namespaced_ignored_systems = from_hierarchy.(ignored_systems)
namespaced_ignored_variables = from_hierarchy.(ignored_variables)
namespaced_ignored = (namespaced_ignored_systems, namespaced_ignored_variables)
# filter the subsystems of `sys` to exclude ignored ones
filtered_subsys = filter(subsys) do ss
all(namespaced_ignored_systems) do igsys
nameof(igsys) != nameof(ss)
end
end

isouter = generate_isouter(sys)
eqs′ = get_eqs(sys)
Expand All @@ -430,7 +501,8 @@ function generate_connection_set!(connectionsets, domain_csets,
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
else
if lhs isa Connection && get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
connection2set!(domain_csets, namespace, get_systems(rhs), isouter;
ignored_connects, namespaced_ignored_systems)
elseif isconnection(rhs)
push!(cts, get_systems(rhs))
else
Expand All @@ -446,17 +518,23 @@ function generate_connection_set!(connectionsets, domain_csets,

# all connectors are eventually inside connectors.
T = ConnectionElement
for s in subsys
# only generate connection sets for systems that are not ignored
for s in filtered_subsys
isconnector(s) || continue
is_domain_connector(s) && continue
_ignored_variables = ignored_systems_for_subsystem(s, ignored_variables)
_namespaced_ignored_variables = from_hierarchy.(_ignored_variables)
for v in unknowns(s)
Flow === get_connection_type(v) || continue
# ignore specified variables
any(isequal(v), _namespaced_ignored_variables) && continue
push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)]))
end
end

for ct in cts
connection2set!(connectionsets, namespace, ct, isouter)
connection2set!(connectionsets, namespace, ct, isouter;
ignored_connects, namespaced_ignored_systems)
end

# pre order traversal
Expand All @@ -465,12 +543,38 @@ function generate_connection_set!(connectionsets, domain_csets,
end
@set! sys.systems = map(
s -> generate_connection_set!(connectionsets, domain_csets, s,
find, replace, scalarize,
renamespace(namespace, s)),
find, replace, scalarize, renamespace(namespace, s),
ignored_systems_for_subsystem.((s,), ignored_connects)),
subsys)
@set! sys.eqs = eqs
end

"""
$(TYPEDSIGNATURES)

Given a subsystem `subsys` of a parent system and a list of systems (variables) to be
ignored by `generate_connection_set!` (`expand_variable_connections`), filter
`ignored_systems` to only include those present in the subtree of `subsys` and update
their hierarchy to not include `subsys`.
"""
function ignored_systems_for_subsystem(
subsys::AbstractSystem, ignored_systems::Vector{<:HierarchyT})
result = eltype(ignored_systems)[]
# in case `subsys` is namespaced, get its hierarchy and compare suffixes
# instead of the just the last element
suffix = reverse!(namespace_hierarchy(nameof(subsys)))
N = length(suffix)
for igsys in ignored_systems
if igsys[(end - N + 1):end] == suffix
push!(result, copy(igsys))
for i in 1:N
pop!(result[end])
end
end
end
return result
end

function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
ele2idx = Dict{ConnectionElement, Int}()
idx2ele = ConnectionElement[]
Expand Down Expand Up @@ -576,16 +680,25 @@ end
Recursively descend through the hierarchy of `sys` and expand all connection equations
of causal variables. Return the modified system.
"""
function expand_variable_connections(sys::AbstractSystem)
function expand_variable_connections(sys::AbstractSystem; ignored_variables = nothing)
if ignored_variables === nothing
ignored_variables = ignored_connections(sys)[2]
end
namespaced_ignored = from_hierarchy.(ignored_variables)
eqs = copy(get_eqs(sys))
valid_idxs = trues(length(eqs))
additional_eqs = Equation[]

for (i, eq) in enumerate(eqs)
eq.lhs isa Connection || continue
connection = eq.rhs
elements = connection.systems
elements = get_systems(connection)
is_causal_variable_connection(connection) || continue
elements = filter(elements) do el
all(namespaced_ignored) do var
getname(var) != getname(el.var)
end
end

valid_idxs[i] = false
elements = map(x -> x.var, elements)
Expand All @@ -595,7 +708,10 @@ function expand_variable_connections(sys::AbstractSystem)
end
end
eqs = [eqs[valid_idxs]; additional_eqs]
subsystems = map(expand_variable_connections, get_systems(sys))
subsystems = map(get_systems(sys)) do subsys
expand_variable_connections(subsys;
ignored_variables = ignored_systems_for_subsystem(subsys, ignored_variables))
end
@set! sys.eqs = eqs
@set! sys.systems = subsystems
return sys
Expand Down
Loading
Loading