
Commit 61b3ad8

xhlulu authored and xhlulu committed
ML Docs: Added 3 new sections to regression notebook
1 parent be71cfe commit 61b3ad8

File tree

1 file changed: +199 -7 lines changed

doc/python/ml-regression.md (+199 -7)
@@ -33,9 +33,28 @@ jupyter:
 thumbnail: thumbnail/knn-classification.png
 ---

-## Basic linear regression
+## Basic linear regression plots

-This example shows how to train a simple linear regression from `sklearn` to predicts the tips servers will receive based on the value of the total bill (dataset is included in `px.data`).
+### Ordinary Least Squares (OLS) with `plotly.express`
+
+This example shows how to use `plotly.express` to fit a simple Ordinary Least Squares (OLS) regression that predicts the tips servers will receive based on the value of the total bill.
+
+```python
+import plotly.express as px
+
+df = px.data.tips()
+fig = px.scatter(
+    df, x='total_bill', y='tip', opacity=0.65,
+    trendline='ols', trendline_color_override='red'
+)
+fig.show()
+```
+
+### Linear Regression with scikit-learn
+
+You can also perform the same prediction using scikit-learn's `LinearRegression`.

 ```python
 import numpy as np
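
The hunk above cuts off right after `import numpy as np`, so the body of the scikit-learn example isn't visible here. As a minimal sketch of how the same tips prediction could look with `LinearRegression` (an assumption about how the truncated block continues, not the commit's verbatim code):

```python
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

df = px.data.tips()
X = df.total_bill.values.reshape(-1, 1)

# Fit a linear model mapping total_bill to tip
model = LinearRegression()
model.fit(X, df.tip)

# Predict over an evenly spaced range so the fitted line is smooth
x_range = np.linspace(X.min(), X.max(), 100)
y_range = model.predict(x_range.reshape(-1, 1))

fig = px.scatter(df, x='total_bill', y='tip', opacity=0.65)
fig.add_traces(go.Scatter(x=x_range, y=y_range, name='Regression Fit'))
fig.show()
```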
@@ -123,7 +142,6 @@ mesh_size = .02
 margin = 0

 df = px.data.iris()
-features = ["sepal_width", "sepal_length", "petal_width"]

 X = df[['sepal_width', 'sepal_length']]
 y = df['petal_width']
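
This hunk and the next only touch fragments of the notebook's 3D prediction-surface example (the next hunk header references `go.Surface`). As a rough sketch of the idea — fit on two features, predict over a mesh grid, and overlay the predictions as a surface — assuming a plain `LinearRegression`, which may not be the estimator the notebook actually uses:

```python
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

mesh_size = .02
margin = 0

df = px.data.iris()
X = df[['sepal_width', 'sepal_length']]
y = df['petal_width']

model = LinearRegression()
model.fit(X, y)

# Build a mesh grid spanning the two input features
xrange = np.arange(X.sepal_width.min() - margin, X.sepal_width.max() + margin, mesh_size)
yrange = np.arange(X.sepal_length.min() - margin, X.sepal_length.max() + margin, mesh_size)
xx, yy = np.meshgrid(xrange, yrange)

# Predict on every grid point and reshape back into the grid
pred = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

fig = px.scatter_3d(df, x='sepal_width', y='sepal_length', z='petal_width')
fig.update_traces(marker=dict(size=5))
fig.add_traces(go.Surface(x=xrange, y=yrange, z=pred, name='pred_surface'))
fig.show()
```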
@@ -150,10 +168,46 @@ fig.add_traces(go.Surface(x=xrange, y=yrange, z=pred, name='pred_surface'))
 fig.show()
 ```

-## Label polynomial fits with latex
+## Displaying `PolynomialFeatures` using $\LaTeX$
+
+It's easy to display LaTeX equations in legends and titles by simply adding `$` before and after your equation.

 ```python
+import numpy as np
+import plotly.express as px
+import plotly.graph_objects as go
+from sklearn.linear_model import LinearRegression
+from sklearn.preprocessing import PolynomialFeatures
+
+def format_coefs(coefs):
+    equation_list = [f"{coef}x^{i}" for i, coef in enumerate(coefs)]
+    equation = "$" + " + ".join(equation_list) + "$"
+
+    replace_map = {"x^0": "", "x^1": "x", '+ -': '- '}
+    for old, new in replace_map.items():
+        equation = equation.replace(old, new)
+
+    return equation

+df = px.data.tips()
+X = df.total_bill.values.reshape(-1, 1)
+x_range = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)
+
+fig = px.scatter(df, x='total_bill', y='tip', opacity=0.65)
+for n_features in [1, 2, 3, 4]:
+    poly = PolynomialFeatures(n_features)
+    poly.fit(X)
+    X_poly = poly.transform(X)
+    x_range_poly = poly.transform(x_range)
+
+    model = LinearRegression(fit_intercept=False)
+    model.fit(X_poly, df.tip)
+    y_poly = model.predict(x_range_poly)
+
+    equation = format_coefs(model.coef_.round(2))
+    fig.add_traces(go.Scatter(x=x_range.squeeze(), y=y_poly, name=equation))
+
+fig.show()
 ```
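
The loop above renders each fitted equation in the legend. The `$...$` trick mentioned in the paragraph before the block also works for titles; a tiny illustrative snippet (the coefficients below are made-up placeholders, not fitted values):

```python
import plotly.express as px

df = px.data.tips()
fig = px.scatter(df, x='total_bill', y='tip')
# Text wrapped in $...$ is rendered as a LaTeX equation
fig.update_layout(title_text=r"$\hat{y} = 0.11x + 0.92$")
fig.show()
```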

 ## Prediction Error Plots
@@ -162,22 +216,160 @@ fig.show()
 ### Simple Prediction Error

 ```python
+import plotly.express as px
+import plotly.graph_objects as go
+from sklearn.linear_model import LinearRegression

+df = px.data.iris()
+X = df[['sepal_width', 'sepal_length']]
+y = df['petal_width']
+
+# Condition the model on sepal width and length, predict the petal width
+model = LinearRegression()
+model.fit(X, y)
+y_pred = model.predict(X)
+
+fig = px.scatter(x=y, y=y_pred, labels={'x': 'y true', 'y': 'y pred'})
+# Dashed identity line where perfect predictions would fall
+fig.add_shape(
+    type="line", line=dict(dash='dash'),
+    x0=y.min(), y0=y.min(),
+    x1=y.max(), y1=y.max()
+)
+fig.show()
 ```

-### Augmented Prediction Error plot using `px`
+### Augmented Prediction Error analysis using `plotly.express`

 ```python
+import plotly.express as px
+import plotly.graph_objects as go
+from sklearn.linear_model import LinearRegression
+from sklearn.model_selection import train_test_split

-```
+df = px.data.iris()

-### Grid Search Visualization using `px.scatter_matrix`
+# Split data into training and test splits
+train_idx, test_idx = train_test_split(df.index, test_size=.25, random_state=0)
+df['split'] = 'train'
+df.loc[test_idx, 'split'] = 'test'

+X = df[['sepal_width', 'sepal_length']]
+y = df['petal_width']
+X_train = df.loc[train_idx, ['sepal_width', 'sepal_length']]
+y_train = df.loc[train_idx, 'petal_width']
+
+# Condition the model on sepal width and length, predict the petal width
+model = LinearRegression()
+model.fit(X_train, y_train)
+df['prediction'] = model.predict(X)
+
+fig = px.scatter(
+    df, x='petal_width', y='prediction',
+    marginal_x='histogram', marginal_y='histogram',
+    color='split', trendline='ols'
+)
+fig.add_shape(
+    type="line", line=dict(dash='dash'),
+    x0=y.min(), y0=y.min(),
+    x1=y.max(), y1=y.max()
+)
+
+fig.show()
+```

 ## Residual Plots

+Just like prediction error plots, it's easy to visualize your prediction residuals in just a few lines of code using `plotly.express` built-in capabilities.
+
+```python
+import numpy as np
+import plotly.express as px
+import plotly.graph_objects as go
+from sklearn.linear_model import LinearRegression
+from sklearn.model_selection import train_test_split
+
+df = px.data.iris()
+
+# Split data into training and test splits
+train_idx, test_idx = train_test_split(df.index, test_size=.25, random_state=0)
+df['split'] = 'train'
+df.loc[test_idx, 'split'] = 'test'
+
+X = df[['sepal_width', 'sepal_length']]
+X_train = df.loc[train_idx, ['sepal_width', 'sepal_length']]
+y_train = df.loc[train_idx, 'petal_width']
+
+# Condition the model on sepal width and length, predict the petal width
+model = LinearRegression()
+model.fit(X_train, y_train)
+df['prediction'] = model.predict(X)
+df['residual'] = df['prediction'] - df['petal_width']
+
+fig = px.scatter(
+    df, x='prediction', y='residual',
+    marginal_y='violin',
+    color='split', trendline='ols'
+)
+fig.show()
+```
+
+## Grid Search Visualization using `px` facets
+
 ```python
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from sklearn.model_selection import GridSearchCV
+from sklearn.tree import DecisionTreeRegressor

+N_FOLD = 5
+
+df = px.data.iris()
+X = df[['sepal_width', 'sepal_length']]
+y = df['petal_width']
+
+model = DecisionTreeRegressor()
+param_grid = {
+    'criterion': ['mse', 'friedman_mse', 'mae'],
+    'max_depth': range(2, 5)
+}
+grid = GridSearchCV(model, param_grid, cv=N_FOLD)
+
+grid.fit(X, y)
+grid_df = pd.DataFrame(grid.cv_results_)
+
+# Convert the wide format of the grid into the long format
+# accepted by plotly.express
+melted = (
+    grid_df
+    .rename(columns=lambda col: col.replace('param_', ''))
+    .melt(
+        value_vars=[f'split{i}_test_score' for i in range(N_FOLD)],
+        id_vars=['rank_test_score', 'mean_test_score',
+                 'mean_fit_time', 'criterion', 'max_depth']
+    )
+)
+
+# Convert R-Squared measure to %
+melted[['value', 'mean_test_score']] *= 100
+
+# Format the variable names for simplicity
+melted['variable'] = (
+    melted['variable']
+    .str.replace('_test_score', '')
+    .str.replace('split', '')
+)
+
+px.bar(
+    melted, x='variable', y='value',
+    color='mean_test_score',
+    facet_row='max_depth',
+    facet_col='criterion',
+    title='Test Scores of Grid Search',
+    hover_data=['mean_fit_time', 'rank_test_score'],
+    labels={'variable': 'cv_split',
+            'value': 'r_squared',
+            'mean_test_score': "mean_r_squared"}
+)
 ```
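
Not part of the commit, but a natural companion to the facet plot above: once the grid is fitted, scikit-learn exposes the winning configuration directly through standard `GridSearchCV` attributes.

```python
# Assuming `grid` was fitted as in the block above
print("Best parameters:", grid.best_params_)
print("Best mean CV R^2:", grid.best_score_)
```

Note that recent scikit-learn releases renamed the `'mse'` and `'mae'` criteria to `'squared_error'` and `'absolute_error'`, so the `param_grid` above may need those names on newer versions.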

 ### Reference
