ROC and PR Curves in R
Interpret the results of your classification using Receiver Operating Characteristics (ROC) and Precision-Recall (PR) Curves in R with Plotly.
New to Plotly?
Plotly is a free and open-source graphing library for R. We recommend you read our Getting Started guide for the latest installation or upgrade instructions, then move on to our Plotly Fundamentals tutorials or dive straight in to some Basic Charts tutorials.
Preliminary plots
Before diving into the receiver operating characteristic (ROC) curve, we will look at two plots that give some context to the thresholding mechanism behind the ROC and PR curves.
In the histogram, we observe that the scores are spread such that most of the positive labels are binned near 1, and many of the negative labels are close to 0. When we set a threshold on the score, all of the bins to its left will be classified as 0's, and everything to the right will be 1's. There are a few outliers, such as negative samples that our model gave a high score, and positive samples with a low score. If we set a threshold right in the middle, those outliers will respectively become false positives and false negatives.
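The thresholding step described above can be sketched in a few lines of base R. The scores, the labels, and the helper name `confusion_at` are made up for illustration and are not part of the tutorial's data or of any package.

```r
# Toy scores and labels (made up for illustration; not the tutorial's data)
scores <- c(0.05, 0.20, 0.35, 0.60, 0.70, 0.80, 0.90, 0.95)
labels <- c(0,    0,    1,    0,    1,    1,    0,    1)

# Hypothetical helper: count the four outcomes at a given threshold.
# A sample is predicted positive when its score is >= threshold.
confusion_at <- function(scores, labels, threshold) {
  predicted <- as.integer(scores >= threshold)
  c(
    TP = sum(predicted == 1 & labels == 1),
    FP = sum(predicted == 1 & labels == 0),
    FN = sum(predicted == 0 & labels == 1),
    TN = sum(predicted == 0 & labels == 0)
  )
}

counts <- confusion_at(scores, labels, threshold = 0.5)
print(counts)
```

With the threshold in the middle, the two high-scoring negatives become false positives and the one low-scoring positive becomes a false negative.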
As we adjust the threshold, the number of false positives will increase or decrease, and at the same time the number of true positives will also change; this is shown in the second plot. As you can see, the model seems to perform fairly well, because the true positive rate decreases slowly, whereas the false positive rate decreases sharply as we increase the threshold. Those two lines each represent a dimension of the ROC curve.
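To make the relationship between threshold and rates concrete, here is a minimal base R sketch that sweeps a few thresholds over toy data. The data and the helper name `rates_at` are assumptions for illustration only; the tutorial code that follows obtains the same kind of table from `roc_curve()`.

```r
# Toy scores and labels (made up for illustration)
scores <- c(0.05, 0.20, 0.35, 0.60, 0.70, 0.80, 0.90, 0.95)
labels <- c(0, 0, 1, 0, 1, 1, 0, 1)

# Hypothetical helper: true and false positive rate at each threshold
rates_at <- function(scores, labels, thresholds) {
  rows <- lapply(thresholds, function(t) {
    predicted <- scores >= t
    data.frame(
      threshold = t,
      tpr = sum(predicted & labels == 1) / sum(labels == 1),
      fpr = sum(predicted & labels == 0) / sum(labels == 0)
    )
  })
  do.call(rbind, rows)
}

rates <- rates_at(scores, labels, thresholds = seq(0, 1, by = 0.25))
print(rates)
```

Both rates can only stay flat or fall as the threshold rises; plotting `tpr` against `fpr` from such a table gives the points of a ROC curve.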
library(plotly)
library(tidymodels)
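set.seed(0)
# Simulated data: 500 samples, 20 random features, random binary labels
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
data <- data.frame(X, y)
data$y <- as.factor(data$y)
X <- subset(data, select = -c(y))
# Fit a logistic regression and score the training data
logistic_glm <-
  logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  fit(y ~ ., data = data)
y_scores <- logistic_glm %>%
  predict(X, type = 'prob')
y_score <- y_scores$.pred_1
db <- data.frame(data$y, y_score)
# roc_curve() returns the columns .threshold, specificity, sensitivity.
# yardstick treats the first factor level ("0") as the event by default, so for class "1"
# 1 - specificity is the true positive rate and sensitivity is the false positive rate.
z <- roc_curve(data = db, 'data.y', 'y_score')
z$specificity <- 1 - z$specificity
colnames(z) <- c('threshold', 'tpr', 'fpr')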
fig1 <- plot_ly(x = y_score, color = data$y, colors = c('blue', 'red'), type = 'histogram', alpha = 0.5, nbinsx = 50) %>%
  layout(barmode = "overlay")
fig1
fig2 <- plot_ly(data = z, x = ~threshold) %>%
  add_trace(y = ~fpr, mode = 'lines', name = 'False Positive Rate', type = 'scatter') %>%
  add_trace(y = ~tpr, mode = 'lines', name = 'True Positive Rate', type = 'scatter') %>%
  layout(title = 'TPR and FPR at every threshold')
fig2 <- fig2 %>% layout(legend = list(title = list(text = '<b> Rate </b>')))
fig2
Multiclass ROC Curve
When you have more than 2 classes, you will need to plot the ROC curve for each class separately. Make sure that you use a one-versus-rest model, or make sure that your problem has a multi-label format; otherwise, your ROC curve might not return the expected results.
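As a rough illustration of the one-versus-rest idea, the base R sketch below turns a multiclass label vector into one 0/1 indicator column per class, so that each class can be scored as its own binary problem. The label vector and the helper name `one_vs_rest` are made up for illustration; the tutorial code below uses `dummy_cols()` from fastDummies for the same purpose.

```r
# Made-up class labels, for illustration only
species <- factor(c("setosa", "virginica", "versicolor", "setosa", "virginica"))

# Hypothetical helper: one 0/1 indicator column per class (one-versus-rest)
one_vs_rest <- function(labels) {
  indicators <- sapply(levels(labels), function(cls) as.integer(labels == cls))
  as.data.frame(indicators)
}

y_onehot <- one_vs_rest(species)
print(y_onehot)
```

Each row has exactly one indicator set to 1, and each column paired with that class's predicted probability gives one binary ROC curve.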
library(plotly)
library(tidymodels)
library(fastDummies)
# Artificially add noise to make task harder
data(iris)
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind,'Species'] = samples
# Define the inputs and outputs
X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)
# Fit the model
logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)
y_scores <- logistic %>%
  predict(X, type = 'prob')
# One hot encode the labels in order to plot them
y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))
z = cbind(y_scores, y_onehot)
# Per-class ROC curve and AUC; versicolor and virginica below follow the same steps.
# yardstick treats the first factor level ("0") as the event by default, so
# 1 - specificity is this class's TPR, sensitivity is its FPR, and its AUC is 1 - roc_auc().
z$setosa <- as.factor(z$setosa)
roc_setosa <- roc_curve(data = z, setosa, .pred_setosa)
roc_setosa$specificity <- 1 - roc_setosa$specificity
colnames(roc_setosa) <- c('threshold', 'tpr', 'fpr')
auc_setosa <- roc_auc(data = z, setosa, .pred_setosa)
auc_setosa <- auc_setosa$.estimate
setosa <- paste('setosa (AUC=',toString(round(1-auc_setosa,2)),')',sep = '')
z$versicolor <- as.factor(z$versicolor)
roc_versicolor <- roc_curve(data = z, versicolor, .pred_versicolor)
roc_versicolor$specificity <- 1 - roc_versicolor$specificity
colnames(roc_versicolor) <- c('threshold', 'tpr', 'fpr')
auc_versicolor <- roc_auc(data = z, versicolor, .pred_versicolor)
auc_versicolor <- auc_versicolor$.estimate
versicolor <- paste('versicolor (AUC=',toString(round(1-auc_versicolor,2)),')', sep = '')
z$virginica <- as.factor(z$virginica)
roc_virginica <- roc_curve(data = z, virginica, .pred_virginica)
roc_virginica$specificity <- 1 - roc_virginica$specificity
colnames(roc_virginica) <- c('threshold', 'tpr', 'fpr')
auc_virginica <- roc_auc(data = z, virginica, .pred_virginica)
auc_virginica <- auc_virginica$.estimate
virginica <- paste('virginica (AUC=',toString(round(1-auc_virginica,2)),')',sep = '')
# Create an empty figure, and iteratively add a line for each class
fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1, line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = roc_setosa, x = ~fpr, y = ~tpr, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = roc_versicolor, x = ~fpr, y = ~tpr, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = roc_virginica, x = ~fpr, y = ~tpr, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(
    xaxis = list(title = "False Positive Rate"),
    yaxis = list(title = "True Positive Rate"),
    legend = list(x = 100, y = 0.5)
  )
fig
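The legend labels above are built from `1 - roc_auc()`. One way to see why flipping is needed: AUC can be read as the probability that a randomly chosen event sample outscores a randomly chosen non-event sample, so swapping which class counts as the event while keeping the same scores gives one minus the original value. The base R sketch below illustrates this on toy data; `auc_pairwise` is a hypothetical helper, not a yardstick function.

```r
# Hypothetical helper: AUC as the fraction of (event, non-event) pairs
# in which the event sample has the higher score (ties count one half)
auc_pairwise <- function(scores, is_event) {
  pos <- scores[is_event]
  neg <- scores[!is_event]
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

# Toy scores and labels (made up for illustration)
scores <- c(0.05, 0.20, 0.35, 0.60, 0.70, 0.80, 0.90, 0.95)
labels <- c(0, 0, 1, 0, 1, 1, 0, 1)

auc_class1 <- auc_pairwise(scores, labels == 1)  # event = class 1
auc_class0 <- auc_pairwise(scores, labels == 0)  # event = class 0, same scores
print(c(auc_class1 = auc_class1, auc_class0 = auc_class0))
```

The two values sum to 1, which is why an AUC computed with the "0" level as the event can be converted with a simple subtraction.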
Precision-Recall Curves
In this example, we use the average precision metric, which is an alternative scoring method to the area under the PR curve.
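For reference, one common definition of average precision is the sum of the precision at each ranked sample weighted by the increase in recall at that sample. The base R sketch below implements that definition under the assumption that there are no tied scores; the helper name `average_precision_sketch` and the toy data are made up for illustration.

```r
# Hypothetical helper: average precision as sum((R_n - R_{n-1}) * P_n),
# walking down the samples from highest to lowest score (assumes no ties)
average_precision_sketch <- function(scores, labels) {
  ord <- order(scores, decreasing = TRUE)
  hits <- labels[ord] == 1
  tp <- cumsum(hits)                  # true positives among the top n samples
  precision <- tp / seq_along(hits)   # P_n
  recall <- tp / sum(hits)            # R_n
  sum(diff(c(0, recall)) * precision)
}

# Toy scores and labels (made up for illustration)
scores <- c(0.05, 0.20, 0.35, 0.60, 0.70, 0.80, 0.90, 0.95)
labels <- c(0, 0, 1, 0, 1, 1, 0, 1)

ap <- average_precision_sketch(scores, labels)
print(ap)
```

Because recall only increases at positive samples, this is the mean of the precision values taken at each positive, which generally differs from a plain mean over every point of the curve.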
library(plotly)
library(tidymodels)
library(fastDummies)
# Artificially add noise to make task harder
data(iris)
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind,'Species'] = samples
# Define the inputs and outputs
X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)
# Fit the model
logistic <-
multinom_reg() %>%
set_engine("nnet") %>%
set_mode("classification") %>%
fit(Species ~ ., data = iris)
y_scores <- logistic %>%
predict(X, type = 'prob')
y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))
z = cbind(y_scores, y_onehot)
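# Treat the second factor level ("1") as the event so that each curve describes
# detecting that class, and compute average precision with yardstick
z$setosa <- as.factor(z$setosa)
pr_setosa <- pr_curve(data = z, setosa, .pred_setosa, event_level = "second")
aps_setosa <- average_precision(data = z, setosa, .pred_setosa, event_level = "second")$.estimate
setosa <- paste('setosa (AP = ',toString(round(aps_setosa,2)),')',sep = '')
z$versicolor <- as.factor(z$versicolor)
pr_versicolor <- pr_curve(data = z, versicolor, .pred_versicolor, event_level = "second")
aps_versicolor <- average_precision(data = z, versicolor, .pred_versicolor, event_level = "second")$.estimate
versicolor <- paste('versicolor (AP = ',toString(round(aps_versicolor,2)),')',sep = '')
z$virginica <- as.factor(z$virginica)
pr_virginica <- pr_curve(data = z, virginica, .pred_virginica, event_level = "second")
aps_virginica <- average_precision(data = z, virginica, .pred_virginica, event_level = "second")$.estimate
virginica <- paste('virginica (AP = ',toString(round(aps_virginica,2)),')',sep = '')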
# Create an empty figure, and add a new line for each class
fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 1, yend = 0, line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = pr_setosa, x = ~recall, y = ~precision, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = pr_versicolor, x = ~recall, y = ~precision, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = pr_virginica, x = ~recall, y = ~precision, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(
    xaxis = list(title = "Recall"),
    yaxis = list(title = "Precision"),
    legend = list(x = 100, y = 0.5)
  )
fig
References
Learn more about histograms, filled area plots and line charts:
https://plot.ly/r/histograms/
https://plot.ly/r/filled-area-plots/
https://plot.ly/r/line-charts/
What About Dash?
Dash for R is an open-source framework for building analytical applications, with no JavaScript required, and it is tightly integrated with the Plotly graphing library.
Learn about how to install Dash for R at https://dashr.plot.ly/installation.
Everywhere in this page that you see fig, you can display the same figure in a Dash for R application by passing it to the figure argument of the Graph component from the built-in dashCoreComponents package like this:
library(plotly)
fig <- plot_ly()
# fig <- fig %>% add_trace( ... )
# fig <- fig %>% layout( ... )
library(dash)
library(dashCoreComponents)
library(dashHtmlComponents)
app <- Dash$new()
app$layout(
  htmlDiv(
    list(
      dccGraph(figure=fig)
    )
  )
)
app$run_server(debug=TRUE, dev_tools_hot_reload=FALSE)