Skip to content

Commit fbf85bf

Browse files
Merge pull request #426 from jlchan/jc/VoA_StructArray_setindex
Fix broadcast assignment for `VectorOfArray` with `StructArrays`
2 parents 0b9c8eb + 96776b9 commit fbf85bf

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

ext/RecursiveArrayToolsStructArraysExt.jl

+36-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module RecursiveArrayToolsStructArraysExt
33
import RecursiveArrayTools, StructArrays
44
RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u)
55

6-
using RecursiveArrayTools: VectorOfArray
6+
using RecursiveArrayTools: VectorOfArray, VectorOfArrayStyle, ArrayInterface, unpack_voa,
7+
narrays, StaticArraysCore
78
using StructArrays: StructArray
89

910
const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}
@@ -17,11 +18,45 @@ const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}
1718
#
1819
# To avoid this, we can materialize a struct entry, modify it, and then use `setindex!`
1920
# with the modified struct entry.
21+
#
2022
function Base.setindex!(VA::VectorOfStructArray{T, N}, v,
2123
I::Int...) where {T, N}
2224
u_I = VA.u[I[end]]
2325
u_I[Base.front(I)...] = v
2426
return VA.u[I[end]] = u_I
2527
end
2628

29+
for (type, N_expr) in [
30+
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
31+
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
32+
]
33+
@eval @inline function Base.copyto!(dest::VectorOfStructArray,
34+
bc::$type)
35+
bc = Broadcast.flatten(bc)
36+
N = $N_expr
37+
@inbounds for i in 1:N
38+
dest_i = dest[:, i]
39+
if dest_i isa AbstractArray
40+
if ArrayInterface.ismutable(dest_i)
41+
copyto!(dest_i, unpack_voa(bc, i))
42+
else
43+
unpacked = unpack_voa(bc, i)
44+
arr_type = StaticArraysCore.similar_type(dest_i)
45+
dest_i = if length(unpacked) == 1 && length(dest_i) == 1
46+
arr_type(unpacked[1])
47+
elseif length(unpacked) == 1
48+
fill(copy(unpacked), arr_type)
49+
else
50+
arr_type(unpacked[j] for j in eachindex(unpacked))
51+
end
52+
end
53+
else
54+
dest_i = copy(unpack_voa(bc, i))
55+
end
56+
dest[:, i] = dest_i
57+
end
58+
dest
59+
end
60+
end
61+
2762
end

test/basic_indexing.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,18 @@ num_allocs = @allocations foo!(u_matrix)
265265

266266
# check VectorOfArray indexing for a StructArray of mutable structs
267267
using StructArrays
268-
using StaticArrays: MVector
268+
using StaticArrays: MVector, SVector
269269
x = VectorOfArray(StructArray{MVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1)))
270+
y = 2 * x
270271

271-
# check VectorOfArray assignment
272+
# check mutable VectorOfArray assignment and broadcast
272273
x[1, 1] = 10
273274
@test x[1, 1] == 10
275+
@. x = y
276+
@test all(all.(y .== x))
277+
278+
# check immutable VectorOfArray broadcast
279+
x = VectorOfArray(StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1)))
280+
y = 2 * x
281+
@. x = y
282+
@test all(all.(y .== x))

0 commit comments

Comments
 (0)