From fc037f94912612730b3e4c98ebad54d74aaea4b8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 May 2024 11:23:28 +0530 Subject: [PATCH 1/2] fix: fix view adjoints --- ext/RecursiveArrayToolsZygoteExt.jl | 24 +++++++++++++++--------- test/adjoints.jl | 12 ++++++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 9b253a4f..c67103cb 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -88,7 +88,7 @@ end @adjoint function Base.copy(u::VectorOfArray) copy(u), - y -> (copy(y),) + tuple ∘ copy end @adjoint function DiffEqArray(u, t) @@ -123,18 +123,24 @@ end end @adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...) - function adjoint(y) - (recursivecopy(parent(y)), map(_ -> nothing, I)...) + view_adjoint = let A = A, I = I + function (y) + A = recursivecopy(A) + copyto!(A, y) + (A, map(_ -> nothing, I)...) + end end - return view(A, I...), adjoint + return view(A, I...), view_adjoint end @adjoint function Base.view(A::AbstractVectorOfArray, I...) - function view_adjoint(y) - A = recursivecopy(parent(y)) - recursivefill!(A, zero(eltype(A))) - A[I...] .= y - return (A, map(_ -> nothing, I)...) + view_adjoint = let A = A, I = I + function (y) + A = recursivecopy(A) + recursivefill!(A, zero(eltype(A))) + A[I...] .= y + return (A, map(_ -> nothing, I)...) + end end view(A, I...), view_adjoint end diff --git a/test/adjoints.jl b/test/adjoints.jl index 1e5ee3c3..e5a1fc50 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -66,6 +66,16 @@ function loss9(x) return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5]) end +function loss10(x) + voa = VectorOfArray([i * x for i in 1:5]) + return sum(view(voa, 2:4, 3:5)) +end + +function loss11(x) + voa = VectorOfArray([i * x for i in 1:5]) + return sum(view(voa, :, :)) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x) @@ -78,3 +88,5 @@ loss(x) @test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x) @test ForwardDiff.derivative(loss9, 0.0) == VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) +@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x) +@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x) From 97edb5870e6a506e823210ce59e9f93e434b0dc1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 May 2024 12:12:16 +0530 Subject: [PATCH 2/2] fixup! fix: fix view adjoints --- ext/RecursiveArrayToolsZygoteExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index c67103cb..e668676f 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -115,7 +115,7 @@ end adj = let VA = VA function Array_adjoint(y) VA = recursivecopy(VA) - copyto!(VA, y) + VA .= y return (VA,) end end @@ -126,7 +126,7 @@ end view_adjoint = let A = A, I = I function (y) A = recursivecopy(A) - copyto!(A, y) + A .= y (A, map(_ -> nothing, I)...) end end @@ -138,7 +138,8 @@ end function (y) A = recursivecopy(A) recursivefill!(A, zero(eltype(A))) - A[I...] .= y + v = view(A, I...) + v .= y return (A, map(_ -> nothing, I)...) end end