Skip to content

Commit a9c56b6

Browse files
committed
Suppose heterogeneous parameters for linearize and remake
1 parent ee4deb9 commit a9c56b6

7 files changed

+85
-77
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ PrecompileTools = "1"
8585
RecursiveArrayTools = "2.3"
8686
Reexport = "0.2, 1"
8787
RuntimeGeneratedFunctions = "0.5.9"
88-
SciMLBase = "2.0.1"
88+
SciMLBase = "1, 2.0.1"
8989
Setfield = "0.7, 0.8, 1"
9090
SimpleNonlinearSolve = "0.1.0"
9191
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"

src/systems/abstractsystem.jl

+15-3
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ for prop in [:eqs
230230
:metadata
231231
:gui_metadata
232232
:discrete_subsystems
233-
:unknown_states]
233+
:unknown_states
234+
:split_idxs]
234235
fname1 = Symbol(:get_, prop)
235236
fname2 = Symbol(:has_, prop)
236237
@eval begin
@@ -1274,14 +1275,25 @@ See also [`linearize`](@ref) which provides a higher-level interface.
12741275
function linearization_function(sys::AbstractSystem, inputs,
12751276
outputs; simplify = false,
12761277
initialize = true,
1278+
op = Dict(),
1279+
p = DiffEqBase.NullParameters(),
12771280
kwargs...)
12781281
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
12791282
kwargs...)
1283+
x0 = merge(defaults(sys), op)
1284+
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1285+
p, split_idxs = split_parameters_by_type(p)
1286+
ps = parameters(sys)
1287+
if p isa Tuple
1288+
ps = Base.Fix1(getindex, ps).(split_idxs)
1289+
ps = (ps...,) #if p is Tuple, ps should be Tuple
1290+
end
1291+
12801292
lin_fun = let diff_idxs = diff_idxs,
12811293
alge_idxs = alge_idxs,
12821294
input_idxs = input_idxs,
12831295
sts = states(sys),
1284-
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
1296+
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, states(sys), ps; p = p),
12851297
h = build_explicit_observed_function(sys, outputs),
12861298
chunk = ForwardDiff.Chunk(input_idxs)
12871299

@@ -1600,11 +1612,11 @@ function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
16001612
allow_input_derivatives = false,
16011613
zero_dummy_der = false,
16021614
kwargs...)
1603-
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
16041615
if zero_dummy_der
16051616
dummyder = setdiff(states(ssys), states(sys))
16061617
op = merge(op, Dict(x => 0.0 for x in dummyder))
16071618
end
1619+
lin_fun, ssys = linearization_function(sys, inputs, outputs; op, kwargs...)
16081620
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
16091621
end
16101622

src/systems/diffeqs/abstractodesystem.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
354354
checkbounds = false,
355355
sparsity = false,
356356
analytic = nothing,
357+
split_idxs = nothing,
357358
kwargs...) where {iip, specialize}
358359
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
359360
expression_module = eval_module, checkbounds = checkbounds,
@@ -508,6 +509,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
508509
nothing
509510
end
510511

512+
@set! sys.split_idxs = split_idxs
511513
ODEFunction{iip, specialize}(f;
512514
sys = sys,
513515
jac = _jac === nothing ? nothing : _jac,
@@ -765,15 +767,17 @@ Take dictionaries with initial conditions and parameters and convert them to num
765767
"""
766768
function get_u0_p(sys,
767769
u0map,
768-
parammap;
770+
parammap = nothing;
769771
use_union = true,
770772
tofloat = true,
771773
symbolic_u0 = false)
772774
dvs = states(sys)
773775
ps = parameters(sys)
774776

775777
defs = defaults(sys)
776-
defs = mergedefaults(defs, parammap, ps)
778+
if parammap !== nothing
779+
defs = mergedefaults(defs, parammap, ps)
780+
end
777781
defs = mergedefaults(defs, u0map, dvs)
778782

779783
if symbolic_u0
@@ -835,7 +839,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
835839
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
836840
checkbounds = checkbounds, p = p,
837841
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
838-
sparse = sparse, eval_expression = eval_expression,
842+
sparse = sparse, eval_expression = eval_expression, split_idxs,
839843
kwargs...)
840844
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
841845
end

src/systems/diffeqs/odesystem.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,19 @@ struct ODESystem <: AbstractODESystem
139139
used for ODAEProblem.
140140
"""
141141
unknown_states::Union{Nothing, Vector{Any}}
142+
"""
143+
split_idxs: a vector of vectors of indices for the split parameters.
144+
"""
145+
split_idxs::Union{Nothing, Vector{Vector{Int}}}
142146

143147
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
144148
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
145149
torn_matching, connector_type, preface, cevents,
146150
devents, metadata = nothing, gui_metadata = nothing,
147151
tearing_state = nothing,
148152
substitutions = nothing, complete = false,
149-
discrete_subsystems = nothing, unknown_states = nothing;
150-
checks::Union{Bool, Int} = true)
153+
discrete_subsystems = nothing, unknown_states = nothing,
154+
split_idxs = nothing; checks::Union{Bool, Int} = true)
151155
if checks == true || (checks & CheckComponents) > 0
152156
check_variables(dvs, iv)
153157
check_parameters(ps, iv)
@@ -161,7 +165,7 @@ struct ODESystem <: AbstractODESystem
161165
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
162166
connector_type, preface, cevents, devents, metadata, gui_metadata,
163167
tearing_state, substitutions, complete, discrete_subsystems,
164-
unknown_states)
168+
unknown_states, split_idxs)
165169
end
166170
end
167171

src/utils.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
661661
else
662662
sym_vs = filter(x -> SymbolicUtils.issym(x) || SymbolicUtils.istree(x), vs)
663663
isempty(sym_vs) || throw_missingvars_in_sys(sym_vs)
664-
664+
665665
C = nothing
666666
for v in vs
667667
E = typeof(v)
@@ -676,7 +676,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
676676
if use_union
677677
C = Union{C, E}
678678
else
679-
@assert C == E "`promote_to_concrete` can't make type $E uniform with $C"
679+
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
680680
C = E
681681
end
682682
end
@@ -686,7 +686,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
686686
if (vs[i] isa Number) & tofloat
687687
y[i] = float(vs[i]) #needed because copyto! can't convert Int to Float automatically
688688
else
689-
y[i] = vs[i]
689+
y[i] = vs[i]
690690
end
691691
end
692692

src/variables.jl

+17-9
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,23 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
145145
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
146146
end
147147

148-
# assemble defaults
149-
defs = defaults(prob.f.sys)
150-
defs = mergedefaults(defs, prob.p, parameters(prob.f.sys))
151-
defs = mergedefaults(defs, p, parameters(prob.f.sys))
152-
defs = mergedefaults(defs, prob.u0, states(prob.f.sys))
153-
defs = mergedefaults(defs, u0, states(prob.f.sys))
154-
155-
u0 = varmap_to_vars(u0, states(prob.f.sys); defaults = defs, tofloat = true)
156-
p = varmap_to_vars(p, parameters(prob.f.sys); defaults = defs)
148+
sys = prob.f.sys
149+
defs = defaults(sys)
150+
ps = parameters(sys)
151+
if has_split_idxs(sys) && (split_idxs = get_split_idxs(sys)) !== nothing
152+
for (i, idxs) in enumerate(split_idxs)
153+
defs = mergedefaults(defs, prob.p[i], ps[idxs])
154+
end
155+
else
156+
# assemble defaults
157+
defs = defaults(sys)
158+
defs = mergedefaults(defs, prob.p, ps)
159+
end
160+
defs = mergedefaults(defs, p, ps)
161+
sts = states(sys)
162+
defs = mergedefaults(defs, prob.u0, sts)
163+
defs = mergedefaults(defs, u0, sts)
164+
u0, p, defs = get_u0_p(sys, defs)
157165

158166
return p, u0
159167
end

test/split_parameters.jl

+35-55
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,29 @@ using ModelingToolkit, Test
22
using ModelingToolkitStandardLibrary.Blocks
33
using OrdinaryDiffEq
44

5-
6-
x = [1, 2.0, false, [1,2,3], Parameter(1.0)]
5+
x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]
76

87
y = ModelingToolkit.promote_to_concrete(x)
98
@test eltype(y) == Union{Float64, Parameter{Float64}, Vector{Int64}}
109

11-
y = ModelingToolkit.promote_to_concrete(x; tofloat=false)
10+
y = ModelingToolkit.promote_to_concrete(x; tofloat = false)
1211
@test eltype(y) == Union{Bool, Float64, Int64, Parameter{Float64}, Vector{Int64}}
1312

14-
15-
x = [1, 2.0, false, [1,2,3]]
13+
x = [1, 2.0, false, [1, 2, 3]]
1614
y = ModelingToolkit.promote_to_concrete(x)
1715
@test eltype(y) == Union{Float64, Vector{Int64}}
1816

1917
x = Any[1, 2.0, false]
20-
y = ModelingToolkit.promote_to_concrete(x; tofloat=false)
18+
y = ModelingToolkit.promote_to_concrete(x; tofloat = false)
2119
@test eltype(y) == Union{Bool, Float64, Int64}
2220

23-
y = ModelingToolkit.promote_to_concrete(x; use_union=false)
21+
y = ModelingToolkit.promote_to_concrete(x; use_union = false)
2422
@test eltype(y) == Float64
2523

26-
x = Float16[1., 2., 3.]
24+
x = Float16[1.0, 2.0, 3.0]
2725
y = ModelingToolkit.promote_to_concrete(x)
2826
@test eltype(y) == Float16
2927

30-
31-
32-
3328
# ------------------------ Mixed Single Values and Vector
3429

3530
dt = 4e-4
@@ -74,7 +69,7 @@ eqs = [y ~ src.output.u
7469
@named sys = ODESystem(eqs, t, vars, []; systems = [int, src])
7570
s = complete(sys)
7671
sys = structural_simplify(sys)
77-
prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; tofloat=false)
72+
prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; tofloat = false)
7873
@test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}}
7974
sol = solve(prob, ImplicitEuler());
8075
@test sol.retcode == ReturnCode.Success
@@ -83,18 +78,15 @@ sol = solve(prob, ImplicitEuler());
8378
#TODO: remake becomes more complicated now, how to improve?
8479
defs = ModelingToolkit.defaults(sys)
8580
defs[s.src.data] = 2x
86-
p′ = ModelingToolkit.varmap_to_vars(defs, parameters(sys); tofloat=false)
81+
p′ = ModelingToolkit.varmap_to_vars(defs, parameters(sys); tofloat = false)
8782
p′, = ModelingToolkit.split_parameters_by_type(p′) #NOTE: we need to ensure this is called now before calling remake()
88-
prob′ = remake(prob; p=p′)
83+
prob′ = remake(prob; p = p′)
8984
sol = solve(prob′, ImplicitEuler());
9085
@test sol.retcode == ReturnCode.Success
9186
@test sol[y][end] == 2x[end]
9287

93-
prob′′ = remake(prob; p=[s.src.data => x])
94-
@test prob′′.p isa Tuple
95-
96-
97-
88+
prob′′ = remake(prob; p = [s.src.data => x])
89+
@test_broken prob′′.p isa Tuple
9890

9991
# ------------------------ Mixed Type Converted to float (default behavior)
10092

@@ -122,11 +114,6 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false)
122114
sol = solve(prob, ImplicitEuler());
123115
@test sol.retcode == ReturnCode.Success
124116

125-
126-
127-
128-
129-
130117
# ------------------------- Bug
131118
using ModelingToolkit, LinearAlgebra
132119
using ModelingToolkitStandardLibrary.Mechanical.Rotational
@@ -136,51 +123,48 @@ using ModelingToolkit: connect
136123

137124
"A wrapper function to make symbolic indexing easier"
138125
function wr(sys)
139-
ODESystem(Equation[], ModelingToolkit.get_iv(sys), systems=[sys], name=:a_wrapper)
126+
ODESystem(Equation[], ModelingToolkit.get_iv(sys), systems = [sys], name = :a_wrapper)
140127
end
141-
indexof(sym,syms) = findfirst(isequal(sym),syms)
128+
indexof(sym, syms) = findfirst(isequal(sym), syms)
142129

143130
# Parameters
144-
m1 = 1.
145-
m2 = 1.
146-
k = 10. # Spring stiffness
147-
c = 3. # Damping coefficient
131+
m1 = 1.0
132+
m2 = 1.0
133+
k = 10.0 # Spring stiffness
134+
c = 3.0 # Damping coefficient
148135

149136
@named inertia1 = Inertia(; J = m1)
150137
@named inertia2 = Inertia(; J = m2)
151138
@named spring = Spring(; c = k)
152139
@named damper = Damper(; d = c)
153-
@named torque = Torque(use_support=false)
140+
@named torque = Torque(use_support = false)
154141

155-
function SystemModel(u=nothing; name=:model)
156-
eqs = [
157-
connect(torque.flange, inertia1.flange_a)
142+
function SystemModel(u = nothing; name = :model)
143+
eqs = [connect(torque.flange, inertia1.flange_a)
158144
connect(inertia1.flange_b, spring.flange_a, damper.flange_a)
159-
connect(inertia2.flange_a, spring.flange_b, damper.flange_b)
160-
]
145+
connect(inertia2.flange_a, spring.flange_b, damper.flange_b)]
161146
if u !== nothing
162147
push!(eqs, connect(torque.tau, u.output))
163-
return @named model = ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper, u])
148+
return @named model = ODESystem(eqs,
149+
t;
150+
systems = [torque, inertia1, inertia2, spring, damper, u])
164151
end
165152
ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper], name)
166153
end
167154

168-
169155
model = SystemModel() # Model with load disturbance
170-
@named d = Step(start_time=1., duration=10., offset=0., height=1.) # Disturbance
156+
@named d = Step(start_time = 1.0, duration = 10.0, offset = 0.0, height = 1.0) # Disturbance
171157
model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.inertia2.phi] # This is the state realization we want to control
172158
inputs = [model.torque.tau.u]
173159
matrices, ssys = ModelingToolkit.linearize(wr(model), inputs, model_outputs)
174160

175161
# Design state-feedback gain using LQR
176162
# Define cost matrices
177-
x_costs = [
178-
model.inertia1.w => 1.
179-
model.inertia2.w => 1.
180-
model.inertia1.phi => 1.
181-
model.inertia2.phi => 1.
182-
]
183-
L = randn(1,4) # Post-multiply by `C` to get the correct input to the controller
163+
x_costs = [model.inertia1.w => 1.0
164+
model.inertia2.w => 1.0
165+
model.inertia1.phi => 1.0
166+
model.inertia2.phi => 1.0]
167+
L = randn(1, 4) # Post-multiply by `C` to get the correct input to the controller
184168

185169
# This old definition of MatrixGain will work because the parameter space does not include K (an Array term)
186170
# @component function MatrixGainAlt(K::AbstractArray; name)
@@ -191,16 +175,12 @@ L = randn(1,4) # Post-multiply by `C` to get the correct input to the controller
191175
# compose(ODESystem(eqs, t, [], []; name = name), [input, output])
192176
# end
193177

194-
@named state_feedback = MatrixGain(K=-L) # Build negative feedback into the feedback matrix
195-
@named add = Add(;k1=1., k2=1.) # To add the control signal and the disturbance
178+
@named state_feedback = MatrixGain(K = -L) # Build negative feedback into the feedback matrix
179+
@named add = Add(; k1 = 1.0, k2 = 1.0) # To add the control signal and the disturbance
196180

197-
connections = [
198-
[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4]
181+
connections = [[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4]
199182
connect(d.output, :d, add.input1)
200183
connect(add.input2, state_feedback.output)
201-
connect(add.output, :u, model.torque.tau)
202-
]
203-
closed_loop = ODESystem(connections, t, systems=[model, state_feedback, add, d], name=:closed_loop)
184+
connect(add.output, :u, model.torque.tau)]
185+
@named closed_loop = ODESystem(connections, t, systems = [model, state_feedback, add, d])
204186
S = get_sensitivity(closed_loop, :u)
205-
206-

0 commit comments

Comments
 (0)