@@ -3,7 +3,8 @@ module RecursiveArrayToolsStructArraysExt
3
3
import RecursiveArrayTools, StructArrays
4
4
RecursiveArrayTools. rewrap (:: StructArrays.StructArray , u) = StructArrays. StructArray (u)
5
5
6
- using RecursiveArrayTools: VectorOfArray
6
+ using RecursiveArrayTools: VectorOfArray, VectorOfArrayStyle, ArrayInterface, unpack_voa,
7
+ narrays, StaticArraysCore
7
8
using StructArrays: StructArray
8
9
9
10
const VectorOfStructArray{T, N} = VectorOfArray{T, N, <: StructArray }
@@ -17,11 +18,45 @@ const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}
17
18
#
18
19
# To avoid this, we can materialize a struct entry, modify it, and then use `setindex!`
19
20
# with the modified struct entry.
21
+ #
20
22
function Base. setindex! (VA:: VectorOfStructArray{T, N} , v,
21
23
I:: Int... ) where {T, N}
22
24
u_I = VA. u[I[end ]]
23
25
u_I[Base. front (I)... ] = v
24
26
return VA. u[I[end ]] = u_I
25
27
end
26
28
29
+ for (type, N_expr) in [
30
+ (Broadcast. Broadcasted{<: VectorOfArrayStyle }, :(narrays (bc))),
31
+ (Broadcast. Broadcasted{<: Broadcast.DefaultArrayStyle }, :(length (dest. u)))
32
+ ]
33
+ @eval @inline function Base. copyto! (dest:: VectorOfStructArray ,
34
+ bc:: $type )
35
+ bc = Broadcast. flatten (bc)
36
+ N = $ N_expr
37
+ @inbounds for i in 1 : N
38
+ dest_i = dest[:, i]
39
+ if dest_i isa AbstractArray
40
+ if ArrayInterface. ismutable (dest_i)
41
+ copyto! (dest_i, unpack_voa (bc, i))
42
+ else
43
+ unpacked = unpack_voa (bc, i)
44
+ arr_type = StaticArraysCore. similar_type (dest_i)
45
+ dest_i = if length (unpacked) == 1 && length (dest_i) == 1
46
+ arr_type (unpacked[1 ])
47
+ elseif length (unpacked) == 1
48
+ fill (copy (unpacked), arr_type)
49
+ else
50
+ arr_type (unpacked[j] for j in eachindex (unpacked))
51
+ end
52
+ end
53
+ else
54
+ dest_i = copy (unpack_voa (bc, i))
55
+ end
56
+ dest[:, i] = dest_i
57
+ end
58
+ dest
59
+ end
60
+ end
61
+
27
62
end
0 commit comments