Skip to content

Commit 9d0bd35

Browse files
committed
specialize broadcasting for StructArrays
1 parent 107fe3c commit 9d0bd35

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

ext/RecursiveArrayToolsStructArraysExt.jl

+36-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module RecursiveArrayToolsStructArraysExt
33
import RecursiveArrayTools, StructArrays
44
RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u)
55

6-
using RecursiveArrayTools: VectorOfArray
6+
using RecursiveArrayTools: VectorOfArray, VectorOfArrayStyle, ArrayInterface, unpack_voa,
7+
narrays, StaticArraysCore
78
using StructArrays: StructArray
89

910
const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}
@@ -17,11 +18,45 @@ const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}
1718
#
1819
# To avoid this, we can materialize a struct entry, modify it, and then use `setindex!`
1920
# with the modified struct entry.
21+
#
2022
function Base.setindex!(VA::VectorOfStructArray{T, N}, v,
2123
I::Int...) where {T, N}
2224
u_I = VA.u[I[end]]
2325
u_I[Base.front(I)...] = v
2426
return VA.u[I[end]] = u_I
2527
end
2628

29+
for (type, N_expr) in [
30+
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
31+
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
32+
]
33+
@eval @inline function Base.copyto!(dest::VectorOfStructArray,
34+
bc::$type)
35+
bc = Broadcast.flatten(bc)
36+
N = $N_expr
37+
@inbounds for i in 1:N
38+
dest_i = dest[:, i]
39+
if dest_i isa AbstractArray
40+
if ArrayInterface.ismutable(dest_i)
41+
copyto!(dest_i, unpack_voa(bc, i))
42+
else
43+
unpacked = unpack_voa(bc, i)
44+
arr_type = StaticArraysCore.similar_type(dest_i)
45+
dest_i = if length(unpacked) == 1 && length(dest_i) == 1
46+
arr_type(unpacked[1])
47+
elseif length(unpacked) == 1
48+
fill(copy(unpacked), arr_type)
49+
else
50+
arr_type(unpacked[j] for j in eachindex(unpacked))
51+
end
52+
end
53+
else
54+
dest_i = copy(unpack_voa(bc, i))
55+
end
56+
dest[:, i] = dest_i
57+
end
58+
dest
59+
end
60+
end
61+
2762
end

0 commit comments

Comments
 (0)