The sl3 package provides a modern framework for machine learning. This includes the Super Learner algorithm (van der Laan, Polley, and Hubbard 2007), a method for performing stacked regressions (Breiman 1996), combined with covariate screening and cross-validation. sl3 uses an object-oriented programming (OOP) approach and leverages R6 classes to define both Tasks (machine learning problems) and Learners (machine learning algorithms that attempt to solve those problems) in a way that is both flexible and extensible. The design of sl3 owes a lot to the SuperLearner and mlr packages, which also provide unified frameworks for Super Learning and machine learning, respectively.
Throughout this vignette, we use data from the Collaborative Perinatal Project (CPP) to illustrate the features of sl3 as well as its proper usage. For convenience, we’ve included an imputed version of this dataset in the sl3 package. Below, we load some useful packages, load the cpp_imputed dataset, and define the variables (columns) from the dataset that we’re interested in:
set.seed(49753)
# packages we'll be using
library(data.table)
library(SuperLearner)
library(origami)
library(sl3)
# load example data set
data(cpp_imputed)
# here are the covariates we are interested in and, of course, the outcome
covars <- c(
"apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs",
"sexn"
)
outcome <- "haz"
As mentioned above, sl3 is designed using basic OOP principles and the R6 OOP framework. While we’ve tried to make it easy to use sl3 without worrying much about OOP, it is helpful to have some intuition about how sl3 is structured. In this section, we briefly outline some key concepts from OOP. Readers familiar with OOP basics are invited to skip this section. The key concept of OOP is that of an object, a collection of data and functions that corresponds to some conceptual unit. Objects have two main types of elements: fields, which can be thought of as nouns and hold information about an object, and methods, which can be thought of as verbs and are actions an object can perform. Objects are members of classes, which define what those specific fields and methods are. Classes can inherit elements from other classes (sometimes called base classes); accordingly, classes that are similar, but not exactly the same, can share some parts of their definitions.
Many different implementations of OOP exist, with variations in how these concepts are implemented and used. R has several different implementations, including S3, S4, reference classes, and R6. sl3 uses the R6 implementation. In R6, methods and fields of a class object are accessed using the $ operator. The next section explains how these concepts are used in sl3 to model machine learning problems and algorithms.
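To make these terms concrete, here is a toy R6 example that has nothing to do with sl3 itself; the Animal and Dog classes are hypothetical. It shows a field, a method, a base class that another class inherits from, and access to both fields and methods with $:

```r
library(R6)

# hypothetical toy classes, unrelated to sl3, illustrating OOP terms
Animal <- R6Class("Animal",
  public = list(
    name = NULL, # a field: information about the object (a "noun")
    initialize = function(name) {
      self$name <- name
    },
    # a method: an action the object can perform (a "verb")
    speak = function() paste(self$name, "makes a sound")
  )
)

# Dog inherits the name field and initialize() from its base class, Animal
Dog <- R6Class("Dog",
  inherit = Animal,
  public = list(
    # overrides the inherited speak() method
    speak = function() paste(self$name, "barks")
  )
)

rex <- Dog$new("Rex")
rex$name                 # fields are accessed with $
rex$speak()              # so are methods: "Rex barks"
inherits(rex, "Animal")  # TRUE: a Dog is also an Animal
```

sl3 follows the same pattern: learner classes such as Lrnr_glm inherit from the base class Lrnr_base.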
sl3 objects
The sl3_Task class defines machine learning problems. An sl3_Task object keeps track of the task data, as well as the role each variable plays in the machine learning problem. We can see an example of that here, using the cpp_imputed dataset described above:
task <- make_sl3_Task(
data = cpp_imputed, covariates = covars,
outcome = outcome, outcome_type = "continuous"
)
We use the make_sl3_Task function to create a new sl3_Task, called task. Here, we specified the underlying data, cpp_imputed, a vector of covariate names, the name of the outcome variable, and the outcome type.
Let’s take a look at this object:
## An sl3 Task with 1441 obs and these nodes:
## $covariates
## [1] "apgar1" "apgar5" "parity" "gagebrth" "mage" "meducyrs" "sexn"
##
## $outcome
## [1] "haz"
##
## $id
## NULL
##
## $weights
## NULL
##
## $offset
## NULL
##
## $time
## NULL
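Beyond printing it, we can pull the data and the role assignments back out of the task. The following minimal sketch re-creates task so that it runs on its own and then reads the sl3_Task fields nrow, nodes, X, and Y:

```r
library(sl3)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)

task$nrow           # number of observations in the task
task$nodes$outcome  # name of the column playing the outcome role
head(task$X)        # covariate columns, returned as a data.table
head(task$Y)        # outcome values
```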
In addition to the simple usage demonstrated above, make_sl3_Task supports a range of options for describing more complex machine learning problems. For example, we can specify the id, weights, and offset nodes listed above. These additional features are documented in the help for sl3_Task.
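As an illustration of one of these options, the sketch below attaches observation weights. The obs_wts column and its random values are hypothetical. They are added to a copy of the data only to show that a node such as weights is specified by column name:

```r
library(sl3)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")

# copy the data and add a hypothetical weights column (illustration only)
dat <- as.data.frame(cpp_imputed)
set.seed(1)
dat$obs_wts <- runif(nrow(dat), min = 0.5, max = 1.5)

# the weights node is given as the name of a column in the data
task_w <- make_sl3_Task(
  data = dat, covariates = covars,
  outcome = "haz", outcome_type = "continuous",
  weights = "obs_wts"
)
task_w$nodes$weights  # name of the weights column
head(task_w$weights)  # the weight values themselves
```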
Lrnr_base is the base class for defining machine learning algorithms, as well as fits of those algorithms to particular sl3_Tasks. Different machine learning algorithms are defined in classes that inherit from Lrnr_base. For instance, the Lrnr_glm class inherits from Lrnr_base and defines a learner that fits generalized linear models. We will use the term learners to refer to the family of classes that inherit from Lrnr_base. Learner objects can be constructed from their class definitions using the make_learner function:
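For example, we can construct a GLM learner like the lrnr_glm object used in the rest of this vignette as follows. This minimal sketch uses default parameters, and make_learner(Lrnr_glm) is essentially shorthand for Lrnr_glm$new():

```r
library(sl3)

# an untrained learner for generalized linear models, with default parameters
lrnr_glm <- make_learner(Lrnr_glm)
class(lrnr_glm)      # Lrnr_glm, which inherits from Lrnr_base
lrnr_glm$is_trained  # FALSE until train() is called
```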
Because all learners inherit from Lrnr_base, they have many features in common and can be used interchangeably. All learners define three main methods: train, predict, and chain. The first, train, takes an sl3_Task object and returns a learner fit, which has the same class as the learner that was trained:
# fit learner to task data
lrnr_glm_fit <- lrnr_glm$train(task)
# verify that the learner is fit
lrnr_glm_fit$is_trained
## [1] TRUE
Here, we fit the learner to the CPP task we defined above. Both lrnr_glm and lrnr_glm_fit are objects of class Lrnr_glm, although the former defines a learner and the latter defines a fit of that learner. We can distinguish between learners and learner fits using the is_trained field, which is TRUE for fits and FALSE for untrained learners.
Now that we’ve fit a learner, we can generate predictions using the predict method:
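The following minimal sketch shows this step. It re-creates task and the GLM fit so that it can run on its own:

```r
library(sl3)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
lrnr_glm_fit <- make_learner(Lrnr_glm)$train(task)

# one prediction per row of task
preds <- lrnr_glm_fit$predict(task)
head(preds)
```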
## [1] 0.36298498 0.36298498 0.25993072 0.25993072 0.25993072 0.05680264
Here, we specified task as the task for which we wanted to generate predictions. If we had omitted it, we would have gotten the same predictions, because predict defaults to using the task provided to train (called the training task). Alternatively, we could have provided a different task for which we want to generate predictions.
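To illustrate predicting on a different task, the sketch below builds new_task (a hypothetical name) from the first ten rows of cpp_imputed as a stand-in for new data. The fitted model and the rows are the same, so its predictions should match the first ten training-task predictions:

```r
library(sl3)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
lrnr_glm_fit <- make_learner(Lrnr_glm)$train(task)

# a second task, here just the first 10 rows, standing in for new data
new_task <- make_sl3_Task(
  data = cpp_imputed[1:10, ], covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
new_preds <- lrnr_glm_fit$predict(new_task)

# same model, same rows: should agree with the training-task predictions
all.equal(new_preds, lrnr_glm_fit$predict(task)[1:10])
```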
The final important learner method, chain, will be discussed below, in the section on learner composition.
As with sl3_Task, learners have a variety of fields and methods we haven’t discussed here. More information on these is available in the help for Lrnr_base.
Learners have properties that indicate what features they support. You can use sl3_list_properties to get a list of all properties supported by at least one learner. You can then use sl3_list_learners to find learners supporting any set of properties. For example:
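Calls of the following form produce listings like those below. In the second call, the choice of properties (binomial outcomes together with offsets) is only an example:

```r
library(sl3)

# every property supported by at least one learner
sl3_list_properties()

# learners that support both binomial outcomes and offsets
sl3_list_learners(c("binomial", "offset"))
```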
## [1] "binomial" "categorical" "continuous" "cv"
## [5] "density" "h2o" "ids" "importance"
## [9] "offset" "preprocessing" "sampling" "screener"
## [13] "timeseries" "weights" "wrapper"
## [1] "Lrnr_bayesglm" "Lrnr_ga" "Lrnr_glm" "Lrnr_glm_fast"
## [5] "Lrnr_glmtree" "Lrnr_h2o_glm" "Lrnr_h2o_grid" "Lrnr_lightgbm"
## [9] "Lrnr_mean" "Lrnr_optim" "Lrnr_solnp" "Lrnr_xgboost"
The list of natively supported learners is currently somewhat limited. However, some learners not yet supported natively in sl3 can be used via their corresponding wrappers in the SuperLearner package: SuperLearner wrappers, screeners, and methods can all be used as sl3 learners via Lrnr_pkg_SuperLearner, Lrnr_pkg_SuperLearner_screener, and Lrnr_pkg_SuperLearner_method, respectively. To learn more about SuperLearner wrappers, screeners, and methods, consult the documentation provided with that R package. Here’s an example of defining an sl3 learner that uses the SL.glmnet wrapper from SuperLearner:
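The sketch below wraps SL.glmnet and then trains and predicts with it like any other sl3 learner. It assumes the glmnet package is installed, and it sets a seed because glmnet’s internal cross-validation is random:

```r
library(sl3)
library(SuperLearner)
set.seed(49753)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)

# wrap the SuperLearner wrapper named "SL.glmnet" as an sl3 learner
lrnr_sl_glmnet <- make_learner(Lrnr_pkg_SuperLearner, "SL.glmnet")

# the usual sl3 interface now applies
sl_glmnet_fit <- lrnr_sl_glmnet$train(task)
head(sl_glmnet_fit$predict())
```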
In most cases, using these wrappers will not be as efficient as using their native sl3 counterparts. If your favorite learner is missing from sl3, please consider adding it by following the “Defining New Learners” vignette.
In general, learners can be instantiated without providing any additional parameters. We’ve tried to provide sensible defaults for each learner; however, if you would like to modify the learners’ behavior, you may do so by instantiating learners with different parameters.
sl3 learners support some common parameters that work with all learners for which they are applicable:
- covariates: subsets covariates before fitting. This allows different learners to be fit to the same task with different covariate subsets.
- outcome_type: overrides task$outcome_type. This allows different learners to be fit to the same task with different outcome types.
- ...: arbitrary parameters, typically passed directly to the internal learner method. The documentation for each learner points to the appropriate function documentation for that method.
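The following sketch shows two kinds of parameters. The first is the common covariates parameter, which restricts a GLM to the two Apgar scores. The second is a learner-specific parameter, alpha = 0, which asks Lrnr_glmnet for a ridge penalty. The specific covariates and the alpha value are illustrative choices only:

```r
library(sl3)
set.seed(49753)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)

# common parameter: this GLM only sees two of the task's seven covariates
lrnr_glm_apgar <- make_learner(Lrnr_glm, covariates = c("apgar1", "apgar5"))

# learner-specific parameter: alpha = 0 requests a ridge (L2) penalty
lrnr_ridge <- make_learner(Lrnr_glmnet, alpha = 0)

# both are trained on the same, unchanged task
apgar_fit <- lrnr_glm_apgar$train(task)
ridge_fit <- lrnr_ridge$train(task)
head(apgar_fit$predict())
head(ridge_fit$predict())
```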
sl3 defines two special learners, Pipeline and Stack, that allow learners to be composed in a flexible manner.
A pipeline is a set of learners to be fit sequentially, where the fit from one learner is used to define the task for the next learner. There are many ways in which a learner can define the task for the downstream learner. The chain method defined by learners determines how this works. Let’s look at the example of pre-screening variables. For now, we’ll rely on a screener from the SuperLearner package, although native sl3 screening algorithms will be implemented soon.
Below, we generate a screener object based on the SuperLearner function screen.corP and fit it to our task. Inspecting the fit, we see that it selected a subset of covariates:
screen_cor <- Lrnr_pkg_SuperLearner_screener$new("screen.corP")
screen_fit <- screen_cor$train(task)
print(screen_fit)
## [1] "Lrnr_pkg_SuperLearner_screener_screen.corP"
## $selected
## [1] "parity" "gagebrth"
Now, chain may be called on this learner fit to define a downstream task:
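The following minimal sketch shows this step. It re-creates task and the screener fit so that it runs on its own and defines the screened_task object that is used further below:

```r
library(sl3)
library(SuperLearner)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
screen_cor <- Lrnr_pkg_SuperLearner_screener$new("screen.corP")
screen_fit <- screen_cor$train(task)

# chain() builds the downstream task: same rows and outcome,
# covariates reduced to the subset selected by the screener
screened_task <- screen_fit$chain(task)
print(screened_task)
```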
## An sl3 Task with 1441 obs and these nodes:
## $covariates
## [1] "parity" "gagebrth"
##
## $outcome
## [1] "haz"
##
## $id
## NULL
##
## $weights
## NULL
##
## $offset
## NULL
##
## $time
## NULL
As with predict, we can omit a task from the call to chain, in which case the call defaults to using the same task that was used for training. We can see that the chained task reduces the covariates to the subset selected by the screener. We can fit this new task using the lrnr_glm we defined above:
screened_glm_fit <- lrnr_glm$train(screened_task)
screened_preds <- screened_glm_fit$predict()
head(screened_preds)
## [1] 0.38084472 0.38084472 0.29887623 0.29887623 0.29887623 -0.00987784
The Pipeline class automates this process. It takes an arbitrary number of learners and fits them sequentially, training and chaining each one in turn. Since Pipeline is a learner like any other, it shares the same interface. We can define a pipeline using make_learner, and use train and predict just as we did before:
sg_pipeline <- make_learner(Pipeline, screen_cor, lrnr_glm)
sg_pipeline_fit <- sg_pipeline$train(task)
sg_pipeline_preds <- sg_pipeline_fit$predict()
head(sg_pipeline_preds)
## [1] 0.38084472 0.38084472 0.29887623 0.29887623 0.29887623 -0.00987784
We see that the pipeline returns the same predictions as manually training lrnr_glm on the chained task from the screening learner.
We can visualize the pipeline we defined above:
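The following sketch builds the pipeline’s training graph with delayed_learner_train, which sets up the training steps without running them, and then plots that graph. It assumes the visNetwork package is available for the interactive plot, and the height and width values are arbitrary display choices:

```r
library(sl3)
library(SuperLearner)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
screen_cor <- Lrnr_pkg_SuperLearner_screener$new("screen.corP")
sg_pipeline <- make_learner(Pipeline, screen_cor, make_learner(Lrnr_glm))

# a Delayed object describing the training steps and their dependencies
dt <- delayed_learner_train(sg_pipeline, task)
plot(dt, color = FALSE, height = "400px", width = "100%")
```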
Like Pipelines, Stacks combine multiple learners. Unlike a pipeline, however, a stack trains its learners side by side on the same task, so that their predictions can be either combined or compared. Again, Stack is just a special learner and so has the same interface as all other learners:
stack <- make_learner(Stack, lrnr_glm, sg_pipeline)
stack_fit <- stack$train(task)
stack_preds <- stack_fit$predict()
head(stack_preds)
## Lrnr_glm_TRUE
## <num>
## 1: 0.36298498
## 2: 0.36298498
## 3: 0.25993072
## 4: 0.25993072
## 5: 0.25993072
## 6: 0.05680264
## Pipeline(Lrnr_pkg_SuperLearner_screener_screen.corP->Lrnr_glm_TRUE)
## <num>
## 1: 0.38084472
## 2: 0.38084472
## 3: 0.29887623
## 4: 0.29887623
## 5: 0.29887623
## 6: -0.00987784
Above, we’ve defined and fit a stack composed of a simple glm learner as well as a pipeline that combines a screening algorithm with that same learner. We could have included any arbitrary set of learners and pipelines, the latter of which are themselves just learners. We can see that the predict method now returns a data.table, with a column for each learner included in the stack.
We can visualize the stack:
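As with the pipeline, the stack’s training graph can be built with delayed_learner_train and then plotted. This sketch re-creates the stack so that it runs on its own and again assumes visNetwork is available:

```r
library(sl3)
library(SuperLearner)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
lrnr_glm <- make_learner(Lrnr_glm)
screen_cor <- Lrnr_pkg_SuperLearner_screener$new("screen.corP")
sg_pipeline <- make_learner(Pipeline, screen_cor, lrnr_glm)
stack <- make_learner(Stack, lrnr_glm, sg_pipeline)

# one branch per learner in the stack
dt_stack <- delayed_learner_train(stack, task)
plot(dt_stack, color = FALSE, height = "400px", width = "100%")
```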
We see one “branch” for each learner in the stack.
Having defined a stack, we might want to compare the performance of
learners in the stack, which we may do using cross-validation.
The Lrnr_cv learner wraps another learner and performs training and prediction in a cross-validated fashion, using separate training and validation splits as defined by task$folds.
Below, we define a new Lrnr_cv object based on the previously defined stack, train it, and generate predictions on the validation sets:
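The following minimal sketch shows this step and defines the cv_fit object used later. It relies on the default folds that the task creates, so it sets a seed first to make the fold assignment reproducible:

```r
library(sl3)
library(SuperLearner)
set.seed(49753)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
lrnr_glm <- make_learner(Lrnr_glm)
screen_cor <- Lrnr_pkg_SuperLearner_screener$new("screen.corP")
sg_pipeline <- make_learner(Pipeline, screen_cor, lrnr_glm)
stack <- make_learner(Stack, lrnr_glm, sg_pipeline)

# wrap the stack so that each fold is fit on its training split
cv_stack <- Lrnr_cv$new(stack)
cv_fit <- cv_stack$train(task)

# one column per learner; each row is predicted by a fit that did not see it
cv_preds <- cv_fit$predict()
head(cv_preds)
```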
We can also use the cv_risk method of the trained Lrnr_cv fit to estimate cross-validated risk values:
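The sketch below computes these risks using squared-error loss via sl3’s loss_squared_error, whose reported risk is the MSE column shown below. It re-creates cv_fit so that it runs on its own:

```r
library(sl3)
library(SuperLearner)
set.seed(49753)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
lrnr_glm <- make_learner(Lrnr_glm)
screen_cor <- Lrnr_pkg_SuperLearner_screener$new("screen.corP")
sg_pipeline <- make_learner(Pipeline, screen_cor, lrnr_glm)
stack <- make_learner(Stack, lrnr_glm, sg_pipeline)
cv_fit <- Lrnr_cv$new(stack)$train(task)

# one row per learner: cross-validated risk, its standard error, and fold-level spread
risks <- cv_fit$cv_risk(loss_squared_error)
print(risks)
```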
## Key: <learner>
## learner
## <fctr>
## 1: Lrnr_glm_TRUE
## 2: Pipeline(Lrnr_pkg_SuperLearner_screener_screen.corP->Lrnr_glm_TRUE)
## coefficients MSE se fold_sd fold_min_MSE fold_max_MSE
## <num> <num> <num> <num> <num> <num>
## 1: NA 1.600368 0.1042245 0.3521137 1.062662 2.139263
## 2: NA 1.601804 0.1032100 0.3419678 1.066897 2.108442
In this example, we don’t see much difference between the two learners, suggesting the addition of the screening step in the pipeline learner didn’t improve performance much.
We can combine all of the above elements (Pipelines, Stacks, and cross-validation using Lrnr_cv) to easily define a Super Learner. The Super Learner algorithm works by fitting a “meta-learner”, which combines predictions from multiple stacked learners. It avoids overfitting by training the meta-learner on cross-validated (validation-set) predictions rather than on predictions from learners fit to the same data. Using some of the objects we defined in the above examples, this becomes a very simple operation:
metalearner <- make_learner(Lrnr_nnls)
cv_task <- cv_fit$chain()
ml_fit <- metalearner$train(cv_task)
Here, we used a special learner, Lrnr_nnls, for the meta-learning step. This fits a non-negative least squares meta-learner. It is important to note that any learner can be used as a meta-learner.
The final Super Learner is defined as a pipeline with the learner stack trained on the full data and the meta-learner trained on the validation-set predictions. Below, we use a special behavior of pipelines: if all objects passed to a pipeline are learner fits (i.e., learner$is_trained is TRUE), the result will also be a fit:
sl_pipeline <- make_learner(Pipeline, stack_fit, ml_fit)
sl_preds <- sl_pipeline$predict()
head(sl_preds)
## [1] 0.36977825 0.36977825 0.27474433 0.27474433 0.27474433 0.03143954
A Super Learner may be fit in a more streamlined manner using the Lrnr_sl learner. For simplicity, we will use the same set of learners and meta-learning algorithm as we did before:
sl <- Lrnr_sl$new(
learners = stack,
metalearner = metalearner
)
sl_fit <- sl$train(task)
lrnr_sl_preds <- sl_fit$predict()
head(lrnr_sl_preds)
## [1] 0.36977825 0.36977825 0.27474433 0.27474433 0.27474433 0.03143954
We can see that this generates the same predictions as the more hands-on definition above.
Computation with delayed
Fitting a Super Learner involves many different training and prediction steps, as the procedure requires that the learners in the stack and the meta-learner be fit on cross-validation folds and on the full data. For large datasets, this can be extremely time-consuming. To alleviate this, we’ve developed a specialized parallelization framework, delayed, that parallelizes across these tasks in a way that takes into account their interdependent nature. Consider a Super Learner with three learners:
lrnr_rf <- make_learner(Lrnr_randomForest)
lrnr_glmnet <- make_learner(Lrnr_glmnet)
sl <- Lrnr_sl$new(
learners = list(lrnr_glm, lrnr_rf, lrnr_glmnet),
metalearner = metalearner
)
We can plot the network of tasks required to train this Super Learner:
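The following sketch re-creates the three-learner Super Learner and plots its training graph. It assumes the randomForest, glmnet, and visNetwork packages are installed, and it uses Lrnr_nnls as the meta-learner, as above:

```r
library(sl3)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
sl <- Lrnr_sl$new(
  learners = list(
    make_learner(Lrnr_glm),
    make_learner(Lrnr_randomForest),
    make_learner(Lrnr_glmnet)
  ),
  metalearner = make_learner(Lrnr_nnls)
)

# every fold-level and full-data fit appears as a node in this graph
sl_delayed <- delayed_learner_train(sl, task)
plot(sl_delayed, color = FALSE, height = "400px", width = "100%")
```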
delayed then allows us to parallelize the procedure across these tasks using the future package. For more information on specifying future plans for parallelization, see the documentation of the future package.
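The following sketch runs that graph in parallel. It assumes the delayed and future packages are installed and uses delayed’s Scheduler with FutureJob on a multisession plan. Two workers is an arbitrary choice, and the plan is reset to sequential afterwards:

```r
library(sl3)
library(delayed)
library(future)
set.seed(49753)
data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
task <- make_sl3_Task(
  data = cpp_imputed, covariates = covars,
  outcome = "haz", outcome_type = "continuous"
)
sl <- Lrnr_sl$new(
  learners = list(
    make_learner(Lrnr_glm),
    make_learner(Lrnr_randomForest),
    make_learner(Lrnr_glmnet)
  ),
  metalearner = make_learner(Lrnr_nnls)
)

# choose a future backend: here, two background R sessions
plan(multisession, workers = 2)

# build the dependency graph, then let the scheduler dispatch ready steps as futures
sl_delayed <- delayed_learner_train(sl, task)
sl_sched <- Scheduler$new(sl_delayed, FutureJob, nworkers = 2, verbose = FALSE)
sl_fit <- sl_sched$compute()

plan(sequential) # restore sequential evaluation
head(sl_fit$predict())
```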
Session information:
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: Etc/UTC
## tzcode source: system (glibc)
##
## attached base packages:
## [1] splines stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] origami_1.0.7 SuperLearner_2.0-29 gam_1.22-5
## [4] foreach_1.5.2 nnls_1.6 data.table_1.16.2
## [7] sl3_1.4.5 rmarkdown_2.29
##
## loaded via a namespace (and not attached):
## [1] Rdpack_2.6.1 delayed_0.5.0 pROC_1.18.5
## [4] rlang_1.1.4 magrittr_2.0.3 compiler_4.4.2
## [7] vctrs_0.6.5 reshape2_1.4.4 stringr_1.5.1
## [10] pkgconfig_2.0.3 shape_1.4.6.1 crayon_1.5.3
## [13] fastmap_1.2.0 backports_1.5.0 utf8_1.2.4
## [16] prodlim_2024.06.25 purrr_1.0.2 xfun_0.49
## [19] glmnet_4.1-8 randomForest_4.7-1.2 cachem_1.1.0
## [22] jsonlite_1.8.9 progress_1.2.3 recipes_1.1.0
## [25] uuid_1.2-1 parallel_4.4.2 prettyunits_1.2.0
## [28] R6_2.5.1 bslib_0.8.0 stringi_1.8.4
## [31] parallelly_1.39.0 rpart_4.1.23 lubridate_1.9.3
## [34] jquerylib_0.1.4 Rcpp_1.0.13-1 assertthat_0.2.1
## [37] iterators_1.0.14 knitr_1.49 future.apply_1.11.3
## [40] R.utils_2.12.3 Matrix_1.7-1 nnet_7.3-19
## [43] R.cache_0.16.0 igraph_2.1.1 timechange_0.3.0
## [46] tidyselect_1.2.1 abind_1.4-8 yaml_2.3.10
## [49] timeDate_4041.110 codetools_0.2-20 listenv_0.9.1
## [52] lattice_0.22-6 tibble_3.2.1 plyr_1.8.9
## [55] withr_3.0.2 ROCR_1.0-11 evaluate_1.0.1
## [58] future_1.34.0 survival_3.7-0 pillar_1.9.0
## [61] checkmate_2.3.2 stats4_4.4.2 generics_0.1.3
## [64] hms_1.1.3 ggplot2_3.5.1 munsell_0.5.1
## [67] scales_1.3.0 globals_0.16.3 class_7.3-22
## [70] glue_1.8.0 maketools_1.3.1 tools_4.4.2
## [73] sys_3.4.3 ModelMetrics_1.2.2.2 rstackdeque_1.1.1
## [76] gower_1.0.1 buildtools_1.0.0 visNetwork_2.1.2
## [79] grid_4.4.2 rbibutils_2.3 ipred_0.9-15
## [82] colorspace_2.1-1 nlme_3.1-166 BBmisc_1.13
## [85] cli_3.6.3 fansi_1.0.6 lava_1.8.0
## [88] dplyr_1.1.4 gtable_0.3.6 R.methodsS3_1.8.2
## [91] R.rsp_0.46.0 sass_0.4.9 digest_0.6.37
## [94] caret_6.0-94 htmlwidgets_1.6.4 htmltools_0.5.8.1
## [97] R.oo_1.27.0 lifecycle_1.0.4 hardhat_1.4.0
## [100] MASS_7.3-61