Merge pull request #63 from plotly/rpy-parity · plotly/plotly.r-docs@3742c1c

Commit 3742c1c

Merge pull request #63 from plotly/rpy-parity
R-Py parity for ml-regression page
2 parents eada8db + 036b474 commit 3742c1c

2 files changed: +360 −2 lines changed

.circleci/config.yml

Lines changed: 3 additions & 2 deletions
@@ -24,8 +24,9 @@ jobs:
       - run:
           name: install application-level dependencies
           command: |
-            sudo apt-get install -y pandoc libudunits2-dev libgdal-dev libxt-dev libglu1-mesa-dev libfftw3-dev libglpk40
-            sudo R -e 'install.packages(c("curl", "devtools", "mvtnorm", "hexbin")); devtools::install_github("hypertidy/anglr"); devtools::install_github("ropensci/plotly"); devtools::install_github("johannesbjork/LaCroixColoR"); install.packages("BiocManager"); BiocManager::install("EBImage"); devtools::install_deps(dependencies = TRUE) '
+            sudo apt-get install -y pandoc libudunits2-dev libgdal-dev libxt-dev libglu1-mesa-dev libfftw3-dev libglpk40 libxml2-dev libcurl4-openssl-dev apt-transport-https software-properties-common
+            sudo R -e 'install.packages(c("curl", "devtools", "mvtnorm", "hexbin", "tidyverse", "tidymodels", "kknn", "kernlab", "pracma", "reshape2", "ggplot2", "datasets")); devtools::install_github("ropensci/plotly"); devtools::install_github("johannesbjork/LaCroixColoR"); install.packages("BiocManager"); BiocManager::install("EBImage"); devtools::install_deps(dependencies = TRUE) '
+            sudo R -e 'install.packages("https://github.com/hypertidy/anglr/archive/refs/tags/v0.7.0.tar.gz", repos=NULL, type="source"); devtools::install_deps(dependencies = TRUE) '
       - save_cache:
           key: cache4
           paths:

r/2021-07-08-ml-regression.Rmd

Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
<!-- #region -->
This page shows how to use Plotly charts to display various types of regression models, starting from simple models like [Linear Regression](https://parsnip.tidymodels.org/reference/linear_reg.html) and progressively moving towards models like Decision Tree and Polynomial Features. We highlight various capabilities of plotly, such as comparative analysis of the same model with different parameters, displaying LaTeX, and [surface plots](https://plotly.com/r/3d-surface-plots/) for 3D data.

We will use [tidymodels](https://tidymodels.tidymodels.org/) to split and preprocess our data and train various regression models. Tidymodels is a popular machine learning (ML) framework in R that is compatible with "tidyverse" concepts, and offers various tools for creating and training ML algorithms, feature engineering, data cleaning, and evaluating and testing models. It is the next-generation version of the popular [caret](http://topepo.github.io/caret/index.html) library for R.

<!-- #endregion -->

## Basic linear regression plots

In this section, we show you how to apply a simple regression model for predicting the tip a server will receive based on various client attributes (such as sex, day of the week, and whether they are a smoker).

We will be using [Linear Regression][lr], a simple model that fits an intercept (a baseline amount of tip) and a slope for each feature we use, such as the value of the total bill.

[lr]: https://parsnip.tidymodels.org/reference/linear_reg.html

### Linear Regression with R

```{r}
library(reshape2)   # to load the tips data set
library(tidyverse)
library(tidymodels) # for the fit() function
library(plotly)
data(tips)

y <- tips$tip
X <- tips$total_bill

# Fit a linear regression of tip on total_bill with the lm engine
lm_model <- linear_reg() %>%
  set_engine('lm') %>%
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips)

# Build a grid of total_bill values on which to evaluate the fitted line
x_range <- seq(min(X), max(X), length.out = 100)
x_range <- matrix(x_range, nrow = 100, ncol = 1)
xdf <- data.frame(x_range)
colnames(xdf) <- c('total_bill')

ydf <- lm_model %>% predict(xdf)

colnames(ydf) <- c('tip')
xy <- data.frame(xdf, ydf)

# Scatter the raw data and overlay the regression line
fig <- plot_ly(tips, x = ~total_bill, y = ~tip, type = 'scatter', alpha = 0.65, mode = 'markers', name = 'Tips')
fig <- fig %>% add_trace(data = xy, x = ~total_bill, y = ~tip, name = 'Regression Fit', mode = 'lines', alpha = 1)
fig
```
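
Because the model only fits an intercept and a slope, it can help to print the fitted coefficients next to the plot. A minimal sketch, assuming the `lm_model` object from the chunk above is still in the session and tidymodels is loaded: `tidy()` works directly on the parsnip fit, while `glance()` is applied here to the underlying `lm` object stored in `lm_model$fit`.

```{r}
# Inspect the fitted intercept and slope (assumes lm_model from the chunk above)
tidy(lm_model)

# Overall fit statistics (r.squared, sigma, ...) from the underlying lm object
glance(lm_model$fit)
```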

## Model generalization on unseen data

With `add_trace()`, you can color your plot based on a predefined train/test split. By coloring the training and the testing data points differently, you can easily see whether the model generalizes well to the test data or not.

```{r}
library(reshape2)
library(tidyverse)
library(tidymodels)
library(plotly)
data(tips)

y <- tips$tip
X <- tips$total_bill

# Split the data into training and testing sets
set.seed(123)
tips_split <- initial_split(tips)
tips_training <- tips_split %>%
  training()
tips_test <- tips_split %>%
  testing()

# Fit the model on the training data only
lm_model <- linear_reg() %>%
  set_engine('lm') %>%
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips_training)

x_range <- seq(min(X), max(X), length.out = 100)
x_range <- matrix(x_range, nrow = 100, ncol = 1)
xdf <- data.frame(x_range)
colnames(xdf) <- c('total_bill')

ydf <- lm_model %>%
  predict(xdf)

colnames(ydf) <- c('tip')
xy <- data.frame(xdf, ydf)

# Color train and test points separately and overlay the fitted line
fig <- plot_ly(data = tips_training, x = ~total_bill, y = ~tip, type = 'scatter', name = 'train', mode = 'markers', alpha = 0.65) %>%
  add_trace(data = tips_test, x = ~total_bill, y = ~tip, type = 'scatter', name = 'test', mode = 'markers', alpha = 0.65) %>%
  add_trace(data = xy, x = ~total_bill, y = ~tip, name = 'prediction', mode = 'lines', alpha = 1)
fig
```
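
The colors give a visual impression of generalization; you can also quantify it with a held-out error metric. Below is a minimal sketch, assuming the `lm_model` and `tips_test` objects from the chunk above; `rmse()` comes from the yardstick package, which is loaded with tidymodels.

```{r}
# RMSE on the held-out test set (assumes lm_model and tips_test from the chunk above)
lm_model %>%
  predict(tips_test) %>%
  bind_cols(tips_test) %>%
  rmse(truth = tip, estimate = .pred)
```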

## Comparing different kNN model parameters

In addition to linear regression, it's possible to fit the same data using [k-Nearest Neighbors][knn]. When you perform a prediction on a new sample, this model takes either the weighted or the unweighted average of the neighbors. In order to see the difference between those two averaging options, we train a kNN model with each of them, and we plot them in the same way as the previous graph.

Notice how we can combine scatter points with lines using Plotly. You can learn more about [multiple chart types](https://plotly.com/r/graphing-multiple-chart-types/).

[knn]: http://klausvigo.github.io/kknn/

```{r}
library(reshape2)
library(tidyverse)
library(tidymodels)
library(plotly)
library(kknn)
data(tips)

y <- tips$tip
X <- tips$total_bill

# Model 1: distance-weighted average of the 10 nearest neighbors
knn_dist <- nearest_neighbor(neighbors = 10, weight_func = 'inv') %>%
  set_engine('kknn') %>%
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips)

# Model 2: unweighted (rectangular) average of the 10 nearest neighbors
knn_uni <- nearest_neighbor(neighbors = 10, weight_func = 'rectangular') %>%
  set_engine('kknn') %>%
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips)

x_range <- seq(min(X), max(X), length.out = 100)
x_range <- matrix(x_range, nrow = 100, ncol = 1)
xdf <- data.frame(x_range)
colnames(xdf) <- c('total_bill')

y_dist <- knn_dist %>%
  predict(xdf)
y_uni <- knn_uni %>%
  predict(xdf)

colnames(y_dist) <- c('dist')
colnames(y_uni) <- c('uni')
xy <- data.frame(xdf, y_dist, y_uni)

# Scatter the data colored by sex and overlay both kNN fits
fig <- plot_ly(tips, type = 'scatter', mode = 'markers', colors = c("#FF7F50", "#6495ED")) %>%
  add_trace(data = tips, x = ~total_bill, y = ~tip, type = 'scatter', mode = 'markers', color = ~sex, alpha = 0.65) %>%
  add_trace(data = xy, x = ~total_bill, y = ~dist, name = 'Weights: Distance', mode = 'lines', alpha = 1) %>%
  add_trace(data = xy, x = ~total_bill, y = ~uni, name = 'Weights: Uniform', mode = 'lines', alpha = 1)
fig
```
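
To back up the visual comparison, both fits can be scored on the same data. A minimal sketch, assuming the `knn_dist`, `knn_uni`, and `tips` objects from the chunk above; note that scoring on the training data only shows how closely each model follows it, not how well it generalizes.

```{r}
# In-sample RMSE for the two weighting schemes (assumes objects from the chunk above)
bind_rows(
  knn_dist %>% predict(tips) %>% bind_cols(tips) %>%
    rmse(truth = tip, estimate = .pred) %>% mutate(model = 'Weights: Distance'),
  knn_uni %>% predict(tips) %>% bind_cols(tips) %>%
    rmse(truth = tip, estimate = .pred) %>% mutate(model = 'Weights: Uniform')
)
```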

## 3D regression surface with `mesh3d` and `add_surface`

Visualize the prediction surface of your model whenever you have more than one variable in your input data. Here, we will use [`svm_rbf`](https://parsnip.tidymodels.org/reference/svm_rbf.html) with the [`kernlab`](https://cran.r-project.org/web/packages/kernlab/index.html) engine in `regression` mode. For generating the 2D mesh on the surface, we use the [`pracma`](https://cran.r-project.org/web/packages/pracma/index.html) package.

```{r}
library(reshape2)
library(tidyverse)
library(tidymodels)
library(plotly)
library(kernlab)
library(pracma) # for meshgrid()
data(iris)

mesh_size <- .02
margin <- 0
X <- iris %>% select(Sepal.Width, Sepal.Length)
y <- iris %>% select(Petal.Width)

# Fit an RBF-kernel support vector regression
model <- svm_rbf(cost = 1.0) %>%
  set_engine("kernlab") %>%
  set_mode("regression") %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris)

# Build a 2D mesh spanning the range of both predictors
x_min <- min(X$Sepal.Width) - margin
x_max <- max(X$Sepal.Width) + margin
y_min <- min(X$Sepal.Length) - margin
y_max <- max(X$Sepal.Length) + margin
xrange <- seq(x_min, x_max, mesh_size)
yrange <- seq(y_min, y_max, mesh_size)
xy <- meshgrid(x = xrange, y = yrange)
xx <- xy$X
yy <- xy$Y
dim_val <- dim(xx)

# Flatten the mesh into a data frame of predictor columns, predict, then reshape back
xx1 <- matrix(xx, length(xx), 1)
yy1 <- matrix(yy, length(yy), 1)
final <- data.frame(cbind(xx1, yy1))
colnames(final) <- c('Sepal.Width', 'Sepal.Length')
pred <- model %>%
  predict(final)

pred <- pred$.pred
pred <- matrix(pred, dim_val[1], dim_val[2])

# Plot the raw points and add the predicted surface
fig <- plot_ly(iris, x = ~Sepal.Width, y = ~Sepal.Length, z = ~Petal.Width) %>%
  add_markers(size = 5) %>%
  add_surface(x = xrange, y = yrange, z = pred, alpha = 0.65, type = 'mesh3d', name = 'pred_surface')
fig
```

## Prediction Error Plots

When you are working with very high-dimensional data, it is inconvenient to plot every dimension against your output `y`. Instead, you can use methods such as prediction error plots, which let you visualize how well your model does compared to the ground truth.

### Simple actual vs predicted plot

This example shows you the simplest way to compare the predicted output vs. the actual output. A good model will have most of the scatter dots near the diagonal black line.

```{r}
library(tidyverse)
library(tidymodels)
library(plotly)
library(ggplot2)

data("iris")

X <- data.frame(Sepal.Width = c(iris$Sepal.Width), Sepal.Length = c(iris$Sepal.Length))
y <- iris$Petal.Width

lm_model <- linear_reg() %>%
  set_engine('lm') %>%
  set_mode('regression') %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris)

y_pred <- lm_model %>%
  predict(X)

db <- cbind(iris, y_pred)

colnames(db)[4] <- "Ground_truth"  # rename Petal.Width
colnames(db)[6] <- "prediction"    # rename .pred

# End points of the diagonal reference line (perfect predictions)
x0 <- min(y)
x1 <- max(y)
p1 <- ggplot(db, aes(x = Ground_truth, y = prediction)) +
  geom_point(color = "blue", show.legend = FALSE) +
  geom_segment(aes(x = x0, y = x0, xend = x1, yend = x1), linetype = 2)

p1 <- ggplotly(p1)
p1
```
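
The same comparison can be summarized numerically. A minimal sketch, assuming the `db` data frame from the chunk above; `metrics()` from yardstick (loaded with tidymodels) reports RMSE, R-squared, and MAE for a numeric outcome.

```{r}
# Numeric summary of the agreement shown in the plot (assumes db from the chunk above)
db %>% metrics(truth = Ground_truth, estimate = prediction)
```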

### Enhanced prediction error analysis using `ggplotly`

Add marginal histograms to quickly diagnose any prediction bias your model might have.

```{r}
library(plotly)
library(ggplot2)
library(tidyverse)
library(tidymodels)
data(iris)

X <- iris %>% select(Sepal.Width, Sepal.Length)
y <- iris %>% select(Petal.Width)

set.seed(0)
iris_split <- initial_split(iris, prop = 3/4)
iris_training <- iris_split %>%
  training()
iris_test <- iris_split %>%
  testing()

# Recover the original row indices so each observation can be labelled train/test
train_index <- as.integer(rownames(iris_training))
test_index <- as.integer(rownames(iris_test))

iris[train_index, 'split'] <- 'train'
iris[test_index, 'split'] <- 'test'

lm_model <- linear_reg() %>%
  set_engine('lm') %>%
  set_mode('regression') %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris_training)

prediction <- lm_model %>%
  predict(X)
colnames(prediction) <- c('prediction')
iris <- cbind(iris, prediction)

# Marginal histogram of the ground truth (top)
hist_top <- ggplot(iris, aes(x = Petal.Width)) +
  geom_histogram(data = subset(iris, split == 'train'), fill = "red", alpha = 0.2, bins = 6) +
  geom_histogram(data = subset(iris, split == 'test'), fill = "blue", alpha = 0.2, bins = 6) +
  theme(axis.title.y = element_blank(), axis.text.y = element_blank(), axis.ticks.y = element_blank())
hist_top <- ggplotly(p = hist_top)

# Central scatter of ground truth vs prediction, colored by split
scatter <- ggplot(iris, aes(x = Petal.Width, y = prediction, color = split)) +
  geom_point() +
  geom_smooth(formula = y ~ x, method = lm, se = FALSE)
scatter <- ggplotly(p = scatter)

# Marginal histogram of the predictions (right)
hist_right <- ggplot(iris, aes(x = prediction)) +
  geom_histogram(data = subset(iris, split == 'train'), fill = "red", alpha = 0.2, bins = 13) +
  geom_histogram(data = subset(iris, split == 'test'), fill = "blue", alpha = 0.2, bins = 13) +
  theme(axis.title.x = element_blank(), axis.text.x = element_blank(), axis.ticks.x = element_blank()) +
  coord_flip()
hist_right <- ggplotly(p = hist_right)

# Assemble the marginal plots around the scatter
s <- subplot(
  hist_top,
  plotly_empty(),
  scatter,
  hist_right,
  nrows = 2, heights = c(0.2, 0.8), widths = c(0.8, 0.2), margin = 0,
  shareX = TRUE, shareY = TRUE, titleX = TRUE, titleY = TRUE
)
layout(s, showlegend = FALSE)
```

## Residual plots

Just like prediction error plots, it's easy to visualize your prediction residuals in just a few lines of code using `ggplotly` and `tidymodels` capabilities.

```{r}
library(plotly)
library(ggplot2)
library(tidyverse)
library(tidymodels)

data(iris)

X <- iris %>% select(Sepal.Width, Sepal.Length)
y <- iris %>% select(Petal.Width)

set.seed(0)
iris_split <- initial_split(iris, prop = 3/4)
iris_training <- iris_split %>%
  training()
iris_test <- iris_split %>%
  testing()

# Recover the original row indices so each observation can be labelled train/test
train_index <- as.integer(rownames(iris_training))
test_index <- as.integer(rownames(iris_test))

iris[train_index, 'split'] <- 'train'
iris[test_index, 'split'] <- 'test'

lm_model <- linear_reg() %>%
  set_engine('lm') %>%
  set_mode('regression') %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris_training)

prediction <- lm_model %>%
  predict(X)
colnames(prediction) <- c('prediction')
iris <- cbind(iris, prediction)

# Residual = prediction - ground truth
residual <- prediction - iris$Petal.Width
colnames(residual) <- c('residual')
iris <- cbind(iris, residual)

# Residuals vs predictions, colored by split
scatter <- ggplot(iris, aes(x = prediction, y = residual, color = split)) +
  geom_point() +
  geom_smooth(formula = y ~ x, method = lm, se = FALSE)

scatter <- ggplotly(p = scatter)

# Distribution of residuals for each split
violin <- iris %>%
  plot_ly(x = ~split, y = ~residual, split = ~split, type = 'violin')

s <- subplot(
  scatter,
  violin,
  nrows = 1, heights = c(1), widths = c(0.65, 0.35), margin = 0.01,
  shareX = TRUE, shareY = TRUE, titleX = TRUE, titleY = TRUE
)

layout(s, showlegend = FALSE)
```
