-
-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathRecursiveArrayToolsStructArraysExt.jl
62 lines (56 loc) · 2.28 KB
/
RecursiveArrayToolsStructArraysExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
module RecursiveArrayToolsStructArraysExt

# Methods that let RecursiveArrayTools' `VectorOfArray` work correctly when its storage `u`
# is a StructArrays.jl `StructArray`.

import RecursiveArrayTools, StructArrays

# When the reference container is a `StructArray`, wrap `u` in a new `StructArray`.
RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u)

using RecursiveArrayTools: VectorOfArray, VectorOfArrayStyle, ArrayInterface, unpack_voa,
                           narrays, StaticArraysCore
using StructArrays: StructArray

# A `VectorOfArray` whose storage `u` is a `StructArray`.
const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}

# Since `StructArray` lazily materializes struct entries, the general `setindex!(x, val, I)`
# operation `VA.u[I[end]][Base.front(I)...] = val` would only update a temporary, lazily
# materialized struct entry of `u`; it would not actually mutate `x::StructArray`. See the
# StructArrays documentation for details:
#
# https://juliaarrays.github.io/StructArrays.jl/stable/counterintuitive/#Modifying-a-field-of-a-struct-element
#
# To avoid this, materialize the struct entry, modify it, and store it back into `u` with
# `setindex!`. Entries that `ArrayInterface.ismutable` reports as immutable (e.g. static
# arrays) cannot be written in place, so they are edited through a mutable `similar` copy and
# rebuilt with `StaticArraysCore.similar_type`, the same way `copyto!` below rebuilds them.
#
# Returns `v`, which is also the value of the expression `VA[I...] = v`.
Base.@propagate_inbounds function Base.setindex!(VA::VectorOfStructArray, v, I::Int...)
    i = I[end]
    entry = VA.u[i]
    if ArrayInterface.ismutable(entry)
        entry[Base.front(I)...] = v
    else
        # NOTE(review): assumes `similar(entry)` is mutable and that `similar_type(entry)`
        # can be constructed from it (true for StaticArrays `SArray`) — confirm for others.
        buf = copyto!(similar(entry), entry)
        buf[Base.front(I)...] = v
        entry = StaticArraysCore.similar_type(entry)(buf)
    end
    # Write the (possibly rebuilt) entry back so the `StructArray` itself is updated.
    VA.u[i] = entry
    return v
end

# Build a new value of `StaticArraysCore.similar_type(dest_i)` holding the values of the
# unpacked broadcast `src`. Used for entries that cannot be mutated in place.
function _rebuild_static(dest_i, src)
    arr_type = StaticArraysCore.similar_type(dest_i)
    if length(src) == 1 && length(dest_i) == 1
        return arr_type(src[1])
    elseif length(src) == 1
        # A single source value is repeated into every slot of the destination entry.
        return fill(src[1], arr_type)
    else
        return arr_type(src[j] for j in eachindex(src))
    end
end

# Broadcast into a `VectorOfStructArray`, one entry of `dest.u` at a time. Two methods are
# generated: for `VectorOfArrayStyle` broadcasts the entry count comes from `narrays(bc)`, and
# for `DefaultArrayStyle` broadcasts from `length(dest.u)`.
#
# Each `dest.u[i]` is materialized from the `StructArray` (see the note above `setindex!`),
# so after it is updated — in place via `copyto!` when mutable, rebuilt via
# `_rebuild_static` when an immutable array, or replaced by `copy` when not an array — it is
# always stored back with `dest.u[i] = dest_i`.
for (BC, nentries_expr) in [
    (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
    (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
]
    @eval @inline function Base.copyto!(dest::VectorOfStructArray, bc::$BC)
        bc = Broadcast.flatten(bc)
        N = $nentries_expr
        # NOTE(review): `@inbounds` assumes `N <= length(dest.u)`, presumably guaranteed by
        # the broadcast machinery's axis checks before `copyto!` is reached — confirm.
        @inbounds for i in 1:N
            dest_i = dest.u[i]
            src_i = unpack_voa(bc, i)
            if dest_i isa AbstractArray
                if ArrayInterface.ismutable(dest_i)
                    copyto!(dest_i, src_i)
                else
                    dest_i = _rebuild_static(dest_i, src_i)
                end
            else
                dest_i = copy(src_i)
            end
            dest.u[i] = dest_i
        end
        return dest
    end
end

end