Title: | Generalized Framework for Cross-Validation |
---|---|
Description: | A general framework for the application of cross-validation schemes to particular functions. By allowing arbitrary lists of results, origami accommodates a range of cross-validation applications. This implementation was first described by Coyle and Hejazi (2018) <doi:10.21105/joss.00512>. |
Authors: | Jeremy Coyle [aut, cre, cph] , Nima Hejazi [aut] , Ivana Malenica [aut] , Rachael Phillips [aut] |
Maintainer: | Jeremy Coyle |
License: | GPL-3 |
Version: | 1.0.7 |
Built: | 2024-12-25 03:11:16 UTC |
Source: | https://github.com/tlverse/origami |
Applies combiners (functions that collapse across a list of similarly structured results) to a list of such lists.
combine_results(results, combiners = NULL, smart_combiners = TRUE)
results |
A |
combiners |
A |
smart_combiners |
A |
You should rarely, if ever, need to call this function directly, because cross_validate calls it automatically. The default combiners, guessed from each result's data type, should work in most cases.
A list of combined results.
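To make the regrouping concrete, here is a minimal base-R sketch. The helper combine_sketch is hypothetical and is not origami's implementation; it assumes every fold returns a named list with the same element names and that a combiner is supplied for each name.

```r
# Hypothetical base-R sketch (not origami's implementation): each fold returns
# a named list; regroup the list of lists by element name and apply one
# combiner per name.
combine_sketch <- function(results, combiners) {
  # names of the elements each fold returned, e.g. "coef" and "SE"
  result_names <- names(results[[1]])
  combined <- lapply(result_names, function(nm) {
    # collect element `nm` from every fold, then collapse with its combiner
    pieces <- lapply(results, `[[`, nm)
    combiners[[nm]](pieces)
  })
  names(combined) <- result_names
  combined
}

# three "folds", each returning a coefficient row and a vector of errors
fold_results <- lapply(1:3, function(v) {
  list(coef = data.frame(b0 = v, b1 = 2 * v), SE = rep(v, 2))
})

out <- combine_sketch(
  fold_results,
  combiners = list(
    coef = function(x) do.call(rbind, x), # stack data frames by row
    SE = function(x) do.call(c, x)        # concatenate vectors
  )
)
```

Here out$coef is a three-row data frame (one row per fold) and out$SE is a single vector of length six.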
Combiners are functions that collapse across a list of similarly structured results. These are standard idioms for combining lists of certain data types.
combiner_rbind(x)
combiner_c(x)
combiner_factor(x)
combiner_array(x)
x |
A |
A combined results object.
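Three of these idioms can be approximated in a few lines of base R. The *_sketch names are hypothetical; origami's own combiner_* functions may treat edge cases (names, NULL results, multi-dimensional arrays) differently, and combiner_array is not reproduced here.

```r
# Hypothetical base-R approximations of three combiner idioms.
rbind_sketch <- function(x) do.call(rbind, x) # stack data frames/matrices
c_sketch <- function(x) do.call(c, x)         # concatenate vectors
factor_sketch <- function(x) {
  # go through character so factors with different level sets merge cleanly
  factor(unlist(lapply(x, as.character)))
}

combined_f <- factor_sketch(list(factor(c("a", "b")), factor("c")))
levels(combined_f) # "a" "b" "c"
```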
Applies cv_fun to the folds using future_lapply and combines the results across folds using combine_results.
cross_validate( cv_fun, folds, ..., use_future = TRUE, .combine = TRUE, .combine_control = list(), .old_results = NULL )
cv_fun |
A function that takes a 'fold' as its first argument and
returns a list of results from that fold. NOTE: the use of an argument
named 'X' is specifically disallowed in any input function for compliance
with the functions |
folds |
A list of folds to loop over generated using
|
... |
Other arguments passed to |
use_future |
A |
.combine |
A |
.combine_control |
A |
.old_results |
A |
A list of results, combined across folds.
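Conceptually, the loop maps cv_fun over the folds and then collapses each named result across folds. The base-R sketch below is hypothetical (cv_loop_sketch is not an origami function); it concatenates every result with unlist() instead of guessing combiners, and calls future.apply::future_lapply only when use_future = TRUE. It also shows why an argument named X is a problem: lapply() and future_lapply() take their first argument as X, so an X passed through ... would collide with it.

```r
# Hypothetical sketch of the cross-validation loop (not origami's code).
cv_loop_sketch <- function(cv_fun, folds, ..., use_future = FALSE) {
  if (use_future) {
    # parallel map; requires the future.apply package
    results <- future.apply::future_lapply(folds, cv_fun, ...)
  } else {
    results <- lapply(folds, cv_fun, ...)
  }
  # collapse each named element across folds by simple concatenation
  result_names <- names(results[[1]])
  combined <- lapply(result_names, function(nm) unlist(lapply(results, `[[`, nm)))
  names(combined) <- result_names
  combined
}

# toy data and leave-one-out folds stored as plain lists
x <- c(2, 4, 6, 8)
folds <- lapply(seq_along(x), function(v) {
  list(v = v, training_set = setdiff(seq_along(x), v), validation_set = v)
})

# squared error of predicting each held-out value by the training mean
sq_err <- function(fold, data) {
  pred <- mean(data[fold$training_set])
  list(SE = (data[fold$validation_set] - pred)^2)
}

res <- cv_loop_sketch(sq_err, folds, data = x)
mean(res$SE) # 80 / 9, about 8.89
```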
###############################################################################
# This example explains how to use the cross_validate function naively.
###############################################################################
library(origami)
data(mtcars)

# resubstitution MSE
r <- lm(mpg ~ ., data = mtcars)
mean(resid(r)^2)

# function to calculate cross-validated squared error
cv_lm <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  out_var <- as.character(unlist(stringr::str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # capture results to be returned as output
  out <- list(
    coef = data.frame(t(coef(mod))),
    SE = ((preds - valid_data[, out_var_ind])^2)
  )
  return(out)
}

# replicate the resubstitution estimate
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(fold = resub, data = mtcars, reg_form = "mpg ~ .")
mean(resub_results$SE)

# cross-validated estimate
folds <- make_folds(mtcars)
cv_results <- cross_validate(
  cv_fun = cv_lm, folds = folds, data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cv_results$SE)

###############################################################################
# This example explains how to use the cross_validate function with
# parallelization using the framework of the future package.
###############################################################################
suppressMessages(library(data.table))
library(future)
data(mtcars)
set.seed(1)

# make a lot of folds
folds <- make_folds(mtcars, fold_fun = folds_bootstrap, V = 1000)

# function to calculate cross-validated squared error for linear regression
cv_lm <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  # (namespaced call, since stringr is not attached in this example)
  out_var <- as.character(unlist(stringr::str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # capture results to be returned as output
  out <- list(
    coef = data.frame(t(coef(mod))),
    SE = ((preds - valid_data[, out_var_ind])^2)
  )
  return(out)
}

plan(sequential)
time_seq <- system.time({
  results_seq <- cross_validate(
    cv_fun = cv_lm, folds = folds, data = mtcars,
    reg_form = "mpg ~ ."
  )
})

# multicore (forked) processing is unavailable on Windows and in RStudio,
# where future falls back to sequential evaluation
plan(multicore)
time_mc <- system.time({
  results_mc <- cross_validate(
    cv_fun = cv_lm, folds = folds, data = mtcars,
    reg_form = "mpg ~ ."
  )
})

if (availableCores() > 1) {
  time_mc["elapsed"] < 1.2 * time_seq["elapsed"]
}
For V-fold type cross-validation. This takes a fold vector (validation set IDs) and builds a fold object for fold v.
fold_from_foldvec(v, folds)
v |
An identifier of the fold in which observations fall for cross-validation. |
folds |
A vector of the fold status for each observation for cross-validation. |
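A hypothetical base-R version of this conversion (not origami's implementation) makes the rule explicit: observations whose fold ID equals v go to the validation set and all others go to the training set. The element names v, training_set and validation_set mirror make_fold's arguments and are an assumption about the fold layout.

```r
# Hypothetical sketch (not origami's code): build one fold from a fold vector.
fold_from_foldvec_sketch <- function(v, folds) {
  idx <- seq_along(folds)
  list(
    v = v,
    training_set = idx[folds != v],  # every observation outside fold v
    validation_set = idx[folds == v] # observations assigned to fold v
  )
}

foldvec <- c(1, 2, 1, 3, 2, 3)
fold2 <- fold_from_foldvec_sketch(2, foldvec)

# build every fold at once, one per distinct fold ID
all_folds <- lapply(sort(unique(foldvec)), fold_from_foldvec_sketch, folds = foldvec)
```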
Other fold generation functions: fold_funs, folds2foldvec(), make_folds(), make_repeated_folds()
These functions represent different cross-validation schemes that can be used with origami. They should be used as options for the fold_fun argument to make_folds, which will call the requested function, specifying n based on its own arguments, and pass any remaining arguments (e.g., V or pvalidation) on.
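To illustrate what two of these schemes produce, here is a base-R sketch. vfold_sketch and bootstrap_sketch are hypothetical stand-ins for folds_vfold and folds_bootstrap, not their actual code; in particular, the bootstrap sketch assumes that the out-of-bag observations form the validation set.

```r
# Hypothetical sketches of V-fold and bootstrap schemes (not origami's code).
vfold_sketch <- function(n, V = 10L) {
  # randomly assign each observation one of V (nearly) equal-sized fold IDs
  foldvec <- sample(rep(seq_len(V), length.out = n))
  lapply(seq_len(V), function(v) {
    list(v = v,
         training_set = which(foldvec != v),
         validation_set = which(foldvec == v))
  })
}

bootstrap_sketch <- function(n, V = 1000L) {
  lapply(seq_len(V), function(v) {
    # training set: n draws with replacement (duplicates allowed)
    train <- sample.int(n, n, replace = TRUE)
    # validation set: the "out-of-bag" observations that were never drawn
    list(v = v, training_set = train,
         validation_set = setdiff(seq_len(n), train))
  })
}

set.seed(1)
vf <- vfold_sketch(32, V = 5)
bs <- bootstrap_sketch(32, V = 3)
```

In the V-fold sketch every observation is validated exactly once across the V folds; in the bootstrap sketch the validation sets vary in size from replicate to replicate.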
folds_vfold(n, V = 10L)
folds_resubstitution(n)
folds_loo(n)
folds_montecarlo(n, V = 1000L, pvalidation = 0.2)
folds_bootstrap(n, V = 1000L)
folds_rolling_origin(n, first_window, validation_size, gap = 0L, batch = 1L)
folds_rolling_window(n, window_size, validation_size, gap = 0L, batch = 1L)
folds_rolling_origin_pooled(n, t, id = NULL, time = NULL, first_window, validation_size, gap = 0L, batch = 1L)
folds_rolling_window_pooled(n, t, id = NULL, time = NULL, window_size, validation_size, gap = 0L, batch = 1L)
folds_vfold_rolling_origin_pooled(n, t, id = NULL, time = NULL, V = 10L, first_window, validation_size, gap = 0L, batch = 1L)
folds_vfold_rolling_window_pooled(n, t, id = NULL, time = NULL, V = 10L, window_size, validation_size, gap = 0L, batch = 1L)
n |
An integer indicating the number of observations. |
V |
An integer indicating the number of folds. |
pvalidation |
A |
first_window |
An integer indicating the number of observations in the first training sample. |
validation_size |
An integer indicating the number of points in the validation samples; should be equal to the largest forecast horizon. |
gap |
An integer indicating the number of points not included in the training or validation samples. The default is zero. |
batch |
An integer indicating how many time points are added to the training set in each iteration of cross-validation; useful for longer time series. The default is one. |
window_size |
An integer indicating the number of observations in each training sample. |
t |
An integer indicating the total amount of time to consider per time-series sample. |
id |
An optional vector of unique identifiers corresponding to the time vector. These can be used to subset the time vector. |
time |
An optional vector of integers of time points observed for each subject in the sample. |
A list of Folds.
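The interaction of first_window, validation_size, gap and batch in a rolling-origin scheme can be sketched as follows. rolling_origin_sketch is hypothetical and is not origami's implementation; it assumes time points are indexed 1..n and stops once a full validation window no longer fits. A rolling-window variant would instead keep only the last window_size points of each training set.

```r
# Hypothetical sketch of rolling-origin splits (not origami's code): the
# training window starts at time 1 and grows by `batch` points per fold; a
# `gap` of unused points separates it from the validation window.
rolling_origin_sketch <- function(n, first_window, validation_size,
                                  gap = 0L, batch = 1L) {
  folds <- list()
  train_end <- first_window
  v <- 1L
  # keep going while a complete validation window still fits in 1..n
  while (train_end + gap + validation_size <= n) {
    valid_start <- train_end + gap + 1L
    folds[[v]] <- list(
      v = v,
      training_set = seq_len(train_end),
      validation_set = seq(valid_start, valid_start + validation_size - 1L)
    )
    train_end <- train_end + batch
    v <- v + 1L
  }
  folds
}

ro <- rolling_origin_sketch(n = 10, first_window = 5, validation_size = 2)
ro_gap <- rolling_origin_sketch(n = 10, first_window = 5, validation_size = 2, gap = 1)
```

With n = 10 the first fold trains on points 1 to 5 and validates on 6 and 7; adding gap = 1 shifts the first validation window to points 7 and 8.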
Other fold generation functions: fold_from_foldvec(), folds2foldvec(), make_folds(), make_repeated_folds()
Accessors and indexers for the different parts of a fold.
training(x = NULL, fold = NULL)
validation(x = NULL, fold = NULL)
fold_index(x = NULL, fold = NULL)
x |
an object to be indexed by a training set, validation set, or fold index. If missing, the index itself will be returned. |
fold |
Fold; the fold used to do the indexing. If missing, |
The elements of x corresponding to the indexes, or the indexes themselves if x is missing.
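A hypothetical base-R sketch of these accessors follows. Unlike training() and validation(), the *_sketch helpers require the fold to be passed explicitly; the sketch assumes a fold is a list with training_set and validation_set index vectors, and it subsets rows for data frames and matrices.

```r
# Hypothetical sketch of fold accessors (not origami's code).
subset_by_index <- function(x, idx) {
  if (is.null(x)) return(idx) # no object supplied: return the index itself
  if (is.data.frame(x) || is.matrix(x)) x[idx, , drop = FALSE] else x[idx]
}
training_sketch <- function(x = NULL, fold) subset_by_index(x, fold$training_set)
validation_sketch <- function(x = NULL, fold) subset_by_index(x, fold$validation_set)

fold <- list(v = 1L, training_set = 1:3, validation_set = 4:5)
df <- data.frame(y = c(10, 20, 30, 40, 50))

training_sketch(fold = fold)     # the training indexes 1 2 3
validation_sketch(df, fold)      # rows 4 and 5 of df
validation_sketch(df$y, fold)    # 40 50
```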
For V-fold type cross-validation. This takes a fold object and returns a fold vector (containing the validation set IDs) for use with other tools like cv.glmnet.
folds2foldvec(folds)
folds |
A |
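The conversion can be sketched in base R. folds2foldvec_sketch is hypothetical; it assumes V-fold style folds in which every observation appears in exactly one validation set. A vector built this way is the kind of input cv.glmnet accepts through its foldid argument.

```r
# Hypothetical sketch (not origami's code): label each observation with the
# index of the fold whose validation set contains it.
folds2foldvec_sketch <- function(folds) {
  # under the V-fold assumption, validation sets partition the observations
  n <- length(unlist(lapply(folds, `[[`, "validation_set")))
  foldvec <- integer(n)
  for (v in seq_along(folds)) {
    foldvec[folds[[v]]$validation_set] <- v
  }
  foldvec
}

folds <- list(
  list(v = 1L, training_set = c(2L, 4L), validation_set = c(1L, 3L)),
  list(v = 2L, training_set = c(1L, 3L), validation_set = c(2L, 4L))
)
foldvec <- folds2foldvec_sketch(folds) # 1 2 1 2
```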
Other fold generation functions: fold_from_foldvec(), fold_funs, make_folds(), make_repeated_folds()
Maps data types into standard combiners that should be sensible.
guess_combiner(result)
result |
A single result; flexibly accepts several object classes. |
A function to combine a list of such results.
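One plausible set of mapping rules is sketched below. guess_combiner_sketch is hypothetical and its rules are assumptions, not origami's exact logic. Note that factors must be checked before generic atomic vectors, because is.atomic() is also TRUE for factors.

```r
# Hypothetical sketch (not origami's exact rules): choose a combiner from the
# class of a single example result.
guess_combiner_sketch <- function(result) {
  if (is.factor(result)) {
    # merge factors via character so differing level sets combine
    function(x) factor(unlist(lapply(x, as.character)))
  } else if (is.data.frame(result) || is.matrix(result)) {
    function(x) do.call(rbind, x) # stack rows
  } else if (is.atomic(result) && is.null(dim(result))) {
    function(x) do.call(c, x)     # concatenate plain vectors
  } else {
    identity # fallback: leave the list of results as-is
  }
}

comb_int <- guess_combiner_sketch(1:2)
comb_int(list(1:2, 3L)) # 1 2 3
```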
This function converts folds that subset IDs into folds that subset observations.
id_folds_to_folds(idfolds, cluster_ids)
idfolds |
folds that subset ids |
cluster_ids |
a vector of cluster ids indicating which observations are in which clusters |
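The mapping can be sketched in base R under one stated assumption: the indexes inside each ID-level fold refer to positions in unique(cluster_ids). id_folds_to_folds_sketch is a hypothetical helper, not origami's implementation.

```r
# Hypothetical sketch (not origami's code): folds built over unique cluster
# IDs are mapped back to the row indexes of the observations in each cluster.
id_folds_to_folds_sketch <- function(idfolds, cluster_ids) {
  unique_ids <- unique(cluster_ids)
  lapply(idfolds, function(fold) {
    # assumption: ID-fold indexes are positions in unique_ids
    list(
      v = fold$v,
      training_set = which(cluster_ids %in% unique_ids[fold$training_set]),
      validation_set = which(cluster_ids %in% unique_ids[fold$validation_set])
    )
  })
}

cluster_ids <- c("a", "a", "b", "c", "c", "c")
idfolds <- list(
  list(v = 1L, training_set = c(2L, 3L), validation_set = 1L), # hold out "a"
  list(v = 2L, training_set = c(1L, 3L), validation_set = 2L)  # hold out "b"
)
obs_folds <- id_folds_to_folds_sketch(idfolds, cluster_ids)
```

Both rows of cluster "a" land in the first validation set together, which is exactly the "clusters are treated as a unit" behavior that cluster-level folds are meant to guarantee.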
Function to make a fold. The current representation is a simple list.
make_fold(v, training_set, validation_set)
v |
An integer index of folds in the larger scheme. |
training_set |
An integer vector of indexes corresponding to the training set. |
validation_set |
An integer vector of indexes corresponding to the validation set. |
A list containing these elements.
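A minimal sketch of such a list follows; make_fold_sketch is hypothetical and its element names simply mirror make_fold's arguments, which is an assumption about origami's internal layout.

```r
# Hypothetical sketch of the simple-list fold representation.
make_fold_sketch <- function(v, training_set, validation_set) {
  list(v = v, training_set = training_set, validation_set = validation_set)
}

fold <- make_fold_sketch(1L, training_set = 3:10, validation_set = 1:2)
str(fold) # a list of three elements: v, training_set, validation_set
```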
Generates a list of folds for a variety of cross-validation schemes.
make_folds( n = NULL, fold_fun = folds_vfold, cluster_ids = NULL, strata_ids = NULL, ... )
n |
- either an integer indicating the number of observations to
cross-validate over, or an object from which to guess the number of
observations; can also be computed from |
fold_fun |
- A function indicating the cross-validation scheme to use.
See |
cluster_ids |
- a vector of cluster ids. Clusters are treated as a unit – that is, all observations within a cluster are placed in either the training or validation set. |
strata_ids |
- a vector of strata ids. Strata are balanced: insofar as possible the distribution in the sample should be the same as the distribution in the training and validation sets. |
... |
other arguments to be passed to |
A list of fold objects. Each fold consists of a list with a training index vector, a validation index vector, and a fold_index (its order in the list of folds).
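Stratification can be sketched by assigning fold IDs separately within each stratum, so each fold receives a similar share of every stratum. stratified_foldvec_sketch is a hypothetical helper, not origami's implementation; it returns a fold vector rather than a list of folds and does not handle cluster_ids.

```r
# Hypothetical sketch of stratified V-fold assignment (not origami's code).
stratified_foldvec_sketch <- function(strata_ids, V = 5L) {
  foldvec <- integer(length(strata_ids))
  for (s in unique(strata_ids)) {
    rows <- which(strata_ids == s)
    # fold IDs 1..V recycled to the stratum size, then shuffled
    foldvec[rows] <- sample(rep(seq_len(V), length.out = length(rows)))
  }
  foldvec
}

set.seed(42)
strata <- rep(c("case", "control"), times = c(10, 20))
fv <- stratified_foldvec_sketch(strata, V = 5)
tab <- table(fv, strata) # each fold: 2 cases and 4 controls
```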
Other fold generation functions: fold_from_foldvec(), fold_funs, folds2foldvec(), make_repeated_folds()
Implementation of repeated window cross-validation: generates fold objects for repeated cross-validation by making repeated calls to make_folds and concatenating the results.
make_repeated_folds(repeats, ...)
repeats |
An integer indicating the number of repeats. |
... |
Arguments passed to |
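The "repeat and concatenate" idea can be sketched in base R. repeated_folds_sketch and simple_vfold are hypothetical helpers; renumbering the folds so every index is unique is a design choice of this sketch, and origami may number repeated folds differently.

```r
# Hypothetical sketch of repeated cross-validation (not origami's code).
repeated_folds_sketch <- function(repeats, fold_fun, ...) {
  # call the fold generator `repeats` times, forwarding the same arguments
  fold_sets <- lapply(seq_len(repeats), function(r, ...) fold_fun(...), ...)
  all_folds <- do.call(c, fold_sets) # concatenate the lists of folds
  # sketch's choice: renumber so every fold index is unique
  for (i in seq_along(all_folds)) all_folds[[i]]$v <- i
  all_folds
}

# a tiny V-fold generator used only for this illustration
simple_vfold <- function(n, V) {
  foldvec <- sample(rep(seq_len(V), length.out = n))
  lapply(seq_len(V), function(v) {
    list(v = v,
         training_set = which(foldvec != v),
         validation_set = which(foldvec == v))
  })
}

set.seed(3)
rf <- repeated_folds_sketch(3, simple_vfold, n = 20, V = 5)
length(rf) # 15: three repeats of 5-fold cross-validation
```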
Other fold generation functions: fold_from_foldvec(), fold_funs, folds2foldvec(), make_folds()
Function factory that generates versions of functions wrapped in try.
wrap_in_try(fun, ...)
fun |
A |
... |
Additional arguments passed to the previous argument |
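A function factory of this kind can be sketched in base R. wrap_in_try_sketch is hypothetical and is not origami's implementation; it assumes that arguments given to the factory's ... are bound once and appended to every later call of the wrapped function.

```r
# Hypothetical sketch of a try-wrapping function factory (not origami's code).
wrap_in_try_sketch <- function(fun, ...) {
  force(fun)
  extra_args <- list(...) # fixed arguments bound when the wrapper is built
  function(...) {
    # on error, try() returns an object of class "try-error" instead of stopping
    try(do.call(fun, c(list(...), extra_args)), silent = TRUE)
  }
}

safe_log <- wrap_in_try_sketch(function(x) {
  if (x < 0) stop("negative input")
  log(x)
})
safe_log(1)                         # 0
inherits(safe_log(-1), "try-error") # TRUE

round1 <- wrap_in_try_sketch(round, digits = 1)
round1(3.14159) # 3.1
```

Wrapping a per-fold function this way lets a loop over folds record a failure for one fold without aborting the others.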