From 20b90c17be238368029219e7b191a76752c244d6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 16:35:29 +0530 Subject: [PATCH 1/3] fix: fix DiffEqArray constructors ambiguity --- src/vector_of_array.jl | 164 ++++++++++++++++++++++++++++++----------- 1 file changed, 119 insertions(+), 45 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index f1947de4..2609d9ff 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -174,54 +174,77 @@ end Base.parent(vec::VectorOfArray) = vec.u -function DiffEqArray(vec::AbstractVector{T}, - ts::AbstractVector, - ::NTuple{N, Int}, - p = nothing, - sys = nothing; discretes = nothing) where {T, N} - DiffEqArray{ - eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}( - vec, +#### 2-argument + +# first element representative +function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) + sys = SymbolCache(something(variables, []), + something(parameters, []), + something(independent_variables, [])) + _size = size(vec[1]) + T = eltype(vec[1]) + return DiffEqArray{ + T, + length(_size) + 1, + typeof(vec), + typeof(ts), + Nothing, + typeof(sys), + typeof(discretes) + }(vec, ts, - p, + nothing, + sys, + discretes) +end + +# T and N from type +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} + sys = SymbolCache(something(variables, []), + something(parameters, []), + something(independent_variables, [])) + return DiffEqArray{ + eltype(eltype(vec)), + N + 1, + typeof(vec), + typeof(ts), + Nothing, + typeof(sys), + typeof(discretes) + }(vec, + ts, + nothing, sys, discretes) end -# ambiguity resolution -function DiffEqArray(vec::AbstractVector{VT}, - ts::AbstractVector, - ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec, +#### 3-argument + +# NTuple, T from type +function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}; discretes = nothing) where {T, N} + DiffEqArray{ + eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, typeof(discretes)}( + vec, ts, nothing, nothing, - nothing) + discretes) end -function DiffEqArray(vec::AbstractVector{VT}, - ts::AbstractVector, - ::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{ - eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec, + +# NTuple parameter +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2} + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec, ts, p, nothing, discretes) end -# Assume that the first element is representative of all other elements -function DiffEqArray(vec::AbstractVector, - ts::AbstractVector, - p = nothing, - sys = nothing; - variables = nothing, - parameters = nothing, - independent_variables = nothing, - discretes = nothing) - sys = something(sys, - SymbolCache(something(variables, []), +# first element representative +function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) + sys = SymbolCache(something(variables, []), something(parameters, []), - something(independent_variables, []))) + something(independent_variables, [])) _size = size(vec[1]) T = eltype(vec[1]) return DiffEqArray{ @@ -239,21 +262,50 @@ function DiffEqArray(vec::AbstractVector, discretes) end -function DiffEqArray(vec::AbstractVector{VT}, - ts::AbstractVector, - p = nothing, - sys = nothing; - variables = nothing, - parameters = nothing, - independent_variables = nothing, - discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} - sys = something(sys, - SymbolCache(something(variables, []), +# T and N from type +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} + sys = SymbolCache(something(variables, []), something(parameters, []), - something(independent_variables, []))) + something(independent_variables, [])) + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, + ts, + p, + sys, + discretes) +end + +#### 4-argument + +# NTuple, T from type +function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p; discretes = nothing) where {T, N} + DiffEqArray{ + eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}( + vec, + ts, + p, + nothing, + discretes) +end + +# NTuple parameter +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2} + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, + ts, + p, + sys, + discretes) +end + +# first element representative +function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p, sys; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) + sys = SymbolCache(something(variables, []), + something(parameters, []), + something(independent_variables, [])) + _size = size(vec[1]) + T = eltype(vec[1]) return DiffEqArray{ - eltype(eltype(vec)), - N + 1, + T, + length(_size) + 1, typeof(vec), typeof(ts), typeof(p), @@ -266,6 +318,28 @@ function DiffEqArray(vec::AbstractVector{VT}, discretes) end +# T and N from type +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, + ts, + p, + sys, + discretes) +end + +#### 5-argument + +# NTuple, T from type +function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p, sys; discretes = nothing) where {T, N} + DiffEqArray{ + eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}( + vec, + ts, + p, + sys, + discretes) +end + has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes) get_discretes(x) = getfield(x, :discretes) From 720735d3b0b0506516d43c1c38085916624f60bc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 16:37:12 +0530 Subject: [PATCH 2/3] test: test some ambiguous `DiffEqArray` constructors --- test/interface_tests.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index a1266082..674d933c 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -279,3 +279,22 @@ end zvoa = zero(voa) @test zvoa.u[1] == zvoa.u[2] == zeros(3) end + +@testset "Issue SciMLBase#889: `DiffEqArray` constructor ambiguity" begin + darr = DiffEqArray([[1.0, 1.0]], [1.0], ()) + @test darr.p == () + @test darr.sys === nothing + @test ndims(darr) == 2 + darr = DiffEqArray([[1.0, 1.0]], [1.0], (), "A") + @test darr.p == () + @test darr.sys == "A" + @test ndims(darr) == 2 + darr = DiffEqArray([ones(2, 2)], [1.0], (1, 1, 1)) + @test darr.p == (1, 1, 1) + @test darr.sys === nothing + @test ndims(darr) == 3 + darr = DiffEqArray([ones(2, 2)], [1.0], (1, 1, 1), "A") + @test darr.p == (1, 1, 1) + @test darr.sys == "A" + @test ndims(darr) == 3 +end From d6f139a95b3f0b788d99d6acb08f0f91fb7d19e4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 17:13:14 +0530 Subject: [PATCH 3/3] test: fix test to account for MTK's `Initial` parameters --- test/downstream/symbol_indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 1dc45dd3..557b65be 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -27,7 +27,7 @@ sol_new = DiffEqArray(sol.u[1:10], @test all(isequal.(all_variable_symbols(sol), all_variable_symbols(sol_new))) @test all(isequal.(all_variable_symbols(sol), [x, RHS])) @test all(isequal.(all_symbols(sol), all_symbols(sol_new))) -@test all(isequal.(all_symbols(sol), [x, RHS, τ, t])) +@test all([any(isequal(sym), all_symbols(sol)) for sym in [x, RHS, τ, t, Initial(x), Initial(RHS)]]) @test sol[solvedvariables] == sol[[x]] @test sol_new[solvedvariables] == sol_new[[x]] @test sol[allvariables] == sol[[x, RHS]]