@@ -7,10 +7,13 @@ How to visualize the K-Nearest Neighbors (kNN) algorithm using scikit-learn.
7
7
8
8
``` python
9
9
import numpy as np
10
- from sklearn.datasets import make_moons
11
- from sklearn.neighbors import KNeighborsClassifier
12
10
import plotly.express as px
13
11
import plotly.graph_objects as go
12
+ from sklearn.datasets import make_moons
13
+ from sklearn.neighbors import KNeighborsClassifier
14
+
15
+ mesh_size = .02
16
+ margin = 1
14
17
15
18
X, y = make_moons(noise = 0.3 , random_state = 0 )
16
19
@@ -22,12 +25,12 @@ yrange = np.arange(y_min, y_max, mesh_size)
22
25
xx, yy = np.meshgrid(xrange , yrange)
23
26
24
27
# Create classifier, run predictions on grid
25
- clf = neighbors. KNeighborsClassifier(15 , weights = ' uniform' )
28
+ clf = KNeighborsClassifier(15 , weights = ' uniform' )
26
29
clf.fit(X, y)
27
30
Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1 ]
28
31
Z = Z.reshape(xx.shape)
29
32
30
- fig = px.scatter(X, x = 0 , y = 1 , color = y.astype(str ))
33
+ fig = px.scatter(X, x = 0 , y = 1 , color = y.astype(str ), labels = { ' 0 ' : ' ' , ' 1 ' : ' ' } )
31
34
fig.add_trace(
32
35
go.Contour(
33
36
x = xrange ,
@@ -38,15 +41,16 @@ fig.add_trace(
38
41
opacity = 0.4
39
42
)
40
43
)
44
+ fig.show()
41
45
```
42
46
43
47
### Multi-class classification with ` px.data ` and ` go.Heatmap `
44
48
45
49
``` python
46
50
import numpy as np
47
- from sklearn.neighbors import KNeighborsClassifier
48
51
import plotly.express as px
49
52
import plotly.graph_objects as go
53
+ from sklearn.neighbors import KNeighborsClassifier
50
54
51
55
mesh_size = .02
52
56
margin = 1
@@ -67,6 +71,8 @@ clf = KNeighborsClassifier(15, weights='distance')
67
71
clf.fit(X, y)
68
72
Z = clf.predict(np.c_[ll.ravel(), ww.ravel()])
69
73
Z = Z.reshape(ll.shape)
74
+ proba = clf.predict_proba(np.c_[ll.ravel(), ww.ravel()])
75
+ proba = proba.reshape(ll.shape + (3 ,))
70
76
71
77
fig = px.scatter(df, x = ' sepal_length' , y = ' sepal_width' , color = ' species' )
72
78
fig.update_traces(marker_size = 10 , marker_line_width = 1 )
@@ -77,17 +83,27 @@ fig.add_trace(
77
83
z = Z,
78
84
showscale = False ,
79
85
colorscale = [[0.0 , ' blue' ], [0.5 , ' red' ], [1.0 , ' green' ]],
80
- opacity = 0.25
86
+ opacity = 0.25 ,
87
+ customdata = proba,
88
+ hovertemplate = (
89
+ ' sepal length: %{x} <br>'
90
+ ' sepal width: %{y} <br>'
91
+ ' p(setosa): %{customdata[0]:.3f } <br>'
92
+ ' p(versicolor): %{customdata[1]:.3f } <br>'
93
+ ' p(virginica): %{customdata[2]:.3f } <extra></extra>'
94
+ )
81
95
)
82
96
)
97
+ fig.show()
83
98
```
84
99
85
100
### Visualizing kNN Regression
86
101
87
102
``` python
88
- from sklearn.neighbors import KNeighborsRegressor
103
+ import numpy as np
89
104
import plotly.express as px
90
105
import plotly.graph_objects as go
106
+ from sklearn.neighbors import KNeighborsRegressor
91
107
92
108
df = px.data.tips()
93
109
X = df.total_bill.values.reshape(- 1 , 1 )
@@ -104,6 +120,7 @@ y_uni = knn_uni.predict(x_range.reshape(-1, 1))
104
120
fig = px.scatter(df, x = ' total_bill' , y = ' tip' , color = ' sex' , opacity = 0.65 )
105
121
fig.add_traces(go.Scatter(x = x_range, y = y_uni, name = ' Weights: Uniform' ))
106
122
fig.add_traces(go.Scatter(x = x_range, y = y_dist, name = ' Weights: Distance' ))
123
+ fig.show()
107
124
```
108
125
109
126
### Reference
0 commit comments