-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathcompare-smooths.R
188 lines (172 loc) · 5.23 KB
/
compare-smooths.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#' Compare smooths across models
#'
#' Estimates the selected smooths in each of the supplied models and stacks
#' the results so that the same smooth can be compared across models.
#'
#' @param model Primary model for comparison.
#' @param ... Additional models to compare smooths against those of `model`.
#'   At least one additional model is required.
#' @param smooths `r lifecycle::badge("deprecated")` Use `select` instead.
#' @param select character; select which smooths to compare. The default
#' (`NULL`) means all smooths in each of the models will be compared; `select`
#' is applied identically to every model. Numeric `select` indexes the smooths
#' in the order they are specified in the formula and stored in each model.
#' Character `select` matches the labels for smooths as shown for example in
#' the output from `summary(model)`. Logical `select` operates as per numeric
#' `select` in the order that smooths are stored.
#' @param partial_match logical; should smooths be selected by partial matches
#' with `select`? If `TRUE`, `select` can only be a single string to match
#' against.
#'
#' @inheritParams smooth_estimates
#'
#' @return A nested tibble of class `"compare_smooths"` with columns `.model`
#'   (the model name, taken from the call), `.smooth`, `.type`, `.by`, and a
#'   list-column `data` holding the evaluated smooth estimates. Rows are
#'   arranged by `.smooth`.
#'
#' @export
#'
#' @importFrom rlang dots_list
#' @importFrom dplyr group_by
#' @importFrom lifecycle deprecated is_present
#'
#' @examples
#' \dontshow{
#' op <- options(cli.unicode = FALSE, pillar.sigfig = 5)
#' }
#' load_mgcv()
#' dat <- data_sim("eg1", seed = 2)
#'
#' ## models to compare smooths across - artificially create differences
#' m1 <- gam(y ~ s(x0, k = 5) + s(x1, k = 5) + s(x2, k = 5) + s(x3, k = 5),
#' data = dat, method = "REML"
#' )
#' m2 <- gam(y ~ s(x0, bs = "ts") + s(x1, bs = "ts") + s(x2, bs = "ts") +
#' s(x3, bs = "ts"), data = dat, method = "REML")
#'
#' ## build comparisons
#' comp <- compare_smooths(m1, m2)
#' comp
#' ## notice that the result is a nested tibble
#'
#' draw(comp)
#' \dontshow{
#' options(op)
#' }
`compare_smooths` <- function(model, ...,
                              select = NULL,
                              smooths = deprecated(),
                              n = 100,
                              data = NULL,
                              unconditional = FALSE,
                              overall_uncertainty = TRUE,
                              partial_match = FALSE) {
  if (lifecycle::is_present(smooths)) {
    lifecycle::deprecate_warn(
      "0.8.9.9", "compare_smooths(smooths)",
      "compare_smooths(select)"
    )
    select <- smooths
  }
  ## grab ... as a named list; names label the supplied model expressions
  dots <- rlang::dots_list(..., .named = TRUE)
  ## deparse1() always yields a single string, so there is exactly one name
  ## per model even when the `model` expression is long
  model_names <- c(deparse1(substitute(model)), names(dots))
  if (length(dots) < 1L) {
    stop("Need at least two models to compare smooths",
      call. = FALSE
    )
  }
  ## combine model and others into a list
  models <- append(list(model), dots)
  ## evaluate the selected smooths in every model, keeping estimates nested
  sm_est <- lapply(models, smooth_estimates,
    select = select,
    n = n,
    data = data,
    unconditional = unconditional,
    overall_uncertainty = overall_uncertainty,
    unnest = FALSE, partial_match = partial_match
  )
  ## loop over list of smooth estimates and add model column
  for (i in seq_along(sm_est)) {
    sm_est[[i]] <- add_column(sm_est[[i]],
      .model = model_names[i],
      .before = 1L
    )
  }
  ## re-nest each model's estimates so there is one row per smooth, with all
  ## non-identifying columns collected into the `data` list-column
  `unnest_nest` <- function(x) {
    x |>
      group_by(.data$.smooth) |>
      group_split() |>
      purrr::map(unnest, cols = all_of("data")) |>
      purrr::map(nest, data = !all_of(c(
        ".model", ".smooth",
        ".type", ".by"
      ))) |>
      bind_rows()
  }
  sm_est <- purrr::map(sm_est, unnest_nest)
  sm_est <- bind_rows(sm_est) |>
    arrange(.data$.smooth)
  class(sm_est) <- c("compare_smooths", class(sm_est))
  sm_est
}
#' Plot comparisons of smooths
#'
#' Draws one panel per smooth, overlaying the estimates and confidence bands
#' from every model that was compared, then lays the panels out with
#' \pkg{patchwork}.
#'
#' @param object of class `"compare_smooths"`, the result of a call to
#' [gratia::compare_smooths()].
#' @param ... additional arguments passed on to [patchwork::wrap_plots()].
#' @inheritParams draw.gam
#'
#' @return The combined plot returned by [patchwork::wrap_plots()].
#'
#' @export
#' @importFrom dplyr group_split
#' @importFrom purrr map
#' @importFrom patchwork wrap_plots
`draw.compare_smooths` <- function(object,
                                   ncol = NULL, nrow = NULL,
                                   guides = "collect",
                                   ...) {
  ## one tibble per smooth, and one comparison panel per tibble
  panels <- object |>
    group_split(.data$.smooth) |>
    map(plot_comparison_of_smooths)

  ## with no layout supplied at all, aim for a near-square grid
  if (is.null(ncol) && is.null(nrow)) {
    n_panels <- length(panels)
    ncol <- ceiling(sqrt(n_panels))
    nrow <- ceiling(n_panels / ncol)
  }

  wrap_plots(panels,
    byrow = TRUE, ncol = ncol, nrow = nrow, guides = guides, ...
  )
}
#' Plot one smooth's estimates from several models on a single panel
#'
#' Internal helper used by `draw.compare_smooths()`. Each model's estimate is
#' drawn as a line with a shaded confidence band; colour and fill identify the
#' model.
#'
#' @param object a tibble with columns `.model`, `.smooth` and a list-column
#'   `data` of smooth estimates. It is presumably the rows for a single smooth,
#'   since `unique(.smooth)` is used as the plot title.
#' @param coverage numeric; coverage of the confidence interval, passed to
#'   `add_confint()`.
#' @param ... currently ignored.
#'
#' @return A ggplot object.
#'
#' @noRd
#' @importFrom dplyr mutate
#' @importFrom tidyr unnest
#' @importFrom ggplot2 ggplot geom_ribbon geom_line aes labs
#' @importFrom rlang .data
`plot_comparison_of_smooths` <- function(object, coverage = 0.95, ...) {
  ## get the covariate labels; the first one is used for the x axis
  sm_vars <- vars_from_label(unique(object[[".smooth"]]))
  ## unnest data cols, then add the frequentist confidence interval.
  ## `coverage` goes straight to add_confint(), so no separate critical value
  ## is needed here.
  object <- unnest(object, cols = all_of("data")) |>
    add_confint(coverage = coverage)
  ## basic plot, one group per model
  plt <- ggplot(object, aes(
    x = .data[[sm_vars[1L]]],
    y = .data[[".estimate"]],
    group = .data[[".model"]]
  ))
  ## add uncertainty bands
  plt <- plt + geom_ribbon(
    aes(
      ymin = .data[[".lower_ci"]],
      ymax = .data[[".upper_ci"]],
      fill = .data[[".model"]]
    ),
    alpha = 0.2
  )
  ## add smooth lines
  plt <- plt + geom_line(aes(colour = .data[[".model"]]))
  ## Add labels
  plt <- plt + labs(
    colour = "Model", fill = "Model",
    title = unique(object[[".smooth"]]),
    y = "Estimate"
  )
  plt
}