Skip to content

Commit ff98d26

Browse files
Merge pull request #423 from AayushSabharwal/as/ctor-ambig
fix: fix DiffEqArray constructors ambiguity
2 parents 7fdbfba + d6f139a commit ff98d26

File tree

3 files changed

+139
-46
lines changed

3 files changed

+139
-46
lines changed

src/vector_of_array.jl

+119-45
Original file line numberDiff line numberDiff line change
@@ -174,54 +174,77 @@ end
174174

175175
Base.parent(vec::VectorOfArray) = vec.u
176176

177-
function DiffEqArray(vec::AbstractVector{T},
178-
ts::AbstractVector,
179-
::NTuple{N, Int},
180-
p = nothing,
181-
sys = nothing; discretes = nothing) where {T, N}
182-
DiffEqArray{
183-
eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(
184-
vec,
177+
#### 2-argument
178+
179+
# first element representative
180+
function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing)
181+
sys = SymbolCache(something(variables, []),
182+
something(parameters, []),
183+
something(independent_variables, []))
184+
_size = size(vec[1])
185+
T = eltype(vec[1])
186+
return DiffEqArray{
187+
T,
188+
length(_size) + 1,
189+
typeof(vec),
190+
typeof(ts),
191+
Nothing,
192+
typeof(sys),
193+
typeof(discretes)
194+
}(vec,
185195
ts,
186-
p,
196+
nothing,
197+
sys,
198+
discretes)
199+
end
200+
201+
# T and N from type
202+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
203+
sys = SymbolCache(something(variables, []),
204+
something(parameters, []),
205+
something(independent_variables, []))
206+
return DiffEqArray{
207+
eltype(eltype(vec)),
208+
N + 1,
209+
typeof(vec),
210+
typeof(ts),
211+
Nothing,
212+
typeof(sys),
213+
typeof(discretes)
214+
}(vec,
215+
ts,
216+
nothing,
187217
sys,
188218
discretes)
189219
end
190220

191-
# ambiguity resolution
192-
function DiffEqArray(vec::AbstractVector{VT},
193-
ts::AbstractVector,
194-
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
195-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec,
221+
#### 3-argument
222+
223+
# NTuple, T from type
224+
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}; discretes = nothing) where {T, N}
225+
DiffEqArray{
226+
eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, typeof(discretes)}(
227+
vec,
196228
ts,
197229
nothing,
198230
nothing,
199-
nothing)
231+
discretes)
200232
end
201-
function DiffEqArray(vec::AbstractVector{VT},
202-
ts::AbstractVector,
203-
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
204-
DiffEqArray{
205-
eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
233+
234+
# NTuple parameter
235+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2}
236+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
206237
ts,
207238
p,
208239
nothing,
209240
discretes)
210241
end
211-
# Assume that the first element is representative of all other elements
212242

213-
function DiffEqArray(vec::AbstractVector,
214-
ts::AbstractVector,
215-
p = nothing,
216-
sys = nothing;
217-
variables = nothing,
218-
parameters = nothing,
219-
independent_variables = nothing,
220-
discretes = nothing)
221-
sys = something(sys,
222-
SymbolCache(something(variables, []),
243+
# first element representative
244+
function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing)
245+
sys = SymbolCache(something(variables, []),
223246
something(parameters, []),
224-
something(independent_variables, [])))
247+
something(independent_variables, []))
225248
_size = size(vec[1])
226249
T = eltype(vec[1])
227250
return DiffEqArray{
@@ -239,21 +262,50 @@ function DiffEqArray(vec::AbstractVector,
239262
discretes)
240263
end
241264

242-
function DiffEqArray(vec::AbstractVector{VT},
243-
ts::AbstractVector,
244-
p = nothing,
245-
sys = nothing;
246-
variables = nothing,
247-
parameters = nothing,
248-
independent_variables = nothing,
249-
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
250-
sys = something(sys,
251-
SymbolCache(something(variables, []),
265+
# T and N from type
266+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
267+
sys = SymbolCache(something(variables, []),
252268
something(parameters, []),
253-
something(independent_variables, [])))
269+
something(independent_variables, []))
270+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
271+
ts,
272+
p,
273+
sys,
274+
discretes)
275+
end
276+
277+
#### 4-argument
278+
279+
# NTuple, T from type
280+
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p; discretes = nothing) where {T, N}
281+
DiffEqArray{
282+
eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(
283+
vec,
284+
ts,
285+
p,
286+
nothing,
287+
discretes)
288+
end
289+
290+
# NTuple parameter
291+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2}
292+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
293+
ts,
294+
p,
295+
sys,
296+
discretes)
297+
end
298+
299+
# first element representative
300+
function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p, sys; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing)
301+
sys = SymbolCache(something(variables, []),
302+
something(parameters, []),
303+
something(independent_variables, []))
304+
_size = size(vec[1])
305+
T = eltype(vec[1])
254306
return DiffEqArray{
255-
eltype(eltype(vec)),
256-
N + 1,
307+
T,
308+
length(_size) + 1,
257309
typeof(vec),
258310
typeof(ts),
259311
typeof(p),
@@ -266,6 +318,28 @@ function DiffEqArray(vec::AbstractVector{VT},
266318
discretes)
267319
end
268320

321+
# T and N from type
322+
function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
323+
DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
324+
ts,
325+
p,
326+
sys,
327+
discretes)
328+
end
329+
330+
#### 5-argument
331+
332+
# NTuple, T from type
333+
function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p, sys; discretes = nothing) where {T, N}
334+
DiffEqArray{
335+
eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(
336+
vec,
337+
ts,
338+
p,
339+
sys,
340+
discretes)
341+
end
342+
269343
has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes)
270344
get_discretes(x) = getfield(x, :discretes)
271345

test/downstream/symbol_indexing.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ sol_new = DiffEqArray(sol.u[1:10],
2727
@test all(isequal.(all_variable_symbols(sol), all_variable_symbols(sol_new)))
2828
@test all(isequal.(all_variable_symbols(sol), [x, RHS]))
2929
@test all(isequal.(all_symbols(sol), all_symbols(sol_new)))
30-
@test all(isequal.(all_symbols(sol), [x, RHS, τ, t]))
30+
@test all([any(isequal(sym), all_symbols(sol)) for sym in [x, RHS, τ, t, Initial(x), Initial(RHS)]])
3131
@test sol[solvedvariables] == sol[[x]]
3232
@test sol_new[solvedvariables] == sol_new[[x]]
3333
@test sol[allvariables] == sol[[x, RHS]]

test/interface_tests.jl

+19
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,22 @@ end
279279
zvoa = zero(voa)
280280
@test zvoa.u[1] == zvoa.u[2] == zeros(3)
281281
end
282+
283+
@testset "Issue SciMLBase#889: `DiffEqArray` constructor ambiguity" begin
284+
darr = DiffEqArray([[1.0, 1.0]], [1.0], ())
285+
@test darr.p == ()
286+
@test darr.sys === nothing
287+
@test ndims(darr) == 2
288+
darr = DiffEqArray([[1.0, 1.0]], [1.0], (), "A")
289+
@test darr.p == ()
290+
@test darr.sys == "A"
291+
@test ndims(darr) == 2
292+
darr = DiffEqArray([ones(2, 2)], [1.0], (1, 1, 1))
293+
@test darr.p == (1, 1, 1)
294+
@test darr.sys === nothing
295+
@test ndims(darr) == 3
296+
darr = DiffEqArray([ones(2, 2)], [1.0], (1, 1, 1), "A")
297+
@test darr.p == (1, 1, 1)
298+
@test darr.sys == "A"
299+
@test ndims(darr) == 3
300+
end

0 commit comments

Comments
 (0)