Skip to content

Commit 987f835

Browse files
committed
Added unrolled implementation of recursivefill! which works on GPUs and avoids recomputing global indices for each setindex!
1 parent fbf5695 commit 987f835

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/array_partition.jl

+7
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ function Base.copyto!(A::ArrayPartition, src::ArrayPartition)
209209
A
210210
end
211211

212+
function recursivefill!(b::ArrayPartition, a::T2) where {T2 <: Union{Number, Bool}}
213+
unrolled_foreach!(b.x) do x
214+
fill!(x, a)
215+
end
216+
end
217+
218+
212219
## indexing
213220

214221
# Interface for the linear indexing. This is just a view of the underlying nested structure

src/utils.jl

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
unrolled_foreach!(f, t::Tuple) = (f(t[1]); unrolled_foreach!(f, Base.tail(t)))
2+
unrolled_foreach!(f, ::Tuple{}) = nothing
3+
4+
15
"""
26
```julia
37
recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T, N}})
@@ -127,6 +131,7 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N},
127131
end
128132
end
129133

134+
130135
for type in [AbstractArray, AbstractVectorOfArray]
131136
@eval function recursivefill!(b::$type{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N}
132137
fill!(b, a)

0 commit comments

Comments
 (0)