Skip to content

Commit c1249b3

Browse files
fix: GPU tests, CuArray conversion, autodiff
1 parent 87ef7d5 commit c1249b3

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,17 @@ end
9595
VectorOfArray(u),
9696
y -> begin
9797
y isa Ref && (y = VectorOfArray(y[].u))
98-
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
99-
for i in 1:size(y.u)[end]]),)
98+
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
99+
for i in 1:size(y)[end]]),)
100100
end
101101
end
102102

103103
@adjoint function DiffEqArray(u, t)
104104
DiffEqArray(u, t),
105105
y -> begin
106106
y isa Ref && (y = VectorOfArray(y[].u))
107-
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
108-
for i in 1:size(y.u)[end]],
107+
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
108+
for i in 1:size(y)[end]],
109109
t), nothing)
110110
end
111111
end

src/RecursiveArrayTools.jl

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ end
2828

2929
import GPUArraysCore
3030
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
31+
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA))
3132

3233
import Requires
3334
@static if !isdefined(Base, :get_extension)

0 commit comments

Comments
 (0)