Commit 612c0f6

xhlulu authored and committed
Create kNN docs draft
1 parent d7d9288 commit 612c0f6

1 file changed: +119 -0 lines changed

doc/python/ml-knn.md

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
## K-Nearest Neighbors (kNN)

How to visualize K-Nearest Neighbors (kNN) models trained with scikit-learn, using Plotly.

### Binary Probability Estimates with `go.Contour`

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.neighbors import KNeighborsClassifier
import plotly.express as px
import plotly.graph_objects as go

# Grid resolution and padding around the data
# (same values as in the multi-class example in this document)
mesh_size = .02
margin = 1

X, y = make_moons(noise=0.3, random_state=0)

# Create a mesh grid on which we will run our model
x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
xrange = np.arange(x_min, x_max, mesh_size)
yrange = np.arange(y_min, y_max, mesh_size)
xx, yy = np.meshgrid(xrange, yrange)

# Create classifier, run predictions on grid
clf = KNeighborsClassifier(15, weights='uniform')
clf.fit(X, y)
# Probability of class 1 at every grid point, reshaped back to the grid
Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
Z = Z.reshape(xx.shape)

# Scatter the training points, then overlay the probability contour
fig = px.scatter(X, x=0, y=1, color=y.astype(str))
fig.add_trace(
    go.Contour(
        x=xrange,
        y=yrange,
        z=Z,
        showscale=False,
        colorscale=['Blue', 'Red'],
        opacity=0.4
    )
)
fig.show()
```
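
To make the contour values easier to interpret, here is a minimal sketch of what `predict_proba` returns when `weights='uniform'`: the class-1 probability at a point is the share of its 15 nearest training points labeled 1. The `query` point is a hypothetical value chosen for illustration, not part of the example above.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.neighbors import KNeighborsClassifier

X, y = make_moons(noise=0.3, random_state=0)
clf = KNeighborsClassifier(15, weights='uniform').fit(X, y)

# Hypothetical query point, chosen only for illustration
query = np.array([[0.5, 0.25]])

# Indices of the 15 training points closest to the query
_, idx = clf.kneighbors(query)

# With uniform weights, every neighbor gets one equal vote,
# so the class-1 probability is the fraction of neighbors labeled 1
manual_proba = y[idx[0]].mean()
model_proba = clf.predict_proba(query)[0, 1]

print(manual_proba, model_proba)
```

Both printed values should agree, which is why the contour can only take values that are multiples of 1/15 in this setting.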

### Multi-class Classification with `px.data` and `go.Heatmap`

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import plotly.express as px
import plotly.graph_objects as go

mesh_size = .02
margin = 1

df = px.data.iris()
# Use plain arrays so that fitting and grid prediction see the same input type
X = df[['sepal_length', 'sepal_width']].values
y = df.species_id

# Create a mesh grid on which we will run our model
l_min, l_max = df.sepal_length.min() - margin, df.sepal_length.max() + margin
w_min, w_max = df.sepal_width.min() - margin, df.sepal_width.max() + margin
lrange = np.arange(l_min, l_max, mesh_size)
wrange = np.arange(w_min, w_max, mesh_size)
ll, ww = np.meshgrid(lrange, wrange)

# Create classifier, run predictions on grid
clf = KNeighborsClassifier(15, weights='distance')
clf.fit(X, y)
# Predicted species_id (1, 2 or 3) at every grid point
Z = clf.predict(np.c_[ll.ravel(), ww.ravel()])
Z = Z.reshape(ll.shape)

fig = px.scatter(df, x='sepal_length', y='sepal_width', color='species')
fig.update_traces(marker_size=10, marker_line_width=1)
fig.add_trace(
    go.Heatmap(
        x=lrange,
        y=wrange,
        z=Z,
        showscale=False,
        # One color per class: lowest, middle and highest species_id
        colorscale=[[0.0, 'blue'], [0.5, 'red'], [1.0, 'green']],
        opacity=0.25
    )
)
fig.show()
```
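
Both classification examples evaluate the model on a mesh grid and then reshape the flat predictions back into a 2-D array. The sketch below shows that round trip on a tiny grid, with a simple arithmetic expression standing in for the classifier (an assumption made so the result is easy to check by hand).

```python
import numpy as np

# Tiny grid: 3 x-values and 2 y-values
xrange_ = np.arange(0, 3, 1.0)
yrange_ = np.arange(0, 2, 1.0)
xx, yy = np.meshgrid(xrange_, yrange_)  # both have shape (2, 3)

# Flatten the grid into one (x, y) pair per row, as a model expects
grid = np.c_[xx.ravel(), yy.ravel()]    # shape (6, 2)

# Stand-in for model.predict(grid): one value per grid point
values = grid[:, 0] + 10 * grid[:, 1]

# Restore the grid layout: Z[i, j] belongs to (xrange_[j], yrange_[i])
Z = values.reshape(xx.shape)
print(Z)
```

Rows of `Z` follow the y-values and columns follow the x-values, which matches how `go.Contour` and `go.Heatmap` read `z` together with `x` and `y`.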

### Visualizing kNN Regression

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
import plotly.express as px
import plotly.graph_objects as go

df = px.data.tips()
X = df.total_bill.values.reshape(-1, 1)

# Fit one model per weighting scheme on the same data
knn_dist = KNeighborsRegressor(10, weights='distance')
knn_uni = KNeighborsRegressor(10, weights='uniform')
knn_dist.fit(X, df.tip)
knn_uni.fit(X, df.tip)

# Predict over an evenly spaced range of bill amounts
x_range = np.linspace(X.min(), X.max(), 100)
y_dist = knn_dist.predict(x_range.reshape(-1, 1))
y_uni = knn_uni.predict(x_range.reshape(-1, 1))

fig = px.scatter(df, x='total_bill', y='tip', color='sex', opacity=0.65)
fig.add_trace(go.Scatter(x=x_range, y=y_uni, name='Weights: Uniform'))
fig.add_trace(go.Scatter(x=x_range, y=y_dist, name='Weights: Distance'))
fig.show()
```
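
The two fitted curves differ only in how neighbors are averaged. As a rough sketch on made-up toy data (not the tips dataset), the snippet below reproduces both predictions by hand: `'uniform'` takes the plain mean of the neighbors' targets, while `'distance'` weights each neighbor by the inverse of its distance to the query.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Toy 1-D data, made up for illustration
X = np.array([[1.0], [2.0], [4.0], [8.0]])
y = np.array([1.0, 2.0, 4.0, 8.0])
query = np.array([[2.5]])  # not a training point, so no zero distances
k = 3

knn_uni = KNeighborsRegressor(k, weights='uniform').fit(X, y)
knn_dist = KNeighborsRegressor(k, weights='distance').fit(X, y)

# Distances to, and indices of, the k nearest training points
dist, idx = knn_uni.kneighbors(query)
neighbor_targets = y[idx[0]]

# Uniform: plain average of the neighbors' targets
manual_uni = neighbor_targets.mean()
# Distance: average weighted by inverse distance
manual_dist = np.average(neighbor_targets, weights=1.0 / dist[0])

print(manual_uni, knn_uni.predict(query)[0])
print(manual_dist, knn_dist.predict(query)[0])
```

Because closer points dominate the weighted average, the `'distance'` curve tends to track individual training points more tightly than the `'uniform'` one.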

### Reference

Learn more about `px`, `go.Contour`, and `go.Heatmap` here:
* https://plot.ly/python/plotly-express/
* https://plot.ly/python/heatmaps/
* https://plot.ly/python/contour-plots/

This tutorial was inspired by examples from the official scikit-learn docs:
* https://scikit-learn.org/stable/auto_examples/neighbors/plot_regression.html
* https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
* https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html
