@@ -132,7 +132,8 @@ function MTKParameters(
132
132
tunable_buffer = narrow_buffer_type .(tunable_buffer)
133
133
disc_buffer = narrow_buffer_type .(disc_buffer)
134
134
const_buffer = narrow_buffer_type .(const_buffer)
135
- nonnumeric_buffer = narrow_buffer_type .(nonnumeric_buffer)
135
+ # Don't narrow nonnumeric types
136
+ nonnumeric_buffer = nonnumeric_buffer
136
137
137
138
if has_parameter_dependencies (sys) &&
138
139
(pdeps = get_parameter_dependencies (sys)) != = nothing
@@ -308,22 +309,31 @@ end
308
309
309
310
function SymbolicIndexingInterface. set_parameter! (
310
311
p:: MTKParameters , val, idx:: ParameterIndex )
311
- @unpack portion, idx = idx
312
+ @unpack portion, idx, validate_size = idx
312
313
i, j, k... = idx
313
314
if portion isa SciMLStructures. Tunable
314
315
if isempty (k)
316
+ if validate_size && size (val) != = size (p. tunable[i][j])
317
+ throw (InvalidParameterSizeException (size (p. tunable[i][j]), size (val)))
318
+ end
315
319
p. tunable[i][j] = val
316
320
else
317
321
p. tunable[i][j][k... ] = val
318
322
end
319
323
elseif portion isa SciMLStructures. Discrete
320
324
if isempty (k)
325
+ if validate_size && size (val) != = size (p. discrete[i][j])
326
+ throw (InvalidParameterSizeException (size (p. discrete[i][j]), size (val)))
327
+ end
321
328
p. discrete[i][j] = val
322
329
else
323
330
p. discrete[i][j][k... ] = val
324
331
end
325
332
elseif portion isa SciMLStructures. Constants
326
333
if isempty (k)
334
+ if validate_size && size (val) != = size (p. constant[i][j])
335
+ throw (InvalidParameterSizeException (size (p. constant[i][j]), size (val)))
336
+ end
327
337
p. constant[i][j] = val
328
338
else
329
339
p. constant[i][j][k... ] = val
@@ -392,14 +402,73 @@ function narrow_buffer_type_and_fallback_undefs(oldbuf::Vector, newbuf::Vector)
392
402
isassigned (newbuf, i) || continue
393
403
type = promote_type (type, typeof (newbuf[i]))
394
404
end
405
+ if type == Union{}
406
+ type = eltype (oldbuf)
407
+ end
395
408
for i in eachindex (newbuf)
396
409
isassigned (newbuf, i) && continue
397
410
newbuf[i] = convert (type, oldbuf[i])
398
411
end
399
412
return convert (Vector{type}, newbuf)
400
413
end
401
414
402
- function SymbolicIndexingInterface. remake_buffer (sys, oldbuf:: MTKParameters , vals:: Dict )
415
+ function validate_parameter_type (ic:: IndexCache , p, index, val)
416
+ p = unwrap (p)
417
+ if p isa Symbol
418
+ p = get (ic. symbol_to_variable, p, nothing )
419
+ if p === nothing
420
+ @warn " No matching variable found for `Symbol` $p , skipping type validation."
421
+ return nothing
422
+ end
423
+ end
424
+ (; portion) = index
425
+ # Nonnumeric parameters have to match the type
426
+ if portion === NONNUMERIC_PORTION
427
+ stype = symtype (p)
428
+ val isa stype && return nothing
429
+ throw (ParameterTypeException (:validate_parameter_type , p, stype, val))
430
+ end
431
+ stype = symtype (p)
432
+ # Array parameters need array values...
433
+ if stype <: AbstractArray && ! isa (val, AbstractArray)
434
+ throw (ParameterTypeException (:validate_parameter_type , p, stype, val))
435
+ end
436
+ # ... and must match sizes
437
+ if stype <: AbstractArray && Symbolics. shape (p) != = Symbolics. Unknown () &&
438
+ size (val) != size (p)
439
+ throw (InvalidParameterSizeException (p, val))
440
+ end
441
+ # Early exit
442
+ val isa stype && return nothing
443
+ if stype <: AbstractArray
444
+ # Arrays need handling when eltype is `Real` (accept any real array)
445
+ etype = eltype (stype)
446
+ if etype <: Real
447
+ etype = Real
448
+ end
449
+ # This is for duals and other complicated number types
450
+ etype = SciMLBase. parameterless_type (etype)
451
+ eltype (val) <: etype || throw (ParameterTypeException (
452
+ :validate_parameter_type , p, AbstractArray{etype}, val))
453
+ else
454
+ # Real check
455
+ if stype <: Real
456
+ stype = Real
457
+ end
458
+ stype = SciMLBase. parameterless_type (stype)
459
+ val isa stype ||
460
+ throw (ParameterTypeException (:validate_parameter_type , p, stype, val))
461
+ end
462
+ end
463
+
464
+ function indp_to_system (indp)
465
+ while hasmethod (symbolic_container, Tuple{typeof (indp)})
466
+ indp = symbolic_container (indp)
467
+ end
468
+ return indp
469
+ end
470
+
471
+ function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , vals:: Dict )
403
472
newbuf = @set oldbuf. tunable = Tuple (Vector {Any} (undef, length (buf))
404
473
for buf in oldbuf. tunable)
405
474
@set! newbuf. discrete = Tuple (Vector {Any} (undef, length (buf))
@@ -409,9 +478,15 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
409
478
@set! newbuf. nonnumeric = Tuple (Vector {Any} (undef, length (buf))
410
479
for buf in newbuf. nonnumeric)
411
480
481
+ # If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
482
+ # down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
483
+ # the index cache.
484
+ ic = get_index_cache (indp_to_system (indp))
412
485
for (p, val) in vals
486
+ idx = parameter_index (indp, p)
487
+ validate_parameter_type (ic, p, idx, val)
413
488
_set_parameter_unchecked! (
414
- newbuf, val, parameter_index (sys, p) ; update_dependent = false )
489
+ newbuf, val, idx ; update_dependent = false )
415
490
end
416
491
417
492
@set! newbuf. tunable = narrow_buffer_type_and_fallback_undefs .(
@@ -588,3 +663,15 @@ function Base.showerror(io::IO, e::MissingParametersError)
588
663
println (io, MISSING_PARAMETERS_MESSAGE)
589
664
println (io, e. vars)
590
665
end
666
+
667
+ function InvalidParameterSizeException (param, val)
668
+ DimensionMismatch (" InvalidParameterSizeException: For parameter $(param) expected value of size $(size (param)) . Received value $(val) of size $(size (val)) ." )
669
+ end
670
+
671
+ function InvalidParameterSizeException (param:: Tuple , val:: Tuple )
672
+ DimensionMismatch (" InvalidParameterSizeException: Expected value of size $(param) . Received value of size $(val) ." )
673
+ end
674
+
675
+ function ParameterTypeException (func, param, expected, val)
676
+ TypeError (func, " Parameter $param " , expected, val)
677
+ end
0 commit comments