@@ -34,7 +34,20 @@ jupyter:
34
34
thumbnail : thumbnail/knn-classification.png
35
35
---
36
36
37
- ## Basic Binary Classification with ` plotly.express `
37
+ ## Basic binary classification with kNN
38
+
39
+
40
+ ### Display training and test splits
41
+
42
+ ``` python
43
+
44
+ ```
45
+
46
+ ### Visualize predictions on test split
47
+
48
+ ``` python
49
+
50
+ ```
38
51
39
52
``` python
40
53
import numpy as np
@@ -113,7 +126,7 @@ fig.add_trace(
113
126
showscale = False ,
114
127
colorscale = [' Blue' , ' Red' ],
115
128
opacity = 0.4 ,
116
- name = ' Confidence '
129
+ name = ' Score '
117
130
)
118
131
)
119
132
fig.show()
@@ -150,7 +163,7 @@ Z = Z.reshape(ll.shape)
150
163
proba = clf.predict_proba(np.c_[ll.ravel(), ww.ravel()])
151
164
proba = proba.reshape(ll.shape + (3 ,))
152
165
153
- fig = px.scatter(df, x = ' sepal_length' , y = ' sepal_width' , color = ' species' , width = 1000 , height = 1000 )
166
+ fig = px.scatter(df, x = ' sepal_length' , y = ' sepal_width' , color = ' species' )
154
167
fig.update_traces(marker_size = 10 , marker_line_width = 1 )
155
168
fig.add_trace(
156
169
go.Heatmap(
@@ -173,77 +186,12 @@ fig.add_trace(
173
186
fig.show()
174
187
```
175
188
176
- ## 3D Classification with ` px.scatter_3d `
177
-
178
- ``` python
179
- import numpy as np
180
- import plotly.express as px
181
- import plotly.graph_objects as go
182
- from sklearn.neighbors import KNeighborsClassifier
183
- from sklearn.model_selection import train_test_split
184
-
185
- df = px.data.iris()
186
- features = [" sepal_width" , " sepal_length" , " petal_width" ]
187
-
188
- X = df[features]
189
- y = df.species
190
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3 , random_state = 0 )
191
-
192
- # Create classifier, run predictions on grid
193
- clf = KNeighborsClassifier(15 , weights = ' distance' )
194
- clf.fit(X_train, y_train)
195
- y_pred = clf.predict(X_test)
196
- y_score = clf.predict_proba(X_test)
197
- y_score = np.around(y_score.max(axis = 1 ), 4 )
198
-
199
- fig = px.scatter_3d(
200
- X_test,
201
- x = ' sepal_length' ,
202
- y = ' sepal_width' ,
203
- z = ' petal_width' ,
204
- symbol = y_pred,
205
- color = y_score,
206
- labels = {' symbol' : ' prediction' , ' color' : ' score' }
207
- )
208
- fig.update_layout(legend = dict (x = 0 , y = 0 ))
209
- fig.show()
210
- ```
211
-
212
- ## High Dimension Visualization with ` px.scatter_matrix `
213
-
214
- If you need to visualize classifications that go beyond 3D, you can use the [scatter plot matrix](https://plot.ly/python/splom/).
215
-
216
- ``` python
217
- import numpy as np
218
- import plotly.express as px
219
- import plotly.graph_objects as go
220
- from sklearn.neighbors import KNeighborsClassifier
221
- from sklearn.model_selection import train_test_split
222
-
223
- df = px.data.iris()
224
- features = [" sepal_width" , " sepal_length" , " petal_width" , " petal_length" ]
225
-
226
- X = df[features]
227
- y = df.species
228
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3 , random_state = 0 )
229
-
230
- # Create classifier, run predictions on grid
231
- clf = KNeighborsClassifier(15 , weights = ' distance' )
232
- clf.fit(X_train, y_train)
233
- y_pred = clf.predict(X_test)
234
-
235
- fig = px.scatter_matrix(X_test, dimensions = features, color = y_pred, labels = {' color' : ' prediction' })
236
- fig.show()
237
- ```
238
-
239
189
### Reference
240
190
241
191
Learn more about `px`, `go.Contour`, and `go.Heatmap` here:
242
192
* https://plot.ly/python/plotly-express/
243
193
* https://plot.ly/python/heatmaps/
244
194
* https://plot.ly/python/contour-plots/
245
- * https://plot.ly/python/3d-scatter-plots/
246
- * https://plot.ly/python/splom/
247
195
248
196
This tutorial was inspired by amazing examples from the official scikit-learn docs:
249
197
* https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
0 commit comments