I am following the recent tutorial on re-creating MTK problems. In summary, it defines a function `loss` that remakes an `ODEProblem` by replacing the tunable portion of the parameters with the input argument and then evaluates the loss. Calling `loss` mutates the parameters of the original `ODEProblem` object in the global scope.
MWE:
```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkbuild odesys = ODESystem(eqs, t)

using OrdinaryDiffEq

odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))

using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, canonicalize, replace, replace!
using PreallocationTools

function loss(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    ps = parameter_values(odeprob) # obtain the parameter object from the problem
    diffcache = p[5]
    # get an appropriately typed preallocated buffer to store the `x` values in
    buffer = get_tmp(diffcache, x)
    # copy the current values to this buffer
    copyto!(buffer, canonicalize(Tunable(), ps)[1])
    # create a copy of the parameter object with the buffer
    ps = replace(Tunable(), ps, buffer)
    # set the updated values in the parameter object
    setter = p[4]
    setter(ps, x)
    # remake the problem, passing in our new parameter object
    newprob = remake(odeprob; p = ps)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end

using SymbolicIndexingInterface
setter = setp(odeprob, [α, β, γ, δ]);
# `DiffCache` to avoid allocations
diffcache = DiffCache(canonicalize(Tunable(), parameter_values(odeprob))[1]);
getter = getp(odeprob, [α, β, γ, δ])
getter(odeprob) # returns original parameter values
loss(ones(4), (odeprob, timesteps, data, setter, diffcache))
getter(odeprob) # returns ones; calling the loss function has mutated `odeprob`
```
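Thanks for pointing this out. The `DiffCache` constructor uses the provided array as its internal buffer, so the cache aliases the tunable portion of `odeprob.p`. When `loss` is called with a plain `Float64` vector, `get_tmp(diffcache, x)` returns that same array, and `setter(ps, x)` then writes directly into the original problem's tunables. This can be fixed by passing a `copy` of the array to the constructor. I'll update the doc example accordingly.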
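A minimal sketch of the aliasing and the `copy` fix, using plain `Float64` vectors as stand-ins for the MTK tunable vector (the vectors `u` and `v` and their values are illustrative, not from the tutorial). It relies on `get_tmp` returning the stored buffer unchanged for non-Dual input; with ForwardDiff dual numbers, `get_tmp` instead uses a separate dual buffer, so the aliasing shows up for plain `Float64` calls like the one in the MWE.

```julia
using PreallocationTools

# Stand-in for canonicalize(Tunable(), parameter_values(odeprob))[1]
u = [1.5, 1.0, 3.0, 1.0]

# Aliased: DiffCache keeps `u` itself as its Float64 buffer
dc_aliased = DiffCache(u)
buf = get_tmp(dc_aliased, u)  # Float64 input, so the stored buffer is returned
aliased = buf === u           # true: the buffer is the same array object as `u`
buf .= 1.0                    # writing into the buffer also overwrites `u`
println(u)                    # prints [1.0, 1.0, 1.0, 1.0]

# Fixed: give DiffCache its own copy of the values
v = [1.5, 1.0, 3.0, 1.0]
dc_safe = DiffCache(copy(v))
buf2 = get_tmp(dc_safe, v)    # independent buffer holding the copied values
buf2 .= 1.0                   # only the cache's copy is modified
println(v)                    # prints [1.5, 1.0, 3.0, 1.0]
```

In the tutorial's setup the corresponding change would be `diffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(odeprob))[1]))`; with it, `getter(odeprob)` should still return the original parameter values after calling `loss`.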