Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for an external synchronous compiler to discrete and hybrid systems #3399

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
initial_state, transition, activeState, entry, hasnode,
ticksInState, timeInState, fixpoint_sub, fast_substitute,
CallWithMetadata, CallWithParent
CallWithMetadata, CallWithParent, Transition, InitialState
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
Expand Down
20 changes: 20 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,26 @@ function namespace_expr(
O
end
end

function namespace_expr(
O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys))
return Transition(
O.from === nothing ? O.from : renamespace(sys, O.from),
O.to === nothing ? O.to : renamespace(sys, O.to),
O.cond === nothing ? O.cond : namespace_expr(O.cond, sys),
O.immediate, O.reset, O.synchronize, O.priority
)
end

function namespace_expr(
O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys))
return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s))
end

function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...)
error("Unhandled state machine operator")
end

_nonum(@nospecialize x) = x isa Num ? x.val : x

"""
Expand Down
11 changes: 10 additions & 1 deletion src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function infer_clocks!(ci::ClockInference)
c = BitSet(c′)
idxs = intersect(c, inferred)
isempty(idxs) && continue
if !allequal(var_domain[i] for i in idxs)
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
display(fullvars[c′])
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
end
Expand Down Expand Up @@ -144,6 +144,9 @@ function split_system(ci::ClockInference{S}) where {S}
var_to_cid = Vector{Int}(undef, ndsts(graph))
cid_to_var = Vector{Int}[]
cid_counter = Ref(0)

# populates clock_to_id and id_to_clock
# checks if there is a continuous_id (for some reason? clock to id does this too)
for (i, d) in enumerate(eq_domain)
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
continuous_id = continuous_id
Expand All @@ -161,9 +164,13 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_eq, i, cid)
end
continuous_id = continuous_id[]
# for each clock partition what are the input (indexes/vars)
input_idxs = map(_ -> Int[], 1:cid_counter[])
inputs = map(_ -> Any[], 1:cid_counter[])
# var_domain corresponds to fullvars/all variables in the system
nvv = length(var_domain)
# put variables into the right clock partition
# keep track of inputs to each partition
for i in 1:nvv
d = var_domain[i]
cid = get(clock_to_id, d, 0)
Expand All @@ -177,6 +184,7 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_var, i, cid)
end

# breaks the system up into a continous and 0 or more discrete systems
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
Expand All @@ -186,6 +194,7 @@ function split_system(ci::ClockInference{S}) where {S}
end
tss[id] = ts_i
end
# put the continous system at the back
if continuous_id != 0
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]
Expand Down
13 changes: 10 additions & 3 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function structural_simplify(
kwargs...)
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
newsys′ = __structural_simplify(sys, io; simplify,
allow_symbolic, allow_parameter, conservative, fully_determined,
allow_symbolic, allow_parameter, conservative, fully_determined, additional_passes,
kwargs...)
if newsys′ isa Tuple
@assert length(newsys′) == 2
Expand Down Expand Up @@ -82,12 +82,13 @@ end

function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
kwargs...)
sys, statemachines = extract_top_level_statemachines(sys)
sys = expand_connections(sys)
state = TearingState(sys)
append!(state.statemachines, statemachines)

@unpack structure, fullvars = state
@unpack graph, var_to_diff, var_types = structure
eqs = equations(state)
brown_vars = Int[]
new_idxs = zeros(Int, length(var_types))
idx = 0
Expand All @@ -104,7 +105,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
Is = Int[]
Js = Int[]
vals = Num[]
new_eqs = copy(eqs)
make_eqs_zero_equals!(state)
new_eqs = copy(equations(state))
dvar2eq = Dict{Any, Int}()
for (v, dv) in enumerate(var_to_diff)
dv === nothing && continue
Expand Down Expand Up @@ -169,3 +171,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
end
end

"""
Mark whether an extra pass `p` can support compiling discrete systems.
"""
discrete_compile_pass(p) = false
149 changes: 121 additions & 28 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DataStructures
using Symbolics: linear_expansion, unwrap, Connection
using Symbolics: linear_expansion, unwrap, Connection, Transition, InitialState
using SymbolicUtils: iscall, operation, arguments, Symbolic
using SymbolicUtils: quick_cancel, maketerm
using ..ModelingToolkit
Expand Down Expand Up @@ -198,16 +198,35 @@ end

mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
sys::T
original_eqs::Vector{Equation}
fullvars::Vector
structure::SystemStructure
extra_eqs::Vector
statemachines::Vector{T}
end

TransformationState(sys::AbstractSystem) = TearingState(sys)
function system_subset(ts::TearingState, ieqs::Vector{Int})
eqs = equations(ts)
@set! ts.original_eqs = ts.original_eqs[ieqs]
@set! ts.sys.eqs = eqs[ieqs]
@set! ts.structure = system_subset(ts.structure, ieqs)
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
names = Symbol[]
for eq in get_eqs(ts.sys)
if eq.lhs isa Transition
push!(names, first(namespace_hierarchy(nameof(eq.rhs.from))))
push!(names, first(namespace_hierarchy(nameof(eq.rhs.to))))
elseif eq.lhs isa InitialState
push!(names, first(namespace_hierarchy(nameof(eq.rhs.s))))
else
error("Unhandled state machine operator")
end
end
@set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines)
else
@set! ts.statemachines = eltype(ts.statemachines)[]
end
ts
end

Expand Down Expand Up @@ -247,12 +266,56 @@ function Base.push!(ev::EquationsView, eq)
push!(ev.ts.extra_eqs, eq)
end

"""
$(TYPEDSIGNATURES)

Descend through the system hierarchy and look for statemachines. Remove equations from
the inner statemachine systems. Return the new `sys` and an array of top-level
statemachines.
"""
function extract_top_level_statemachines(sys::AbstractSystem)
eqs = get_eqs(sys)

if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs)
# top-level statemachine
with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys))
return with_removed, [sys]
elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs)
# error: can't mix
error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.")
else
# descend
subsystems = get_systems(sys)
newsubsystems = eltype(subsystems)[]
statemachines = eltype(subsystems)[]
for subsys in subsystems
newsubsys, sub_statemachines = extract_top_level_statemachines(subsys)
push!(newsubsystems, newsubsys)
append!(statemachines, sub_statemachines)
end
@set! sys.systems = newsubsystems
return sys, statemachines
end
end

"""
$(TYPEDSIGNATURES)

Return `sys` with all equations (including those in subsystems) removed.
"""
function remove_child_equations(sys::AbstractSystem)
@set! sys.eqs = eltype(get_eqs(sys))[]
@set! sys.systems = map(remove_child_equations, get_systems(sys))
return sys
end

function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
iv = length(ivs) == 1 ? ivs[1] : nothing
# scalarize array equations, without scalarizing arguments to registered functions
eqs = flatten_equations(copy(equations(sys)))
original_eqs = flatten_equations(copy(equations(sys)))
eqs = copy(original_eqs)
neqs = length(eqs)
dervaridxs = OrderedSet{Int}()
var2idx = Dict{Any, Int}()
Expand All @@ -275,7 +338,12 @@ function TearingState(sys; quick_cancel = false, check = true)
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
return nothing
end
if _iszero(eq′.lhs)
is_statemachine_equation = false
if eq′.lhs isa StateMachineOperation
is_statemachine_equation = true
eq = eq′
rhs = eq.rhs
elseif _iszero(eq′.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = eq′
else
Expand Down Expand Up @@ -340,7 +408,7 @@ function TearingState(sys; quick_cancel = false, check = true)
empty!(unknownvars)
empty!(vars)
empty!(varsvec)
if isalgeq
if isalgeq || is_statemachine_equation
eqs[i] = eq
else
eqs[i] = eqs[i].lhs ~ rhs
Expand Down Expand Up @@ -428,10 +496,10 @@ function TearingState(sys; quick_cancel = false, check = true)

eq_to_diff = DiffGraph(nsrcs(graph))

ts = TearingState(sys, fullvars,
ts = TearingState(sys, original_eqs, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, sys isa DiscreteSystem),
Any[])
Any[], typeof(sys)[])
if sys isa DiscreteSystem
ts = shift_discrete_system(ts)
end
Expand Down Expand Up @@ -622,44 +690,69 @@ function merge_io(io, inputs)
return io
end

function make_eqs_zero_equals!(ts::TearingState)
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
i, eq = kvp
isalgeq = true
for j in 𝑠neighbors(ts.structure.graph, i)
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
end
if isalgeq
return 0 ~ eq.rhs - eq.lhs
else
return eq
end
end
copyto!(get_eqs(ts.sys), neweqs)
end

function structural_simplify!(state::TearingState, io = nothing; simplify = false,
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
kwargs...)
if state.sys isa ODESystem
# split_system returns one or two systems and the inputs for each
# mod clock inference to be binary
# if it's continous keep going, if not then error unless given trait impl in additional passes
ci = ModelingToolkit.ClockInference(state)
ci = ModelingToolkit.infer_clocks!(ci)
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
if continuous_id == 0
# do a trait check here - handle fully discrete system
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) &&
any(discrete_compile_pass, additional_passes)
# take the first discrete compilation pass given for now
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
return discrete_compile(tss, inputs)
else
# error goes here! this is a purely discrete system
throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem."))
end
end
make_eqs_zero_equals!(tss[continuous_id])
# puts the ios passed in to the call into the continous system
cont_io = merge_io(io, inputs[continuous_id])
# simplify as normal
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
check_consistency, fully_determined,
kwargs...)
if length(tss) > 1
if continuous_id > 0
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) &&
any(discrete_compile_pass, additional_passes)
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
sys = discrete_compile(sys, tss[2:end], inputs)
else
throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/"))
end
# TODO: rename it to something else
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
# Note that the appended_parameters must agree with
# `generate_discrete_affect`!
appended_parameters = parameters(sys)
for (i, state) in enumerate(tss)
if i == continuous_id
discrete_subsystems[i] = sys
continue
end
dist_io = merge_io(io, inputs[i])
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
fully_determined, kwargs...)
append!(appended_parameters, inputs[i], unknowns(ss))
discrete_subsystems[i] = ss
end
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
id_to_clock
@set! sys.ps = appended_parameters
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous()))
Expand Down
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,25 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
function vars!(vars, eq::Equation; op = Differential)
(vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars)
end
function vars!(vars, O::AbstractSystem; op = Differential)
for eq in equations(O)
vars!(vars, eq; op)
end
return vars
end
function vars!(vars, O::Transition; op = Differential)
vars!(vars, O.from)
vars!(vars, O.to)
vars!(vars, O.cond; op)
return vars
end
function vars!(vars, O::InitialState; op = Differential)
vars!(vars, O.s; op)
return vars
end
function vars!(vars, O::StateMachineOperator; op = Differential)
error("Unhandled state machine operator")
end
function vars!(vars, O; op = Differential)
if isvariable(O)
if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O)))
Expand Down