Skip to content

Commit 914a209

Browse files
committed
Support typed arrays
1 parent 8ab6bfe commit 914a209

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

Diff for: src/systems/model_parsing.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
160160
Expr(:(::), a,
161161
Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))),
162162
nothing))
163-
push!(where_types, :($vartype <: $type))
163+
if !isnothing(meta) && haskey(meta, VariableUnit)
164+
push!(where_types, vartype)
165+
else
166+
push!(where_types, :($vartype <: $type))
167+
end
164168
dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type})
165169
end
166170
if dict[varclass] isa Vector
@@ -624,10 +628,20 @@ function convert_units(varunits::DynamicQuantities.Quantity, value)
624628
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
625629
end
626630

631+
function convert_units(varunits::DynamicQuantities.Quantity, value::AbstractArray{T}) where T
632+
DynamicQuantities.ustrip.(DynamicQuantities.uconvert.(
633+
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
634+
end
635+
627636
function convert_units(varunits::Unitful.FreeUnits, value)
628637
Unitful.ustrip(varunits, value)
629638
end
630639

640+
function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where T
641+
Unitful.ustrip.(varunits, value)
642+
end
643+
644+
631645
function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
632646
vv, def, metadata_with_exprs = parse_variable_def!(
633647
dict, mod, arg, varclass, kwargs, where_types)

Diff for: test/dq_units.jl

+11
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,14 @@ end
174174

175175
@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
176176
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")
177+
178+
@mtkmodel ArrayParamTest begin
179+
@parameters begin
180+
a[1:2], [unit = u"m"]
181+
end
182+
end
183+
184+
@named sys = ArrayParamTest()
185+
186+
@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
187+
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]

Diff for: test/units.jl

+11
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,14 @@ end
209209

210210
@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
211211
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")
212+
213+
@mtkmodel ArrayParamTest begin
214+
@parameters begin
215+
a[1:2], [unit = u"m"]
216+
end
217+
end
218+
219+
@named sys = ArrayParamTest()
220+
221+
@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
222+
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]

0 commit comments

Comments
 (0)