Skip to content

Commit 3819dee

Browse files
Merge pull request #406 from huiyuxie/fix
Fix broadcast failure for `VectorOfArray` with `SVector{1}`
2 parents 6947538 + baa2af9 commit 3819dee

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/vector_of_array.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,9 @@ for (type, N_expr) in [
905905
else
906906
unpacked = unpack_voa(bc, i)
907907
arr_type = StaticArraysCore.similar_type(dest[:, i])
908-
dest[:, i] = if length(unpacked) == 1
908+
dest[:, i] = if length(unpacked) == 1 && length(dest[:, i]) == 1
909+
arr_type(unpacked[1])
910+
elseif length(unpacked) == 1
909911
fill(copy(unpacked), arr_type)
910912
else
911913
arr_type(unpacked[j] for j in eachindex(unpacked))

test/copy_static_array_test.jl

+32
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,35 @@ b = recursivecopy(a)
8282
@test a[1] == b[1]
8383
a[1] *= 2
8484
@test a[1] != b[1]
85+
86+
# Broadcasting when SVector{N} where N = 1
87+
a = [SVector(0.0) for _ in 1:2]
88+
a_voa = VectorOfArray(a)
89+
b_voa = copy(a_voa)
90+
a_voa[1] = SVector(1.0)
91+
a_voa[2] = SVector(1.0)
92+
@. b_voa = a_voa
93+
@test b_voa[1] == a_voa[1]
94+
@test b_voa[2] == a_voa[2]
95+
96+
a = [SVector(0.0) for _ in 1:2]
97+
a_voa = VectorOfArray(a)
98+
a_voa .= 1.0
99+
@test a_voa[1] == SVector(1.0)
100+
@test a_voa[2] == SVector(1.0)
101+
102+
# Broadcasting when SVector{N} where N > 1
103+
a = [SVector(0.0, 0.0) for _ in 1:2]
104+
a_voa = VectorOfArray(a)
105+
b_voa = copy(a_voa)
106+
a_voa[1] = SVector(1.0, 1.0)
107+
a_voa[2] = SVector(1.0, 1.0)
108+
@. b_voa = a_voa
109+
@test b_voa[1] == a_voa[1]
110+
@test b_voa[2] == a_voa[2]
111+
112+
a = [SVector(0.0, 0.0) for _ in 1:2]
113+
a_voa = VectorOfArray(a)
114+
a_voa .= 1.0
115+
@test a_voa[1] == SVector(1.0, 1.0)
116+
@test a_voa[2] == SVector(1.0, 1.0)

0 commit comments

Comments
 (0)