Skip to content

Commit 9fc81be

Browse files
author
xhlu
committed
ML Docs: Update Regression notebook
Added a preliminary section that introduces roc curves
1 parent 0b861e9 commit 9fc81be

File tree

1 file changed

+74
-3
lines changed

1 file changed

+74
-3
lines changed

doc/python/ml-roc-pr.md

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,58 @@ jupyter:
3434
thumbnail: thumbnail/ml-roc-pr.png
3535
---
3636

37-
## Basic Binary ROC Curve
37+
## Preliminary plots
38+
39+
Before diving into the receiver operating characteristic (ROC) curve, we will look at two plots that will give some context to the thresholds mechanism behind the ROC and PR curves.
40+
41+
In the histogram, we observe that the score spread such that most of the positive labels are binned near 1, and a lot of the negative labels are close to 0. When we set a threshold on the score, all of the bins to its left will be classified as 0's, and everything to the right will be 1's. There are obviously a few outliers, such as **negative** samples that our model gave a high score, and *positive* samples with a low score. If we set a threshold right in the middle, those outliers will respectively become **false positives** and *false negatives*.
42+
43+
As we adjust thresholds, the number of positive positives will increase or decrease, and at the same time the number of true positives will also change; this is shown in the second plot. As you can see, the model seems to perform fairly well, because the true positive rate decreases slowly, whereas the false positive rate decreases sharply as we increase the threshold. Those two lines each represent a dimension of the ROC curve.
44+
45+
```python
46+
import plotly.express as px
47+
import plotly.graph_objects as go
48+
from sklearn.linear_model import LogisticRegression
49+
from sklearn.metrics import roc_curve, auc
50+
from sklearn.datasets import make_classification
51+
52+
X, y = make_classification(n_samples=500, random_state=0)
53+
54+
model = LogisticRegression()
55+
model.fit(X, y)
56+
y_score = model.predict_proba(X)[:, 1]
57+
fpr, tpr, thresholds = roc_curve(y, y_score)
58+
59+
# The histogram of scores compared to true labels
60+
fig_hist = px.histogram(
61+
x=y_score, color=y, nbins=50,
62+
labels=dict(color='True Labels', x='Score')
63+
)
64+
65+
# Evaluating model performance at various thresholds
66+
fig_thresh = go.Figure([
67+
go.Scatter(x=thresholds, y=fpr, name='False Positive Rate'),
68+
go.Scatter(x=thresholds, y=tpr, name='True Positive Rate')
69+
])
70+
fig_thresh.update_layout(
71+
title='TPR and FPR at every threshold',
72+
xaxis_title='Threshold',
73+
yaxis_title='Rate',
74+
yaxis=dict(scaleanchor="x", scaleratio=1),
75+
xaxis=dict(constrain='domain')
76+
)
77+
fig_thresh.update_xaxes(range=[0, 1])
78+
79+
# Display plots
80+
fig_hist.show()
81+
fig_thresh.show()
82+
```
83+
84+
## Basic binary ROC curve
85+
86+
Notice how this ROC curve looks similar to the True Positive Rate curve from the previous plot. This is because they are the same curve, except the x-axis consists of increasing values of FPR instead of threshold, which is why the line is flipped and distorted.
87+
88+
We also display the area under the ROC curve (ROC AUC), which is fairly high, thus consistent with our intepretation of the previous plots.
3889

3990
```python
4091
import plotly.express as px
@@ -59,6 +110,10 @@ fig.add_shape(
59110
type='line', line=dict(dash='dash'),
60111
x0=0, x1=1, y0=0, y1=1
61112
)
113+
fig.update_layout(
114+
yaxis=dict(scaleanchor="x", scaleratio=1),
115+
xaxis=dict(constrain='domain')
116+
)
62117
fig.show()
63118
```
64119

@@ -112,7 +167,9 @@ for i in range(y_scores.shape[1]):
112167

113168
fig.update_layout(
114169
xaxis_title='False Positive Rate',
115-
yaxis_title='True Positive Rate'
170+
yaxis_title='True Positive Rate',
171+
yaxis=dict(scaleanchor="x", scaleratio=1),
172+
xaxis=dict(constrain='domain')
116173
)
117174
fig.show()
118175
```
@@ -144,6 +201,11 @@ fig.add_shape(
144201
type='line', line=dict(dash='dash'),
145202
x0=0, x1=1, y0=1, y1=0
146203
)
204+
fig.update_layout(
205+
yaxis=dict(scaleanchor="x", scaleratio=1),
206+
xaxis=dict(constrain='domain')
207+
)
208+
147209
fig.show()
148210
```
149211

@@ -195,7 +257,16 @@ for i in range(y_scores.shape[1]):
195257

196258
fig.update_layout(
197259
xaxis_title='Recall',
198-
yaxis_title='Precision'
260+
yaxis_title='Precision',
261+
yaxis=dict(scaleanchor="x", scaleratio=1),
262+
xaxis=dict(constrain='domain')
199263
)
200264
fig.show()
201265
```
266+
267+
## References
268+
269+
Learn more about `px`, `px.area`, `px.hist`:
270+
* https://plot.ly/python/histograms/
271+
* https://plot.ly/python/filled-area-plots/
272+
* https://plot.ly/python/line-charts/

0 commit comments

Comments
 (0)