@@ -23,39 +23,43 @@ def plot_group_kfold():
         mask[i, train] = 1
         mask[i, test] = 2
 
-
     for i in range(n_folds):
         # test is grey
-        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
+        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
         # not selected has no hatch
-
-        boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter, left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
+
+        boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
+                          left=i * n_samples_per_fold, height=.6, color=colors,
+                          hatch="//", edgecolor="k")
         for j in np.where(mask[:, i] == 0)[0]:
             boxes[j].set_hatch("")
-
-    axes.barh(bottom=[n_iter] * n_folds, width=[1 - 0.1] * n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6, color="w")
+
+    axes.barh(bottom=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
+              left=np.arange(n_folds) * n_samples_per_fold, height=.6,
+              color="w", edgecolor='k')
 
     for i in range(12):
-        axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" % groups[i], horizontalalignment="center")
-    #ax.set_ylim(4, -0.1)
-
+        axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" %
+                  groups[i], horizontalalignment="center")
+
     axes.invert_yaxis()
     axes.set_xlim(0, n_samples + 1)
     axes.set_ylabel("CV iterations")
     axes.set_xlabel("Data points")
     axes.set_xticks(np.arange(n_samples) + .5)
     axes.set_xticklabels(np.arange(1, n_samples + 1))
     axes.set_yticks(np.arange(n_iter + 1) + .3)
-    axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"]);
-    plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"], loc=(1, .3));
+    axes.set_yticklabels(
+        ["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"])
+    plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"], loc=(1, .3))
     plt.tight_layout()
 
 
-
 def plot_shuffle_split():
     from sklearn.model_selection import ShuffleSplit
     plt.figure(figsize=(10, 2))
-    plt.title("ShuffleSplit with 10 points, train_size=5, test_size=2, n_splits=4")
+    plt.title("ShuffleSplit with 10 points"
+              ", train_size=5, test_size=2, n_splits=4")
 
     axes = plt.gca()
     axes.set_frame_on(False)
@@ -71,13 +75,14 @@ def plot_shuffle_split():
         mask[i, train] = 1
         mask[i, test] = 2
 
-
     for i in range(n_folds):
         # test is grey
-        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
+        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
         # not selected has no hatch
-
-        boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter, left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
+
+        boxes = axes.barh(bottom=range(n_iter), width=[1 - 0.1] * n_iter,
+                          left=i * n_samples_per_fold, height=.6, color=colors,
+                          hatch="//", edgecolor='k')
         for j in np.where(mask[:, i] == 0)[0]:
             boxes[j].set_hatch("")
 
@@ -88,17 +93,16 @@ def plot_shuffle_split():
     axes.set_xticks(np.arange(n_samples) + .5)
     axes.set_xticklabels(np.arange(1, n_samples + 1))
     axes.set_yticks(np.arange(n_iter) + .3)
-    axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)]);
+    axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)])
     # legend hacked for this random state
-    plt.legend([boxes[1], boxes[0], boxes[2]], ["Training set", "Test set", "Not selected"], loc=(1, .3));
+    plt.legend([boxes[1], boxes[0], boxes[2]], [
+               "Training set", "Test set", "Not selected"], loc=(1, .3))
     plt.tight_layout()
-    plt.savefig("images/06_shuffle_split.png")
-    plt.close()
 
 
 def plot_stratified_cross_validation():
     fig, both_axes = plt.subplots(2, 1, figsize=(12, 5))
-    #plt.title("cross_validation_not_stratified")
+    # plt.title("cross_validation_not_stratified")
     axes = both_axes[0]
     axes.set_title("Standard cross-validation with sorted class labels")
 
@@ -109,25 +113,30 @@ def plot_stratified_cross_validation():
 
     n_samples_per_fold = n_samples / float(n_folds)
 
-
     for i in range(n_folds):
         colors = ["w"] * n_folds
         colors[i] = "grey"
-        axes.barh(bottom=range(n_folds), width=[n_samples_per_fold - 1] * n_folds, left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
-
-    axes.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] * n_folds, left=np.arange(3) * n_samples_per_fold, height=.6, color="w")
+        axes.barh(bottom=range(n_folds), width=[n_samples_per_fold - 1] *
+                  n_folds, left=i * n_samples_per_fold, height=.6,
+                  color=colors, hatch="//", edgecolor='k')
+
+    axes.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
+              n_folds, left=np.arange(3) * n_samples_per_fold, height=.6,
+              color="w", edgecolor='k')
 
     axes.invert_yaxis()
     axes.set_xlim(0, n_samples + 1)
     axes.set_ylabel("CV iterations")
     axes.set_xlabel("Data points")
-    axes.set_xticks(np.arange(n_samples_per_fold / 2., n_samples, n_samples_per_fold))
+    axes.set_xticks(np.arange(n_samples_per_fold / 2.,
+                              n_samples, n_samples_per_fold))
     axes.set_xticklabels(["Fold %d" % x for x in range(1, n_folds + 1)])
     axes.set_yticks(np.arange(n_folds + 1) + .3)
-    axes.set_yticklabels(["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
+    axes.set_yticklabels(
+        ["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
     for i in range(3):
-        axes.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" % i, horizontalalignment="center")
-
+        axes.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" %
+                  i, horizontalalignment="center")
 
     ax = both_axes[1]
     ax.set_title("Stratified Cross-validation")
@@ -138,24 +147,38 @@ def plot_stratified_cross_validation():
     ax.set_xlabel("Data points")
 
     ax.set_yticks(np.arange(n_folds + 1) + .3)
-    ax.set_yticklabels(["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"]);
+    ax.set_yticklabels(
+        ["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
 
     n_subsplit = n_samples_per_fold / 3.
     for i in range(n_folds):
-        test_bars = ax.barh(bottom=[i] * n_folds, width=[n_subsplit - 1] * n_folds, left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit, height=.6, color="grey", hatch="//")
+        test_bars = ax.barh(
+            bottom=[i] * n_folds, width=[n_subsplit - 1] * n_folds,
+            left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit,
+            height=.6, color="grey", hatch="//", edgecolor='k')
 
     w = 2 * n_subsplit - 1
-    ax.barh(bottom=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds) * n_samples_per_fold + (0 + 1) * n_subsplit, height=.6, color="w", hatch="//")
-    ax.barh(bottom=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.], left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold - n_subsplit), height=.6, color="w", hatch="//")
-    training_bars = ax.barh(bottom=[2] * n_folds, width=[w] * n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6, color="w", hatch="//")
-
-
-    ax.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] * n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6, color="w")
+    ax.barh(bottom=[0] * n_folds, width=[w] * n_folds, left=np.arange(n_folds)
+            * n_samples_per_fold + (0 + 1) * n_subsplit, height=.6, color="w",
+            hatch="//", edgecolor='k')
+    ax.barh(bottom=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
+            left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold -
+                            n_subsplit), height=.6, color="w", hatch="//",
+            edgecolor='k')
+    training_bars = ax.barh(bottom=[2] * n_folds, width=[w] * n_folds,
+                            left=np.arange(n_folds) * n_samples_per_fold,
+                            height=.6, color="w", hatch="//", edgecolor='k')
+
+    ax.barh(bottom=[n_folds] * n_folds, width=[n_samples_per_fold - 1] *
+            n_folds, left=np.arange(n_folds) * n_samples_per_fold, height=.6,
+            color="w", edgecolor='k')
 
     for i in range(3):
-        ax.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" % i, horizontalalignment="center")
+        ax.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" %
+                i, horizontalalignment="center")
     ax.set_ylim(4, -0.1)
-    plt.legend([training_bars[0], test_bars[0]], ['Training data', 'Test data'], loc=(1.05, 1), frameon=False);
+    plt.legend([training_bars[0], test_bars[0]], [
+               'Training data', 'Test data'], loc=(1.05, 1), frameon=False)
 
     fig.tight_layout()
 
@@ -171,33 +194,43 @@ def plot_cross_validation():
 
     n_samples_per_fold = n_samples / float(n_folds)
 
-
     for i in range(n_folds):
         colors = ["w"] * n_folds
         colors[i] = "grey"
-        bars = plt.barh(bottom=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
-                        left=i * n_samples_per_fold, height=.6, color=colors, hatch="//")
+        bars = plt.barh(
+            bottom=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
+            left=i * n_samples_per_fold, height=.6, color=colors, hatch="//",
+            edgecolor='k')
     axes.invert_yaxis()
     axes.set_xlim(0, n_samples + 1)
     plt.ylabel("CV iterations")
     plt.xlabel("Data points")
-    plt.xticks(np.arange(n_samples_per_fold / 2., n_samples, n_samples_per_fold), ["Fold %d" % x for x in range(1, n_folds + 1)])
-    plt.yticks(np.arange(n_folds) + .3, ["Split %d" % x for x in range(1, n_folds + 1)])
-    plt.legend([bars[0], bars[4]], ['Training data', 'Test data'], loc=(1.05, 0.4), frameon=False);
+    plt.xticks(np.arange(n_samples_per_fold / 2., n_samples,
+                         n_samples_per_fold),
+               ["Fold %d" % x for x in range(1, n_folds + 1)])
+    plt.yticks(np.arange(n_folds) + .3,
+               ["Split %d" % x for x in range(1, n_folds + 1)])
+    plt.legend([bars[0], bars[4]], ['Training data', 'Test data'],
+               loc=(1.05, 0.4), frameon=False)
 
 
 def plot_threefold_split():
     plt.figure(figsize=(15, 1))
     axis = plt.gca()
-    bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15], color=['white', 'grey', 'grey'], hatch="//")
+    bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15], color=[
+                     'white', 'grey', 'grey'], hatch="//", edgecolor='k')
     bars[2].set_hatch(r"")
     axis.set_yticks(())
     axis.set_frame_on(False)
     axis.set_ylim(-.1, .8)
     axis.set_xlim(-0.1, 20.1)
     axis.set_xticks([6, 13.3, 17.5])
-    axis.set_xticklabels(["training set", "validation set", "test set"], fontdict={'fontsize': 20});
+    axis.set_xticklabels(["training set", "validation set",
+                          "test set"], fontdict={'fontsize': 20})
     axis.tick_params(length=0, labeltop=True, labelbottom=False)
-    axis.text(6, -.3, "Model fitting", fontdict={'fontsize': 13}, horizontalalignment="center")
-    axis.text(13.3, -.3, "Parameter selection", fontdict={'fontsize': 13}, horizontalalignment="center")
-    axis.text(17.5, -.3, "Evaluation", fontdict={'fontsize': 13}, horizontalalignment="center")
+    axis.text(6, -.3, "Model fitting",
+              fontdict={'fontsize': 13}, horizontalalignment="center")
+    axis.text(13.3, -.3, "Parameter selection",
+              fontdict={'fontsize': 13}, horizontalalignment="center")
+    axis.text(17.5, -.3, "Evaluation",
+              fontdict={'fontsize': 13}, horizontalalignment="center")
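
For context, the functions touched by this diff are plotting helpers that each draw one cross-validation schematic onto the current matplotlib figure. A minimal usage sketch, assuming the file is importable as mglearn.plot_cross_validation (the module path is an assumption, not stated in the diff) and that matplotlib and scikit-learn are installed:

    import matplotlib.pyplot as plt
    # Assumed module path; the function names come from the diff above.
    from mglearn.plot_cross_validation import (
        plot_cross_validation, plot_group_kfold, plot_shuffle_split,
        plot_stratified_cross_validation, plot_threefold_split)

    # Each helper builds its figure on the current matplotlib state;
    # show (or savefig) after each call to view the schematic.
    plot_cross_validation()
    plt.show()

    plot_group_kfold()
    plt.show()

Note that the diff also removes the plt.savefig/plt.close calls from plot_shuffle_split, so saving the figures is left to the caller.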