Skip to content

feat: create initialization systems for all problem types #3253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Dec 25, 2024
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
219aee3
refactor: add guesses to `SDESystem`, `NonlinearSystem`, `JumpSystem`
AayushSabharwal Dec 1, 2024
af8cd67
feat: support arbitrary systems in `generate_initializesystem`
AayushSabharwal Dec 1, 2024
3928194
refactor: use `initialization_data` in SciMLFunction constructors
AayushSabharwal Dec 2, 2024
4044317
fix: don't build initializeprob for initializeprob
AayushSabharwal Dec 2, 2024
180b978
feat: build initialization system for all system types in `process_Sc…
AayushSabharwal Dec 2, 2024
c9c613f
fix: retain system data on `structural_simplify` of `SDESystem`
AayushSabharwal Dec 2, 2024
4d5daa3
fix: pass `t` to `process_SciMLProblem` in `SDEProblem`
AayushSabharwal Dec 3, 2024
8d09409
feat: support arbitrary systems in `remake_initialization_data`
AayushSabharwal Dec 3, 2024
def207b
fix: fix type promotion bug in `remake_buffer`
AayushSabharwal Dec 3, 2024
d971b18
test: test initialization on `SDEProblem`, `DDEProblem`, `SDDEProblem`
AayushSabharwal Dec 3, 2024
39bb59c
fix: handle integer `u0` in `DDEProblem`
AayushSabharwal Dec 6, 2024
2f2e625
feat: enable creating `InitializationProblem` for non-`AbstractODESys…
AayushSabharwal Dec 6, 2024
671b93f
fix: filter kwargs in `SDEProblem`
AayushSabharwal Dec 6, 2024
9987da0
test: test initialization on `NonlinearProblem` and `NonlinearLeastSq…
AayushSabharwal Dec 6, 2024
51eeeeb
fix: store and propagate `initialization_eqs` provided to Problem
AayushSabharwal Dec 6, 2024
96f8d5d
build: bump compats
AayushSabharwal Dec 14, 2024
0a881e7
fix: better handle reconstructing initializeprob with new types
AayushSabharwal Dec 16, 2024
2e07200
test: fix incorrect initial values in tests
AayushSabharwal Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: support arbitrary systems in generate_initializesystem
  • Loading branch information
AayushSabharwal committed Dec 24, 2024
commit af8cd679571d6c2d078fbebdd9219536b270f335
65 changes: 38 additions & 27 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ $(TYPEDSIGNATURES)

Generate `NonlinearSystem` which initializes an ODE problem from specified initial conditions of an `ODESystem`.
"""
function generate_initializesystem(sys::ODESystem;
function generate_initializesystem(sys::AbstractSystem;
u0map = Dict(),
pmap = Dict(),
initialization_eqs = [],
Expand All @@ -12,28 +12,36 @@ function generate_initializesystem(sys::ODESystem;
algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), extra_metadata = (;), kwargs...)
trueobs, eqs = unhack_observed(observed(sys), equations(sys))
eqs = equations(sys)
eqs = filter(x -> x isa Equation, eqs)
trueobs, eqs = unhack_observed(observed(sys), eqs)
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
vars_set = Set(vars) # for efficient in-lookup

idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff

# prepare map for dummy derivative substitution
eqs_diff = eqs[idxs_diff]
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
)

# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
eqs_ics = Equation[]
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
additional_guesses = anydict(guesses)
guesses = merge(get_guesses(sys), additional_guesses)
schedule = getfield(sys, :schedule)
if !isnothing(schedule)
idxs_diff = isdiffeq.(eqs)

# 1) Use algebraic equations of time-dependent systems as initialization constraints
if has_iv(sys)
idxs_alge = .!idxs_diff
append!(eqs_ics, eqs[idxs_alge]) # start equation list with algebraic equations

eqs_diff = eqs[idxs_diff]
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
)
else
diffmap = Dict()
end

if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
# 2) process dummy derivatives and u0map into initialization system
# prepare map for dummy derivative substitution
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
Expand Down Expand Up @@ -61,9 +69,14 @@ function generate_initializesystem(sys::ODESystem;
process_u0map_with_dummysubs(y, x)
end
end
else
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
for (k, v) in u0map
defs[k] = v
end
end

# 2) process other variables
# 3) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
Expand All @@ -74,7 +87,7 @@ function generate_initializesystem(sys::ODESystem;
end
end

# 3) process explicitly provided initialization equations
# 4) process explicitly provided initialization equations
if !algebraic_only
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
for eq in initialization_eqs
Expand All @@ -83,7 +96,7 @@ function generate_initializesystem(sys::ODESystem;
end
end

# 4) process parameters as initialization unknowns
# 5) process parameters as initialization unknowns
paramsubs = Dict()
if pmap isa SciMLBase.NullParameters
pmap = Dict()
Expand Down Expand Up @@ -138,7 +151,7 @@ function generate_initializesystem(sys::ODESystem;
end
end

# 5) parameter dependencies become equations, their LHS become unknowns
# 6) parameter dependencies become equations, their LHS become unknowns
# non-numeric dependent parameters stay as parameter dependencies
new_parameter_deps = Equation[]
for eq in parameter_dependencies(sys)
Expand All @@ -153,20 +166,18 @@ function generate_initializesystem(sys::ODESystem;
push!(defs, varp => guessval)
end

# 6) handle values provided for dependent parameters similar to values for observed variables
# 7) handle values provided for dependent parameters similar to values for observed variables
for (k, v) in merge(defaults(sys), pmap)
if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k)
push!(eqs_ics, paramsubs[k] ~ v)
end
end

# parameters do not include ones that became initialization unknowns
pars = vcat(
[get_iv(sys)], # include independent variable as pseudo-parameter
[p for p in parameters(sys) if !haskey(paramsubs, p)]
)
pars = Vector{SymbolicParam}(filter(p -> !haskey(paramsubs, p), parameters(sys)))
is_time_dependent(sys) && push!(pars, get_iv(sys))

# 7) use observed equations for guesses of observed variables if not provided
# 8) use observed equations for guesses of observed variables if not provided
for eq in trueobs
haskey(defs, eq.lhs) && continue
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
Expand Down