Skip to content

Commit a22861d

Browse files
authored
Update plot_cross_validation.py
I think `LabelKFold` got renamed to `GroupKFold` in Scikit-Learn `0.18`
1 parent 3b524cf commit a22861d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mglearn/plot_cross_validation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44

55
def plot_label_kfold():
6-
from sklearn.model_selection import LabelKFold
7-
labels = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3]
6+
from sklearn.model_selection import GroupKFold
7+
groups = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3]
88

99
plt.figure(figsize=(10, 2))
1010
plt.title("LabelKFold")
@@ -17,9 +17,9 @@ def plot_label_kfold():
1717
n_iter = 3
1818
n_samples_per_fold = 1
1919

20-
cv = LabelKFold(n_splits=3)
20+
cv = GroupKFold(n_splits=3)
2121
mask = np.zeros((n_iter, n_samples))
22-
for i, (train, test) in enumerate(cv.split(range(12), labels=labels)):
22+
for i, (train, test) in enumerate(cv.split(range(12), groups=groups)):
2323
mask[i, train] = 1
2424
mask[i, test] = 2
2525

0 commit comments

Comments
 (0)