Skip to content

Commit dfa4db8

Browse files
docs: update remake doc page to account for extra tunables
1 parent 7baf788 commit dfa4db8

File tree

1 file changed

+19
-36
lines changed

1 file changed

+19
-36
lines changed

docs/src/examples/remake.md

+19-36
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,22 @@ parameters to optimize.
4545

4646
```@example Remake
4747
using SymbolicIndexingInterface: parameter_values, state_values
48-
using SciMLStructures: Tunable, replace, replace!
48+
using SciMLStructures: Tunable, canonicalize, replace, replace!
49+
using PreallocationTools
4950
5051
function loss(x, p)
5152
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
5253
ps = parameter_values(odeprob) # obtain the parameter object from the problem
53-
ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
54+
diffcache = p[5]
55+
# get an appropriately typed preallocated buffer to store the `x` values in
56+
buffer = get_tmp(diffcache, x)
57+
# copy the current values to this buffer
58+
copyto!(buffer, canonicalize(Tunable(), ps)[1])
59+
# create a copy of the parameter object with the buffer
60+
ps = replace(Tunable(), ps, buffer)
61+
# set the updated values in the parameter object
62+
setter = p[4]
63+
setter(ps, x)
5464
# remake the problem, passing in our new parameter object
5565
newprob = remake(odeprob; p = ps)
5666
timesteps = p[2]
@@ -81,49 +91,22 @@ We can perform the optimization as below:
8191
```@example Remake
8292
using Optimization
8393
using OptimizationOptimJL
94+
using SymbolicIndexingInterface
8495
8596
# manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
8697
# require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
8798
optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
99+
# function to set the parameters we are optimizing
100+
setter = setp(odeprob, [α, β, γ, δ])
101+
# `DiffCache` to avoid allocations
102+
diffcache = DiffCache(canonicalize(Tunable(), parameter_values(odeprob))[1])
88103
# parameter object is a tuple, to store differently typed objects together
89104
optprob = OptimizationProblem(
90-
optfn, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
105+
optfn, rand(4), (odeprob, timesteps, data, setter, diffcache),
106+
lb = 0.1zeros(4), ub = 3ones(4))
91107
sol = solve(optprob, BFGS())
92108
```
93109

94-
To identify which values correspond to which parameters, we can `replace!` them into the
95-
`ODEProblem`:
96-
97-
```@example Remake
98-
replace!(Tunable(), parameter_values(odeprob), sol.u)
99-
odeprob.ps[[α, β, γ, δ]]
100-
```
101-
102-
`replace!` operates in-place, so the values being replaced must be of the same type as those
103-
stored in the parameter object, or convertible to that type. For demonstration purposes, we
104-
can construct a loss function that uses `replace!`, and calculate gradients using
105-
`AutoFiniteDiff` rather than `AutoForwardDiff`.
106-
107-
```@example Remake
108-
function loss2(x, p)
109-
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
110-
newprob = remake(odeprob) # copy the problem with `remake`
111-
# update the parameter values in-place
112-
replace!(Tunable(), parameter_values(newprob), x)
113-
timesteps = p[2]
114-
sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
115-
truth = p[3]
116-
data = Array(sol)
117-
return sum((truth .- data) .^ 2) / length(truth)
118-
end
119-
120-
# use finite-differencing to calculate derivatives
121-
optfn2 = OptimizationFunction(loss2, Optimization.AutoFiniteDiff())
122-
optprob2 = OptimizationProblem(
123-
optfn2, rand(4), (odeprob, timesteps, data), lb = 0.1zeros(4), ub = 3ones(4))
124-
sol = solve(optprob2, BFGS())
125-
```
126-
127110
# Re-creating the problem
128111

129112
There are multiple ways to re-create a problem with new state/parameter values. We will go

0 commit comments

Comments
 (0)