diff --git a/Project.toml b/Project.toml index 4a4b9322..b1be4604 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "3.28.0" +version = "3.29.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -60,7 +60,7 @@ SafeTestsets = "0.1" SparseArrays = "1.10" StaticArrays = "1.6" StaticArraysCore = "1.4" -Statistics = "1.10" +Statistics = "1.10, 1.11" StructArrays = "0.6.11, 0.7" SymbolicIndexingInterface = "0.3.25" Tables = "1.11" diff --git a/src/array_partition.jl b/src/array_partition.jl index 7bd8bb52..b0325fe0 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -209,6 +209,19 @@ function Base.copyto!(A::ArrayPartition, src::ArrayPartition) A end +function Base.fill!(A::ArrayPartition, x) + unrolled_foreach!(A.x) do x_ + fill!(x_, x) + end + A +end + +function recursivefill!(b::ArrayPartition, a::T2) where {T2 <: Union{Number, Bool}} + unrolled_foreach!(b.x) do x + fill!(x, a) + end +end + ## indexing # Interface for the linear indexing. This is just a view of the underlying nested structure diff --git a/src/utils.jl b/src/utils.jl index ae4f1d2f..58945065 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,7 @@ +unrolled_foreach!(f, t::Tuple) = (f(t[1]); unrolled_foreach!(f, Base.tail(t))) +unrolled_foreach!(f, ::Tuple{}) = nothing + + """ ```julia recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T, N}}) @@ -127,6 +131,7 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N}, end end + for type in [AbstractArray, AbstractVectorOfArray] @eval function recursivefill!(b::$type{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N} fill!(b, a) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index eef23eac..f1947de4 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -704,6 +704,12 @@ function Base.similar(vec::VectorOfArray{ return VectorOfArray(similar.(Base.parent(vec))) end +function Base.similar(vec::VectorOfArray{ + T, N, AT}) where {T, N, AT <: AbstractArray{<:StaticArraysCore.StaticVecOrMat{T}}} + # this avoids behavior such as similar(SVector) returning an MVector + return VectorOfArray(similar(Base.parent(vec))) +end + @inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T} VectorOfArray(similar.(VA.u, T)) end diff --git a/test/copy_static_array_test.jl b/test/copy_static_array_test.jl index d3346593..de0132ce 100644 --- a/test/copy_static_array_test.jl +++ b/test/copy_static_array_test.jl @@ -120,4 +120,14 @@ x = StructArray{SVector{2, Float64}}((randn(2), randn(2))) vx = VectorOfArray(x) vx2 = copy(vx) .+ 1 ans = vx .+ vx2 -@test ans.u isa StructArray \ No newline at end of file +@test ans.u isa StructArray + +# check that Base.similar(VectorOfArray{<:StaticArray}) returns the +# same type as the original VectorOfArray +x_staticvector = [SVector(0.0, 0.0) for _ in 1:2] +x_structarray = StructArray{SVector{2, Float64}}((randn(2), randn(2))) +x_mutablefv = [MutableFV(1.0, 2.0)] +x_immutablefv = [ImmutableFV(1.0, 2.0)] +for vec in [x_staticvector, x_structarray, x_mutablefv, x_immutablefv] + @test typeof(similar(VectorOfArray(vec))) === typeof(VectorOfArray(vec)) +end \ No newline at end of file diff --git a/test/gpu/arraypartition_gpu.jl b/test/gpu/arraypartition_gpu.jl new file mode 100644 index 00000000..3b335855 --- /dev/null +++ b/test/gpu/arraypartition_gpu.jl @@ -0,0 +1,20 @@ +using RecursiveArrayTools, CUDA, Test +CUDA.allowscalar(false) + + +# Test indexing with colon +a = (CUDA.zeros(5), CUDA.zeros(5)) +pA = ArrayPartition(a) +pA[:, :] + +# Indexing with boolean masks does not work yet +mask = pA .> 0 +# pA[mask] + +# Test recursive filling is done using GPU kernels and not scalar indexing +RecursiveArrayTools.recursivefill!(pA, true) +@test all(pA .== true) + +# Test that regular filling is done using GPU kernels and not scalar indexing +fill!(pA, false) +@test all(pA .== false) diff --git a/test/runtests.jl b/test/runtests.jl index 819e40f3..4ec9d6f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,5 +54,6 @@ end if GROUP == "GPU" activate_gpu_env() @time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl") + @time @safetestset "ArrayPartition GPU" include("gpu/arraypartition_gpu.jl") end end