Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix analysis point transform ignoring too many connections #3469

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1421,9 +1421,34 @@ function assertions(sys::AbstractSystem)
return merge(asserts, namespaced_asserts)
end

"""
$(TYPEDEF)

Information about an `AnalysisPoint` for which the corresponding connection must be
ignored during `expand_connections`, since the analysis point has been transformed.

# Fields

$(TYPEDFIELDS)
"""
struct IgnoredAnalysisPoint
"""
The input variable/connector.
"""
input::Union{BasicSymbolic, AbstractSystem}
"""
The output variables/connectors.
"""
outputs::Vector{Union{BasicSymbolic, AbstractSystem}}
end

const HierarchyVariableT = Vector{Union{BasicSymbolic, Symbol}}
const HierarchySystemT = Vector{Union{AbstractSystem, Symbol}}
"""
The type returned from `analysis_point_common_hierarchy`.
"""
const HierarchyAnalysisPointT = Vector{Union{IgnoredAnalysisPoint, Symbol}}
"""
The type returned from `as_hierarchy`.
"""
const HierarchyT = Union{HierarchyVariableT, HierarchySystemT}
Expand All @@ -1440,6 +1465,29 @@ function from_hierarchy(hierarchy::HierarchyT)
end
end

"""
$(TYPEDSIGNATURES)

Represent an ignored analysis point as a namespaced hierarchy. The hierarchy is built
using the common hierarchy of all involved systems/variables.
"""
function analysis_point_common_hierarchy(ap::IgnoredAnalysisPoint)::HierarchyAnalysisPointT
isys = as_hierarchy(ap.input)
osyss = as_hierarchy.(ap.outputs)
suffix = Symbol[]
while isys[end] == osyss[1][end] && allequal(last.(osyss))
push!(suffix, isys[end])
pop!(isys)
pop!.(osyss)
end
isys = from_hierarchy(isys)
osyss = from_hierarchy.(osyss)
newap = IgnoredAnalysisPoint(isys, osyss)
hierarchy = HierarchyAnalysisPointT([suffix; newap])
reverse!(hierarchy)
return hierarchy
end

"""
$(TYPEDSIGNATURES)

Expand All @@ -1466,19 +1514,20 @@ end
"""
$(TYPEDSIGNATURES)

Get the connections to ignore for `sys` and its subsystems. The returned value is a
`Tuple` similar in structure to the `ignored_connections` field. Each system (variable)
in the first (second) element of the tuple is also passed through `as_hierarchy`.
Get the analysis points to ignore for `sys` and its subsystems. The returned value is a
`Tuple` similar in structure to the `ignored_connections` field.
"""
function ignored_connections(sys::AbstractSystem)
has_ignored_connections(sys) || return (HierarchySystemT[], HierarchyVariableT[])
has_ignored_connections(sys) ||
return (HierarchyAnalysisPointT[], HierarchyAnalysisPointT[])

ics = get_ignored_connections(sys)
if ics === nothing
ics = (HierarchySystemT[], HierarchyVariableT[])
ics = (HierarchyAnalysisPointT[], HierarchyAnalysisPointT[])
end
# turn into hierarchies
ics = (map(as_hierarchy, ics[1]), map(as_hierarchy, ics[2]))
ics = (map(analysis_point_common_hierarchy, ics[1]),
map(analysis_point_common_hierarchy, ics[2]))
systems = get_systems(sys)
# for each subsystem, get its ignored connections, add the name of the subsystem
# to the hierarchy and concatenate corresponding buffers of the result
Expand All @@ -1487,7 +1536,8 @@ function ignored_connections(sys::AbstractSystem)
(map(Base.Fix2(push!, nameof(subsys)), sub_ics[1]),
map(Base.Fix2(push!, nameof(subsys)), sub_ics[2]))
end
return (Vector{HierarchySystemT}(result[1]), Vector{HierarchyVariableT}(result[2]))
return (Vector{HierarchyAnalysisPointT}(result[1]),
Vector{HierarchyAnalysisPointT}(result[2]))
end

"""
Expand Down
17 changes: 9 additions & 8 deletions src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ function Symbolics.connect(
outs::ConnectableSymbolicT...; verbose = true)
allvars = (in, out, outs...)
validate_causal_variables_connection(allvars)
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
return AnalysisPoint() ~ AnalysisPoint(
unwrap(in), name, unwrap.([out; collect(outs)]); verbose)
end

"""
Expand Down Expand Up @@ -416,20 +417,20 @@ function with_analysis_point_ignored(sys::AbstractSystem, ap::AnalysisPoint)
has_ignored_connections(sys) || return sys
ignored = get_ignored_connections(sys)
if ignored === nothing
ignored = (ODESystem[], BasicSymbolic[])
ignored = (IgnoredAnalysisPoint[], IgnoredAnalysisPoint[])
else
ignored = copy.(ignored)
end
if ap.outputs === nothing
error("Empty analysis point")
end
for x in ap.outputs
if x isa ODESystem
push!(ignored[1], x)
else
push!(ignored[2], unwrap(x))
end

if ap.input isa AbstractSystem && all(x -> x isa AbstractSystem, ap.outputs)
push!(ignored[1], IgnoredAnalysisPoint(ap.input, ap.outputs))
else
push!(ignored[2], IgnoredAnalysisPoint(unwrap(ap.input), unwrap.(ap.outputs)))
end

return @set sys.ignored_connections = ignored
end

Expand Down
146 changes: 104 additions & 42 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolic
vars::ConnectableSymbolicT...)
allvars = (var1, var2, vars...)
validate_causal_variables_connection(allvars)
return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars)))
return Equation(Connection(), Connection(map(SymbolicWithNameof, unwrap.(allvars))))
end

function flowvar(sys::AbstractSystem)
Expand Down Expand Up @@ -328,14 +328,12 @@ namespaced by `namespace`.
`ignored_connects[1]`, purely to avoid unnecessary recomputation.
"""
function connection2set!(connectionsets, namespace, ss, isouter;
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]),
namespaced_ignored_systems = ODESystem[])
ignored_systems, ignored_variables = ignored_connects
ignored_systems = HierarchySystemT[], ignored_variables = HierarchyVariableT[])
ns_ignored_systems = from_hierarchy.(ignored_systems)
ns_ignored_variables = from_hierarchy.(ignored_variables)
# ignore specified systems
ss = filter(ss) do s
all(namespaced_ignored_systems) do igsys
nameof(igsys) != nameof(s)
end
!any(x -> nameof(x) == nameof(s), ns_ignored_systems)
end
# `ignored_variables` for each `s` in `ss`
corresponding_ignored_variables = map(
Expand Down Expand Up @@ -434,15 +432,95 @@ function generate_connection_set(
connectionsets = ConnectionSet[]
domain_csets = ConnectionSet[]
sys = generate_connection_set!(
connectionsets, domain_csets, sys, find, replace, scalarize, nothing,
# include systems to be ignored
ignored_connections(sys))
connectionsets, domain_csets, sys, find, replace, scalarize, nothing, ignored_connections(sys))
csets = merge(connectionsets)
domain_csets = merge([csets; domain_csets], true)

sys, (csets, domain_csets)
end

"""
$(TYPEDSIGNATURES)

For a list of `systems` in a connect equation, return the subset of it to ignore (as a
list of hierarchical systems) based on `ignored_system_aps`, the analysis points to be
ignored. All analysis points in `ignored_system_aps` must contain systems (connectors)
as their input/outputs.
"""
function systems_to_ignore(ignored_system_aps::Vector{HierarchyAnalysisPointT},
systems::Union{Vector{<:AbstractSystem}, Tuple{Vararg{<:AbstractSystem}}})
to_ignore = HierarchySystemT[]
for ap in ignored_system_aps
# if `systems` contains the input of the AP, ignore any outputs of the AP present in it.
isys_hierarchy = HierarchySystemT([ap[1].input; @view ap[2:end]])
isys = from_hierarchy(isys_hierarchy)
any(x -> nameof(x) == nameof(isys), systems) || continue

for outsys in ap[1].outputs
osys_hierarchy = HierarchySystemT([outsys; @view ap[2:end]])
osys = from_hierarchy(osys_hierarchy)
any(x -> nameof(x) == nameof(osys), systems) || continue
push!(to_ignore, HierarchySystemT(osys_hierarchy))
end
end

return to_ignore
end

"""
$(TYPEDSIGNATURES)

For a list of `systems` in a connect equation, return the subset of their variables to
ignore (as a list of hierarchical variables) based on `ignored_system_aps`, the analysis
points to be ignored. All analysis points in `ignored_system_aps` must contain variables
as their input/outputs.
"""
function variables_to_ignore(ignored_variable_aps::Vector{HierarchyAnalysisPointT},
systems::Union{Vector{<:AbstractSystem}, Tuple{Vararg{<:AbstractSystem}}})
to_ignore = HierarchyVariableT[]
for ap in ignored_variable_aps
ivar_hierarchy = HierarchyVariableT([ap[1].input; @view ap[2:end]])
ivar = from_hierarchy(ivar_hierarchy)
any(x -> any(isequal(ivar), renamespace.((x,), unknowns(x))), systems) || continue

for outvar in ap[1].outputs
ovar_hierarchy = HierarchyVariableT([as_hierarchy(outvar); @view ap[2:end]])
ovar = from_hierarchy(ovar_hierarchy)
any(x -> any(isequal(ovar), renamespace.((x,), unknowns(x))), systems) ||
continue
push!(to_ignore, HierarchyVariableT(ovar_hierarchy))
end
end
return to_ignore
end

"""
$(TYPEDSIGNATURES)

For a list of variables `vars` in a connect equation, return the subset of them ignore
(as a list of symbolic variables) based on `ignored_system_aps`, the analysis points to
be ignored. All analysis points in `ignored_system_aps` must contain variables as their
input/outputs.
"""
function variables_to_ignore(ignored_variable_aps::Vector{HierarchyAnalysisPointT},
vars::Union{Vector{<:BasicSymbolic}, Tuple{Vararg{<:BasicSymbolic}}})
to_ignore = eltype(vars)[]
for ap in ignored_variable_aps
ivar_hierarchy = HierarchyVariableT([ap[1].input; @view ap[2:end]])
ivar = from_hierarchy(ivar_hierarchy)
any(isequal(ivar), vars) || continue

for outvar in ap[1].outputs
ovar_hierarchy = HierarchyVariableT([outvar; @view ap[2:end]])
ovar = from_hierarchy(ovar_hierarchy)
any(isequal(ovar), vars) || continue
push!(to_ignore, ovar)
end
end

return to_ignore
end

"""
$(TYPEDSIGNATURES)

Expand All @@ -456,26 +534,12 @@ Generate connection sets from `connect` equations.
- `sys` is the system whose equations are to be searched.
- `namespace` is a system representing the namespace in which `sys` exists, or `nothing`
for no namespace (if `sys` is top-level).
- `ignored_connects` is a tuple. The first (second) element is a list of systems
(variables) in the format returned by `as_hierarchy` to be ignored when generating
connections. This is typically because the connections they are used in were removed by
analysis point transformations.
"""
function generate_connection_set!(connectionsets, domain_csets,
sys::AbstractSystem, find, replace, scalarize, namespace = nothing,
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]))
ignored_connects = (HierarchyAnalysisPointT[], HierarchyAnalysisPointT[]))
subsys = get_systems(sys)
ignored_systems, ignored_variables = ignored_connects
# turn hierarchies into namespaced systems
namespaced_ignored_systems = from_hierarchy.(ignored_systems)
namespaced_ignored_variables = from_hierarchy.(ignored_variables)
namespaced_ignored = (namespaced_ignored_systems, namespaced_ignored_variables)
# filter the subsystems of `sys` to exclude ignored ones
filtered_subsys = filter(subsys) do ss
all(namespaced_ignored_systems) do igsys
nameof(igsys) != nameof(ss)
end
end
ignored_system_aps, ignored_variable_aps = ignored_connects

isouter = generate_isouter(sys)
eqs′ = get_eqs(sys)
Expand All @@ -501,8 +565,12 @@ function generate_connection_set!(connectionsets, domain_csets,
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
else
if lhs isa Connection && get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter;
ignored_connects, namespaced_ignored_systems)
connected_systems = get_systems(rhs)
connection2set!(domain_csets, namespace, connected_systems, isouter;
ignored_systems = systems_to_ignore(
ignored_system_aps, connected_systems),
ignored_variables = variables_to_ignore(
ignored_variable_aps, connected_systems))
elseif isconnection(rhs)
push!(cts, get_systems(rhs))
else
Expand All @@ -519,22 +587,19 @@ function generate_connection_set!(connectionsets, domain_csets,
# all connectors are eventually inside connectors.
T = ConnectionElement
# only generate connection sets for systems that are not ignored
for s in filtered_subsys
for s in subsys
isconnector(s) || continue
is_domain_connector(s) && continue
_ignored_variables = ignored_systems_for_subsystem(s, ignored_variables)
_namespaced_ignored_variables = from_hierarchy.(_ignored_variables)
for v in unknowns(s)
Flow === get_connection_type(v) || continue
# ignore specified variables
any(isequal(v), _namespaced_ignored_variables) && continue
push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)]))
end
end

for ct in cts
connection2set!(connectionsets, namespace, ct, isouter;
ignored_connects, namespaced_ignored_systems)
ignored_systems = systems_to_ignore(ignored_system_aps, ct),
ignored_variables = variables_to_ignore(ignored_variable_aps, ct))
end

# pre order traversal
Expand All @@ -558,14 +623,15 @@ ignored by `generate_connection_set!` (`expand_variable_connections`), filter
their hierarchy to not include `subsys`.
"""
function ignored_systems_for_subsystem(
subsys::AbstractSystem, ignored_systems::Vector{<:HierarchyT})
subsys::AbstractSystem, ignored_systems::Vector{<:Union{
HierarchyT, HierarchyAnalysisPointT}})
result = eltype(ignored_systems)[]
# in case `subsys` is namespaced, get its hierarchy and compare suffixes
# instead of the just the last element
suffix = reverse!(namespace_hierarchy(nameof(subsys)))
N = length(suffix)
for igsys in ignored_systems
if igsys[(end - N + 1):end] == suffix
if length(igsys) > N && igsys[(end - N + 1):end] == suffix
push!(result, copy(igsys))
for i in 1:N
pop!(result[end])
Expand Down Expand Up @@ -684,7 +750,6 @@ function expand_variable_connections(sys::AbstractSystem; ignored_variables = no
if ignored_variables === nothing
ignored_variables = ignored_connections(sys)[2]
end
namespaced_ignored = from_hierarchy.(ignored_variables)
eqs = copy(get_eqs(sys))
valid_idxs = trues(length(eqs))
additional_eqs = Equation[]
Expand All @@ -694,14 +759,11 @@ function expand_variable_connections(sys::AbstractSystem; ignored_variables = no
connection = eq.rhs
elements = get_systems(connection)
is_causal_variable_connection(connection) || continue
elements = filter(elements) do el
all(namespaced_ignored) do var
getname(var) != getname(el.var)
end
end

valid_idxs[i] = false
elements = map(x -> x.var, elements)
to_ignore = variables_to_ignore(ignored_variables, elements)
elements = setdiff(elements, to_ignore)
outvar = first(elements)
for invar in Iterators.drop(elements, 1)
push!(additional_eqs, outvar ~ invar)
Expand Down
9 changes: 5 additions & 4 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,12 @@ struct ODESystem <: AbstractODESystem
"""
split_idxs::Union{Nothing, Vector{Vector{Int}}}
"""
The connections to ignore (since they're removed by analysis point transformations).
The first element of the tuple are systems that can't be in the same connection set,
and the second are variables (for the trivial form of `connect`).
The analysis points removed by transformations, representing connections to be
ignored. The first element of the tuple analysis points connecting systems and
the second are ones connecting variables (for the trivial form of `connect`).
"""
ignored_connections::Union{Nothing, Tuple{Vector{ODESystem}, Vector{BasicSymbolic}}}
ignored_connections::Union{
Nothing, Tuple{Vector{IgnoredAnalysisPoint}, Vector{IgnoredAnalysisPoint}}}
"""
The hierarchical parent system before simplification.
"""
Expand Down
Loading
Loading