From 4f7810cf92d9c864fde8bc5da847e814dd5af2d2 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 17 Feb 2025 17:51:46 -0800 Subject: [PATCH 01/10] Early work on the new discrete backend for MTK --- src/systems/clock_inference.jl | 11 +++++++- src/systems/systems.jl | 8 +++++- src/systems/systemstructure.jl | 49 +++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 611f8e2fae..66454e0785 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -93,7 +93,7 @@ function infer_clocks!(ci::ClockInference) c = BitSet(c′) idxs = intersect(c, inferred) isempty(idxs) && continue - if !allequal(var_domain[i] for i in idxs) + if !allequal(iscontinuous(var_domain[i]) for i in idxs) display(fullvars[c′]) throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) end @@ -144,6 +144,9 @@ function split_system(ci::ClockInference{S}) where {S} var_to_cid = Vector{Int}(undef, ndsts(graph)) cid_to_var = Vector{Int}[] cid_counter = Ref(0) + + # populates clock_to_id and id_to_clock + # checks if there is a continuous_id (for some reason? clock to id does this too) for (i, d) in enumerate(eq_domain) cid = let cid_counter = cid_counter, id_to_clock = id_to_clock, continuous_id = continuous_id @@ -161,9 +164,13 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_eq, i, cid) end continuous_id = continuous_id[] + # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) inputs = map(_ -> Any[], 1:cid_counter[]) + # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) + # put variables into the right clock partition + # keep track of inputs to each partition for i in 1:nvv d = var_domain[i] cid = get(clock_to_id, d, 0) @@ -177,6 +184,7 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_var, i, cid) end + # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) for (id, ieqs) in enumerate(cid_to_eq) ts_i = system_subset(ts, ieqs) @@ -186,6 +194,7 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end + # put the continous system at the back if continuous_id != 0 tss[continuous_id], tss[end] = tss[end], tss[continuous_id] inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] diff --git a/src/systems/systems.jl b/src/systems/systems.jl index f8630f2d20..5c44c0b4fe 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -31,7 +31,7 @@ function structural_simplify( kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) newsys′ = __structural_simplify(sys, io; simplify, - allow_symbolic, allow_parameter, conservative, fully_determined, + allow_symbolic, allow_parameter, conservative, fully_determined, additional_passes, kwargs...) if newsys′ isa Tuple @assert length(newsys′) == 2 @@ -169,3 +169,9 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal guesses = guesses(sys), initialization_eqs = initialization_equations(sys)) end end + +""" +Mark whether an extra pass `p` can support compiling discrete systems. +""" +discrete_compile_pass(p) = false + diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 1bdc11f06a..1763b6175a 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -626,40 +626,45 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals check_consistency = true, fully_determined = true, warn_initialize_determined = true, kwargs...) if state.sys isa ODESystem + # split_system returns one or two systems and the inputs for each + # mod clock inference to be binary + # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) time_domains = merge(Dict(state.fullvars .=> ci.var_domain), Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) + if continuous_id == 0 + # do a trait check here - handle fully discrete system + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + # take the first discrete compilation pass given for now + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + return discrete_compile(tss, inputs) + else + # error goes here! this is a purely discrete system + throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem.")) + end + end + # puts the ios passed in to the call into the continous system cont_io = merge_io(io, inputs[continuous_id]) + # simplify as normal sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify, check_consistency, fully_determined, kwargs...) if length(tss) > 1 - if continuous_id > 0 + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems + # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed + sys = discrete_compile(sys, tss[2:end], inputs) + else throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/")) end - # TODO: rename it to something else - discrete_subsystems = Vector{ODESystem}(undef, length(tss)) - # Note that the appended_parameters must agree with - # `generate_discrete_affect`! - appended_parameters = parameters(sys) - for (i, state) in enumerate(tss) - if i == continuous_id - discrete_subsystems[i] = sys - continue - end - dist_io = merge_io(io, inputs[i]) - ss, = _structural_simplify!(state, dist_io; simplify, check_consistency, - fully_determined, kwargs...) - append!(appended_parameters, inputs[i], unknowns(ss)) - discrete_subsystems[i] = ss - end - @set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id, - id_to_clock - @set! sys.ps = appended_parameters - @set! sys.defaults = merge(ModelingToolkit.defaults(sys), - Dict(v => 0.0 for v in Iterators.flatten(inputs))) end ps = [sym isa CallWithMetadata ? sym : setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous())) From 404d7dc32eeb08f248efe982a722b14abbd55a4a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 14:51:34 +0530 Subject: [PATCH 02/10] fixup! Early work on the new discrete backend for MTK --- src/systems/systems.jl | 1 - src/systems/systemstructure.jl | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 5c44c0b4fe..0947e63cfd 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -174,4 +174,3 @@ end Mark whether an extra pass `p` can support compiling discrete systems. """ discrete_compile_pass(p) = false - diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 1763b6175a..1e839f33e6 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -637,7 +637,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals if continuous_id == 0 # do a trait check here - handle fully discrete system additional_passes = get(kwargs, :additional_passes, nothing) - if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + if !isnothing(additional_passes) && + any(discrete_compile_pass, additional_passes) # take the first discrete compilation pass given for now discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] @@ -655,7 +656,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals check_consistency, fully_determined, kwargs...) if length(tss) > 1 - if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + if !isnothing(additional_passes) && + any(discrete_compile_pass, additional_passes) discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] deleteat!(additional_passes, discrete_pass_idx) From dea08c6b05c7cb2223c5b70a85dd7aa7eb1ab3bf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 14:52:14 +0530 Subject: [PATCH 03/10] feat: retain original equations of the system in `TearingState` --- src/systems/systems.jl | 4 ++-- src/systems/systemstructure.jl | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 0947e63cfd..aa7762b741 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -87,7 +87,6 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure - eqs = equations(state) brown_vars = Int[] new_idxs = zeros(Int, length(var_types)) idx = 0 @@ -104,7 +103,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal Is = Int[] Js = Int[] vals = Num[] - new_eqs = copy(eqs) + make_eqs_zero_equals!(state) + new_eqs = copy(equations(state)) dvar2eq = Dict{Any, Int}() for (v, dv) in enumerate(var_to_diff) dv === nothing && continue diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 1e839f33e6..2ccc9c4d99 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -198,6 +198,7 @@ end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} sys::T + original_eqs::Vector{Equation} fullvars::Vector structure::SystemStructure extra_eqs::Vector @@ -206,6 +207,7 @@ end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}) eqs = equations(ts) + @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) ts @@ -252,7 +254,8 @@ function TearingState(sys; quick_cancel = false, check = true) ivs = independent_variables(sys) iv = length(ivs) == 1 ? ivs[1] : nothing # scalarize array equations, without scalarizing arguments to registered functions - eqs = flatten_equations(copy(equations(sys))) + original_eqs = flatten_equations(copy(equations(sys))) + eqs = copy(original_eqs) neqs = length(eqs) dervaridxs = OrderedSet{Int}() var2idx = Dict{Any, Int}() @@ -428,7 +431,7 @@ function TearingState(sys; quick_cancel = false, check = true) eq_to_diff = DiffGraph(nsrcs(graph)) - ts = TearingState(sys, fullvars, + ts = TearingState(sys, original_eqs, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, sys isa DiscreteSystem), Any[]) @@ -622,6 +625,22 @@ function merge_io(io, inputs) return io end +function make_eqs_zero_equals!(ts::TearingState) + neweqs = map(enumerate(get_eqs(ts.sys))) do kvp + i, eq = kvp + isalgeq = true + for j in 𝑠neighbors(ts.structure.graph, i) + isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing + end + if isalgeq + return 0 ~ eq.rhs - eq.lhs + else + return eq + end + end + copyto!(get_eqs(ts.sys), neweqs) +end + function structural_simplify!(state::TearingState, io = nothing; simplify = false, check_consistency = true, fully_determined = true, warn_initialize_determined = true, kwargs...) @@ -649,6 +668,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem.")) end end + make_eqs_zero_equals!(tss[continuous_id]) # puts the ios passed in to the call into the continous system cont_io = merge_io(io, inputs[continuous_id]) # simplify as normal From 1f7fb2e2453bdff11de841373653ad0d114a6ab7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:50:31 +0530 Subject: [PATCH 04/10] feat: allow namespacing statemachine equations --- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 09b59c4ed6..5bbc8d9aed 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -69,7 +69,7 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables, NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, initial_state, transition, activeState, entry, hasnode, ticksInState, timeInState, fixpoint_sub, fast_substitute, - CallWithMetadata, CallWithParent + CallWithMetadata, CallWithParent, Transition, InitialState const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 8213b8f241..724194fb94 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1390,6 +1390,26 @@ function namespace_expr( O end end + +function namespace_expr( + O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys)) + return Transition( + O.from === nothing ? O.from : renamespace(sys, O.from), + O.to === nothing ? O.to : renamespace(sys, O.to), + O.cond === nothing ? O.cond : namespace_expr(O.cond, sys), + O.immediate, O.reset, O.synchronize, O.priority + ) +end + +function namespace_expr( + O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys)) + return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s)) +end + +function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...) + error("Unhandled state machine operator") +end + _nonum(@nospecialize x) = x isa Num ? x.val : x """ From 8d1cf25775913b4bc7ba6a637b16417faec86dc0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:55:40 +0530 Subject: [PATCH 05/10] feat: implement `vars!` for state machine operators --- src/utils.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index cf49d9f445..3b4ccd3be9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -401,6 +401,25 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end +function vars!(vars, O::AbstractSystem; op = Differential) + for eq in equations(O) + vars!(vars, eq; op) + end + return vars +end +function vars!(vars, O::Transition; op = Differential) + vars!(vars, O.from) + vars!(vars, O.to) + vars!(vars, O.cond; op) + return vars +end +function vars!(vars, O::InitialState; op = Differential) + vars!(vars, O.s; op) + return vars +end +function vars!(vars, O::StateMachineOperator; op = Differential) + error("Unhandled state machine operator") +end function vars!(vars, O; op = Differential) if isvariable(O) if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O))) From c49858614744f839f25fcd1649f2957554c2975a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:56:00 +0530 Subject: [PATCH 06/10] feat: propagate state machines in structural simplification --- src/systems/systems.jl | 2 + src/systems/systemstructure.jl | 74 ++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index aa7762b741..924be54b38 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -82,8 +82,10 @@ end function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false, kwargs...) + sys, statemachines = extract_top_level_statemachines(sys) sys = expand_connections(sys) state = TearingState(sys) + append!(state.statemachines, statemachines) @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 2ccc9c4d99..2a0346bccf 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1,5 +1,5 @@ using DataStructures -using Symbolics: linear_expansion, unwrap, Connection +using Symbolics: linear_expansion, unwrap, Connection, Transition, InitialState using SymbolicUtils: iscall, operation, arguments, Symbolic using SymbolicUtils: quick_cancel, maketerm using ..ModelingToolkit @@ -202,6 +202,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} fullvars::Vector structure::SystemStructure extra_eqs::Vector + statemachines::Vector{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) @@ -210,6 +211,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}) @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) + if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) + names = Symbol[] + for eq in get_eqs(ts.sys) + if eq.lhs isa Transition + push!(names, first(namespace_hierarchy(nameof(eq.rhs.from)))) + push!(names, first(namespace_hierarchy(nameof(eq.rhs.to)))) + elseif eq.lhs isa InitialState + push!(names, first(namespace_hierarchy(nameof(eq.rhs.s)))) + else + error("Unhandled state machine operator") + end + end + @set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines) + else + @set! ts.statemachines = eltype(ts.statemachines)[] + end ts end @@ -249,6 +266,49 @@ function Base.push!(ev::EquationsView, eq) push!(ev.ts.extra_eqs, eq) end +""" + $(TYPEDSIGNATURES) + +Descend through the system hierarchy and look for statemachines. Remove equations from +the inner statemachine systems. Return the new `sys` and an array of top-level +statemachines. +""" +function extract_top_level_statemachines(sys::AbstractSystem) + eqs = get_eqs(sys) + + if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + # top-level statemachine + with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) + return with_removed, [sys] + elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + # error: can't mix + error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") + else + # descend + subsystems = get_systems(sys) + newsubsystems = eltype(subsystems)[] + statemachines = eltype(subsystems)[] + for subsys in subsystems + newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) + push!(newsubsystems, newsubsys) + append!(statemachines, sub_statemachines) + end + @set! sys.systems = newsubsystems + return sys, statemachines + end +end + +""" + $(TYPEDSIGNATURES) + +Return `sys` with all equations (including those in subsystems) removed. +""" +function remove_child_equations(sys::AbstractSystem) + @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.systems = map(remove_child_equations, get_systems(sys)) + return sys +end + function TearingState(sys; quick_cancel = false, check = true) sys = flatten(sys) ivs = independent_variables(sys) @@ -278,7 +338,12 @@ function TearingState(sys; quick_cancel = false, check = true) check ? error("$(nameof(sys)) has unexpanded `connect` statements") : return nothing end - if _iszero(eq′.lhs) + is_statemachine_equation = false + if eq′.lhs isa StateMachineOperation + is_statemachine_equation = true + eq = eq′ + rhs = eq.rhs + elseif _iszero(eq′.lhs) rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs eq = eq′ else @@ -343,7 +408,7 @@ function TearingState(sys; quick_cancel = false, check = true) empty!(unknownvars) empty!(vars) empty!(varsvec) - if isalgeq + if isalgeq || is_statemachine_equation eqs[i] = eq else eqs[i] = eqs[i].lhs ~ rhs @@ -434,7 +499,7 @@ function TearingState(sys; quick_cancel = false, check = true) ts = TearingState(sys, original_eqs, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, sys isa DiscreteSystem), - Any[]) + Any[], typeof(sys)[]) if sys isa DiscreteSystem ts = shift_discrete_system(ts) end @@ -676,6 +741,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals check_consistency, fully_determined, kwargs...) if length(tss) > 1 + additional_passes = get(kwargs, :additional_passes, nothing) if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) From 0f1e24fda1caf65910028008d5a7d99a3f6adcee Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Mar 2025 14:57:50 +0530 Subject: [PATCH 07/10] fix: import `StateMachineOperator` from Symbolics.jl --- src/ModelingToolkit.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 5bbc8d9aed..5db5b3992b 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -69,7 +69,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables, NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, initial_state, transition, activeState, entry, hasnode, ticksInState, timeInState, fixpoint_sub, fast_substitute, - CallWithMetadata, CallWithParent, Transition, InitialState + CallWithMetadata, CallWithParent, Transition, InitialState, + StateMachineOperator const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, From e4b65d6431b814de468cf00a4d4d1710b0b8f2e3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Mar 2025 15:08:22 +0530 Subject: [PATCH 08/10] fix: use `StateMachineOperator` not `StateMachineOperation` --- src/systems/systemstructure.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 2a0346bccf..664bf8c570 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -339,7 +339,7 @@ function TearingState(sys; quick_cancel = false, check = true) return nothing end is_statemachine_equation = false - if eq′.lhs isa StateMachineOperation + if eq′.lhs isa StateMachineOperator is_statemachine_equation = true eq = eq′ rhs = eq.rhs From 24bf3812234f23f4f2ae71132dec947893b49305 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 14 Mar 2025 18:24:11 -0700 Subject: [PATCH 09/10] Handle nothing updates better --- src/systems/imperative_affect.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 4c9ff3d248..56c3721317 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -211,7 +211,9 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + end for idx in save_idxs SciMLBase.save_discretes!(integ, idx) From dc70f88a80682ca47a2f76db7f9a38969d7d6334 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 14 Mar 2025 18:24:24 -0700 Subject: [PATCH 10/10] Redefine the discrete_compile interface a bit --- src/systems/systemstructure.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 664bf8c570..6f51311cbb 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -727,7 +727,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] deleteat!(additional_passes, discrete_pass_idx) - return discrete_compile(tss, inputs) + return discrete_compile(tss, inputs, ci) else # error goes here! this is a purely discrete system throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem.")) @@ -749,7 +749,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals deleteat!(additional_passes, discrete_pass_idx) # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed - sys = discrete_compile(sys, tss[2:end], inputs) + sys = discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], inputs, ci) else throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/")) end