Skip to content

Commit ea4d2fc

Browse files
feat: add parameter type and size validation in remake_buffer and setp
1 parent 3f67126 commit ea4d2fc

File tree

5 files changed

+206
-34
lines changed

5 files changed

+206
-34
lines changed

src/systems/index_cache.jl

+27-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct BufferTemplate
2-
type::DataType
2+
type::Union{DataType, UnionAll}
33
length::Int
44
end
55

@@ -16,8 +16,11 @@ const NONNUMERIC_PORTION = Nonnumeric()
1616
struct ParameterIndex{P, I}
1717
portion::P
1818
idx::I
19+
validate_size::Bool
1920
end
2021

22+
ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
23+
2124
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
2225
const UnknownIndexMap = Dict{
2326
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
@@ -34,11 +37,14 @@ struct IndexCache
3437
constant_buffer_sizes::Vector{BufferTemplate}
3538
dependent_buffer_sizes::Vector{BufferTemplate}
3639
nonnumeric_buffer_sizes::Vector{BufferTemplate}
40+
symbol_to_variable::Dict{Symbol, BasicSymbolic}
3741
end
3842

3943
function IndexCache(sys::AbstractSystem)
4044
unks = solved_unknowns(sys)
4145
unk_idxs = UnknownIndexMap()
46+
symbol_to_variable = Dict{Symbol, BasicSymbolic}()
47+
4248
let idx = 1
4349
for sym in unks
4450
usym = unwrap(sym)
@@ -50,7 +56,9 @@ function IndexCache(sys::AbstractSystem)
5056
unk_idxs[usym] = sym_idx
5157

5258
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
53-
unk_idxs[getname(usym)] = sym_idx
59+
name = getname(usym)
60+
unk_idxs[name] = sym_idx
61+
symbol_to_variable[name] = sym
5462
end
5563
idx += length(sym)
5664
end
@@ -66,7 +74,9 @@ function IndexCache(sys::AbstractSystem)
6674
end
6775
unk_idxs[arrsym] = idxs
6876
if hasname(arrsym)
69-
unk_idxs[getname(arrsym)] = idxs
77+
name = getname(arrsym)
78+
unk_idxs[name] = idxs
79+
symbol_to_variable[name] = arrsym
7080
end
7181
end
7282
end
@@ -144,14 +154,15 @@ function IndexCache(sys::AbstractSystem)
144154
idxs[default_toterm(p)] = (i, j)
145155
if hasname(p) && (!istree(p) || operation(p) !== getindex)
146156
idxs[getname(p)] = (i, j)
157+
symbol_to_variable[getname(p)] = p
147158
idxs[getname(default_toterm(p))] = (i, j)
159+
symbol_to_variable[getname(default_toterm(p))] = p
148160
end
149161
end
150162
push!(buffer_sizes, BufferTemplate(T, length(buf)))
151163
end
152164
return idxs, buffer_sizes
153165
end
154-
155166
disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers)
156167
tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
157168
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
@@ -169,7 +180,8 @@ function IndexCache(sys::AbstractSystem)
169180
tunable_buffer_sizes,
170181
const_buffer_sizes,
171182
dependent_buffer_sizes,
172-
nonnumeric_buffer_sizes
183+
nonnumeric_buffer_sizes,
184+
symbol_to_variable
173185
)
174186
end
175187

@@ -190,16 +202,21 @@ function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
190202
end
191203

192204
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
205+
if sym isa Symbol
206+
sym = ic.symbol_to_variable[sym]
207+
end
208+
validate_size = Symbolics.isarraysymbolic(sym) &&
209+
Symbolics.shape(sym) !== Symbolics.Unknown()
193210
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
194-
ParameterIndex(SciMLStructures.Tunable(), idx)
211+
ParameterIndex(SciMLStructures.Tunable(), idx, validate_size)
195212
elseif (idx = check_index_map(ic.discrete_idx, sym)) !== nothing
196-
ParameterIndex(SciMLStructures.Discrete(), idx)
213+
ParameterIndex(SciMLStructures.Discrete(), idx, validate_size)
197214
elseif (idx = check_index_map(ic.constant_idx, sym)) !== nothing
198-
ParameterIndex(SciMLStructures.Constants(), idx)
215+
ParameterIndex(SciMLStructures.Constants(), idx, validate_size)
199216
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
200-
ParameterIndex(NONNUMERIC_PORTION, idx)
217+
ParameterIndex(NONNUMERIC_PORTION, idx, validate_size)
201218
elseif (idx = check_index_map(ic.dependent_idx, sym)) !== nothing
202-
ParameterIndex(DEPENDENT_PORTION, idx)
219+
ParameterIndex(DEPENDENT_PORTION, idx, validate_size)
203220
else
204221
nothing
205222
end
@@ -224,26 +241,6 @@ function check_index_map(idxmap, sym)
224241
end
225242
end
226243

227-
function ParameterIndex(ic::IndexCache, p, sub_idx = ())
228-
p = unwrap(p)
229-
return if haskey(ic.tunable_idx, p)
230-
ParameterIndex(SciMLStructures.Tunable(), (ic.tunable_idx[p]..., sub_idx...))
231-
elseif haskey(ic.discrete_idx, p)
232-
ParameterIndex(SciMLStructures.Discrete(), (ic.discrete_idx[p]..., sub_idx...))
233-
elseif haskey(ic.constant_idx, p)
234-
ParameterIndex(SciMLStructures.Constants(), (ic.constant_idx[p]..., sub_idx...))
235-
elseif haskey(ic.dependent_idx, p)
236-
ParameterIndex(DEPENDENT_PORTION, (ic.dependent_idx[p]..., sub_idx...))
237-
elseif haskey(ic.nonnumeric_idx, p)
238-
ParameterIndex(NONNUMERIC_PORTION, (ic.nonnumeric_idx[p]..., sub_idx...))
239-
elseif istree(p) && operation(p) === getindex
240-
_p, sub_idx... = arguments(p)
241-
ParameterIndex(ic, _p, sub_idx)
242-
else
243-
nothing
244-
end
245-
end
246-
247244
function discrete_linear_index(ic::IndexCache, idx::ParameterIndex)
248245
idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected")
249246
ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0)

src/systems/parameter_buffer.jl

+91-4
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ function MTKParameters(
132132
tunable_buffer = narrow_buffer_type.(tunable_buffer)
133133
disc_buffer = narrow_buffer_type.(disc_buffer)
134134
const_buffer = narrow_buffer_type.(const_buffer)
135-
nonnumeric_buffer = narrow_buffer_type.(nonnumeric_buffer)
135+
# Don't narrow nonnumeric types
136+
nonnumeric_buffer = nonnumeric_buffer
136137

137138
if has_parameter_dependencies(sys) &&
138139
(pdeps = get_parameter_dependencies(sys)) !== nothing
@@ -308,22 +309,31 @@ end
308309

309310
function SymbolicIndexingInterface.set_parameter!(
310311
p::MTKParameters, val, idx::ParameterIndex)
311-
@unpack portion, idx = idx
312+
@unpack portion, idx, validate_size = idx
312313
i, j, k... = idx
313314
if portion isa SciMLStructures.Tunable
314315
if isempty(k)
316+
if validate_size && size(val) !== size(p.tunable[i][j])
317+
throw(InvalidParameterSizeException(size(p.tunable[i][j]), size(val)))
318+
end
315319
p.tunable[i][j] = val
316320
else
317321
p.tunable[i][j][k...] = val
318322
end
319323
elseif portion isa SciMLStructures.Discrete
320324
if isempty(k)
325+
if validate_size && size(val) !== size(p.discrete[i][j])
326+
throw(InvalidParameterSizeException(size(p.discrete[i][j]), size(val)))
327+
end
321328
p.discrete[i][j] = val
322329
else
323330
p.discrete[i][j][k...] = val
324331
end
325332
elseif portion isa SciMLStructures.Constants
326333
if isempty(k)
334+
if validate_size && size(val) !== size(p.constant[i][j])
335+
throw(InvalidParameterSizeException(size(p.constant[i][j]), size(val)))
336+
end
327337
p.constant[i][j] = val
328338
else
329339
p.constant[i][j][k...] = val
@@ -392,14 +402,73 @@ function narrow_buffer_type_and_fallback_undefs(oldbuf::Vector, newbuf::Vector)
392402
isassigned(newbuf, i) || continue
393403
type = promote_type(type, typeof(newbuf[i]))
394404
end
405+
if type == Union{}
406+
type = eltype(oldbuf)
407+
end
395408
for i in eachindex(newbuf)
396409
isassigned(newbuf, i) && continue
397410
newbuf[i] = convert(type, oldbuf[i])
398411
end
399412
return convert(Vector{type}, newbuf)
400413
end
401414

402-
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
415+
function validate_parameter_type(ic::IndexCache, p, index, val)
416+
p = unwrap(p)
417+
if p isa Symbol
418+
p = get(ic.symbol_to_variable, p, nothing)
419+
if p === nothing
420+
@warn "No matching variable found for `Symbol` $p, skipping type validation."
421+
return nothing
422+
end
423+
end
424+
(; portion) = index
425+
# Nonnumeric parameters have to match the type
426+
if portion === NONNUMERIC_PORTION
427+
stype = symtype(p)
428+
val isa stype && return nothing
429+
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
430+
end
431+
stype = symtype(p)
432+
# Array parameters need array values...
433+
if stype <: AbstractArray && !isa(val, AbstractArray)
434+
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
435+
end
436+
# ... and must match sizes
437+
if stype <: AbstractArray && Symbolics.shape(p) !== Symbolics.Unknown() &&
438+
size(val) != size(p)
439+
throw(InvalidParameterSizeException(p, val))
440+
end
441+
# Early exit
442+
val isa stype && return nothing
443+
if stype <: AbstractArray
444+
# Arrays need handling when eltype is `Real` (accept any real array)
445+
etype = eltype(stype)
446+
if etype <: Real
447+
etype = Real
448+
end
449+
# This is for duals and other complicated number types
450+
etype = SciMLBase.parameterless_type(etype)
451+
eltype(val) <: etype || throw(ParameterTypeException(
452+
:validate_parameter_type, p, AbstractArray{etype}, val))
453+
else
454+
# Real check
455+
if stype <: Real
456+
stype = Real
457+
end
458+
stype = SciMLBase.parameterless_type(stype)
459+
val isa stype ||
460+
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
461+
end
462+
end
463+
464+
function indp_to_system(indp)
465+
while hasmethod(symbolic_container, Tuple{typeof(indp)})
466+
indp = symbolic_container(indp)
467+
end
468+
return indp
469+
end
470+
471+
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, vals::Dict)
403472
newbuf = @set oldbuf.tunable = Tuple(Vector{Any}(undef, length(buf))
404473
for buf in oldbuf.tunable)
405474
@set! newbuf.discrete = Tuple(Vector{Any}(undef, length(buf))
@@ -409,9 +478,15 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
409478
@set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf))
410479
for buf in newbuf.nonnumeric)
411480

481+
# If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
482+
# down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
483+
# the index cache.
484+
ic = get_index_cache(indp_to_system(indp))
412485
for (p, val) in vals
486+
idx = parameter_index(indp, p)
487+
validate_parameter_type(ic, p, idx, val)
413488
_set_parameter_unchecked!(
414-
newbuf, val, parameter_index(sys, p); update_dependent = false)
489+
newbuf, val, idx; update_dependent = false)
415490
end
416491

417492
@set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.(
@@ -588,3 +663,15 @@ function Base.showerror(io::IO, e::MissingParametersError)
588663
println(io, MISSING_PARAMETERS_MESSAGE)
589664
println(io, e.vars)
590665
end
666+
667+
function InvalidParameterSizeException(param, val)
668+
DimensionMismatch("InvalidParameterSizeException: For parameter $(param) expected value of size $(size(param)). Received value $(val) of size $(size(val)).")
669+
end
670+
671+
function InvalidParameterSizeException(param::Tuple, val::Tuple)
672+
DimensionMismatch("InvalidParameterSizeException: Expected value of size $(param). Received value of size $(val).")
673+
end
674+
675+
function ParameterTypeException(func, param, expected, val)
676+
TypeError(func, "Parameter $param", expected, val)
677+
end

test/index_cache.jl

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using ModelingToolkit, SymbolicIndexingInterface
2+
using ModelingToolkit: t_nounits as t
3+
4+
# Ensure indexes of array symbolics are cached appropriately
5+
@variables x(t)[1:2]
6+
@named sys = ODESystem(Equation[], t, [x], [])
7+
sys1 = complete(sys)
8+
@named sys = ODESystem(Equation[], t, [x...], [])
9+
sys2 = complete(sys)
10+
for sys in [sys1, sys2]
11+
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
12+
@test is_variable(sys, sym)
13+
@test variable_index(sys, sym) == idx
14+
end
15+
end
16+
17+
@variables x(t)[1:2, 1:2]
18+
@named sys = ODESystem(Equation[], t, [x], [])
19+
sys1 = complete(sys)
20+
@named sys = ODESystem(Equation[], t, [x...], [])
21+
sys2 = complete(sys)
22+
for sys in [sys1, sys2]
23+
@test is_variable(sys, x)
24+
@test variable_index(sys, x) == [1 3; 2 4]
25+
for i in eachindex(x)
26+
@test is_variable(sys, x[i])
27+
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
28+
end
29+
end
30+
31+
# Ensure Symbol to symbolic map is correct
32+
@parameters p1 p2[1:2] p3::String
33+
@variables x(t) y(t)[1:2] z(t)
34+
35+
@named sys = ODESystem(Equation[], t, [x, y, z], [p1, p2, p3])
36+
sys = complete(sys)
37+
38+
ic = ModelingToolkit.get_index_cache(sys)
39+
40+
@test isequal(ic.symbol_to_variable[:p1], p1)
41+
@test isequal(ic.symbol_to_variable[:p2], p2)
42+
@test isequal(ic.symbol_to_variable[:p3], p3)
43+
@test isequal(ic.symbol_to_variable[:x], x)
44+
@test isequal(ic.symbol_to_variable[:y], y)
45+
@test isequal(ic.symbol_to_variable[:z], z)

test/mtkparameters.jl

+42
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,45 @@ newps = remake_buffer(
241241
VDual = Vector{<:ForwardDiff.Dual}
242242
VVDual = Vector{<:Vector{<:ForwardDiff.Dual}}
243243
@test newps.dependent isa Union{Tuple{VDual, VVDual}, Tuple{VVDual, VDual}}
244+
245+
@testset "Parameter type validation" begin
246+
struct Foo{T}
247+
x::T
248+
end
249+
250+
@parameters a b::Int c::Vector{Float64} d[1:2, 1:2]::Int e::Foo{Int} f::Foo
251+
@named sys = ODESystem(Equation[], t, [], [a, b, c, d, e, f])
252+
sys = complete(sys)
253+
ps = MTKParameters(sys,
254+
Dict(a => 1.0, b => 2, c => 3ones(2),
255+
d => 3ones(Int, 2, 2), e => Foo(1), f => Foo("a")))
256+
@test_nowarn setp(sys, c)(ps, ones(4)) # so this is fixed when SII is fixed
257+
@test_throws DimensionMismatch set_parameter!(
258+
ps, 4ones(Int, 3, 2), parameter_index(sys, d))
259+
@test_throws DimensionMismatch set_parameter!(
260+
ps, 4ones(Int, 4), parameter_index(sys, d)) # size has to match, not just length
261+
@test_nowarn setp(sys, f)(ps, Foo(:a)) # can change non-concrete type
262+
263+
# Same flexibility is afforded to `b::Int` to allow for ForwardDiff
264+
for sym in [a, b]
265+
@test_nowarn remake_buffer(sys, ps, Dict(sym => 1))
266+
newps = @test_nowarn remake_buffer(sys, ps, Dict(sym => 1.0f0)) # Can change type if it's numeric
267+
@test getp(sys, sym)(newps) isa Float32
268+
newps = @test_nowarn remake_buffer(sys, ps, Dict(sym => ForwardDiff.Dual(1.0)))
269+
@test getp(sys, sym)(newps) isa ForwardDiff.Dual
270+
@test_throws TypeError remake_buffer(sys, ps, Dict(sym => :a)) # still has to be numeric
271+
end
272+
273+
newps = @test_nowarn remake_buffer(sys, ps, Dict(c => view(1.0:4.0, 2:4))) # can change type of array
274+
@test getp(sys, c)(newps) == 2.0:4.0
275+
@test parameter_values(newps, parameter_index(sys, c)) [2.0, 3.0, 4.0]
276+
@test_throws TypeError remake_buffer(sys, ps, Dict(c => [:a, :b, :c])) # can't arbitrarily change eltype
277+
@test_throws TypeError remake_buffer(sys, ps, Dict(c => :a)) # can't arbitrarily change type
278+
279+
newps = @test_nowarn remake_buffer(sys, ps, Dict(d => ForwardDiff.Dual.(ones(2, 2)))) # can change eltype
280+
@test_throws TypeError remake_buffer(sys, ps, Dict(d => [:a :b; :c :d])) # eltype still has to be numeric
281+
@test getp(sys, d)(newps) isa Matrix{<:ForwardDiff.Dual}
282+
283+
@test_throws TypeError remake_buffer(sys, ps, Dict(e => Foo(2.0))) # need exact same type for nonnumeric
284+
@test_nowarn remake_buffer(sys, ps, Dict(f => Foo(:a)))
285+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ end
2424
@safetestset "Parsing Test" include("variable_parsing.jl")
2525
@safetestset "Simplify Test" include("simplify.jl")
2626
@safetestset "Direct Usage Test" include("direct.jl")
27+
@safetestset "IndexCache Test" include("index_cache.jl")
2728
@safetestset "System Linearity Test" include("linearity.jl")
2829
@safetestset "Input Output Test" include("input_output_handling.jl")
2930
@safetestset "Clock Test" include("clock.jl")

0 commit comments

Comments
 (0)