Commit a28ee1f

xhlulu authored and committed
ML Docs: Added new section to regression, updated references
1 parent 1de7a14 commit a28ee1f

1 file changed: +71 -5 lines changed

doc/python/ml-regression.md

Lines changed: 71 additions & 5 deletions
@@ -323,6 +323,58 @@ fig = px.scatter(
 fig.show()
 ```
 
+## Regularization visualization
+
+
+### Plot alphas for individual folds
+
+```python
+import pandas as pd
+import numpy as np
+import plotly.express as px
+import plotly.graph_objects as go
+from sklearn.linear_model import LassoCV
+
+# Load and preprocess the data
+df = px.data.gapminder()
+X = df.drop(columns=['lifeExp', 'iso_num'])
+X = pd.get_dummies(X, columns=['country', 'continent', 'iso_alpha'])
+y = df['lifeExp']
+N_FOLD = 6  # number of cross-validation folds (assumed value; the original snippet never defines it)
+# Train model to predict life expectancy (the normalize argument requires scikit-learn < 1.2)
+model = LassoCV(cv=N_FOLD, normalize=True)
+model.fit(X, y)
+mean_mse = model.mse_path_.mean(axis=-1)  # average MSE across folds for each alpha
+
+fig = go.Figure([
+    go.Scatter(
+        x=model.alphas_, y=model.mse_path_[:, i],
+        name=f"Fold: {i+1}", opacity=.5, line=dict(dash='dash'),
+        hovertemplate="alpha: %{x} <br>MSE: %{y}"
+    )
+    for i in range(N_FOLD)
+])
+fig.add_traces(go.Scatter(
+    x=model.alphas_, y=mean_mse,
+    name='Mean', line=dict(color='black', width=3),
+    hovertemplate="alpha: %{x} <br>MSE: %{y}",
+))
+
+fig.add_shape(
+    type="line", line=dict(dash='dash'),
+    x0=model.alpha_, y0=0,
+    x1=model.alpha_, y1=1,
+    yref='paper'
+)
+
+fig.update_layout(
+    xaxis_title='alpha',
+    xaxis_type="log",
+    yaxis_title="Mean Square Error (MSE)"
+)
+fig.show()
+```
+
 ## Grid search visualization using `px.density_heatmap` and `px.box`
 
 In this example, we show how to visualize the results of a grid search on a `DecisionTreeRegressor`. The first plot shows how to visualize the score of each model parameter on individual splits (grouped using facets). The second plot aggregates the results of all splits such that each box represents a single model.
@@ -401,8 +453,22 @@ fig_box.show()
 
 ### Reference
 
-Learn more about `px` here:
-* https://plot.ly/python/plotly-express/
-
-This tutorial was inspired by amazing examples from the official scikit-learn docs:
-* https://scikit-learn.org/stable/auto_examples/neighbors/plot_regression.html
+Learn more about the `px` figures used in this tutorial:
+* Plotly Express: https://plot.ly/python/plotly-express/
+* Vertical Lines: https://plot.ly/python/shapes/
+* Heatmaps: https://plot.ly/python/heatmaps/
+* Box Plots: https://plot.ly/python/box-plots/
+* 3D Scatter: https://plot.ly/python/3d-scatter-plots/
+* Surface Plots: https://plot.ly/python/3d-surface-plots/
+
+Learn more about the Machine Learning models used in this tutorial:
+* https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html
+* https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html
+* https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsRegressor.html
+* https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html
+* https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.PolynomialFeatures.html
+
+Other tutorials that inspired this notebook:
+* https://seaborn.pydata.org/examples/residplot.html
+* https://scikit-learn.org/stable/auto_examples/linear_model/plot_lasso_model_selection.html
+* http://www.scikit-yb.org/zh/latest/api/regressor/peplot.html
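
Note on the regularization example in this commit: the black "Mean" trace averages the `mse_path_` table over its last axis (one column per fold), and the dashed vertical line sits at `model.alpha_`, the alpha with the lowest fold-averaged error. The sketch below reproduces that selection in plain Python so the relationship is easy to check by hand. The helper names `mean_mse_per_alpha` and `best_alpha` are hypothetical, and the numbers are made up for illustration; they are not results from the gapminder fit.

```python
def mean_mse_per_alpha(mse_path):
    """Average each row of an (n_alphas x n_folds) table across folds."""
    return [sum(row) / len(row) for row in mse_path]


def best_alpha(alphas, mse_path):
    """Return the alpha whose fold-averaged MSE is lowest."""
    means = mean_mse_per_alpha(mse_path)
    best_index = min(range(len(means)), key=means.__getitem__)
    return alphas[best_index]


# Illustrative values only: 3 candidate alphas, 2 folds each.
alphas = [1.0, 0.1, 0.01]
mse_path = [
    [9.0, 11.0],  # MSE per fold at alpha=1.0
    [4.0, 6.0],   # MSE per fold at alpha=0.1
    [5.0, 7.0],   # MSE per fold at alpha=0.01
]

print(mean_mse_per_alpha(mse_path))  # [10.0, 5.0, 6.0]
print(best_alpha(alphas, mse_path))  # 0.1
```

In the plot, each row of `mse_path` corresponds to one x position, each column to one dashed fold trace, and the returned alpha to the position of the vertical line.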

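Note on the grid search paragraph kept as context in the first hunk: both plots rely on having one score per (model, split) pair. A minimal sketch of that reshaping follows, assuming a dictionary laid out like scikit-learn's `GridSearchCV.cv_results_` (a `params` list plus one `split<k>_test_score` list per split). The helper names and the scores are hypothetical and only illustrate the data layout; the plotting calls themselves are left out.

```python
def split_scores_long(cv_results, n_splits):
    """Return one row per (parameter setting, split) pair."""
    rows = []
    for i, params in enumerate(cv_results["params"]):
        for split in range(n_splits):
            score = cv_results[f"split{split}_test_score"][i]
            rows.append({**params, "split": split, "score": score})
    return rows


def scores_by_model(rows, param_names):
    """Group scores by parameter setting; each group feeds one box."""
    groups = {}
    for row in rows:
        key = tuple(row[name] for name in param_names)
        groups.setdefault(key, []).append(row["score"])
    return groups


# Made-up scores for two tree depths evaluated on two splits.
cv_results = {
    "params": [{"max_depth": 2}, {"max_depth": 4}],
    "split0_test_score": [0.5, 0.75],
    "split1_test_score": [0.25, 0.75],
}

rows = split_scores_long(cv_results, n_splits=2)
groups = scores_by_model(rows, ["max_depth"])
print(rows[0])
print(groups)
```

The long-form `rows` are what a faceted per-split plot would consume, while each list in `groups` holds the scores that a single box summarizes.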