Commit 04746c2 — plotly/plotly.r-docs
Author: Kalpit Desai — "Adding ML ROC & PR page draft" (parent: 598068a)
1 file changed: r/2021-07-26-ml-roc-pr.Rmd (+287, −0 lines)

## ROC and PR Curves in R

Interpret the results of your classification using Receiver Operating Characteristic (ROC) and Precision-Recall (PR) curves in R with Plotly.
5+
6+
## Preliminary plots
7+
8+
Before diving into the receiver operating characteristic (ROC) curve, we will look at two plots that will give some context to the thresholds mechanism behind the ROC and PR curves.

In the histogram, we observe that the scores spread such that most of the positive labels are binned near 1, and most of the negative labels are close to 0. When we set a threshold on the score, all of the bins to its left will be classified as 0's, and everything to the right will be 1's. There are obviously a few outliers, such as **negative** samples that our model gave a high score, and *positive* samples with a low score. If we set a threshold right in the middle, those outliers will respectively become **false positives** and *false negatives*.
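
To make the threshold mechanism concrete, here is a minimal base-R sketch (the `scores` and `labels` vectors are made up for illustration, not taken from the model below) that counts the false positives and false negatives produced by a given cutoff:

```r
# Toy scores and true labels (illustrative only)
scores <- c(0.1, 0.2, 0.35, 0.4, 0.6, 0.65, 0.8, 0.9)
labels <- c(0,   0,   1,    0,   1,   0,    1,   1)

threshold <- 0.5
pred <- as.integer(scores >= threshold)  # everything right of the cutoff becomes 1

fp <- sum(pred == 1 & labels == 0)  # negatives our model scored high: false positives
fn <- sum(pred == 0 & labels == 1)  # positives our model scored low: false negatives
fp
fn
```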

As we adjust the threshold, the number of false positives will increase or decrease, and at the same time the number of true positives will also change; this is shown in the second plot. As you can see, the model seems to perform fairly well, because the true positive rate and the false positive rate decrease sharply as we increase the threshold. Those two lines each represent a dimension of the ROC curve.
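
The mechanics behind that second plot can be sketched in base R (again on toy `scores`/`labels`, not the fitted model): for each candidate threshold, count the true and false positives among the samples scored at or above it.

```r
scores <- c(0.1, 0.2, 0.35, 0.4, 0.6, 0.65, 0.8, 0.9)
labels <- c(0,   0,   1,    0,   1,   0,    1,   1)

thresholds <- sort(unique(scores))
tpr <- sapply(thresholds, function(t) sum(scores >= t & labels == 1) / sum(labels == 1))
fpr <- sapply(thresholds, function(t) sum(scores >= t & labels == 0) / sum(labels == 0))

# Both rates start at 1 and fall toward 0 as the threshold rises
rbind(threshold = thresholds, tpr = tpr, fpr = fpr)
```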

```{r}
library(plotly)
library(tidymodels)

set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
data <- data.frame(X, y)
data$y <- as.factor(data$y)
X <- subset(data, select = -c(y))

logistic_glm <-
  logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  fit(y ~ ., data = data)

y_scores <- logistic_glm %>%
  predict(X, type = 'prob')

y_score <- y_scores$.pred_1
db <- data.frame(data$y, y_score)

z <- roc_curve(data = db, 'data.y', 'y_score')
z$specificity <- 1 - z$specificity
colnames(z) <- c('threshold', 'tpr', 'fpr')

fig1 <- plot_ly(x = y_score, color = data$y, colors = c('blue', 'red'), type = 'histogram', alpha = 0.5, nbinsx = 50) %>%
  layout(barmode = "overlay")
fig1

fig2 <- plot_ly(data = z, x = ~threshold) %>%
  add_trace(y = ~fpr, mode = 'lines', name = 'false positive rate', type = 'scatter') %>%
  add_trace(y = ~tpr, mode = 'lines', name = 'true positive rate', type = 'scatter') %>%
  layout(title = 'TPR and FPR at every threshold')
fig2
```

## Basic binary ROC curve

We display the area under the ROC curve (ROC AUC), which is fairly high and thus consistent with our interpretation of the previous plots.
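
Under the hood, ROC AUC is just the area under the (FPR, TPR) curve. Here is a minimal base-R sketch using the trapezoidal rule on toy data (illustrative only; the chunk below uses `yardstick::roc_auc()` instead):

```r
scores <- c(0.1, 0.3, 0.5, 0.7, 0.9)
labels <- c(0,   0,   1,   0,   1)

# Sweep thresholds from high to low so the curve runs from (0, 0) to (1, 1)
thresholds <- c(Inf, sort(unique(scores), decreasing = TRUE))
tpr <- sapply(thresholds, function(t) sum(scores >= t & labels == 1) / sum(labels == 1))
fpr <- sapply(thresholds, function(t) sum(scores >= t & labels == 0) / sum(labels == 0))

# Trapezoidal area under the ROC curve
auc <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
auc
```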

```{r}
library(dplyr)
library(ggplot2)
library(plotly)
library(pROC)
library(tidymodels)

set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
db <- data.frame(X, y)
db$y <- as.factor(db$y)
test_data <- db[1:20]  # the 20 predictor columns

model <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  # Fit the model
  fit(y ~ ., data = db)

ypred <- predict(model,
                 new_data = test_data,
                 type = "prob")

yscore <- data.frame(ypred$.pred_0)
rdb <- cbind(db$y, yscore)
colnames(rdb) <- c('y', 'yscore')

pdb <- roc_curve(rdb, y, yscore)
pdb$specificity <- 1 - pdb$specificity  # convert specificity to false positive rate
auc <- roc_auc(rdb, y, yscore)
auc <- auc$.estimate

tit <- paste('ROC Curve (AUC = ', toString(round(auc, 2)), ')', sep = '')

fig <- plot_ly(data = pdb, x = ~specificity, y = ~sensitivity, type = 'scatter', mode = 'lines', fill = 'tozeroy') %>%
  layout(title = tit, xaxis = list(title = "False Positive Rate"), yaxis = list(title = "True Positive Rate")) %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1, line = list(dash = "dash", color = 'black'), inherit = FALSE, showlegend = FALSE)
fig
```

## Multiclass ROC Curve

When you have more than 2 classes, you will need to plot the ROC curve for each class separately. Make sure that you use a [one-versus-rest](https://cran.r-project.org/web/packages/multiclassPairs/vignettes/Tutorial.html) model, or that your problem has a multi-label format; otherwise, your ROC curve might not return the expected results.
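
The one-vs-rest reduction amounts to binarizing the labels class by class. A hand-rolled base-R sketch of that step (the chunk below uses `fastDummies::dummy_cols()` for the same purpose):

```r
species <- factor(c("setosa", "versicolor", "virginica", "setosa"))

# One column per class: 1 where the sample belongs to it, 0 elsewhere
y_onehot <- sapply(levels(species), function(cl) as.integer(species == cl))
y_onehot
```

Each class's column then serves as the binary "truth" for that class's ROC curve.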

```{r}
library(plotly)
library(tidymodels)
library(fastDummies)

# Randomly reassign 50 of the labels so the classes are not perfectly separable
data(iris)
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind, 'Species'] <- samples

X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)

logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

y_scores <- logistic %>%
  predict(X, type = 'prob')

y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))

z <- cbind(y_scores, y_onehot)

# yardstick treats the first factor level ('0') as the event,
# so the one-vs-rest AUC for each class is 1 - .estimate
z$setosa <- as.factor(z$setosa)
roc_setosa <- roc_curve(data = z, setosa, .pred_setosa)
roc_setosa$specificity <- 1 - roc_setosa$specificity
colnames(roc_setosa) <- c('threshold', 'tpr', 'fpr')
auc_setosa <- roc_auc(data = z, setosa, .pred_setosa)
auc_setosa <- auc_setosa$.estimate
setosa <- paste('setosa (AUC = ', toString(round(1 - auc_setosa, 2)), ')', sep = '')

z$versicolor <- as.factor(z$versicolor)
roc_versicolor <- roc_curve(data = z, versicolor, .pred_versicolor)
roc_versicolor$specificity <- 1 - roc_versicolor$specificity
colnames(roc_versicolor) <- c('threshold', 'tpr', 'fpr')
auc_versicolor <- roc_auc(data = z, versicolor, .pred_versicolor)
auc_versicolor <- auc_versicolor$.estimate
versicolor <- paste('versicolor (AUC = ', toString(round(1 - auc_versicolor, 2)), ')', sep = '')

z$virginica <- as.factor(z$virginica)
roc_virginica <- roc_curve(data = z, virginica, .pred_virginica)
roc_virginica$specificity <- 1 - roc_virginica$specificity
colnames(roc_virginica) <- c('threshold', 'tpr', 'fpr')
auc_virginica <- roc_auc(data = z, virginica, .pred_virginica)
auc_virginica <- auc_virginica$.estimate
virginica <- paste('virginica (AUC = ', toString(round(1 - auc_virginica, 2)), ')', sep = '')

fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1, line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = roc_setosa, x = ~fpr, y = ~tpr, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = roc_versicolor, x = ~fpr, y = ~tpr, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = roc_virginica, x = ~fpr, y = ~tpr, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(xaxis = list(title = "False Positive Rate"),
         yaxis = list(title = "True Positive Rate"),
         legend = list(x = 100, y = 0.5))
fig
```

## Precision-Recall Curves

Plotting the PR curve is very similar to plotting the ROC curve. The following examples are slightly modified from the previous ones:
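
The quantities on the PR curve can also be computed by hand. A base-R sketch on toy data (illustrative only; the chunk below uses `yardstick::pr_curve()`):

```r
scores <- c(0.1, 0.4, 0.6, 0.8, 0.9)
labels <- c(0,   1,   0,   1,   1)

precision_recall <- function(t) {
  pred <- scores >= t
  c(precision = sum(pred & labels == 1) / sum(pred),     # fraction of predicted positives that are correct
    recall    = sum(pred & labels == 1) / sum(labels == 1))  # fraction of actual positives recovered
}

precision_recall(0.5)  # only the three highest scores are predicted positive
```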

```{r}
library(dplyr)
library(ggplot2)
library(plotly)
library(pROC)
library(tidymodels)

set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
db <- data.frame(X, y)
db$y <- as.factor(db$y)
test_data <- db[1:20]  # the 20 predictor columns

model <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  # Fit the model
  fit(y ~ ., data = db)

ypred <- predict(model,
                 new_data = test_data,
                 type = "prob")

yscore <- data.frame(ypred$.pred_0)
rdb <- cbind(db$y, yscore)
colnames(rdb) <- c('y', 'yscore')

pdb <- pr_curve(rdb, y, yscore)
# Score with the area under the PR curve, to match the plotted curve
auc <- pr_auc(rdb, y, yscore)
auc <- auc$.estimate

tit <- paste('PR Curve (AUC = ', toString(round(auc, 2)), ')', sep = '')

fig <- plot_ly(data = pdb, x = ~recall, y = ~precision, type = 'scatter', mode = 'lines', fill = 'tozeroy') %>%
  add_segments(x = 0, xend = 1, y = 1, yend = 0, line = list(dash = "dash", color = 'black'), inherit = FALSE, showlegend = FALSE) %>%
  layout(title = tit, xaxis = list(title = "Recall"), yaxis = list(title = "Precision"))

fig
```

In this example, we use the average precision metric, an alternative scoring method to the area under the PR curve. Here it is approximated as the mean of the precision values along the PR curve.
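
The textbook definition of average precision weights each precision value by the corresponding increase in recall, AP = Σ (Rₙ − Rₙ₋₁) Pₙ, rather than taking a plain mean. A base-R sketch of that definition on toy data (illustrative; the chunk below uses the simpler `mean(precision)` approximation):

```r
scores <- c(0.1, 0.4, 0.6, 0.8, 0.9)
labels <- c(0,   1,   0,   1,   1)

# Walk down the ranking from the highest score
ord <- order(scores, decreasing = TRUE)
tp_cum <- cumsum(labels[ord] == 1)
precision <- tp_cum / seq_along(tp_cum)
recall <- tp_cum / sum(labels == 1)

# AP = sum over ranking positions of (recall increase) * precision
ap <- sum(diff(c(0, recall)) * precision)
ap
```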
```{r}
library(plotly)
library(tidymodels)
library(fastDummies)

# Randomly reassign 50 of the labels so the classes are not perfectly separable
data(iris)
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind, 'Species'] <- samples

X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)

logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

y_scores <- logistic %>%
  predict(X, type = 'prob')

y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))

z <- cbind(y_scores, y_onehot)

# Approximate each class's average precision as the mean precision over its PR curve points
z$setosa <- as.factor(z$setosa)
pr_setosa <- pr_curve(data = z, setosa, .pred_setosa)
aps_setosa <- mean(pr_setosa$precision)
setosa <- paste('setosa (AP = ', toString(round(aps_setosa, 2)), ')', sep = '')

z$versicolor <- as.factor(z$versicolor)
pr_versicolor <- pr_curve(data = z, versicolor, .pred_versicolor)
aps_versicolor <- mean(pr_versicolor$precision)
versicolor <- paste('versicolor (AP = ', toString(round(aps_versicolor, 2)), ')', sep = '')

z$virginica <- as.factor(z$virginica)
pr_virginica <- pr_curve(data = z, virginica, .pred_virginica)
aps_virginica <- mean(pr_virginica$precision)
virginica <- paste('virginica (AP = ', toString(round(aps_virginica, 2)), ')', sep = '')

fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 1, yend = 0, line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = pr_setosa, x = ~recall, y = ~precision, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = pr_versicolor, x = ~recall, y = ~precision, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = pr_virginica, x = ~recall, y = ~precision, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(xaxis = list(title = "Recall"),
         yaxis = list(title = "Precision"),
         legend = list(x = 100, y = 0.5))
fig
```

## References

Learn more about histograms, filled area plots and line charts:

* https://plot.ly/r/histograms/
* https://plot.ly/r/filled-area-plots/
* https://plot.ly/r/line-charts/
