Skip to content

Commit da739eb

Browse files
Merge pull request #3466 from AayushSabharwal/as/better-nlsys-init
feat: handle `Initial(x)` initialization_eqs in time-independent systems
2 parents b7bb95d + ec2fbb4 commit da739eb

15 files changed

+392
-176
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ DataInterpolations = "6.4"
9292
DataStructures = "0.17, 0.18"
9393
DeepDiffs = "1"
9494
DelayDiffEq = "5.50"
95-
DiffEqBase = "6.157"
95+
DiffEqBase = "6.165.1"
9696
DiffEqCallbacks = "2.16, 3, 4"
9797
DiffEqNoiseProcess = "5"
9898
DiffRules = "0.1, 1.0"

src/doc

-1
This file was deleted.

src/systems/abstractsystem.jl

+26-38
Original file line numberDiff line numberDiff line change
@@ -676,23 +676,27 @@ function SymbolicUtils.maketerm(::Type{<:BasicSymbolic}, ::Initial, args, meta)
676676
return metadata(val, meta)
677677
end
678678

679+
supports_initialization(sys::AbstractSystem) = true
680+
679681
function add_initialization_parameters(sys::AbstractSystem)
680682
@assert !has_systems(sys) || isempty(get_systems(sys))
683+
supports_initialization(sys) || return sys
684+
is_initializesystem(sys) && return sys
685+
681686
all_initialvars = Set{BasicSymbolic}()
682687
# time-independent systems don't initialize unknowns
683-
if is_time_dependent(sys)
684-
eqs = equations(sys)
685-
if !(eqs isa Vector{Equation})
686-
eqs = Equation[x for x in eqs if x isa Equation]
687-
end
688-
obs, eqs = unhack_observed(observed(sys), eqs)
689-
for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs)))
690-
x = unwrap(x)
691-
if iscall(x) && operation(x) == getindex
692-
push!(all_initialvars, arguments(x)[1])
693-
else
694-
push!(all_initialvars, x)
695-
end
688+
# but may initialize parameters using guesses for unknowns
689+
eqs = equations(sys)
690+
if !(eqs isa Vector{Equation})
691+
eqs = Equation[x for x in eqs if x isa Equation]
692+
end
693+
obs, eqs = unhack_observed(observed(sys), eqs)
694+
for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs)))
695+
x = unwrap(x)
696+
if iscall(x) && operation(x) == getindex
697+
push!(all_initialvars, arguments(x)[1])
698+
else
699+
push!(all_initialvars, x)
696700
end
697701
end
698702
for eq in parameter_dependencies(sys)
@@ -722,15 +726,8 @@ Returns true if the parameter `p` is of the form `Initial(x)`.
722726
"""
723727
function isinitial(p)
724728
p = unwrap(p)
725-
if iscall(p)
726-
operation(p) isa Initial && return true
727-
if operation(p) === getindex
728-
operation(arguments(p)[1]) isa Initial && return true
729-
end
730-
else
731-
return false
732-
end
733-
return false
729+
return iscall(p) && (operation(p) isa Initial ||
730+
operation(p) === getindex && isinitial(arguments(p)[1]))
734731
end
735732

736733
"""
@@ -744,7 +741,8 @@ the global structure of the system.
744741
One property to note is that if a system is complete, the system will no longer
745742
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
746743
"""
747-
function complete(sys::AbstractSystem; split = true, flatten = true)
744+
function complete(
745+
sys::AbstractSystem; split = true, flatten = true, add_initial_parameters = true)
748746
newunknowns = OrderedSet()
749747
newparams = OrderedSet()
750748
iv = has_iv(sys) ? get_iv(sys) : nothing
@@ -765,7 +763,9 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
765763
@set! newsys.parent = complete(sys; split = false, flatten = false)
766764
end
767765
sys = newsys
768-
sys = add_initialization_parameters(sys)
766+
if add_initial_parameters
767+
sys = add_initialization_parameters(sys)
768+
end
769769
end
770770
if split && has_index_cache(sys)
771771
@set! sys.index_cache = IndexCache(sys)
@@ -1345,20 +1345,8 @@ function parameters(sys::AbstractSystem; initial_parameters = false)
13451345
systems = get_systems(sys)
13461346
result = unique(isempty(systems) ? ps :
13471347
[ps; reduce(vcat, namespace_parameters.(systems))])
1348-
if !initial_parameters
1349-
if is_time_dependent(sys)
1350-
# time-dependent systems have `Initial` parameters for all their
1351-
# unknowns/pdeps, all of which should be hidden.
1352-
filter!(x -> !iscall(x) || !isa(operation(x), Initial), result)
1353-
else
1354-
# time-independent systems only have `Initial` parameters for
1355-
# pdeps. Any other `Initial` parameters should be kept (e.g. initialization
1356-
# systems)
1357-
filter!(
1358-
x -> !iscall(x) || !isa(operation(x), Initial) ||
1359-
!has_parameter_dependency_with_lhs(sys, only(arguments(x))),
1360-
result)
1361-
end
1348+
if !initial_parameters && !is_initializesystem(sys)
1349+
filter!(x -> !iscall(x) || !isa(operation(x), Initial), result)
13621350
end
13631351
return result
13641352
end

src/systems/discrete_system/discrete_system.jl

+2
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,5 @@ end
421421
function DiscreteFunctionExpr(sys::DiscreteSystem, args...; kwargs...)
422422
DiscreteFunctionExpr{true}(sys, args...; kwargs...)
423423
end
424+
425+
supports_initialization(::DiscreteSystem) = false

src/systems/discrete_system/implicit_discrete_system.jl

-2
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,8 @@ function SciMLBase.ImplicitDiscreteProblem(
333333

334334
u0map = to_varmap(u0map, dvs)
335335
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
336-
@show u0map
337336
f, u0, p = process_SciMLProblem(
338337
ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...)
339-
@show u0
340338

341339
kwargs = filter_kwargs(kwargs)
342340
ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)

src/systems/jumps/jumpsystem.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
507507
systems = get_systems(sys), defaults = defaults(sys), guesses = guesses(sys),
508508
parameter_dependencies = parameter_dependencies(sys),
509509
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
510-
osys = complete(osys)
510+
osys = complete(osys; add_initial_parameters = false)
511511
return ODEProblem(osys, u0map, tspan, parammap; check_length = false,
512512
build_initializeprob = false, kwargs...)
513513
else
@@ -685,3 +685,5 @@ function (ratemap::JumpSysMajParamMapper{U, V, W})(maj::MassActionJump, newparam
685685
scale_rates && JumpProcesses.scalerates!(maj.scaled_rates, maj.reactant_stoch)
686686
nothing
687687
end
688+
689+
supports_initialization(::JumpSystem) = false

0 commit comments

Comments
 (0)