Skip to content

Commit d8ee8e3

Browse files
committed
minor plotting fixes, fixes for scikit-learn master
1 parent 1351c27 commit d8ee8e3

7 files changed

+47
-24
lines changed

mglearn/datasets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def load_citibike():
3838
data_mine['one'] = 1
3939
data_mine['starttime'] = pd.to_datetime(data_mine.starttime)
4040
data_starttime = data_mine.set_index("starttime")
41-
data_resampled = data_starttime.resample("3h", how="sum").fillna(0)
41+
data_resampled = data_starttime.resample("3h").sum().fillna(0)
4242
return data_resampled.one
4343

4444

mglearn/plot_cross_validation.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def plot_label_kfold():
1717
n_iter = 3
1818
n_samples_per_fold = 1
1919

20-
cv = LabelKFold(n_folds=3)
20+
cv = LabelKFold(n_splits=3)
2121
mask = np.zeros((n_iter, n_samples))
2222
for i, (train, test) in enumerate(cv.split(range(12), labels=labels)):
2323
mask[i, train] = 1
@@ -45,8 +45,8 @@ def plot_label_kfold():
4545
axes.set_xlabel("Data points")
4646
axes.set_xticks(np.arange(n_samples) + .5)
4747
axes.set_xticklabels(np.arange(1, n_samples + 1))
48-
axes.set_yticks(np.arange(n_iter) + .3)
49-
axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)] + ["labels"]);
48+
axes.set_yticks(np.arange(n_iter + 1) + .3)
49+
axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"]);
5050
plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"], loc=(1, .3));
5151
plt.tight_layout()
5252

@@ -55,7 +55,7 @@ def plot_label_kfold():
5555
def plot_shuffle_split():
5656
from sklearn.model_selection import ShuffleSplit
5757
plt.figure(figsize=(10, 2))
58-
plt.title("ShuffleSplit with 10 points, train_size=5, test_size=2, n_iter=4")
58+
plt.title("ShuffleSplit with 10 points, train_size=5, test_size=2, n_splits=4")
5959

6060
axes = plt.gca()
6161
axes.set_frame_on(False)
@@ -65,7 +65,7 @@ def plot_shuffle_split():
6565
n_iter = 4
6666
n_samples_per_fold = 1
6767

68-
ss = ShuffleSplit(n_iter=4, train_size=5, test_size=2, random_state=43)
68+
ss = ShuffleSplit(n_splits=4, train_size=5, test_size=2, random_state=43)
6969
mask = np.zeros((n_iter, n_samples))
7070
for i, (train, test) in enumerate(ss.split(range(10))):
7171
mask[i, train] = 1

mglearn/plot_helpers.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from matplotlib.colors import ListedColormap, colorConverter, LinearSegmentedColormap
55

66

7+
cm_cycle = ListedColormap(['#0000aa', '#ff2020', '#50ff50', 'c', '#fff000'])
78
cm3 = ListedColormap(['#0000aa', '#ff2020', '#50ff50'])
89
cm2 = ListedColormap(['#0000aa', '#ff2020'])
910

mglearn/plot_interactive_tree.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,24 @@ def tree_image(tree, fout=None):
3939

4040

4141
def plot_tree_progressive():
42-
fig, axes = plt.subplots(4, 2, figsize=(15, 25), subplot_kw={'xticks': (), 'yticks': ()})
4342
X, y = make_moons(n_samples=100, noise=0.25, random_state=3)
43+
plt.figure()
44+
ax = plt.gca()
45+
discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)
46+
ax.set_xticks(())
47+
ax.set_yticks(())
48+
49+
axes = []
50+
for i in range(3):
51+
fig, ax = plt.subplots(1, 2, figsize=(12, 4),
52+
subplot_kw={'xticks': (), 'yticks': ()})
53+
axes.append(ax)
54+
axes = np.array(axes)
4455

4556
for i, max_depth in enumerate([1, 2, 9]):
46-
tree = plot_tree(X, y, max_depth=max_depth, ax=axes[i + 1, 0])
47-
axes[i + 1, 1].imshow(tree_image(tree))
48-
axes[i + 1, 1].set_axis_off()
49-
axes[0, 1].set_visible(False)
50-
for ax in axes[:, 0]:
51-
discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)
52-
ax.legend(loc="best")
57+
tree = plot_tree(X, y, max_depth=max_depth, ax=axes[i, 0])
58+
axes[i, 1].imshow(tree_image(tree))
59+
axes[i, 1].set_axis_off()
5360

5461

5562
def plot_tree_partition(X, y, tree, ax=None):

mglearn/plot_metrics.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77

88
def plot_confusion_matrix_illustration():
9+
plt.figure(figsize=(8, 8))
910
confusion = np.array([[401, 2], [8, 39]])
10-
plt.title("confusion_matrix")
11-
plt.text(0.45, .6, confusion[0, 0], size=70, horizontalalignment='right')
12-
plt.text(0.45, .1, confusion[1, 0], size=70, horizontalalignment='right')
13-
plt.text(.95, .6, confusion[0, 1], size=70, horizontalalignment='right')
14-
plt.text(.95, 0.1, confusion[1, 1], size=70, horizontalalignment='right')
15-
plt.xticks([.25, .75], ["predicted 'not 9'", "predicted '9'"], size=20)
16-
plt.yticks([.25, .75], ["true '9'", "true 'not 9'"], size=20)
11+
plt.text(0.40, .7, confusion[0, 0], size=70, horizontalalignment='right')
12+
plt.text(0.40, .2, confusion[1, 0], size=70, horizontalalignment='right')
13+
plt.text(.90, .7, confusion[0, 1], size=70, horizontalalignment='right')
14+
plt.text(.90, 0.2, confusion[1, 1], size=70, horizontalalignment='right')
15+
plt.xticks([.25, .75], ["predicted 'not nine'", "predicted 'nine'"], size=20)
16+
plt.yticks([.25, .75], ["true 'nine'", "true 'not nine'"], size=20)
1717
plt.plot([.5, .5], [0, 1], '--', c='k')
1818
plt.plot([0, 1], [.5, .5], '--', c='k')
1919

@@ -22,7 +22,6 @@ def plot_confusion_matrix_illustration():
2222

2323

2424
def plot_binary_confusion_matrix():
25-
plt.title("binary_confusion_matrix_tp_fp")
2625
plt.text(0.45, .6, "TN", size=100, horizontalalignment='right')
2726
plt.text(0.45, .1, "FN", size=100, horizontalalignment='right')
2827
plt.text(.95, .6, "FP", size=100, horizontalalignment='right')

mglearn/plot_tree_nonmonotonous.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def plot_tree_not_monotone():
1212
y = y % 2
1313
plt.figure()
1414
discrete_scatter(X[:, 0], X[:, 1], y)
15+
plt.legend(["Class 0", "Class 1"], loc="best")
1516

1617
# learn a decision tree model
1718
tree = DecisionTreeClassifier(random_state=0).fit(X, y)

preamble.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
1-
from IPython.display import set_matplotlib_formats
1+
from IPython.display import set_matplotlib_formats, display
2+
import pandas as pd
23
import numpy as np
34
import matplotlib.pyplot as plt
45
import mglearn
6+
from cycler import cycler
57

8+
#set_matplotlib_formats('png', 'svg')
69
set_matplotlib_formats('pdf', 'png')
710
plt.rcParams['savefig.dpi'] = 300
11+
plt.rcParams['image.cmap'] = "viridis"
812
plt.rcParams['image.interpolation'] = "none"
913
plt.rcParams['savefig.bbox'] = "tight"
10-
np.set_printoptions(precision=3)
14+
plt.rcParams['lines.linewidth'] = 2
15+
plt.rcParams['legend.numpoints'] = 1
16+
plt.rc('axes', prop_cycle=(cycler('color', mglearn.plot_helpers.cm_cycle.colors) +
17+
cycler('linestyle', ['-', '--', ':',
18+
'-.', '--'])
19+
)
20+
)
21+
22+
np.set_printoptions(precision=3, suppress=True)
23+
24+
pd.set_option("display.max_columns", 8)
25+
pd.set_option('precision', 2)
1126

1227
np, mglearn

0 commit comments

Comments
 (0)