Skip to content

Commit 3f67126

Browse files
Merge pull request SciML#2729 from AayushSabharwal/as/mtkparams-bug
fix: fix bug in `remake_buffer`
2 parents 6d84b8f + 97bb913 commit 3f67126

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

Diff for: src/systems/parameter_buffer.jl

+7-8
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function MTKParameters(
140140
dep_exprs = ArrayPartition((Any[missing for _ in 1:length(v)] for v in dep_buffer)...)
141141
for (sym, val) in pdeps
142142
i, j = ic.dependent_idx[sym]
143-
dep_exprs.x[i][j] = wrap(val)
143+
dep_exprs.x[i][j] = unwrap(val)
144144
end
145145
dep_exprs = identity.(dep_exprs)
146146
p = reorder_parameters(ic, full_parameters(sys))
@@ -423,7 +423,10 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
423423
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
424424
oldbuf.nonnumeric, newbuf.nonnumeric)
425425
if newbuf.dependent_update_oop !== nothing
426-
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
426+
@set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.(
427+
oldbuf.dependent,
428+
split_into_buffers(
429+
newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false)))
427430
end
428431
return newbuf
429432
end
@@ -447,6 +450,7 @@ _num_subarrays(v::Tuple) = length(v)
447450
# getindex indexes the vectors, setindex! linearly indexes values
448451
# it's inconsistent, but we need it to be this way
449452
function Base.getindex(buf::MTKParameters, i)
453+
i_orig = i
450454
if !isempty(buf.tunable)
451455
i <= _num_subarrays(buf.tunable) && return _subarrays(buf.tunable)[i]
452456
i -= _num_subarrays(buf.tunable)
@@ -467,7 +471,7 @@ function Base.getindex(buf::MTKParameters, i)
467471
i <= _num_subarrays(buf.dependent) && return _subarrays(buf.dependent)[i]
468472
i -= _num_subarrays(buf.dependent)
469473
end
470-
throw(BoundsError(buf, i))
474+
throw(BoundsError(buf, i_orig))
471475
end
472476
function Base.setindex!(p::MTKParameters, val, i)
473477
function _helper(buf)
@@ -551,9 +555,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
551555
for (i, val) in zip(input_idxs, p_small_inner)
552556
_set_parameter_unchecked!(p_big, val, i)
553557
end
554-
# tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
555-
# tunable[input_idxs] .= p_small_inner
556-
# p_big = repack(tunable)
557558
return if pf isa SciMLBase.ParamJacobianWrapper
558559
buffer = Array{dualtype}(undef, size(pf.u))
559560
pf(buffer, p_big)
@@ -563,8 +564,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
563564
end
564565
end
565566
end
566-
# tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
567-
# p_small = tunable[input_idxs]
568567
p_small = parameter_values.((p,), input_idxs)
569568
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
570569
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))

Diff for: test/mtkparameters.jl

+17
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,20 @@ function loss(x)
224224
end
225225

226226
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))
227+
228+
# Ensure dependent parameters are `Tuple{...}` and not `ArrayPartition` when using
229+
# `remake_buffer`.
230+
@parameters p1 p2 p3[1:2] p4[1:2]
231+
@named sys = ODESystem(
232+
Equation[], t, [], [p1, p2, p3, p4]; parameter_dependencies = [p2 => 2p1, p4 => 3p3])
233+
sys = complete(sys)
234+
ps = MTKParameters(sys, [p1 => 1.0, p3 => [2.0, 3.0]])
235+
@test ps[parameter_index(sys, p2)] == 2.0
236+
@test ps[parameter_index(sys, p4)] == [6.0, 9.0]
237+
238+
newps = remake_buffer(
239+
sys, ps, Dict(p1 => ForwardDiff.Dual(2.0), p3 => ForwardDiff.Dual.([3.0, 4.0])))
240+
241+
VDual = Vector{<:ForwardDiff.Dual}
242+
VVDual = Vector{<:Vector{<:ForwardDiff.Dual}}
243+
@test newps.dependent isa Union{Tuple{VDual, VVDual}, Tuple{VVDual, VDual}}

0 commit comments

Comments
 (0)