Skip to content

Commit 4fdfb5e

Browse files
Fix StructArray broadcast in VectorOfArray
Fixes #410 This specializes so that if `u.u` is not a vector, it will convert the broadcast to fix that. I couldn't find a nice generic way to use `map` so the fallback is to build the vector and convert, which seems to not be a big performance issue. For StructArrays, `convert(typeof(x), Vector(x))` fails, and so this case is specialized.
1 parent 600a9b5 commit 4fdfb5e

4 files changed

+26
-6
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2323
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2424
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2525
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
26+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2627
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2728
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2829

@@ -33,6 +34,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements"
3334
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
3435
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
3536
RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
37+
RecursiveArrayToolsStructArraysExt = "StructArrays"
3638
RecursiveArrayToolsTrackerExt = "Tracker"
3739
RecursiveArrayToolsZygoteExt = "Zygote"
3840

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
module RecursiveArrayToolsStructArraysExt
2+
3+
import RecursiveArrayTools, StructArrays
4+
RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u)
5+
6+
end

src/vector_of_array.jl

+11-6
Original file line numberDiff line numberDiff line change
@@ -849,28 +849,33 @@ end
849849

850850
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
851851
bc = Broadcast.flatten(bc)
852-
853852
parent = find_VoA_parent(bc.args)
854853

855-
if parent isa AbstractVector
854+
u = if parent isa AbstractVector
856855
# this is the default behavior in v3.15.0
857856
N = narrays(bc)
858-
return VectorOfArray(map(1:N) do i
857+
map(1:N) do i
859858
copy(unpack_voa(bc, i))
860-
end)
859+
end
861860
else # if parent isa AbstractArray
862-
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
861+
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
863862
copy(unpack_voa(bc, i))
864-
end)
863+
end
865864
end
865+
VectorOfArray(rewrap(parent, u))
866866
end
867867

868+
rewrap(::Array,u) = u
869+
rewrap(parent, u) = convert(typeof(parent), u)
870+
868871
for (type, N_expr) in [
869872
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
870873
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
871874
]
872875
@eval @inline function Base.copyto!(dest::AbstractVectorOfArray,
873876
bc::$type)
877+
@show typeof(dest)
878+
error()
874879
bc = Broadcast.flatten(bc)
875880
N = $N_expr
876881
@inbounds for i in 1:N

test/copy_static_array_test.jl

+7
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,10 @@ a_voa = VectorOfArray(a)
114114
a_voa .= 1.0
115115
@test a_voa[1] == SVector(1.0, 1.0)
116116
@test a_voa[2] == SVector(1.0, 1.0)
117+
118+
#Broadcast Copy of StructArray
119+
x = StructArray{SVector{2, Float64}}((randn(2), randn(2)))
120+
vx = VectorOfArray(x)
121+
vx2 = copy(vx) .+ 1
122+
ans = vx .+ vx2
123+
@test ans.u isa StructArray

0 commit comments

Comments
 (0)