Skip to content

Commit f8a5bf9

Browse files
committed
fixes for matplotlib v2 in cross-validation, also some minor fixes otherwise
1 parent f188c47 commit f8a5bf9

4 files changed

+91
-58
lines changed

05-model-evaluation-and-improvement.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -2832,9 +2832,9 @@
28322832
"metadata": {
28332833
"anaconda-cloud": {},
28342834
"kernelspec": {
2835-
"display_name": "Python [Root]",
2835+
"display_name": "Python [conda root]",
28362836
"language": "python",
2837-
"name": "Python [Root]"
2837+
"name": "conda-root-py"
28382838
},
28392839
"language_info": {
28402840
"codemirror_mode": {

07-working-with-text-data.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -1416,9 +1416,9 @@
14161416
"metadata": {
14171417
"anaconda-cloud": {},
14181418
"kernelspec": {
1419-
"display_name": "Python [Root]",
1419+
"display_name": "Python [conda root]",
14201420
"language": "python",
1421-
"name": "Python [Root]"
1421+
"name": "conda-root-py"
14221422
},
14231423
"language_info": {
14241424
"codemirror_mode": {

mglearn/plot_cross_validation.py

+84-51
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,43 @@ def plot_group_kfold():
2323
mask[i, train] = 1
2424
mask[i, test] = 2
2525

26-
2726
for i in range(n_folds):
2827
# test is grey
29-
colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
28+
colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
3029
# not selected has no hatch
31-
32-
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter, left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
30+
31+
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
32+
left=i * n_samples_per_fold, height=.6, color=colors,
33+
hatch="//", edgecolor="k")
3334
for j in np.where(mask[:, i] == 0)[0]:
3435
boxes[j].set_hatch("")
35-
36-
axes.barh(bottom=[n_iter] * n_folds, width=[1 - 0.1] * n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6, color="w")
36+
37+
axes.barh(bottom=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
38+
left=np.arange(n_folds) * n_samples_per_fold, height=.6,
39+
color="w", edgecolor='k')
3740

3841
for i in range(12):
39-
axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" % groups[i], horizontalalignment="center")
40-
#ax.set_ylim(4, -0.1)
41-
42+
axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" %
43+
groups[i], horizontalalignment="center")
44+
4245
axes.invert_yaxis()
4346
axes.set_xlim(0, n_samples + 1)
4447
axes.set_ylabel("CV iterations")
4548
axes.set_xlabel("Data points")
4649
axes.set_xticks(np.arange(n_samples) + .5)
4750
axes.set_xticklabels(np.arange(1, n_samples + 1))
4851
axes.set_yticks(np.arange(n_iter + 1) + .3)
49-
axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"]);
50-
plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"], loc=(1, .3));
52+
axes.set_yticklabels(
53+
["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"])
54+
plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"], loc=(1, .3))
5155
plt.tight_layout()
5256

5357

54-
5558
def plot_shuffle_split():
5659
from sklearn.model_selection import ShuffleSplit
5760
plt.figure(figsize=(10, 2))
58-
plt.title("ShuffleSplit with 10 points, train_size=5, test_size=2, n_splits=4")
61+
plt.title("ShuffleSplit with 10 points"
62+
", train_size=5, test_size=2, n_splits=4")
5963

6064
axes = plt.gca()
6165
axes.set_frame_on(False)
@@ -71,13 +75,14 @@ def plot_shuffle_split():
7175
mask[i, train] = 1
7276
mask[i, test] = 2
7377

74-
7578
for i in range(n_folds):
7679
# test is grey
77-
colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
80+
colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
7881
# not selected has no hatch
79-
80-
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter, left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
82+
83+
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
84+
left=i * n_samples_per_fold, height=.6, color=colors,
85+
hatch="//", edgecolor='k')
8186
for j in np.where(mask[:, i] == 0)[0]:
8287
boxes[j].set_hatch("")
8388

@@ -88,17 +93,16 @@ def plot_shuffle_split():
8893
axes.set_xticks(np.arange(n_samples) + .5)
8994
axes.set_xticklabels(np.arange(1, n_samples + 1))
9095
axes.set_yticks(np.arange(n_iter) + .3)
91-
axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)]);
96+
axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)])
9297
# legend hacked for this random state
93-
plt.legend([boxes[1], boxes[0], boxes[2]], ["Training set", "Test set", "Not selected"], loc=(1, .3));
98+
plt.legend([boxes[1], boxes[0], boxes[2]], [
99+
"Training set", "Test set", "Not selected"], loc=(1, .3))
94100
plt.tight_layout()
95-
plt.savefig("images/06_shuffle_split.png")
96-
plt.close()
97101

98102

99103
def plot_stratified_cross_validation():
100104
fig, both_axes = plt.subplots(2, 1, figsize=(12, 5))
101-
#plt.title("cross_validation_not_stratified")
105+
# plt.title("cross_validation_not_stratified")
102106
axes = both_axes[0]
103107
axes.set_title("Standard cross-validation with sorted class labels")
104108

@@ -109,25 +113,30 @@ def plot_stratified_cross_validation():
109113

110114
n_samples_per_fold = n_samples / float(n_folds)
111115

112-
113116
for i in range(n_folds):
114117
colors = ["w"] * n_folds
115118
colors[i] = "grey"
116-
axes.barh(bottom=range(n_folds), width=[n_samples_per_fold - 1] * n_folds, left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
117-
118-
axes.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] * n_folds, left=np.arange(3) * n_samples_per_fold, height=.6, color="w")
119+
axes.barh(bottom=range(n_folds), width=[n_samples_per_fold - 1] *
120+
n_folds, left=i * n_samples_per_fold, height=.6,
121+
color=colors, hatch="//", edgecolor='k')
122+
123+
axes.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
124+
n_folds, left=np.arange(3) * n_samples_per_fold, height=.6,
125+
color="w", edgecolor='k')
119126

120127
axes.invert_yaxis()
121128
axes.set_xlim(0, n_samples + 1)
122129
axes.set_ylabel("CV iterations")
123130
axes.set_xlabel("Data points")
124-
axes.set_xticks(np.arange(n_samples_per_fold / 2., n_samples, n_samples_per_fold))
131+
axes.set_xticks(np.arange(n_samples_per_fold / 2.,
132+
n_samples, n_samples_per_fold))
125133
axes.set_xticklabels(["Fold %d" % x for x in range(1, n_folds + 1)])
126134
axes.set_yticks(np.arange(n_folds + 1) + .3)
127-
axes.set_yticklabels(["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
135+
axes.set_yticklabels(
136+
["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
128137
for i in range(3):
129-
axes.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" % i, horizontalalignment="center")
130-
138+
axes.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" %
139+
i, horizontalalignment="center")
131140

132141
ax = both_axes[1]
133142
ax.set_title("Stratified Cross-validation")
@@ -138,24 +147,38 @@ def plot_stratified_cross_validation():
138147
ax.set_xlabel("Data points")
139148

140149
ax.set_yticks(np.arange(n_folds + 1) + .3)
141-
ax.set_yticklabels(["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"]);
150+
ax.set_yticklabels(
151+
["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
142152

143153
n_subsplit = n_samples_per_fold / 3.
144154
for i in range(n_folds):
145-
test_bars = ax.barh(bottom=[i] * n_folds, width=[n_subsplit - 1] * n_folds, left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit, height=.6, color="grey", hatch="//")
155+
test_bars = ax.barh(
156+
bottom=[i] * n_folds, width=[n_subsplit - 1] * n_folds,
157+
left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit,
158+
height=.6, color="grey", hatch="//", edgecolor='k')
146159

147160
w = 2 * n_subsplit - 1
148-
ax.barh(bottom=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds) * n_samples_per_fold + (0 + 1) * n_subsplit, height=.6, color="w", hatch="//")
149-
ax.barh(bottom=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.], left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold - n_subsplit), height=.6, color="w", hatch="//")
150-
training_bars = ax.barh(bottom=[2] * n_folds, width=[w] * n_folds, left=np.arange(n_folds) * n_samples_per_fold , height=.6, color="w", hatch="//")
151-
152-
153-
ax.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] * n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6, color="w")
161+
ax.barh(bottom=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds)
162+
* n_samples_per_fold + (0 + 1) * n_subsplit, height=.6, color="w",
163+
hatch="//", edgecolor='k')
164+
ax.barh(bottom=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
165+
left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold -
166+
n_subsplit), height=.6, color="w", hatch="//",
167+
edgecolor='k')
168+
training_bars = ax.barh(bottom=[2] * n_folds, width=[w] * n_folds,
169+
left=np.arange(n_folds) * n_samples_per_fold,
170+
height=.6, color="w", hatch="//", edgecolor='k')
171+
172+
ax.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
173+
n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6,
174+
color="w", edgecolor='k')
154175

155176
for i in range(3):
156-
ax.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" % i, horizontalalignment="center")
177+
ax.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" %
178+
i, horizontalalignment="center")
157179
ax.set_ylim(4, -0.1)
158-
plt.legend([training_bars[0], test_bars[0]], ['Training data', 'Test data'], loc=(1.05, 1), frameon=False);
180+
plt.legend([training_bars[0], test_bars[0]], [
181+
'Training data', 'Test data'], loc=(1.05, 1), frameon=False)
159182

160183
fig.tight_layout()
161184

@@ -171,33 +194,43 @@ def plot_cross_validation():
171194

172195
n_samples_per_fold = n_samples / float(n_folds)
173196

174-
175197
for i in range(n_folds):
176198
colors = ["w"] * n_folds
177199
colors[i] = "grey"
178-
bars = plt.barh(bottom=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
179-
left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
200+
bars = plt.barh(
201+
bottom=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
202+
left=i * n_samples_per_fold, height=.6, color=colors, hatch="//",
203+
edgecolor='k')
180204
axes.invert_yaxis()
181205
axes.set_xlim(0, n_samples + 1)
182206
plt.ylabel("CV iterations")
183207
plt.xlabel("Data points")
184-
plt.xticks(np.arange(n_samples_per_fold / 2., n_samples, n_samples_per_fold), ["Fold %d" % x for x in range(1, n_folds + 1)])
185-
plt.yticks(np.arange(n_folds) + .3, ["Split %d" % x for x in range(1, n_folds + 1)])
186-
plt.legend([bars[0], bars[4]], ['Training data', 'Test data'], loc=(1.05, 0.4), frameon=False);
208+
plt.xticks(np.arange(n_samples_per_fold / 2., n_samples,
209+
n_samples_per_fold),
210+
["Fold %d" % x for x in range(1, n_folds + 1)])
211+
plt.yticks(np.arange(n_folds) + .3,
212+
["Split %d" % x for x in range(1, n_folds + 1)])
213+
plt.legend([bars[0], bars[4]], ['Training data', 'Test data'],
214+
loc=(1.05, 0.4), frameon=False)
187215

188216

189217
def plot_threefold_split():
190218
plt.figure(figsize=(15, 1))
191219
axis = plt.gca()
192-
bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15], color=['white', 'grey', 'grey'], hatch="//")
220+
bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15], color=[
221+
'white', 'grey', 'grey'], hatch="//", edgecolor='k')
193222
bars[2].set_hatch(r"")
194223
axis.set_yticks(())
195224
axis.set_frame_on(False)
196225
axis.set_ylim(-.1, .8)
197226
axis.set_xlim(-0.1, 20.1)
198227
axis.set_xticks([6, 13.3, 17.5])
199-
axis.set_xticklabels(["training set", "validation set", "test set"], fontdict={'fontsize': 20});
228+
axis.set_xticklabels(["training set", "validation set",
229+
"test set"], fontdict={'fontsize': 20})
200230
axis.tick_params(length=0, labeltop=True, labelbottom=False)
201-
axis.text(6, -.3, "Model fitting", fontdict={'fontsize': 13}, horizontalalignment="center")
202-
axis.text(13.3, -.3, "Parameter selection", fontdict={'fontsize': 13}, horizontalalignment="center")
203-
axis.text(17.5, -.3, "Evaluation", fontdict={'fontsize': 13}, horizontalalignment="center")
231+
axis.text(6, -.3, "Model fitting",
232+
fontdict={'fontsize': 13}, horizontalalignment="center")
233+
axis.text(13.3, -.3, "Parameter selection",
234+
fontdict={'fontsize': 13}, horizontalalignment="center")
235+
axis.text(17.5, -.3, "Evaluation",
236+
fontdict={'fontsize': 13}, horizontalalignment="center")

mglearn/plot_grid_search.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def plot_cross_val_selection():
2323
plt.xlim(-1, len(results))
2424
plt.ylim(0, 1.1)
2525
for i, (_, row) in enumerate(results.iterrows()):
26-
scores = row[['test_split%d_score' % i for i in range(5)]]
26+
scores = row[['test_split%d_test_score' % i for i in range(5)]]
2727
marker_cv, = plt.plot([i] * 5, scores, '^', c='gray', markersize=5,
2828
alpha=.5)
2929
marker_mean, = plt.plot(i, row.mean_test_score, 'v', c='none', alpha=1,
30-
markersize=10)
30+
markersize=10, markeredgecolor='k')
3131
if i == best:
3232
marker_best, = plt.plot(i, row.mean_test_score, 'o', c='red',
3333
fillstyle="none", alpha=1, markersize=20,
@@ -44,7 +44,7 @@ def plot_cross_val_selection():
4444

4545

4646
def plot_grid_search_overview():
47-
plt.figure(figsize=(10, 3))
47+
plt.figure(figsize=(10, 3), dpi=70)
4848
axes = plt.gca()
4949
axes.yaxis.set_visible(False)
5050
axes.xaxis.set_visible(False)

0 commit comments

Comments
 (0)