Skip to content

Commit 446bb8c

Browse files
committed
format
1 parent 607a6f6 commit 446bb8c

File tree

2 files changed

+113
-69
lines changed

2 files changed

+113
-69
lines changed

ext/MTKBifurcationKitExt.jl

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import BifurcationKit
99
### Observable Plotting Handling ###
1010

1111
# Functor used when the plotting variable is an observable. Keeps track of the required information for computing the observable's value at each point of the bifurcation diagram.
12-
struct ObservableRecordFromSolution{S,T}
12+
struct ObservableRecordFromSolution{S, T}
1313
# The equations determining the observables values.
1414
obs_eqs::S
1515
# The index of the observable that we wish to plot.
@@ -23,32 +23,43 @@ struct ObservableRecordFromSolution{S,T}
2323
# A Vector of pairs (Symbolic => value) with teh default values of all system variables and parameters.
2424
subs_vals::T
2525

26-
function ObservableRecordFromSolution(nsys::NonlinearSystem, plot_var, bif_idx, u0_vals, p_vals) where {S,T}
26+
function ObservableRecordFromSolution(nsys::NonlinearSystem,
27+
plot_var,
28+
bif_idx,
29+
u0_vals,
30+
p_vals) where {S, T}
2731
obs_eqs = observed(nsys)
2832
target_obs_idx = findfirst(isequal(plot_var, eq.lhs) for eq in observed(nsys))
2933
state_end_idxs = length(states(nsys))
3034
param_end_idxs = state_end_idxs + length(parameters(nsys))
3135

3236
bif_par_idx = state_end_idxs + bif_idx
3337
# Gets the (base) substitution values for states.
34-
subs_vals_states = Pair.(states(nsys),u0_vals)
38+
subs_vals_states = Pair.(states(nsys), u0_vals)
3539
# Gets the (base) substitution values for parameters.
36-
subs_vals_params = Pair.(parameters(nsys),p_vals)
40+
subs_vals_params = Pair.(parameters(nsys), p_vals)
3741
# Gets the (base) substitution values for observables.
38-
subs_vals_obs = [obs.lhs => substitute(obs.rhs, [subs_vals_states; subs_vals_params]) for obs in observed(nsys)]
42+
subs_vals_obs = [obs.lhs => substitute(obs.rhs,
43+
[subs_vals_states; subs_vals_params]) for obs in observed(nsys)]
3944
# Sometimes observables depend on other observables, hence we make a second upate to this vector.
40-
subs_vals_obs = [obs.lhs => substitute(obs.rhs, [subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)]
45+
subs_vals_obs = [obs.lhs => substitute(obs.rhs,
46+
[subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)]
4147
# During the bifurcation process, teh value of some states, parameters, and observables may vary (and are calculated in each step). Those that are not are stored in this vector
4248
subs_vals = [subs_vals_states; subs_vals_params; subs_vals_obs]
4349

4450
param_end_idxs = state_end_idxs + length(parameters(nsys))
45-
new{typeof(obs_eqs),typeof(subs_vals)}(obs_eqs, target_obs_idx, state_end_idxs, param_end_idxs, bif_par_idx, subs_vals)
51+
new{typeof(obs_eqs), typeof(subs_vals)}(obs_eqs,
52+
target_obs_idx,
53+
state_end_idxs,
54+
param_end_idxs,
55+
bif_par_idx,
56+
subs_vals)
4657
end
4758
end
4859
# Functor function that computes the value.
4960
function (orfs::ObservableRecordFromSolution)(x, p)
5061
# Updates the state values (in subs_vals).
51-
for state_idx in 1:orfs.state_end_idxs
62+
for state_idx in 1:(orfs.state_end_idxs)
5263
orfs.subs_vals[state_idx] = orfs.subs_vals[state_idx][1] => x[state_idx]
5364
end
5465

@@ -57,7 +68,8 @@ function (orfs::ObservableRecordFromSolution)(x, p)
5768

5869
# Updates the observable values (in subs_vals).
5970
for (obs_idx, obs_eq) in enumerate(orfs.obs_eqs)
60-
orfs.subs_vals[orfs.param_end_idxs+obs_idx] = orfs.subs_vals[orfs.param_end_idxs+obs_idx][1] => substitute(obs_eq.rhs, orfs.subs_vals)
71+
orfs.subs_vals[orfs.param_end_idxs + obs_idx] = orfs.subs_vals[orfs.param_end_idxs + obs_idx][1] => substitute(obs_eq.rhs,
72+
orfs.subs_vals)
6173
end
6274

6375
# Substitutes in the value for all states, parameters, and observables into the equation for the designated observable.
@@ -68,42 +80,55 @@ end
6880

6981
# When input is a NonlinearSystem.
7082
function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
71-
u0_bif,
72-
ps,
73-
bif_par,
74-
args...;
75-
plot_var = nothing,
76-
record_from_solution = BifurcationKit.record_sol_default,
77-
jac = true,
78-
kwargs...)
83+
u0_bif,
84+
ps,
85+
bif_par,
86+
args...;
87+
plot_var = nothing,
88+
record_from_solution = BifurcationKit.record_sol_default,
89+
jac = true,
90+
kwargs...)
7991
# Creates F and J functions.
8092
ofun = NonlinearFunction(nsys; jac = jac)
8193
F = ofun.f
8294
J = jac ? ofun.jac : nothing
8395

8496
# Converts the input state guess.
85-
u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys); defaults=nsys.defaults)
86-
p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults=nsys.defaults)
97+
u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif,
98+
states(nsys);
99+
defaults = nsys.defaults)
100+
p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults = nsys.defaults)
87101

88102
# Computes bifurcation parameter and the plotting function.
89103
bif_idx = findfirst(isequal(bif_par), parameters(nsys))
90-
if !isnothing(plot_var)
104+
if !isnothing(plot_var)
91105
# If the plot var is a normal state.
92106
if any(isequal(plot_var, var) for var in states(nsys))
93107
plot_idx = findfirst(isequal(plot_var), states(nsys))
94108
record_from_solution = (x, p) -> x[plot_idx]
95109

96-
# If the plot var is an observed state.
97-
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
98-
record_from_solution = ObservableRecordFromSolution(nsys, plot_var, bif_idx, u0_bif_vals, p_vals)
99-
100-
# If neither an variable nor observable, throw an error.
110+
# If the plot var is an observed state.
111+
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
112+
record_from_solution = ObservableRecordFromSolution(nsys,
113+
plot_var,
114+
bif_idx,
115+
u0_bif_vals,
116+
p_vals)
117+
118+
# If neither an variable nor observable, throw an error.
101119
else
102120
error("The plot variable ($plot_var) was neither recognised as a system state nor observable.")
103121
end
104122
end
105123

106-
return BifurcationKit.BifurcationProblem(F, u0_bif_vals, p_vals, (@lens _[bif_idx]), args...; record_from_solution = record_from_solution, J = J, kwargs...)
124+
return BifurcationKit.BifurcationProblem(F,
125+
u0_bif_vals,
126+
p_vals,
127+
(@lens _[bif_idx]),
128+
args...;
129+
record_from_solution = record_from_solution,
130+
J = J,
131+
kwargs...)
107132
end
108133

109134
# When input is a ODESystem.

test/extensions/bifurcationkit.jl

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ using BifurcationKit, ModelingToolkit, Test
22

33
# Simple pitchfork diagram, compares solution to native BifurcationKit, checks they are identical.
44
# Checks using `jac=false` option.
5-
let
6-
# Creaets model.
5+
let
6+
# Creates model.
77
@variables t x(t) y(t)
88
@parameters μ α
99
eqs = [0 ~ μ * x - x^3 + α * y,
@@ -14,64 +14,80 @@ let
1414
bif_par = μ
1515
p_start ==> -1.0, α => 1.0]
1616
u0_guess = [x => 1.0, y => 1.0]
17-
plot_var = x;
18-
bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var=plot_var, jac=false)
17+
plot_var = x
18+
bprob = BifurcationProblem(nsys,
19+
u0_guess,
20+
p_start,
21+
bif_par;
22+
plot_var = plot_var,
23+
jac = false)
1924

2025
# Conputes bifurcation diagram.
2126
p_span = (-4.0, 6.0)
22-
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 20)
23-
opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01,
24-
max_steps = 100, nev = 2, newton_options = opt_newton,
25-
p_min = p_span[1], p_max = p_span[2],
26-
detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8, dsmin_bisection = 1e-9)
27-
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside=true)
27+
opts_br = ContinuationPar(max_steps = 500, p_min = p_span[1], p_max = p_span[2])
28+
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
2829

2930
# Computes bifurcation diagram using BifurcationKit directly (without going through MTK).
3031
function f_BK(u, p)
3132
x, y = u
32-
μ, α =p
33-
return*x - x^3 + α*y, -y]
33+
μ, α = p
34+
return * x - x^3 + α * y, -y]
3435
end
35-
bprob_BK = BifurcationProblem(f_BK, [1.0, 1.0], [-1.0, 1.0], (@lens _[1]); record_from_solution = (x, p) -> x[1])
36-
bif_dia_BK = bifurcationdiagram(bprob_BK, PALC(), 2, (args...) -> opts_br; bothside=true)
36+
bprob_BK = BifurcationProblem(f_BK,
37+
[1.0, 1.0],
38+
[-1.0, 1.0],
39+
(@lens _[1]);
40+
record_from_solution = (x, p) -> x[1])
41+
bif_dia_BK = bifurcationdiagram(bprob_BK,
42+
PALC(),
43+
2,
44+
(args...) -> opts_br;
45+
bothside = true)
3746

3847
# Compares results.
39-
@test getfield.(bif_dia.γ.branch, :x) getfield.(bif_dia_BK.γ.branch, :x)
40-
@test getfield.(bif_dia.γ.branch, :param) getfield.(bif_dia_BK.γ.branch, :param)
48+
@test getfield.(bif_dia.γ.branch, :x) getfield.(bif_dia_BK.γ.branch, :x)
49+
@test getfield.(bif_dia.γ.branch, :param) getfield.(bif_dia_BK.γ.branch, :param)
4150
@test bif_dia.γ.specialpoint[1].x == bif_dia_BK.γ.specialpoint[1].x
4251
@test bif_dia.γ.specialpoint[1].param == bif_dia_BK.γ.specialpoint[1].param
4352
@test bif_dia.γ.specialpoint[1].type == bif_dia_BK.γ.specialpoint[1].type
4453
end
4554

4655
# Lotka–Volterra model, checks exact position of bifurcation variable and bifurcation points.
4756
# Checks using ODESystem input.
48-
let
57+
let
4958
# Creates a Lotka–Volterra model.
5059
@parameters α a b
5160
@variables t x(t) y(t) z(t)
5261
D = Differential(t)
53-
eqs = [D(x) ~ -x + a*y + x^2*y,
54-
D(y) ~ b - a*y - x^2*y]
62+
eqs = [D(x) ~ -x + a * y + x^2 * y,
63+
D(y) ~ b - a * y - x^2 * y]
5564
@named sys = ODESystem(eqs)
5665

5766
# Creates BifurcationProblem
58-
bprob = BifurcationProblem(sys, [x => 1.5, y => 1.0], [a => 0.1, b => 0.5], b; plot_var = x)
67+
bprob = BifurcationProblem(sys,
68+
[x => 1.5, y => 1.0],
69+
[a => 0.1, b => 0.5],
70+
b;
71+
plot_var = x)
5972

6073
# Computes bifurcation diagram.
6174
p_span = (0.0, 2.0)
6275
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 2000)
63-
opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01,
64-
max_steps = 100, nev = 2, newton_options = opt_newton,
65-
p_min = p_span[1], p_max = p_span[2],
66-
detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8,
67-
dsmin_bisection = 1e-9)
76+
opts_br = ContinuationPar(dsmax = 0.05,
77+
max_steps = 500,
78+
newton_options = opt_newton,
79+
p_min = p_span[1],
80+
p_max = p_span[2],
81+
n_inversion = 4)
6882
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
6983

7084
# Tests that the diagram has the correct values (x = b)
7185
all([b.x b.param for b in bif_dia.γ.branch])
7286

7387
# Tests that we get two Hopf bifurcations at the correct positions.
74-
hopf_points = sort(getfield.(filter(sp -> sp.type == :hopf, bif_dia.γ.specialpoint), :x); by=x->x[1])
88+
hopf_points = sort(getfield.(filter(sp -> sp.type == :hopf, bif_dia.γ.specialpoint),
89+
:x);
90+
by = x -> x[1])
7591
@test length(hopf_points) == 2
7692
@test hopf_points[1] [0.41998733080424205, 1.5195495712453098]
7793
@test hopf_points[2] [0.7899715592573977, 1.0910379583813192]
@@ -80,38 +96,41 @@ end
8096
# Simple fold bifurcation model, checks exact position of bifurcation variable and bifurcation points.
8197
# Checks that default parameter values are accounted for.
8298
# Checks that observables (that depend on other observables, as in this case) are accounted for.
83-
let
99+
let
84100
# Creates model, and uses `structural_simplify` to generate observables.
85101
@parameters μ p=2
86102
@variables t x(t) y(t) z(t)
87103
D = Differential(t)
88104
eqs = [0 ~ μ - x^3 + 2x^2,
89-
0 ~ p*μ - y,
90-
0 ~ y - z]
105+
0 ~ p * μ - y,
106+
0 ~ y - z]
91107
@named nsys = NonlinearSystem(eqs, [x, y, z], [μ, p])
92108
nsys = structural_simplify(nsys)
93-
109+
94110
# Creates BifurcationProblem.
95111
bif_par = μ
96112
p_start ==> 1.0]
97113
u0_guess = [x => 1.0, y => 0.1, z => 0.1]
98-
plot_var = x;
99-
bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var=plot_var)
100-
114+
plot_var = x
115+
bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var = plot_var)
116+
101117
# Computes bifurcation diagram.
102118
p_span = (-4.3, 12.0)
103119
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 20)
104-
opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01,
105-
max_steps = 100, nev = 2, newton_options = opt_newton,
106-
p_min = p_span[1], p_max = p_span[2],
107-
detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8, dsmin_bisection = 1e-9);
108-
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside=true)
109-
120+
opts_br = ContinuationPar(dsmax = 0.05,
121+
max_steps = 500,
122+
newton_options = opt_newton,
123+
p_min = p_span[1],
124+
p_max = p_span[2],
125+
n_inversion = 4)
126+
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
127+
110128
# Tests that the diagram has the correct values (x = b)
111-
all([b.x 2*b.param for b in bif_dia.γ.branch])
112-
129+
all([b.x 2 * b.param for b in bif_dia.γ.branch])
130+
113131
# Tests that we get two fold bifurcations at the correct positions.
114-
fold_points = sort(getfield.(filter(sp -> sp.type == :bp, bif_dia.γ.specialpoint), :param))
132+
fold_points = sort(getfield.(filter(sp -> sp.type == :bp, bif_dia.γ.specialpoint),
133+
:param))
115134
@test length(fold_points) == 2
116135
@test fold_points [-1.1851851706940317, -5.6734983580551894e-6] # test that they occur at the correct parameter values).
117-
end
136+
end

0 commit comments

Comments
 (0)