Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end

@adjoint function Base.copy(u::VectorOfArray)
copy(u),
y -> (copy(y),)
tuple ∘ copy
end

@adjoint function DiffEqArray(u, t)
Expand All @@ -115,26 +115,33 @@ end
adj = let VA = VA
function Array_adjoint(y)
VA = recursivecopy(VA)
copyto!(VA, y)
VA .= y
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
A .= y
(A, map(_ -> nothing, I)...)
end
end
return view(A, I...), adjoint
return view(A, I...), view_adjoint
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
function view_adjoint(y)
A = recursivecopy(parent(y))
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
recursivefill!(A, zero(eltype(A)))
v = view(A, I...)
v .= y
return (A, map(_ -> nothing, I)...)
end
end
view(A, I...), view_adjoint
end
Expand Down
12 changes: 12 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ function loss9(x)
return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5])
end

function loss10(x)
voa = VectorOfArray([i * x for i in 1:5])
return sum(view(voa, 2:4, 3:5))
end

function loss11(x)
voa = VectorOfArray([i * x for i in 1:5])
return sum(view(voa, :, :))
end

x = float.(6:10)
loss(x)
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
Expand All @@ -78,3 +88,5 @@ loss(x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
@test ForwardDiff.derivative(loss9, 0.0) ==
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)