Completed AI/ML section and started on Fundamentals Section by kvdesai · Pull Request #70 · plotly/plotly.r-docs · GitHub

Completed AI/ML section and started on Fundamentals Section #70


Merged: 38 commits, merged on Aug 17, 2021
Changes from 1 commit

Commits (38):
da8f391 ml-knn page with make_moons.csv data (Jul 24, 2021)
858b82a Fixing text (Jul 26, 2021)
c81963c Adding ML ROC & PR page draft (Jul 27, 2021)
5f3f90b Merge branch 'rpy-parity-dev' of https://github.com/plotly/plotly.r-d… (Jul 28, 2021)
739b1c1 Adding R page for PCA (Jul 28, 2021)
eb138bc ML page for t-SNE and UMAP (Jul 29, 2021)
ea5d15b Adding fundamentals/multiple-chart-types page draft (Aug 1, 2021)
e41bacd Adding page for Styling in Plotly with R (Aug 3, 2021)
39bc7d8 committing horiz-vert shapes and figure label pages, without Dash (Aug 4, 2021)
4248f8e Adding dependencies=TRUE for Anglr (Aug 5, 2021)
dfd7011 Merge branch 'rpy-parity' of https://github.com/plotly/plotly.r-docs … (Aug 5, 2021)
8c4a8ed Added Dash code (Aug 5, 2021)
d9c481b Merge pull request #69 from plotly/rpy-parity-dev (kvdesai, Aug 5, 2021)
5d68cc0 Build fix with Anglr issue (Aug 5, 2021)
477cced Build fix2 with Anglr issue (Aug 5, 2021)
3dd73c3 Build fix3 with Anglr issue (Aug 5, 2021)
5d65e56 Build fix4 for Anglr (kvdesai, Aug 5, 2021)
7bd525c Added the Dash code for Figure-Labels page (Aug 6, 2021)
cfca991 Merge branch 'rpy-parity' of https://github.com/plotly/plotly.r-docs … (Aug 6, 2021)
c5d254c Fixing the broken data URL (kvdesai, Aug 6, 2021)
90f139a Update r/2021-07-27-ml-pca.rmd (kvdesai, Aug 16, 2021)
d9d083c Update r/2021-07-27-ml-pca.rmd (kvdesai, Aug 16, 2021)
2026b0d Update r/2021-07-28-ml-tsne-umap.rmd (kvdesai, Aug 16, 2021)
c4f782b Update r/2021-07-28-ml-tsne-umap.rmd (kvdesai, Aug 16, 2021)
9477afc Update r/2021-08-02-styling-plotly-in-r.rmd (kvdesai, Aug 16, 2021)
697b63d Update r/2021-08-04-figure-labels.rmd (kvdesai, Aug 16, 2021)
24d5d26 Merge branch 'master' of https://github.com/plotly/plotly.r-docs into… (Aug 17, 2021)
567ea3b Merge branch 'rpy-parity' of https://github.com/plotly/plotly.r-docs … (Aug 17, 2021)
75082eb Improving aesthetics for plot background and grid lines (Aug 17, 2021)
ff0008b Adding front-matter tags and cleanup (HammadTheOne, Aug 17, 2021)
dc9ebd6 Renaming Rmd files (HammadTheOne, Aug 17, 2021)
1d91867 Adding explicit tsne install (HammadTheOne, Aug 17, 2021)
de2521a Added explicit umap install (HammadTheOne, Aug 17, 2021)
0169502 Explicitly install rsvd (HammadTheOne, Aug 17, 2021)
e824a2d Explicitly install dash (HammadTheOne, Aug 17, 2021)
5132343 Fixing dependencies (HammadTheOne, Aug 17, 2021)
8cf5d02 Fixing order and skipping Dash chunk eval (HammadTheOne, Aug 17, 2021)
f3175c6 Deleting old figure-labels page (HammadTheOne, Aug 17, 2021)
Adding ML ROC & PR page draft
Kalpit Desai committed Jul 27, 2021
commit c81963c87d060b2a6a307a4b6f9d9e573a57817b
287 changes: 287 additions & 0 deletions r/2021-07-26-ml-roc-pr.Rmd
@@ -0,0 +1,287 @@

## ROC and PR Curves in R

Interpret the results of your classification using Receiver Operating Characteristic (ROC) and Precision-Recall (PR) curves in R with Plotly.

## Preliminary plots

Before diving into the receiver operating characteristic (ROC) curve, we will look at two plots that give some context to the thresholding mechanism behind the ROC and PR curves.

In the histogram, we observe that the scores are spread such that most of the positive labels are binned near 1, and many of the negative labels are close to 0. When we set a threshold on the score, all of the bins to its left will be classified as 0's, and everything to the right will be 1's. There are obviously a few outliers, such as **negative** samples that our model gave a high score, and *positive* samples with a low score. If we set a threshold right in the middle, those outliers will respectively become **false positives** and *false negatives*.

As we adjust the threshold, the number of false positives will increase or decrease, and at the same time the number of true positives will also change; this is shown in the second plot. As you can see, the model seems to perform fairly well, because the true positive rate and the false positive rate decrease sharply as we increase the threshold. Those two lines each represent a dimension of the ROC curve.

```{r}
library(plotly)
library(tidymodels)

set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
data <- data.frame(X, y)
data$y <- as.factor(data$y)
X <- subset(data, select = -c(y))

logistic_glm <-
  logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  fit(y ~ ., data = data)

y_scores <- logistic_glm %>%
  predict(X, type = 'prob')

y_score <- y_scores$.pred_1
db <- data.frame(data$y, y_score)

# The default event level is the first factor level ('0'), while y_score is the
# probability of class '1'. After flipping specificity, the second column holds the
# class-1 true positive rate and the sensitivity column the class-1 false positive
# rate, which is what the renaming below reflects.
z <- roc_curve(data = db, 'data.y', 'y_score')
z$specificity <- 1 - z$specificity
colnames(z) <- c('threshold', 'tpr', 'fpr')

fig1 <- plot_ly(x = y_score, color = data$y, colors = c('blue', 'red'),
                type = 'histogram', alpha = 0.5, nbinsx = 50) %>%
  layout(barmode = "overlay")
fig1

fig2 <- plot_ly(data = z, x = ~threshold) %>%
  add_trace(y = ~fpr, mode = 'lines', name = 'false positive rate', type = 'scatter') %>%
  add_trace(y = ~tpr, mode = 'lines', name = 'true positive rate', type = 'scatter') %>%
  layout(title = 'TPR and FPR at every threshold')
fig2
```
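
To make the thresholding mechanism concrete, here is a minimal sketch (not part of the original page) that applies a single illustrative cutoff of 0.5 to the `y_score` and `data$y` objects from the chunk above and computes the resulting true and false positive rates by hand:

```{r}
# Illustrative only: one fixed cutoff instead of sweeping over all thresholds.
cutoff <- 0.5
predicted <- as.integer(y_score >= cutoff)   # scores to the right of the cutoff become 1
actual <- as.integer(as.character(data$y))

tp <- sum(predicted == 1 & actual == 1)
fp <- sum(predicted == 1 & actual == 0)
fn <- sum(predicted == 0 & actual == 1)
tn <- sum(predicted == 0 & actual == 0)

c(tpr = tp / (tp + fn), fpr = fp / (fp + tn))
```

Moving the cutoff and recomputing these two rates traces out exactly the pair of lines shown in the second plot.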

## Basic binary ROC curve

We display the area under the ROC curve (ROC AUC), which is fairly high and thus consistent with our interpretation of the previous plots.

```{r}
library(dplyr)
library(ggplot2)
library(plotly)
library(pROC)
library(tidymodels)   # logistic_reg(), roc_curve(), roc_auc()

set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
db <- data.frame(X, y)
db$y <- as.factor(db$y)
test_data <- db[1:20]   # the 20 predictor columns X1..X20

model <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  # Fit the model
  fit(y ~ ., data = db)

ypred <- predict(model,
                 new_data = test_data,
                 type = "prob")

# yscore is the predicted probability of class 0, which is the default event level
yscore <- data.frame(ypred$.pred_0)
rdb <- cbind(db$y, yscore)
colnames(rdb) <- c('y', 'yscore')

pdb <- roc_curve(rdb, y, yscore)
pdb$specificity <- 1 - pdb$specificity
auc <- roc_auc(rdb, y, yscore)
auc <- auc$.estimate

tit <- paste('ROC Curve (AUC = ', toString(round(auc, 2)), ')', sep = '')

fig <- plot_ly(data = pdb, x = ~specificity, y = ~sensitivity,
               type = 'scatter', mode = 'lines', fill = 'tozeroy') %>%
  layout(title = tit,
         xaxis = list(title = "False Positive Rate"),
         yaxis = list(title = "True Positive Rate")) %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1,
               line = list(dash = "dash", color = 'black'),
               inherit = FALSE, showlegend = FALSE)
fig
```
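
As a quick sanity check (a sketch, assuming the `pdb` and `auc` objects from the chunk above are still in scope), integrating the curve with the trapezoidal rule should closely reproduce the value returned by `roc_auc()`:

```{r}
# Approximate the area under the ROC curve with the trapezoidal rule.
ord <- order(pdb$specificity, pdb$sensitivity)   # traverse the curve from left to right
xs  <- pdb$specificity[ord]                      # specificity was flipped to FPR above
ys  <- pdb$sensitivity[ord]                      # true positive rate
manual_auc <- sum(diff(xs) * (head(ys, -1) + tail(ys, -1)) / 2)
c(trapezoid = manual_auc, roc_auc = auc)
```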



## Multiclass ROC Curve

When you have more than two classes, you will need to plot the ROC curve for each class separately. Make sure that you use a [one-versus-rest](https://cran.r-project.org/web/packages/multiclassPairs/vignettes/Tutorial.html) model, or that your problem has a multi-label format; otherwise, your ROC curves might not return the expected results.

```{r}
library(plotly)
library(tidymodels)
library(fastDummies)

data(iris)

# Randomly relabel 50 of the 150 observations so the classes are not perfectly separable
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind, 'Species'] <- samples

X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)

logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

y_scores <- logistic %>%
  predict(X, type = 'prob')

# One-hot encode the true labels: one 0/1 indicator column per class
y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))

z <- cbind(y_scores, y_onehot)

# Note: the default event level is the first factor level ('0'), while .pred_* scores
# class '1'. Flipping specificity therefore gives the class-1 TPR, sensitivity is the
# class-1 FPR, and roc_auc() returns 1 minus the class-1 AUC (hence 1 - auc in the labels).
z$setosa <- as.factor(z$setosa)
roc_setosa <- roc_curve(data = z, setosa, .pred_setosa)
roc_setosa$specificity <- 1 - roc_setosa$specificity
colnames(roc_setosa) <- c('threshold', 'tpr', 'fpr')
auc_setosa <- roc_auc(data = z, setosa, .pred_setosa)
auc_setosa <- auc_setosa$.estimate
setosa <- paste('setosa (AUC = ', toString(round(1 - auc_setosa, 2)), ')', sep = '')

z$versicolor <- as.factor(z$versicolor)
roc_versicolor <- roc_curve(data = z, versicolor, .pred_versicolor)
roc_versicolor$specificity <- 1 - roc_versicolor$specificity
colnames(roc_versicolor) <- c('threshold', 'tpr', 'fpr')
auc_versicolor <- roc_auc(data = z, versicolor, .pred_versicolor)
auc_versicolor <- auc_versicolor$.estimate
versicolor <- paste('versicolor (AUC = ', toString(round(1 - auc_versicolor, 2)), ')', sep = '')

z$virginica <- as.factor(z$virginica)
roc_virginica <- roc_curve(data = z, virginica, .pred_virginica)
roc_virginica$specificity <- 1 - roc_virginica$specificity
colnames(roc_virginica) <- c('threshold', 'tpr', 'fpr')
auc_virginica <- roc_auc(data = z, virginica, .pred_virginica)
auc_virginica <- auc_virginica$.estimate
virginica <- paste('virginica (AUC = ', toString(round(1 - auc_virginica, 2)), ')', sep = '')

fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1,
               line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = roc_setosa, x = ~fpr, y = ~tpr, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = roc_versicolor, x = ~fpr, y = ~tpr, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = roc_virginica, x = ~fpr, y = ~tpr, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(xaxis = list(title = "False Positive Rate"),
         yaxis = list(title = "True Positive Rate"),
         legend = list(x = 100, y = 0.5))
fig
```
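The three per-class blocks above differ only in the class name, so the same one-versus-rest computation can be wrapped in a small helper. The `roc_one_vs_rest()` function below is a sketch of ours, not part of yardstick or the published page; it puts level `1` first so that the class itself is the event, which avoids the `1 - AUC` bookkeeping used above, and it assumes the `z` data frame from the previous chunk:

```{r}
# Hypothetical helper: one-versus-rest ROC curve for a single class.
roc_one_vs_rest <- function(df, class) {
  d <- data.frame(
    truth = factor(df[[class]], levels = c(1, 0)),  # '1' first, so this class is the event
    est   = df[[paste0(".pred_", class)]]           # predicted probability of this class
  )
  curve <- roc_curve(d, truth, est)
  data.frame(threshold = curve$.threshold,
             tpr = curve$sensitivity,
             fpr = 1 - curve$specificity,
             class = class)
}

curves <- do.call(rbind, lapply(c("setosa", "versicolor", "virginica"),
                                function(cl) roc_one_vs_rest(z, cl)))
head(curves)
```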


## Precision-Recall Curves

Plotting the PR curve is very similar to plotting the ROC curve. The following examples are slightly modified from the previous examples:

```{r}
library(dplyr)
library(ggplot2)
library(plotly)
library(pROC)
library(tidymodels)   # logistic_reg(), pr_curve(), roc_auc()

set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
db <- data.frame(X, y)
db$y <- as.factor(db$y)
test_data <- db[1:20]   # the 20 predictor columns X1..X20

model <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  # Fit the model
  fit(y ~ ., data = db)

ypred <- predict(model,
                 new_data = test_data,
                 type = "prob")

# yscore is the predicted probability of class 0, which is the default event level
yscore <- data.frame(ypred$.pred_0)
rdb <- cbind(db$y, yscore)
colnames(rdb) <- c('y', 'yscore')

pdb <- pr_curve(rdb, y, yscore)
auc <- roc_auc(rdb, y, yscore)   # ROC AUC for the same model, shown in the plot title
auc <- auc$.estimate

tit <- paste('PR Curve (AUC = ', toString(round(auc, 2)), ')', sep = '')

fig <- plot_ly(data = pdb, x = ~recall, y = ~precision,
               type = 'scatter', mode = 'lines', fill = 'tozeroy') %>%
  add_segments(x = 0, xend = 1, y = 1, yend = 0,
               line = list(dash = "dash", color = 'black'),
               inherit = FALSE, showlegend = FALSE) %>%
  layout(title = tit,
         xaxis = list(title = "Recall"),
         yaxis = list(title = "Precision"))

fig
```
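
The PR curve is driven by the same thresholding mechanism as the ROC curve, just summarized by precision and recall instead of TPR and FPR. Here is a minimal sketch (illustrative cutoff of 0.5, using the `rdb` data frame from the chunk above, where `yscore` is the predicted probability of class 0):

```{r}
# Illustrative only: precision and recall for class 0 at a single cutoff.
cutoff <- 0.5
predicted_0 <- rdb$yscore >= cutoff   # predict class 0 when P(class 0) >= cutoff
actual_0 <- rdb$y == "0"

c(precision = sum(predicted_0 & actual_0) / sum(predicted_0),
  recall    = sum(predicted_0 & actual_0) / sum(actual_0))
```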

In this example, we use the average precision metric (approximated here as the mean of the precision values along each curve), which is an alternative scoring method to the area under the PR curve.

```{r}
library(plotly)
library(tidymodels)
library(fastDummies)

data(iris)

# Randomly relabel 50 observations, as in the ROC example above
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind, 'Species'] <- samples

X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)

logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

y_scores <- logistic %>%
  predict(X, type = 'prob')

y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))

z <- cbind(y_scores, y_onehot)

# Relevel so that '1' (the class itself) is the first level and hence the event,
# matching the .pred_* probabilities used as the estimate.
z$setosa <- relevel(as.factor(z$setosa), ref = '1')
pr_setosa <- pr_curve(data = z, setosa, .pred_setosa)
aps_setosa <- mean(pr_setosa$precision)
setosa <- paste('setosa (AP = ', toString(round(aps_setosa, 2)), ')', sep = '')

z$versicolor <- relevel(as.factor(z$versicolor), ref = '1')
pr_versicolor <- pr_curve(data = z, versicolor, .pred_versicolor)
aps_versicolor <- mean(pr_versicolor$precision)
versicolor <- paste('versicolor (AP = ', toString(round(aps_versicolor, 2)), ')', sep = '')

z$virginica <- relevel(as.factor(z$virginica), ref = '1')
pr_virginica <- pr_curve(data = z, virginica, .pred_virginica)
aps_virginica <- mean(pr_virginica$precision)
virginica <- paste('virginica (AP = ', toString(round(aps_virginica, 2)), ')', sep = '')

fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 1, yend = 0,
               line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = pr_setosa, x = ~recall, y = ~precision, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = pr_versicolor, x = ~recall, y = ~precision, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = pr_virginica, x = ~recall, y = ~precision, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(xaxis = list(title = "Recall"),
         yaxis = list(title = "Precision"),
         legend = list(x = 100, y = 0.5))
fig
```
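
The `mean(precision)` figures above are only a rough stand-in for average precision. As a quick comparison (a sketch, assuming the `z` data frame and `aps_setosa` from the chunk above are still in scope), yardstick's `pr_auc()` reports the area under the setosa PR curve directly:

```{r}
# Compare the mean-precision approximation with the area under the PR curve.
pr_auc_setosa <- pr_auc(data = z, setosa, .pred_setosa)
c(mean_precision = aps_setosa, pr_auc = pr_auc_setosa$.estimate)
```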


## References


Learn more about histograms, filled area plots and line charts:

* https://plot.ly/r/histograms/

* https://plot.ly/r/filled-area-plots/

* https://plot.ly/r/line-charts/



