Skip to content

Commit 6b3bbb1

Browse files
Xing Hanxhlulu
Xing Han
authored and
xhlulu
committed
Update based on Emma's suggestions
1 parent 612c0f6 commit 6b3bbb1

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

doc/python/ml-knn.md

+24-7
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ How to visualize the K-Nearest Neighbors (kNN) algorithm using scikit-learn.
77

88
```python
99
import numpy as np
10-
from sklearn.datasets import make_moons
11-
from sklearn.neighbors import KNeighborsClassifier
1210
import plotly.express as px
1311
import plotly.graph_objects as go
12+
from sklearn.datasets import make_moons
13+
from sklearn.neighbors import KNeighborsClassifier
14+
15+
mesh_size = .02
16+
margin = 1
1417

1518
X, y = make_moons(noise=0.3, random_state=0)
1619

@@ -22,12 +25,12 @@ yrange = np.arange(y_min, y_max, mesh_size)
2225
xx, yy = np.meshgrid(xrange, yrange)
2326

2427
# Create classifier, run predictions on grid
25-
clf = neighbors.KNeighborsClassifier(15, weights='uniform')
28+
clf = KNeighborsClassifier(15, weights='uniform')
2629
clf.fit(X, y)
2730
Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
2831
Z = Z.reshape(xx.shape)
2932

30-
fig = px.scatter(X, x=0, y=1, color=y.astype(str))
33+
fig = px.scatter(X, x=0, y=1, color=y.astype(str), labels={'0':'', '1':''})
3134
fig.add_trace(
3235
go.Contour(
3336
x=xrange,
@@ -38,15 +41,16 @@ fig.add_trace(
3841
opacity=0.4
3942
)
4043
)
44+
fig.show()
4145
```
4246

4347
### Multi-class classification with `px.data` and `go.Heatmap`
4448

4549
```python
4650
import numpy as np
47-
from sklearn.neighbors import KNeighborsClassifier
4851
import plotly.express as px
4952
import plotly.graph_objects as go
53+
from sklearn.neighbors import KNeighborsClassifier
5054

5155
mesh_size = .02
5256
margin = 1
@@ -67,6 +71,8 @@ clf = KNeighborsClassifier(15, weights='distance')
6771
clf.fit(X, y)
6872
Z = clf.predict(np.c_[ll.ravel(), ww.ravel()])
6973
Z = Z.reshape(ll.shape)
74+
proba = clf.predict_proba(np.c_[ll.ravel(), ww.ravel()])
75+
proba = proba.reshape(ll.shape + (3,))
7076

7177
fig = px.scatter(df, x='sepal_length', y='sepal_width', color='species')
7278
fig.update_traces(marker_size=10, marker_line_width=1)
@@ -77,17 +83,27 @@ fig.add_trace(
7783
z=Z,
7884
showscale=False,
7985
colorscale=[[0.0, 'blue'], [0.5, 'red'], [1.0, 'green']],
80-
opacity=0.25
86+
opacity=0.25,
87+
customdata=proba,
88+
hovertemplate=(
89+
'sepal length: %{x} <br>'
90+
'sepal width: %{y} <br>'
91+
'p(setosa): %{customdata[0]:.3f}<br>'
92+
'p(versicolor): %{customdata[1]:.3f}<br>'
93+
'p(virginica): %{customdata[2]:.3f}<extra></extra>'
94+
)
8195
)
8296
)
97+
fig.show()
8398
```
8499

85100
### Visualizing kNN Regression
86101

87102
```python
88-
from sklearn.neighbors import KNeighborsRegressor
103+
import numpy as np
89104
import plotly.express as px
90105
import plotly.graph_objects as go
106+
from sklearn.neighbors import KNeighborsRegressor
91107

92108
df = px.data.tips()
93109
X = df.total_bill.values.reshape(-1, 1)
@@ -104,6 +120,7 @@ y_uni = knn_uni.predict(x_range.reshape(-1, 1))
104120
fig = px.scatter(df, x='total_bill', y='tip', color='sex', opacity=0.65)
105121
fig.add_traces(go.Scatter(x=x_range, y=y_uni, name='Weights: Uniform'))
106122
fig.add_traces(go.Scatter(x=x_range, y=y_dist, name='Weights: Distance'))
123+
fig.show()
107124
```
108125

109126
### Reference

0 commit comments

Comments
 (0)