You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: doc/python/ml-regression.md
+199-7
Original file line number
Diff line number
Diff line change
@@ -33,9 +33,28 @@ jupyter:
33
33
thumbnail: thumbnail/knn-classification.png
34
34
---
35
35
36
-
## Basic linear regression
36
+
## Basic linear regression plots
37
37
38
-
This example shows how to train a simple linear regression from `sklearn` to predicts the tips servers will receive based on the value of the total bill (dataset is included in `px.data`).
38
+
39
+
### Ordinary Least Square (OLS) with `plotly.express`
40
+
41
+
42
+
This example shows how to use `plotly.express` to train a simply Ordinary Least Square (OLS) that can predict the tips servers will receive based on the value of the total bill.
43
+
44
+
```python
45
+
import plotly.express as px
46
+
47
+
df = px.data.tips()
48
+
fig = px.scatter(
49
+
df, x='total_bill', y='tip', opacity=0.65,
50
+
trendline='ols', trendline_color_override='red'
51
+
)
52
+
fig.show()
53
+
```
54
+
55
+
### Linear Regression with scikit-learn
56
+
57
+
You can also perform the same prediction using scikit-learn's `LinearRegression`.
39
58
40
59
```python
41
60
import numpy as np
@@ -123,7 +142,6 @@ mesh_size = .02
123
142
margin =0
124
143
125
144
df = px.data.iris()
126
-
features = ["sepal_width", "sepal_length", "petal_width"]
# Condition the model on sepal width and length, predict the petal width
261
+
model = LinearRegression()
262
+
model.fit(X_train, y_train)
263
+
df['prediction'] = model.predict(X)
264
+
265
+
fig = px.scatter(
266
+
df, x='petal_width', y='prediction',
267
+
marginal_x='histogram', marginal_y='histogram',
268
+
color='split', trendline='ols'
269
+
)
270
+
fig.add_shape(
271
+
type="line", line=dict(dash='dash'),
272
+
x0=y.min(), y0=y.min(),
273
+
x1=y.max(), y1=y.max()
274
+
)
275
+
276
+
fig.show()
277
+
```
176
278
177
279
## Residual Plots
178
280
281
+
Just like prediction error plots, it's easy to visualize your prediction residuals in just a few lines of codes using `plotly.express` built-in capabilities.
282
+
283
+
```python
284
+
import numpy as np
285
+
import plotly.express as px
286
+
import plotly.graph_objects as go
287
+
from sklearn.linear_model import LinearRegression
288
+
from sklearn.model_selection import train_test_split
0 commit comments