Skip to content

Commit f34faf0

Browse files
committed
plot fixes for matplotlib v2, fixed typo in grid search plot
1 parent e268a6d commit f34faf0

5 files changed

+34
-21
lines changed

mglearn/plot_cross_validation.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ def plot_group_kfold():
3030

3131
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
3232
left=i * n_samples_per_fold, height=.6, color=colors,
33-
hatch="//", edgecolor="k")
33+
hatch="//", edgecolor="k", align='edge')
3434
for j in np.where(mask[:, i] == 0)[0]:
3535
boxes[j].set_hatch("")
3636

3737
axes.barh(bottom=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
3838
left=np.arange(n_folds) * n_samples_per_fold, height=.6,
39-
color="w", edgecolor='k')
39+
color="w", edgecolor='k', align="edge")
4040

4141
for i in range(12):
4242
axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" %
@@ -82,7 +82,7 @@ def plot_shuffle_split():
8282

8383
boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
8484
left=i * n_samples_per_fold, height=.6, color=colors,
85-
hatch="//", edgecolor='k')
85+
hatch="//", edgecolor='k', align='edge')
8686
for j in np.where(mask[:, i] == 0)[0]:
8787
boxes[j].set_hatch("")
8888

@@ -118,11 +118,11 @@ def plot_stratified_cross_validation():
118118
colors[i] = "grey"
119119
axes.barh(bottom=range(n_folds), width=[n_samples_per_fold - 1] *
120120
n_folds, left=i * n_samples_per_fold, height=.6,
121-
color=colors, hatch="//", edgecolor='k')
121+
color=colors, hatch="//", edgecolor='k', align='edge')
122122

123123
axes.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
124124
n_folds, left=np.arange(3) * n_samples_per_fold, height=.6,
125-
color="w", edgecolor='k')
125+
color="w", edgecolor='k', align='edge')
126126

127127
axes.invert_yaxis()
128128
axes.set_xlim(0, n_samples + 1)
@@ -155,23 +155,24 @@ def plot_stratified_cross_validation():
155155
test_bars = ax.barh(
156156
bottom=[i] * n_folds, width=[n_subsplit - 1] * n_folds,
157157
left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit,
158-
height=.6, color="grey", hatch="//", edgecolor='k')
158+
height=.6, color="grey", hatch="//", edgecolor='k', align='edge')
159159

160160
w = 2 * n_subsplit - 1
161161
ax.barh(bottom=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds)
162162
* n_samples_per_fold + (0 + 1) * n_subsplit, height=.6, color="w",
163-
hatch="//", edgecolor='k')
163+
hatch="//", edgecolor='k', align='edge')
164164
ax.barh(bottom=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
165165
left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold -
166166
n_subsplit), height=.6, color="w", hatch="//",
167-
edgecolor='k')
167+
edgecolor='k', align='edge')
168168
training_bars = ax.barh(bottom=[2] * n_folds, width=[w] * n_folds,
169169
left=np.arange(n_folds) * n_samples_per_fold,
170-
height=.6, color="w", hatch="//", edgecolor='k')
170+
height=.6, color="w", hatch="//", edgecolor='k',
171+
align='edge')
171172

172173
ax.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
173174
n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6,
174-
color="w", edgecolor='k')
175+
color="w", edgecolor='k', align='edge')
175176

176177
for i in range(3):
177178
ax.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" %
@@ -200,7 +201,7 @@ def plot_cross_validation():
200201
bars = plt.barh(
201202
bottom=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
202203
left=i * n_samples_per_fold, height=.6, color=colors, hatch="//",
203-
edgecolor='k')
204+
edgecolor='k', align='edge')
204205
axes.invert_yaxis()
205206
axes.set_xlim(0, n_samples + 1)
206207
plt.ylabel("CV iterations")
@@ -218,7 +219,8 @@ def plot_threefold_split():
218219
plt.figure(figsize=(15, 1))
219220
axis = plt.gca()
220221
bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15], color=[
221-
'white', 'grey', 'grey'], hatch="//", edgecolor='k')
222+
'white', 'grey', 'grey'], hatch="//", edgecolor='k',
223+
align='edge')
222224
bars[2].set_hatch(r"")
223225
axis.set_yticks(())
224226
axis.set_frame_on(False)

mglearn/plot_grid_search.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def plot_cross_val_selection():
2323
plt.xlim(-1, len(results))
2424
plt.ylim(0, 1.1)
2525
for i, (_, row) in enumerate(results.iterrows()):
26-
scores = row[['test_split%d_test_score' % i for i in range(5)]]
26+
scores = row[['split%d_test_score' % i for i in range(5)]]
2727
marker_cv, = plt.plot([i] * 5, scores, '^', c='gray', markersize=5,
2828
alpha=.5)
2929
marker_mean, = plt.plot(i, row.mean_test_score, 'v', c='none', alpha=1,

mglearn/plot_improper_preprocessing.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def plot_improper_processing():
1414

1515
for axis in axes:
1616
bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15],
17-
color=['white', 'grey', 'grey'], hatch="//")
17+
color=['white', 'grey', 'grey'], hatch="//",
18+
align='edge', edgecolor='k')
1819
bars[2].set_hatch(r"")
1920
axis.set_yticks(())
2021
axis.set_frame_on(False)
@@ -46,18 +47,22 @@ def plot_proper_processing():
4647

4748
for axis in axes:
4849
bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9],
49-
left=[0, 12, 15], color=['white', 'grey', 'grey'], hatch="//")
50+
left=[0, 12, 15], color=['white', 'grey', 'grey'],
51+
hatch="//", align='edge', edgecolor='k')
5052
bars[2].set_hatch(r"")
5153
axis.set_yticks(())
5254
axis.set_frame_on(False)
5355
axis.set_ylim(-.1, 4.5)
5456
axis.set_xlim(-0.1, 20.1)
5557
axis.set_xticks(())
5658
axis.tick_params(length=0, labeltop=True, labelbottom=False)
57-
axis.text(6, -.3, "training folds", fontdict={'fontsize': 14}, horizontalalignment="center")
58-
axis.text(13.5, -.3, "validation fold", fontdict={'fontsize': 14}, horizontalalignment="center")
59-
axis.text(17.5, -.3, "test set", fontdict={'fontsize': 14}, horizontalalignment="center")
60-
59+
axis.text(6, -.3, "training folds", fontdict={'fontsize': 14},
60+
horizontalalignment="center")
61+
axis.text(13.5, -.3, "validation fold", fontdict={'fontsize': 14},
62+
horizontalalignment="center")
63+
axis.text(17.5, -.3, "test set", fontdict={'fontsize': 14},
64+
horizontalalignment="center")
65+
6166
make_bracket("scaler fit", (6, 1.3), (6, 2.), 12, axes[0])
6267
make_bracket("SVC fit", (6, 3), (6, 4), 12, axes[0])
6368
make_bracket("SVC predict", (13.4, 3), (13.4, 4), 2.5, axes[0])

mglearn/plot_nmf.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ def plot_nmf_illustration():
1919

2020
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
2121

22-
axes[0].scatter(X_blob[:, 0], X_blob[:, 1], c=X_nmf[:, 0], linewidths=0, s=60, cmap='viridis')
22+
axes[0].scatter(X_blob[:, 0], X_blob[:, 1], c=X_nmf[:, 0], linewidths=0,
23+
s=60, cmap='viridis')
2324
axes[0].set_xlabel("feature 1")
2425
axes[0].set_ylabel("feature 2")
26+
axes[0].set_xlim(0, 12)
27+
axes[0].set_ylim(0, 12)
2528
axes[0].arrow(0, 0, nmf.components_[0, 0], nmf.components_[0, 1], width=.1,
2629
head_width=.3, color='k')
2730
axes[0].arrow(0, 0, nmf.components_[1, 0], nmf.components_[1, 1], width=.1,
@@ -37,6 +40,8 @@ def plot_nmf_illustration():
3740
s=60, cmap='viridis')
3841
axes[1].set_xlabel("feature 1")
3942
axes[1].set_ylabel("feature 2")
43+
axes[1].set_xlim(0, 12)
44+
axes[1].set_ylim(0, 12)
4045
axes[1].arrow(0, 0, nmf.components_[0, 0], nmf.components_[0, 1], width=.1,
4146
head_width=.3, color='k')
4247

mglearn/tools.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def visualize_coefficients(coefficients, feature_names, n_top_features=25):
2828
# this is not a row or column vector
2929
raise ValueError("coeffients must be 1d array or column vector, got"
3030
" shape {}".format(coefficients.shape))
31+
coefficients = coefficients.ravel()
3132

3233
if len(coefficients) != len(feature_names):
3334
raise ValueError("Number of coefficients {} doesn't match number of"
@@ -59,7 +60,7 @@ def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
5960
if ax is None:
6061
ax = plt.gca()
6162
# plot the mean cross-validation scores
62-
img = ax.pcolor(values, cmap=cmap, vmin=None, vmax=None)
63+
img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
6364
img.update_scalarmappable()
6465
ax.set_xlabel(xlabel)
6566
ax.set_ylabel(ylabel)

0 commit comments

Comments
 (0)