Skip to content

Commit e8ccbbe

Browse files
Merge pull request #3347 from AayushSabharwal/as/init-no-ts
feat: always build initialization problem
2 parents a2db412 + 0b72737 commit e8ccbbe

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1722
-1302
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ RecursiveArrayTools = "3.26"
135135
Reexport = "0.2, 1"
136136
RuntimeGeneratedFunctions = "0.5.9"
137137
SCCNonlinearSolve = "1.0.0"
138-
SciMLBase = "2.72.1"
138+
SciMLBase = "2.73"
139139
SciMLStructures = "1.0"
140140
Serialization = "1"
141141
Setfield = "0.7, 0.8, 1"

docs/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1515
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
1616
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1717
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
18+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
1819
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
1920
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2021
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

docs/src/basics/Variable_metadata.md

+1
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ state_priority(important_dof)
201201
Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions.
202202

203203
```@example metadata
204+
using DynamicQuantities
204205
@variables speed [unit = u"m/s"]
205206
hasunit(speed)
206207
```

docs/src/examples/remake.md

+19-36
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,22 @@ parameters to optimize.
4545

4646
```@example Remake
4747
using SymbolicIndexingInterface: parameter_values, state_values
48-
using SciMLStructures: Tunable, replace, replace!
48+
using SciMLStructures: Tunable, canonicalize, replace, replace!
49+
using PreallocationTools
4950
5051
function loss(x, p)
5152
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
5253
ps = parameter_values(odeprob) # obtain the parameter object from the problem
53-
ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
54+
diffcache = p[5]
55+
# get an appropriately typed preallocated buffer to store the `x` values in
56+
buffer = get_tmp(diffcache, x)
57+
# copy the current values to this buffer
58+
copyto!(buffer, canonicalize(Tunable(), ps)[1])
59+
# create a copy of the parameter object with the buffer
60+
ps = replace(Tunable(), ps, buffer)
61+
# set the updated values in the parameter object
62+
setter = p[4]
63+
setter(ps, x)
5464
# remake the problem, passing in our new parameter object
5565
newprob = remake(odeprob; p = ps)
5666
timesteps = p[2]
@@ -81,49 +91,22 @@ We can perform the optimization as below:
8191
```@example Remake
8292
using Optimization
8393
using OptimizationOptimJL
94+
using SymbolicIndexingInterface
8495
8596
# manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
8697
# require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
8798
optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
99+
# function to set the parameters we are optimizing
100+
setter = setp(odeprob, [α, β, γ, δ])
101+
# `DiffCache` to avoid allocations
102+
diffcache = DiffCache(canonicalize(Tunable(), parameter_values(odeprob))[1])
88103
# parameter object is a tuple, to store differently typed objects together
89104
optprob = OptimizationProblem(
90-
optfn, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
105+
optfn, rand(4), (odeprob, timesteps, data, setter, diffcache),
106+
lb = 0.1zeros(4), ub = 3ones(4))
91107
sol = solve(optprob, BFGS())
92108
```
93109

94-
To identify which values correspond to which parameters, we can `replace!` them into the
95-
`ODEProblem`:
96-
97-
```@example Remake
98-
replace!(Tunable(), parameter_values(odeprob), sol.u)
99-
odeprob.ps[[α, β, γ, δ]]
100-
```
101-
102-
`replace!` operates in-place, so the values being replaced must be of the same type as those
103-
stored in the parameter object, or convertible to that type. For demonstration purposes, we
104-
can construct a loss function that uses `replace!`, and calculate gradients using
105-
`AutoFiniteDiff` rather than `AutoForwardDiff`.
106-
107-
```@example Remake
108-
function loss2(x, p)
109-
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
110-
newprob = remake(odeprob) # copy the problem with `remake`
111-
# update the parameter values in-place
112-
replace!(Tunable(), parameter_values(newprob), x)
113-
timesteps = p[2]
114-
sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
115-
truth = p[3]
116-
data = Array(sol)
117-
return sum((truth .- data) .^ 2) / length(truth)
118-
end
119-
120-
# use finite-differencing to calculate derivatives
121-
optfn2 = OptimizationFunction(loss2, Optimization.AutoFiniteDiff())
122-
optprob2 = OptimizationProblem(
123-
optfn2, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
124-
sol = solve(optprob2, BFGS())
125-
```
126-
127110
# Re-creating the problem
128111

129112
There are multiple ways to re-create a problem with new state/parameter values. We will go

docs/src/tutorials/initialization.md

+70
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,73 @@ long enough you will see that `λ = 0` is required for this equation, but since
203203
problem constructor. Additionally, any warning about not being fully determined can
204204
be suppressed via passing `warn_initialize_determined = false`.
205205

206+
## Constant constraints in initialization
207+
208+
Consider the pendulum system again:
209+
210+
```@repl init
211+
equations(pend)
212+
observed(pend)
213+
```
214+
215+
Suppose we want to solve the same system with multiple different initial
216+
y-velocities from a given position.
217+
218+
```@example init
219+
prob = ODEProblem(
220+
pend, [x => 1, D(y) => 0], (0.0, 1.5), [g => 1], guesses = [λ => 0, y => 1, x => 1])
221+
sol1 = solve(prob, Rodas5P())
222+
```
223+
224+
```@example init
225+
sol1[D(y), 1]
226+
```
227+
228+
Repeatedly re-creating the `ODEProblem` with different values of `D(y)` and `x` or
229+
repeatedly calling `remake` is slow. Instead, for any `variable => constant` constraint
230+
in the `ODEProblem` initialization (whether provided to the `ODEProblem` constructor or
231+
a default value) we can update the `constant` value. ModelingToolkit refers to these
232+
values using the `Initial` operator. For example:
233+
234+
```@example init
235+
prob.ps[[Initial(x), Initial(D(y))]]
236+
```
237+
238+
To solve with a different starting y-velocity, we can simply do
239+
240+
```@example init
241+
prob.ps[Initial(D(y))] = -0.1
242+
sol2 = solve(prob, Rodas5P())
243+
```
244+
245+
```@example init
246+
sol2[D(y), 1]
247+
```
248+
249+
Note that this _only_ applies for constant constraints for the current ODEProblem.
250+
For example, `D(x)` does not have a constant constraint - it is solved for by
251+
initialization. Thus, mutating `Initial(D(x))` does not have any effect:
252+
253+
```@repl init
254+
sol2[D(x), 1]
255+
prob.ps[Initial(D(x))] = 1.0
256+
sol3 = solve(prob, Rodas5P())
257+
sol3[D(x), 1]
258+
```
259+
260+
To enforce this constraint, we would have to `remake` the problem (or construct a new one).
261+
262+
```@repl init
263+
prob2 = remake(prob; u0 = [y => 0.0, D(x) => 0.0, x => nothing, D(y) => nothing]);
264+
sol4 = solve(prob2, Rodas5P())
265+
sol4[D(x), 1]
266+
```
267+
268+
Note the need to provide `x => nothing, D(y) => nothing` to override the previously
269+
provided initial conditions. Since `remake` is a partial update, the constraints provided
270+
to it are merged with the ones already present in the problem. Existing constraints can be
271+
removed by providing a value of `nothing`.
272+
206273
## Initialization of parameters
207274

208275
Parameters may also be treated as unknowns in the initialization system. Doing so works
@@ -231,6 +298,9 @@ constraints provided to it. The new values will be combined with the original
231298
variable-value mapping provided to `ODEProblem` and used to construct the initialization
232299
problem.
233300

301+
The variable on the left hand side of all parameter dependencies also has an `Initial`
302+
variant, which is used if a constant constraint is provided for the variable.
303+
234304
### Parameter initialization by example
235305

236306
Consider the following system, where the sum of two unknowns is a constant parameter

src/ModelingToolkit.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ using Compat
4545
using AbstractTrees
4646
using DiffEqBase, SciMLBase, ForwardDiff
4747
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain,
48-
PeriodicClock, Clock, SolverStepClock, Continuous
48+
PeriodicClock, Clock, SolverStepClock, Continuous, OverrideInit, NoInit
4949
using Distributed
5050
import JuliaFormatter
5151
using MLStyle
@@ -56,6 +56,7 @@ using RecursiveArrayTools
5656
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5757
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
5858
undef_blocks, blocks
59+
using OffsetArrays: Origin
5960
import CommonSolve
6061
import EnumX
6162

@@ -152,6 +153,7 @@ include("systems/imperative_affect.jl")
152153
include("systems/callbacks.jl")
153154
include("systems/codegen_utils.jl")
154155
include("systems/problem_utils.jl")
156+
include("linearization.jl")
155157

156158
include("systems/nonlinear/nonlinearsystem.jl")
157159
include("systems/nonlinear/homotopy_continuation.jl")
@@ -258,7 +260,8 @@ export Term, Sym
258260
export SymScope, LocalScope, ParentScope, DelayParentScope, GlobalScope
259261
export independent_variable, equations, controls, observed, full_equations
260262
export initialization_equations, guesses, defaults, parameter_dependencies, hierarchy
261-
export structural_simplify, expand_connections, linearize, linearization_function
263+
export structural_simplify, expand_connections, linearize, linearization_function,
264+
LinearizationProblem
262265

263266
export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function
264267
export calculate_control_jacobian, generate_control_jacobian
@@ -278,7 +281,7 @@ export toexpr, get_variables
278281
export simplify, substitute
279282
export build_function
280283
export modelingtoolkitize
281-
export generate_initializesystem
284+
export generate_initializesystem, Initial
282285

283286
export alg_equations, diff_equations, has_alg_equations, has_diff_equations
284287
export get_alg_eqs, get_diff_eqs, has_alg_eqs, has_diff_eqs

src/inputoutput.jl

+5-6
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
211211
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
212212

213213
dvs = unknowns(sys)
214-
ps = parameters(sys)
214+
ps = parameters(sys; initial_parameters = true)
215215
ps = setdiff(ps, inputs)
216216
if disturbance_inputs !== nothing
217217
# remove from inputs since we do not want them as actual inputs to the dynamics
@@ -234,16 +234,14 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
234234
[eq.rhs for eq in eqs]
235235

236236
# TODO: add an optional check on the ordering of observed equations
237-
u = map(x -> time_varying_as_func(value(x), sys), dvs)
238-
p = map(x -> time_varying_as_func(value(x), sys), ps)
239-
p = reorder_parameters(sys, p)
237+
p = reorder_parameters(sys, ps)
240238
t = get_iv(sys)
241239

242240
# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
243241
if disturbance_argument
244-
args = (u, inputs, p..., t, disturbance_inputs)
242+
args = (dvs, inputs, p..., t, disturbance_inputs)
245243
else
246-
args = (u, inputs, p..., t)
244+
args = (dvs, inputs, p..., t)
247245
end
248246
if implicit_dae
249247
ddvs = map(Differential(get_iv(sys)), dvs)
@@ -252,6 +250,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
252250
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
253251
p_end = length(p) + 2 + implicit_dae)
254252
f = eval_or_rgf.(f; eval_expression, eval_module)
253+
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
255254
(; f, dvs, ps, io_sys = sys)
256255
end
257256

0 commit comments

Comments
 (0)