Skip to content

Commit 1357bc0

Browse files
committed
matplotlib barh bottom->y, add return_train_score
1 parent 65cf6ed commit 1357bc0

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

mglearn/plot_cross_validation.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def plot_group_kfold():
2828
colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
2929
# not selected has no hatch
3030

31-
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
31+
boxes = axes.barh(y=range(n_iter), width=[1 - 0.1] * n_iter,
3232
left=i * n_samples_per_fold, height=.6, color=colors,
3333
hatch="//", edgecolor="k", align='edge')
3434
for j in np.where(mask[:, i] == 0)[0]:
3535
boxes[j].set_hatch("")
3636

37-
axes.barh(bottom=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
37+
axes.barh(y=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
3838
left=np.arange(n_folds) * n_samples_per_fold, height=.6,
3939
color="w", edgecolor='k', align="edge")
4040

@@ -80,7 +80,7 @@ def plot_shuffle_split():
8080
colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
8181
# not selected has no hatch
8282

83-
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
83+
boxes = axes.barh(y=range(n_iter), width=[1 - 0.1] * n_iter,
8484
left=i * n_samples_per_fold, height=.6, color=colors,
8585
hatch="//", edgecolor='k', align='edge')
8686
for j in np.where(mask[:, i] == 0)[0]:
@@ -116,11 +116,11 @@ def plot_stratified_cross_validation():
116116
for i in range(n_folds):
117117
colors = ["w"] * n_folds
118118
colors[i] = "grey"
119-
axes.barh(bottom=range(n_folds), width=[n_samples_per_fold - 1] *
119+
axes.barh(y=range(n_folds), width=[n_samples_per_fold - 1] *
120120
n_folds, left=i * n_samples_per_fold, height=.6,
121121
color=colors, hatch="//", edgecolor='k', align='edge')
122122

123-
axes.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
123+
axes.barh(y=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
124124
n_folds, left=np.arange(3) * n_samples_per_fold, height=.6,
125125
color="w", edgecolor='k', align='edge')
126126

@@ -153,24 +153,24 @@ def plot_stratified_cross_validation():
153153
n_subsplit = n_samples_per_fold / 3.
154154
for i in range(n_folds):
155155
test_bars = ax.barh(
156-
bottom=[i] * n_folds, width=[n_subsplit - 1] * n_folds,
156+
y=[i] * n_folds, width=[n_subsplit - 1] * n_folds,
157157
left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit,
158158
height=.6, color="grey", hatch="//", edgecolor='k', align='edge')
159159

160160
w = 2 * n_subsplit - 1
161-
ax.barh(bottom=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds)
161+
ax.barh(y=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds)
162162
* n_samples_per_fold + (0 + 1) * n_subsplit, height=.6, color="w",
163163
hatch="//", edgecolor='k', align='edge')
164-
ax.barh(bottom=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
164+
ax.barh(y=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
165165
left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold -
166166
n_subsplit), height=.6, color="w", hatch="//",
167167
edgecolor='k', align='edge')
168-
training_bars = ax.barh(bottom=[2] * n_folds, width=[w] * n_folds,
168+
training_bars = ax.barh(y=[2] * n_folds, width=[w] * n_folds,
169169
left=np.arange(n_folds) * n_samples_per_fold,
170170
height=.6, color="w", hatch="//", edgecolor='k',
171171
align='edge')
172172

173-
ax.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
173+
ax.barh(y=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
174174
n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6,
175175
color="w", edgecolor='k', align='edge')
176176

@@ -199,7 +199,7 @@ def plot_cross_validation():
199199
colors = ["w"] * n_folds
200200
colors[i] = "grey"
201201
bars = plt.barh(
202-
bottom=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
202+
y=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
203203
left=i * n_samples_per_fold, height=.6, color=colors, hatch="//",
204204
edgecolor='k', align='edge')
205205
axes.invert_yaxis()

mglearn/plot_grid_search.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ def plot_cross_val_selection():
1414

1515
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
1616
'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
17-
grid_search = GridSearchCV(SVC(), param_grid, cv=5,
18-
return_train_score=True)
17+
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
1918
grid_search.fit(X_trainval, y_trainval)
2019
results = pd.DataFrame(grid_search.cv_results_)[15:]
2120

0 commit comments

Comments
 (0)