Skip to content

Commit 4d6845f

Browse files
fix: add RecursiveArrayToolsReverseDiffExt
1 parent 13b2a67 commit 4d6845f

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

Project.toml

+3
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2323
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2424
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2525
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2627

2728
[extensions]
2829
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
2930
RecursiveArrayToolsMeasurementsExt = "Measurements"
3031
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
32+
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
3133
RecursiveArrayToolsTrackerExt = "Tracker"
3234
RecursiveArrayToolsZygoteExt = "Zygote"
3335

@@ -49,6 +51,7 @@ OrdinaryDiffEq = "6.62"
4951
Pkg = "1"
5052
Random = "1"
5153
RecipesBase = "1.1"
54+
ReverseDiff = "1.15"
5255
SafeTestsets = "0.1"
5356
SparseArrays = "1.10"
5457
StaticArrays = "1.6"
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module RecursiveArrayToolsReverseDiffExt
2+
3+
using RecursiveArrayTools
4+
using ReverseDiff
5+
using Zygote: @adjoint
6+
7+
function trackedarraycopyto!(dest, src)
8+
for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims=ndims(src)))
9+
if dest.u[i] isa AbstractArray
10+
dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i]))
11+
else
12+
trackedarraycopyto!(dest.u[i], slice)
13+
end
14+
end
15+
end
16+
17+
@adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal})
18+
function Array_adjoint(y)
19+
VA = recursivecopy(VA)
20+
trackedarraycopyto!(VA, y)
21+
return (VA,)
22+
end
23+
return Array(VA), Array_adjoint
24+
end
25+
end # module

0 commit comments

Comments
 (0)