diff --git a/Project.toml b/Project.toml index 7c028056..3cf40904 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "3.37.0" +version = "3.37.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index e1987dd8..b8ec612e 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -99,6 +99,30 @@ end end end +Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u}) + function literal_AbstractVofA_u_adjoint(d) + dA = vofa_u_adjoint(d, A) + (dA, nothing) + end + A.u, literal_AbstractVofA_u_adjoint +end + +function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractVectorOfArray) + m = map(enumerate(d)) do (idx, d_i) + isnothing(d_i) && return zero(A.u[idx]) + d_i + end + VectorOfArray(m) +end + +function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractDiffEqArray) + m = map(enumerate(d)) do (idx, d_i) + isnothing(d_i) && return zero(A.u[idx]) + d_i + end + DiffEqArray(m, A.t) +end + @adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x}) function literal_ArrayPartition_x_adjoint(d) (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) diff --git a/test/adjoints.jl b/test/adjoints.jl index af2abd42..a390c33a 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -92,3 +92,9 @@ loss(x) VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) @test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x) @test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x) + +voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3)) +voa_gs, = Zygote.gradient(voa) do x + sum(sum.(x.u)) +end +@test voa_gs isa RecursiveArrayTools.VectorOfArray