Skip to content

Commit 5691fa3

Browse files
Merge pull request #3448 from vyudu/fix_float
fix: propagate `tofloat`, `use_union` to `better_varmap_to_vars`
2 parents 8d37790 + cabd060 commit 5691fa3

12 files changed

+60
-51
lines changed

src/systems/diffeqs/abstractodesystem.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1504,5 +1504,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
15041504
else
15051505
NonlinearLeastSquaresProblem
15061506
end
1507-
TProb(isys, u0map, parammap; kwargs..., build_initializeprob = false, is_initializeprob = true)
1507+
TProb(isys, u0map, parammap; kwargs...,
1508+
build_initializeprob = false, is_initializeprob = true)
15081509
end

src/systems/discrete_system/discrete_system.jl

-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,6 @@ function SciMLBase.DiscreteProblem(
302302
parammap = SciMLBase.NullParameters();
303303
eval_module = @__MODULE__,
304304
eval_expression = false,
305-
use_union = false,
306305
kwargs...
307306
)
308307
if !iscomplete(sys)

src/systems/discrete_system/implicit_discrete_system.jl

-1
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,6 @@ function SciMLBase.ImplicitDiscreteProblem(
321321
parammap = SciMLBase.NullParameters();
322322
eval_module = @__MODULE__,
323323
eval_expression = false,
324-
use_union = false,
325324
kwargs...
326325
)
327326
if !iscomplete(sys)

src/systems/jumps/jumpsystem.jl

+3-8
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,6 @@ end
383383
```julia
384384
DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
385385
parammap = DiffEqBase.NullParameters;
386-
use_union = true,
387386
kwargs...)
388387
```
389388
@@ -403,7 +402,6 @@ dprob = DiscreteProblem(complete(js), u₀map, tspan, parammap)
403402
"""
404403
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
405404
parammap = DiffEqBase.NullParameters();
406-
use_union = true,
407405
eval_expression = false,
408406
eval_module = @__MODULE__,
409407
kwargs...)
@@ -416,7 +414,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
416414
end
417415

418416
_f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
419-
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false, build_initializeprob = false)
417+
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false)
420418
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
421419

422420
observedfun = ObservedFunctionCache(
@@ -451,14 +449,13 @@ struct DiscreteProblemExpr{iip} end
451449

452450
function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
453451
parammap = DiffEqBase.NullParameters();
454-
use_union = true,
455452
kwargs...) where {iip}
456453
if !iscomplete(sys)
457454
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
458455
end
459456

460457
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
461-
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
458+
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false)
462459
# identity function to make syms works
463460
quote
464461
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
@@ -475,7 +472,6 @@ end
475472
```julia
476473
DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan,
477474
parammap = DiffEqBase.NullParameters;
478-
use_union = true,
479475
kwargs...)
480476
```
481477
@@ -497,7 +493,6 @@ oprob = ODEProblem(complete(js), u₀map, tspan, parammap)
497493
"""
498494
function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
499495
parammap = DiffEqBase.NullParameters();
500-
use_union = false,
501496
eval_expression = false,
502497
eval_module = @__MODULE__,
503498
kwargs...)
@@ -517,7 +512,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
517512
build_initializeprob = false, kwargs...)
518513
else
519514
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
520-
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
515+
t = tspan === nothing ? nothing : tspan[1], tofloat = false,
521516
check_length = false)
522517
f = (du, u, p, t) -> (du .= 0; nothing)
523518
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module,

src/systems/optimization/optimizationsystem.jl

+6-8
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
295295
cons_sparse = false, checkbounds = false,
296296
linenumbers = true, parallel = SerialForm(),
297297
eval_expression = false, eval_module = @__MODULE__,
298-
use_union = false,
299298
checks = true,
300299
kwargs...) where {iip}
301300
if !iscomplete(sys)
@@ -338,10 +337,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
338337
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
339338
p = MTKParameters(sys, parammap, u0map)
340339
else
341-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
340+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false)
342341
end
343-
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false, use_union)
344-
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false, use_union)
342+
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false)
343+
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false)
345344

346345
if !isnothing(lb) && all(lb .== -Inf) && !isnothing(ub) && all(ub .== Inf)
347346
lb = nothing
@@ -538,7 +537,6 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0map,
538537
checkbounds = false,
539538
linenumbers = false, parallel = SerialForm(),
540539
eval_expression = false, eval_module = @__MODULE__,
541-
use_union = false,
542540
kwargs...) where {iip}
543541
if !iscomplete(sys)
544542
error("A completed `OptimizationSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `OptimizationProblemExpr`")
@@ -578,10 +576,10 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0map,
578576
if has_index_cache(sys) && get_index_cache(sys) !== nothing
579577
p = MTKParameters(sys, parammap, u0map)
580578
else
581-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
579+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false)
582580
end
583-
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false, use_union)
584-
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false, use_union)
581+
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false)
582+
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false)
585583

586584
if !isnothing(lb) && all(lb .== -Inf) && !isnothing(ub) && all(ub .== Inf)
587585
lb = nothing

src/systems/parameter_buffer.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ This requires that `complete` has been called on the system (usually via
2727
the default behavior).
2828
"""
2929
function MTKParameters(
30-
sys::AbstractSystem, p, u0 = Dict(); tofloat = false, use_union = false,
30+
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
3131
t0 = nothing, substitution_limit = 1000)
3232
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
3333
get_index_cache(sys)

src/systems/problem_utils.jl

+21-22
Original file line numberDiff line numberDiff line change
@@ -330,17 +330,17 @@ struct MissingGuessError <: Exception
330330
vals::Vector{Any}
331331
end
332332

333-
function Base.showerror(io::IO, err::MissingGuessError)
334-
println(io,
335-
"""
336-
Cyclic guesses detected in the system. Symbolic values were found for the following variables/parameters in the map: \
337-
""")
333+
function Base.showerror(io::IO, err::MissingGuessError)
334+
println(io,
335+
"""
336+
Cyclic guesses detected in the system. Symbolic values were found for the following variables/parameters in the map: \
337+
""")
338338
for (sym, val) in zip(err.syms, err.vals)
339339
println(io, "$sym => $val")
340340
end
341341
println(io,
342-
"""
343-
In order to resolve this, please provide additional numeric guesses so that the chain can be resolved to assign numeric values to each variable. """)
342+
"""
343+
In order to resolve this, please provide additional numeric guesses so that the chain can be resolved to assign numeric values to each variable. """)
344344
end
345345

346346
"""
@@ -351,20 +351,20 @@ in `varmap`. Does not perform symbolic substitution in the values of `varmap`.
351351
352352
Keyword arguments:
353353
- `tofloat`: Convert values to floating point numbers using `float`.
354-
- `use_union`: Use a `Union`-typed array if the values have heterogeneous types.
355354
- `container_type`: The type of container to use for the values.
356355
- `toterm`: The `toterm` method to use for converting symbolics.
357356
- `promotetoconcrete`: whether the promote to a concrete buffer (respecting
358-
`tofloat` and `use_union`). Defaults to `container_type <: AbstractArray`.
357+
`tofloat`). Defaults to `container_type <: AbstractArray`.
359358
- `check`: Error if any variables in `vars` do not have a mapping in `varmap`. Uses
360359
[`missingvars`](@ref) to perform the check.
361360
- `allow_symbolic` allows the returned array to contain symbolic values. If this is `true`,
362361
`promotetoconcrete` is set to `false`.
363362
- `is_initializeprob, guesses`: Used to determine whether the system is missing guesses.
364363
"""
365364
function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
366-
tofloat = true, use_union = true, container_type = Array,
367-
toterm = default_toterm, promotetoconcrete = nothing, check = true, allow_symbolic = false, is_initializeprob = false)
365+
tofloat = true, container_type = Array,
366+
toterm = default_toterm, promotetoconcrete = nothing, check = true,
367+
allow_symbolic = false, is_initializeprob = false)
368368
isempty(vars) && return nothing
369369

370370
if check
@@ -382,8 +382,8 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
382382
end
383383

384384
if !isempty(missingsyms)
385-
is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) :
386-
throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1]))
385+
is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) :
386+
throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1]))
387387
end
388388
end
389389

@@ -393,7 +393,7 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
393393

394394
promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray)
395395
if promotetoconcrete && !allow_symbolic
396-
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = use_union)
396+
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = false)
397397
end
398398

399399
if isempty(vals)
@@ -731,8 +731,7 @@ Keyword arguments:
731731
- `fully_determined`: Override whether the initialization system is fully determined.
732732
- `check_initialization_units`: Enable or disable unit checks when constructing the
733733
initialization problem.
734-
- `tofloat`, `use_union`, `is_initializeprob`: Passed to [`better_varmap_to_vars`](@ref) for building `u0` (and
735-
possibly `p`).
734+
- `tofloat`, `is_initializeprob`: Passed to [`better_varmap_to_vars`](@ref) for building `u0` (and possibly `p`).
736735
- `u0_constructor`: A function to apply to the `u0` value returned from `better_varmap_to_vars`
737736
to construct the final `u0` value.
738737
- `du0map`: A map of derivatives to values. See `implicit_dae`.
@@ -762,7 +761,7 @@ function process_SciMLProblem(
762761
implicit_dae = false, t = nothing, guesses = AnyDict(),
763762
warn_initialize_determined = true, initialization_eqs = [],
764763
eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing,
765-
check_initialization_units = false, tofloat = true, use_union = false,
764+
check_initialization_units = false, tofloat = true,
766765
u0_constructor = identity, du0map = nothing, check_length = true,
767766
symbolic_u0 = false, warn_cyclic_dependency = false,
768767
circular_dependency_max_cycle_length = length(all_symbols(sys)),
@@ -841,7 +840,7 @@ function process_SciMLProblem(
841840
evaluate_varmap!(op, dvs; limit = substitution_limit)
842841

843842
u0 = better_varmap_to_vars(
844-
op, dvs; tofloat = true, use_union = false,
843+
op, dvs; tofloat,
845844
container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob)
846845

847846
if u0 !== nothing
@@ -867,15 +866,15 @@ function process_SciMLProblem(
867866
if is_split(sys)
868867
p = MTKParameters(sys, op)
869868
else
870-
p = better_varmap_to_vars(op, ps; tofloat, use_union, container_type = pType)
869+
p = better_varmap_to_vars(op, ps; tofloat, container_type = pType)
871870
end
872871

873872
if implicit_dae && du0map !== nothing
874873
ddvs = map(Differential(iv), dvs)
875874
du0map = to_varmap(du0map, ddvs)
876875
merge!(op, du0map)
877876
du0 = varmap_to_vars(op, ddvs; toterm = identity,
878-
tofloat = true)
877+
tofloat)
879878
kwargs = merge(kwargs, (; ddvs))
880879
else
881880
du0 = nothing
@@ -944,8 +943,8 @@ function get_u0_p(sys,
944943
u0map,
945944
parammap = nothing;
946945
t0 = nothing,
947-
use_union = true,
948946
tofloat = true,
947+
use_union = true,
949948
symbolic_u0 = false)
950949
dvs = unknowns(sys)
951950
ps = parameters(sys; initial_parameters = true)
@@ -985,7 +984,7 @@ function get_u0_p(sys,
985984
if symbolic_u0
986985
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
987986
else
988-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
987+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat, use_union)
989988
end
990989
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
991990
p = p === nothing ? SciMLBase.NullParameters() : p

src/variables.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,14 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
186186

187187
vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
188188
varmap = todict(varmap)
189-
_varmap_to_vars(varmap, varlist; defaults = defaults, check = check,
190-
toterm = toterm)
189+
_varmap_to_vars(varmap, varlist; defaults, check, toterm)
191190
else # plain array-like initialization
192191
varmap
193192
end
194193

195194
promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray)
196195
if promotetoconcrete
197-
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = use_union)
196+
vals = promote_to_concrete(vals; tofloat, use_union)
198197
end
199198

200199
if isempty(vals)

test/initial_values.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,14 @@ end
208208
x^2 + y^2 ~ 1]
209209
@mtkbuild pend = ODESystem(eqs, t)
210210

211-
@test_throws ModelingToolkit.MissingGuessError ODEProblem(pend, [x => 1], (0, 1), [g => 1], guesses = [y => λ, λ => y + 1])
211+
@test_throws ModelingToolkit.MissingGuessError ODEProblem(
212+
pend, [x => 1], (0, 1), [g => 1], guesses = [y => λ, λ => y + 1])
212213
ODEProblem(pend, [x => 1], (0, 1), [g => 1], guesses = [y => λ, λ => 0.5])
213214

214215
# Throw multiple if multiple are missing
215216
@variables a(t) b(t) c(t) d(t) e(t)
216217
eqs = [D(a) ~ b, D(b) ~ c, D(c) ~ d, D(d) ~ e, D(e) ~ 1]
217218
@mtkbuild sys = ODESystem(eqs, t)
218-
@test_throws ["a(t)", "c(t)"] ODEProblem(sys, [e => 2, a => b, b => a + 1, c => d, d => c + 1], (0, 1))
219+
@test_throws ["a(t)", "c(t)"] ODEProblem(
220+
sys, [e => 2, a => b, b => a + 1, c => d, d => c + 1], (0, 1))
219221
end

test/jumpsystem.jl

+18
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,21 @@ let
533533
Xsamp /= Nsims
534534
@test abs(Xsamp - Xf(0.2, p) < 0.05 * Xf(0.2, p))
535535
end
536+
537+
@testset "JumpProcess simulation should be Int64 valued (#3446)" begin
538+
@parameters p d
539+
@variables X(t)
540+
rate1 = p
541+
rate2 = X * d
542+
affect1 = [X ~ X + 1]
543+
affect2 = [X ~ X - 1]
544+
j1 = ConstantRateJump(rate1, affect1)
545+
j2 = ConstantRateJump(rate2, affect2)
546+
547+
# Works.
548+
@mtkbuild js = JumpSystem([j1, j2], t, [X], [p, d])
549+
dprob = DiscreteProblem(js, [X => 15], (0.0, 10.0), [p => 2.0, d => 0.5])
550+
jprob = JumpProblem(js, dprob, Direct())
551+
sol = solve(jprob, SSAStepper())
552+
@test eltype(sol[X]) === Int64
553+
end

test/odesystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ let
745745
# No longer supported, Tuple used instead
746746
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
747747
# tspan = (0.0, 1.0)
748-
# prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
748+
# prob = ODEProblem(sys, u0map, tspan, pmap)
749749
# @test eltype(prob.p) === Union{Float64, Int}
750750
end
751751

@@ -1208,7 +1208,7 @@ end
12081208
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
12091209

12101210
function x_at_1(P)
1211-
prob = ODEProblem(sys, [x => P], (0.0, 1.0), [sys.P => P], use_union = false)
1211+
prob = ODEProblem(sys, [x => P], (0.0, 1.0), [sys.P => P])
12121212
return solve(prob, Tsit5())(1.0)
12131213
end
12141214

test/split_parameters.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,8 @@ sol = solve(prob, ImplicitEuler());
123123
# ------------------------ Mixed Type Conserved
124124

125125
prob = ODEProblem(
126-
sys, [], tspan, []; tofloat = false, use_union = true, build_initializeprob = false)
126+
sys, [], tspan, []; tofloat = false, build_initializeprob = false)
127127

128-
@test prob.p isa Vector{Union{Float64, Int64}}
129128
sol = solve(prob, ImplicitEuler());
130129
@test sol.retcode == ReturnCode.Success
131130

0 commit comments

Comments
 (0)