-
-
Notifications
You must be signed in to change notification settings - Fork 212
/
Copy pathMTKBifurcationKitExt.jl
159 lines (140 loc) · 6.33 KB
/
MTKBifurcationKitExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
module MTKBifurcationKitExt
### Preparations ###
# Imports
using ModelingToolkit, Setfield
import BifurcationKit
### Observable Plotting Handling ###
# Functor used when the plotting variable is an observable. Keeps track of the information
# required to compute the observable's value at each point of the bifurcation diagram.
# Layout of `subs_vals`: entries 1:state_end_idxs are states (ordered as `unknowns(nsys)`),
# entries state_end_idxs+1:param_end_idxs are parameters (ordered as `parameters(nsys)`),
# and the remaining entries are observables (ordered as `observed(nsys)`).
struct ObservableRecordFromSolution{S, T}
    # The equations determining the observables' values.
    obs_eqs::S
    # The index (in `obs_eqs`) of the observable that we wish to plot.
    target_obs_idx::Int64
    # The final index in subs_vals that contains a state.
    state_end_idxs::Int64
    # The final index in subs_vals that contains a param.
    param_end_idxs::Int64
    # The index (in subs_vals) that contains the bifurcation parameter.
    bif_par_idx::Int64
    # A Vector of pairs (Symbolic => value) holding the values of all system states,
    # parameters, and observables (see the layout note above).
    subs_vals::T

    # Builds the functor for the (completed) system `nsys`.
    # - `plot_var`: the observable to record; must be the lhs of one of `observed(nsys)`.
    # - `bif_idx`: index of the bifurcation parameter within `parameters(nsys)`.
    # - `u0_vals` / `p_vals`: base state / parameter values, ordered as
    #   `unknowns(nsys)` / `parameters(nsys)`.
    # Throws an error if `plot_var` is not the lhs of any observed equation.
    function ObservableRecordFromSolution(nsys::NonlinearSystem,
            plot_var,
            bif_idx,
            u0_vals,
            p_vals)
        obs_eqs = observed(nsys)
        target_obs_idx = findfirst(eq -> isequal(plot_var, eq.lhs), obs_eqs)
        if isnothing(target_obs_idx)
            error("The plot variable ($plot_var) was not recognised as a system observable.")
        end
        sys_states = unknowns(nsys)
        sys_params = parameters(nsys)
        state_end_idxs = length(sys_states)
        param_end_idxs = state_end_idxs + length(sys_params)
        bif_par_idx = state_end_idxs + bif_idx

        # Base substitution values for states and parameters.
        subs_vals_states = Pair.(sys_states, u0_vals)
        subs_vals_params = Pair.(sys_params, p_vals)
        base_subs = [subs_vals_states; subs_vals_params]
        # Base substitution values for observables. Observables may depend on other
        # observables, hence a second pass that also substitutes the first-pass values.
        subs_vals_obs = [obs.lhs => substitute(obs.rhs, base_subs) for obs in obs_eqs]
        subs_vals_obs = [obs.lhs => substitute(obs.rhs, [base_subs; subs_vals_obs])
                         for obs in obs_eqs]

        # During the bifurcation process, the entries for states, the bifurcation parameter,
        # and observables are recomputed at each step; all other parameters keep these values.
        subs_vals = [base_subs; subs_vals_obs]
        return new{typeof(obs_eqs), typeof(subs_vals)}(obs_eqs,
            target_obs_idx,
            state_end_idxs,
            param_end_idxs,
            bif_par_idx,
            subs_vals)
    end
end
# Functor call: returns the value of the target observable at the current point (x, p)
# of the bifurcation diagram.
# `x` is indexed like `unknowns(nsys)` (its first `state_end_idxs` entries are read) and `p`
# is stored as the new value of the bifurcation parameter.
# NOTE(review): assumes BifurcationKit passes the scalar value of the bifurcation parameter
# as `p` — confirm against the BifurcationKit version in use.
# Side effect: overwrites entries of `orfs.subs_vals` in place.
function (orfs::ObservableRecordFromSolution)(x, p; k...)
    vals = orfs.subs_vals

    # Refresh every state entry with the current solution.
    for i in 1:(orfs.state_end_idxs)
        sym = first(vals[i])
        vals[i] = sym => x[i]
    end

    # Refresh the bifurcation parameter entry.
    bif_sym = first(vals[orfs.bif_par_idx])
    vals[orfs.bif_par_idx] = bif_sym => p

    # Recompute observables in order. Each write is visible to the following iterations,
    # so an observable can use observables that precede it in `obs_eqs`.
    offset = orfs.param_end_idxs
    for (j, eq) in enumerate(orfs.obs_eqs)
        slot = offset + j
        vals[slot] = first(vals[slot]) => substitute(eq.rhs, vals)
    end

    # Evaluate the plotted observable's defining expression with all current values.
    target_rhs = orfs.obs_eqs[orfs.target_obs_idx].rhs
    return substitute(target_rhs, vals)
end
### Creates BifurcationProblem Overloads ###
# When input is a NonlinearSystem.
# Builds a BifurcationKit `BifurcationProblem` from a completed `NonlinearSystem`.
# - `u0_bif`: initial state guess, converted with `varmap_to_vars` over `unknowns(nsys)`.
# - `ps`: parameter values, converted with `varmap_to_vars` over `parameters(nsys)`.
# - `bif_par`: the (symbolic) parameter that is varied; must be one of `parameters(nsys)`.
# - `plot_var`: optional state or observable whose value is recorded at each point;
#   overrides `record_from_solution` when given.
# - `jac`: whether to generate the Jacobian and pass it to BifurcationKit as `J`.
# Remaining `args...` / `kwargs...` are forwarded to BifurcationKit's constructor.
# Throws an error if the system is not complete, if `bif_par` is not a system parameter,
# or if `plot_var` is neither a state nor an observable.
function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
        u0_bif,
        ps,
        bif_par,
        args...;
        plot_var = nothing,
        record_from_solution = BifurcationKit.record_sol_default,
        jac = true,
        kwargs...)
    if !ModelingToolkit.iscomplete(nsys)
        error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
    end
    @set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`

    # Creates F and J functions.
    ofun = NonlinearFunction(nsys; jac = jac)
    # In-place variant returns the filled residual; out-of-place forwards directly.
    F = let f = ofun.f
        _f(resid, u, p) = (f(resid, u, p); resid)
        _f(u, p) = f(u, p)
    end
    J = jac ? ofun.jac : nothing

    # Converts the input state guess and parameter values to plain value vectors.
    sys_defaults = ModelingToolkit.get_defaults(nsys)
    u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif,
        unknowns(nsys);
        defaults = sys_defaults)
    p_vals = ModelingToolkit.varmap_to_vars(
        ps, parameters(nsys); defaults = sys_defaults)

    # Locates the bifurcation parameter; fail early with a clear message rather than
    # letting `nothing` reach the parameter lens below.
    bif_idx = findfirst(isequal(bif_par), parameters(nsys))
    if isnothing(bif_idx)
        error("The bifurcation parameter ($bif_par) was not recognised as a system parameter.")
    end

    # Computes the plotting function.
    if !isnothing(plot_var)
        plot_idx = findfirst(isequal(plot_var), unknowns(nsys))
        if !isnothing(plot_idx)
            # The plot var is a normal state: record its entry of the solution vector.
            record_from_solution = (x, p; k...) -> x[plot_idx]
        elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
            # The plot var is an observed state: evaluate it symbolically at each point.
            record_from_solution = ObservableRecordFromSolution(nsys,
                plot_var,
                bif_idx,
                u0_bif_vals,
                p_vals)
        else
            # Neither a variable nor an observable.
            error("The plot variable ($plot_var) was neither recognised as a system state nor observable.")
        end
    end

    return BifurcationKit.BifurcationProblem(F,
        u0_bif_vals,
        p_vals,
        (BifurcationKit.@optic _[bif_idx]),
        args...;
        record_from_solution = record_from_solution,
        J = J,
        inplace = true,
        kwargs...)
end
# When input is an ODESystem.
# Converts a completed `ODESystem` into a `NonlinearSystem` whose residuals are the
# right-hand sides of its (fully substituted) equations, i.e. `0 ~ rhs` for each equation
# (the lhs is dropped — presumably each equation has the form `D(x) ~ rhs`; confirm).
# Unknowns, parameters, observed equations and name are carried over. The result is then
# handled by the `NonlinearSystem` method with the same `args...` / `kwargs...`.
# Throws an error if `osys` is not complete.
function BifurcationKit.BifurcationProblem(osys::ODESystem, args...; kwargs...)
    ModelingToolkit.iscomplete(osys) ||
        error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")

    steady_state_eqs = map(eq -> 0 ~ eq.rhs, full_equations(osys))
    steady_sys = NonlinearSystem(steady_state_eqs,
        unknowns(osys),
        parameters(osys);
        observed = observed(osys),
        name = nameof(osys))
    return BifurcationKit.BifurcationProblem(complete(steady_sys), args...; kwargs...)
end
end # module