Re-creating MTK problems causes side effects #3407

Closed
ysfoo opened this issue Feb 22, 2025 · 2 comments · Fixed by #3405
Labels
bug Something isn't working

ysfoo commented Feb 22, 2025

I am following the recent tutorial on re-creating MTK problems. In it, a function `loss` is defined that remakes an `ODEProblem`, replacing the tunable portion of the parameters with the input argument, and then evaluates the loss. Calling `loss`, however, mutates the parameters of the original `ODEProblem` object in the global scope.

MWE:

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkbuild odesys = ODESystem(eqs, t)

using OrdinaryDiffEq

odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))

using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, canonicalize, replace, replace!
using PreallocationTools

function loss(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    ps = parameter_values(odeprob) # obtain the parameter object from the problem
    diffcache = p[5]
    # get an appropriately typed preallocated buffer to store the `x` values in
    buffer = get_tmp(diffcache, x)
    # copy the current values to this buffer
    copyto!(buffer, canonicalize(Tunable(), ps)[1])
    # create a copy of the parameter object with the buffer
    ps = replace(Tunable(), ps, buffer)
    # set the updated values in the parameter object
    setter = p[4]
    setter(ps, x)
    # remake the problem, passing in our new parameter object
    newprob = remake(odeprob; p = ps)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end

using SymbolicIndexingInterface

setter = setp(odeprob, [α, β, γ, δ]);
# `DiffCache` to avoid allocations
diffcache = DiffCache(canonicalize(Tunable(), parameter_values(odeprob))[1]);

getter = getp(odeprob, [α, β, γ, δ])
getter(odeprob) # returns original parameter values
loss(ones(4), (odeprob, timesteps, data, setter, diffcache))
getter(odeprob) # returns ones, calling the loss function has mutated `odeprob`
```

ysfoo added the bug (Something isn't working) label Feb 22, 2025
ysfoo changed the title from "Re-creating MTK problems creates side effects" to "Re-creating MTK problems causes side effects" Feb 22, 2025

AayushSabharwal (Member) commented:

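Thanks for pointing this out. The `DiffCache` constructor uses the provided array as its internal buffer, so `diffcache` aliases the tunable portion of `odeprob.p`. Since `get_tmp` returns that same array for a plain `Float64` input, `setter(ps, x)` in `loss` ends up writing into the original problem's parameters. This can be fixed by passing a copy of the array to the constructor. I'll update the doc example accordingly.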
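A minimal plain-Julia sketch of the aliasing, with no SciML packages involved. `ToyCache` and `toy_get_tmp` are hypothetical stand-ins, not the PreallocationTools API; they only mimic the behaviour described in this comment (the constructor keeps the array it is given, and a plain `Float64` request hands back that same array).

```julia
# Hypothetical stand-in for DiffCache: like the constructor described above,
# it stores the array it receives instead of copying it.
struct ToyCache{T<:AbstractVector}
    buffer::T
end

# Hypothetical stand-in for get_tmp on a plain Float64 input:
# it returns the stored buffer itself.
toy_get_tmp(c::ToyCache, x::AbstractVector{Float64}) = c.buffer

original = [1.5, 1.0, 3.0, 1.0]   # plays the role of odeprob's tunable parameters

aliased = ToyCache(original)       # shares memory with `original`
safe = ToyCache(copy(original))    # owns an independent copy

x = ones(4)

copyto!(toy_get_tmp(safe, x), x)     # writing through the copied cache...
println(original)                    # ...leaves `original` unchanged: [1.5, 1.0, 3.0, 1.0]

copyto!(toy_get_tmp(aliased, x), x)  # writing through the aliased cache...
println(original)                    # ...overwrites `original`: [1.0, 1.0, 1.0, 1.0]
```

Applied to the MWE, the same idea means constructing the cache as `DiffCache(copy(canonicalize(Tunable(), parameter_values(odeprob))[1]))`, so the buffer handed out by `get_tmp` is not shared with `odeprob`.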

ysfoo commented Feb 24, 2025

I can confirm that copying resolves the issue, thanks.

ysfoo closed this as completed Feb 24, 2025