Stochastic treatment regimes offer a relatively simple way to assess the effects of continuous treatments, through parameters that quantify the effect of counterfactually shifting the observed values of a treatment of interest. Here, we present an implementation of a new algorithm for computing targeted minimum loss-based estimates of treatment shift parameters defined by a shifting function \(d(A,W)\). For a technical presentation of the algorithm, see Díaz and van der Laan (2018). For additional background on Targeted Learning and previous work on stochastic treatment regimes, see van der Laan and Rose (2011), van der Laan and Rose (2018), and Díaz and van der Laan (2012).
To start, let’s load the packages we’ll use and set a seed for simulation:
Consider \(n\) observed units \(O_1, \ldots, O_n\), where each random variable \(O = (W, A, Y)\) corresponds to a single observational unit. Let \(W\) denote baseline covariates (e.g., age, sex, education level), \(A\) an intervention variable of interest (e.g., nutritional supplements), and \(Y\) an outcome of interest (e.g., disease status). Though it need not be the case, let \(A\) be continuous-valued, i.e., \(A \in \mathbb{R}\). Let \(O_i \sim \mathcal{P} \in \mathcal{M}\), where \(\mathcal{M}\) is the nonparametric statistical model defined as the set of continuous densities on \(O\) with respect to some dominating measure. To formalize the definition of stochastic interventions and their corresponding causal effects, we introduce a nonparametric structural equation model (NPSEM), based on Pearl (2000), to define how the system changes under posited interventions: \[\begin{align*}\label{eqn:npsem} W &= f_W(U_W) \\ A &= f_A(W, U_A) \\ Y &= f_Y(A, W, U_Y), \end{align*}\] where \(f_W\), \(f_A\), and \(f_Y\) are deterministic functions and \(U_W\), \(U_A\), and \(U_Y\) are exogenous random errors. We denote the observed data structure \(O = (W, A, Y)\).
Letting \(A\) denote a continuous-valued treatment, we assume that the distribution of \(A\)
conditional on \(W = w\) has support in the interval \((l(w), u(w))\); for convenience, we take
this to hold almost everywhere (a.e.). That is, the minimum natural value of treatment
\(A\) for an individual with covariates
\(W = w\) is \(l(w)\); similarly, the maximum is \(u(w)\). Then, a simple stochastic
intervention, based on a shift \(\delta\), may be defined as \[\begin{equation}\label{eqn:shift}
d(a, w) =
\begin{cases}
a - \delta & \text{if } a > l(w) + \delta \\
a & \text{if } a \leq l(w) + \delta,
\end{cases}
\end{equation}\] where \(0 \leq \delta
\leq u(w)\) is an arbitrary pre-specified value that defines the
degree to which the observed value \(A\) is to be shifted, where possible. To
use such a shift in practice, the present software
provides the functions shift_additive and
shift_additive_inv, which define a variation of this shift,
assuming that the density of treatment \(A\), conditional on the covariates \(W\), has support a.e.
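To make the piecewise definition of \(d(a, w)\) concrete, the following base R sketch evaluates it on a few toy values. The helper `shift_d` and the lower-bound function `toy_lower` are hypothetical names introduced only for this illustration; they are not part of the software described here and are not meant to reproduce the behavior of shift_additive.

```r
# Hypothetical helper: evaluates d(a, w) exactly as written in the display above.
# `lower` is an assumed user-supplied function returning l(w).
shift_d <- function(a, w, delta, lower) {
  stopifnot(delta >= 0, length(a) == length(w))
  # the shift is applied only where the shifted value stays above l(w)
  can_shift <- a > lower(w) + delta
  ifelse(can_shift, a - delta, a)
}

# toy lower bound (an assumption): l(w) = 0 when w = 0 and l(w) = 1 when w = 1
toy_lower <- function(w) ifelse(w == 1, 1, 0)

a <- c(0.2, 0.8, 1.4, 2.0)
w <- c(0, 0, 1, 1)
shifted <- shift_d(a, w, delta = 0.5, lower = toy_lower)
print(shifted)
```

In this toy example, the first and third observations are left unchanged because shifting them by \(\delta = 0.5\) would move them to or below \(l(w)\), while the second and fourth are shifted down by \(\delta\).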
# simulate simple data for tmle-shift sketch
n_obs <- 1000 # number of observations
n_w <- 1 # number of baseline covariates
tx_mult <- 2 # multiplier for the effect of W = 1 on the treatment
## baseline covariates -- simple, binary
W <- as.numeric(replicate(n_w, rbinom(n_obs, 1, 0.5)))
## create treatment based on baseline W
A <- as.numeric(rnorm(n_obs, mean = tx_mult * W, sd = 1))
## create outcome as a linear function of A, W + white noise
Y <- A + W + rnorm(n_obs, mean = 0, sd = 0.5)

The above composes our observed data structure \(O = (W, A, Y)\). To formally express this
fact using the tlverse grammar introduced by the tmle3 package,
we create a single data object and specify the functional relationships
between the nodes in the directed acyclic graph (DAG) via
nonparametric structural equation models (NPSEMs), reflected in
the node list that we set up:
# organize data and nodes for tmle3
data <- data.table(W, A, Y)
node_list <- list(W = "W", A = "A", Y = "Y")
head(data)
##        W          A          Y
##    <num>      <num>      <num>
## 1:     1  2.4031607  3.7157578
## 2:     1  4.4973744  5.9651611
## 3:     1  2.0330871  2.2531970
## 4:     0 -0.8089023 -0.8849531
## 5:     1  1.8432067  2.7193091
## 6:     1  1.3555863  2.5705832
We now have an observed data structure (data) and a
specification of the role that each variable in the data set plays as
the nodes in a DAG.
To start, we will initialize a specification for the TMLE of our
parameter of interest (called a tmle3_Spec in the
tlverse nomenclature) simply by calling
tmle_shift. We specify the argument
shift_val = 0.5 when initializing the
tmle3_Spec object to communicate that we’re interested in a
shift of \(0.5\) on the scale of the
treatment \(A\) – that is, we specify
\(\delta = 0.5\) (note that this is an
arbitrarily chosen value for this example).
# initialize a tmle specification
tmle_spec <- tmle_shift(
  shift_val = 0.5,
  shift_fxn = shift_additive,
  shift_fxn_inv = shift_additive_inv
)

As seen above, the tmle_shift specification object (like
all tmle3_Spec objects) does not store the data
for our specific analysis of interest. Later, we’ll see that passing a
data object directly to the tmle3 wrapper function,
alongside the instantiated tmle_spec, will serve to
construct a tmle3_Task object internally (see the tmle3 documentation
for details). Note that, by default, the tmle_spec object
is set up to use cross-validated estimation of the likelihood
components. Cross-validation relaxes certain empirical process
conditions by reducing the contribution of an empirical process term
to the estimated influence function (Zheng and
van der Laan 2011). In practice, this automatic incorporation of
cross-validation (CV-TMLE) means that the user need not verify
that these theoretical conditions are satisfied; moreover,
cross-validated estimation of the efficient influence function
helps guard against underestimating the variance when flexible
learners overfit.
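The idea behind cross-validated estimation of a likelihood component can be sketched in base R: each observation receives a prediction from a model fit without it. The example below uses a plain linear model and hypothetical variable names; it illustrates the general sample-splitting scheme only and is not the internal implementation used by tmle3.

```r
set.seed(1)  # arbitrary seed, used only for this illustration
n <- 200
w <- rbinom(n, 1, 0.5)
a <- rnorm(n, mean = 2 * w)
y <- a + w + rnorm(n, sd = 0.5)
dat <- data.frame(w, a, y)

# assign each observation to one of V validation folds
V <- 5
fold_id <- sample(rep(seq_len(V), length.out = n))

# for each fold, fit on the remaining folds and predict on the held-out fold
cv_pred <- rep(NA_real_, n)
for (v in seq_len(V)) {
  valid <- fold_id == v
  fit <- lm(y ~ a + w, data = dat[!valid, ])
  cv_pred[valid] <- predict(fit, newdata = dat[valid, ])
}

# cross-validated mean squared error of the outcome regression
cv_mse <- mean((dat$y - cv_pred)^2)
print(cv_mse)
```

Because no observation is predicted by a model that was trained on it, the resulting predictions are less sensitive to overfitting than in-sample predictions.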
To easily incorporate ensemble machine learning into the estimation
procedure, we rely on the facilities provided in the sl3 R package. For a
complete guide to the sl3 R package, see
https://tlverse.org/sl3; for the tlverse ecosystem,
of which sl3 is a core component, see https://tlverse.org.
Using the framework provided by the sl3 package, the
nuisance parameters of the TML estimator may be fit with ensemble
learning, using the cross-validation framework of the Super Learner
algorithm of van der Laan et
al. (2007). To estimate the treatment mechanism (denoted
g), we must make use of learning algorithms specifically suited to
conditional density estimation; a list of such learners may be extracted
from sl3 using sl3_list_learners("density"):
## [1] "Lrnr_density_discretize" "Lrnr_density_hse"
## [3] "Lrnr_density_semiparametric" "Lrnr_haldensify"
## [5] "Lrnr_solnp_density"
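To make concrete what a conditional density estimate of the treatment mechanism returns, the base R sketch below fits a deliberately simple Gaussian location model: a linear regression for the conditional mean of the treatment, a single residual standard deviation, and a normal density evaluated at the requested points. The function name `g_hat` and the Gaussian assumption are introduced here for illustration only; this is a simplified stand-in and not the algorithm used by any of the sl3 learners listed above.

```r
set.seed(2)  # arbitrary seed, used only for this illustration
n <- 500
w <- rbinom(n, 1, 0.5)
a <- rnorm(n, mean = 2 * w, sd = 1)

# step 1: estimate the conditional mean of the treatment given the covariate
mean_fit <- lm(a ~ w)
# step 2: estimate a single (homoscedastic) residual standard deviation
resid_sd <- sqrt(mean(residuals(mean_fit)^2))

# step 3: hypothetical density estimate g_hat(a | w) under Gaussian errors
g_hat <- function(a_new, w_new) {
  mu <- unname(predict(mean_fit, newdata = data.frame(w = w_new)))
  dnorm(a_new, mean = mu, sd = resid_sd)
}

# density evaluated at the observed (a, w) pairs
dens_at_obs <- g_hat(a, w)

# sanity check: for fixed w, the estimate integrates to about one over a
area_w1 <- integrate(function(x) g_hat(x, rep(1, length(x))), -Inf, Inf)$value
print(area_w1)
```

Unlike a regression, which returns one predicted value per observation, a conditional density estimator must return a valid density in the treatment for every covariate value, which is why specialized learners are needed.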
To proceed, we’ll select two of the above learners:
Lrnr_haldensify, which uses the highly adaptive lasso for
conditional density estimation, based on an algorithm given by Díaz and van der Laan
(2011), and Lrnr_density_semiparametric, an approach
for semiparametric conditional density estimation (used below in two
configurations, with and without a learner for the conditional variance):
# learners used for conditional density regression (i.e., propensity score)
haldensify_lrnr <- Lrnr_haldensify$new(
  n_bins = 3, grid_type = "equal_mass",
  lambda_seq = exp(seq(-1, -9, length.out = 100))
)
# semiparametric density learner with a GLM for the conditional mean
hse_lrnr <- Lrnr_density_semiparametric$new(mean_learner = Lrnr_glm$new())
# as above, additionally modeling the conditional variance
mvd_lrnr <- Lrnr_density_semiparametric$new(
  mean_learner = Lrnr_glm$new(),
  var_learner = Lrnr_mean$new()
)
# Super Learner for the treatment mechanism, with a density metalearner
sl_lrn_dens <- Lrnr_sl$new(
  learners = list(haldensify_lrnr, hse_lrnr, mvd_lrnr),
  metalearner = Lrnr_solnp_density$new()
)

We also require an approach for estimating the outcome regression (denoted Q). For this, we build a Super Learner composed of an intercept model, a main terms GLM, and the xgboost algorithm for gradient boosting:
# learners used for conditional expectation regression (e.g., outcome)
mean_lrnr <- Lrnr_mean$new()
glm_lrnr <- Lrnr_glm$new()
xgb_lrnr <- Lrnr_xgboost$new()
sl_lrn <- Lrnr_sl$new(
  learners = list(mean_lrnr, glm_lrnr, xgb_lrnr),
  metalearner = Lrnr_nnls$new()
)

We can make the above explicit with respect to standard notation by
bundling the ensemble learners into a list object
below.
# specify outcome and treatment regressions and create learner list
Q_learner <- sl_lrn
g_learner <- sl_lrn_dens
learner_list <- list(Y = Q_learner, A = g_learner)

The learner_list object above specifies the role that
each of the ensemble learners we’ve generated is to play in computing
initial estimators to be used in building a TMLE for the parameter of
interest here. In particular, it makes explicit the fact that our
Q_learner is used in fitting the outcome regression while
our g_learner is used in fitting our treatment mechanism
regression.
Note that, by default, the
## Warning in check.deprecation(deprecated_train_params, match.call(), ...):
## Passed invalid function arguments: nthread. These should be passed as a list to
## argument 'params'. Conversion from argument to 'params' entry will be done
## automatically, but this behavior will become an error in a future version.
## Warning in throw_err_or_depr_msg("Parameter '", match_old, "' has been renamed
## to '", : Parameter 'watchlist' has been renamed to 'evals'. This warning will
## become an error in a future version.
## Error in fit_object$training_offset <- task$has_node("offset") :
## ALTLIST classes must provide a Set_elt method [class: XGBAltrepPointerClass, pkg: xgboost]
## Warning in private$.train(processed_task, trained_sublearners): Lrnr_xgboost_20_1 failed with message: Error in fit_object$training_offset <- task$has_node("offset") :
## ALTLIST classes must provide a Set_elt method [class: XGBAltrepPointerClass, pkg: xgboost]
## . It will be removed from the stack
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
##
## Iter: 1 fn: 1383.2954 Pars: 0.38063 0.31823 0.30114
## Iter: 2 fn: 1383.2954 Pars: 0.38063 0.31823 0.30114
## solnp--> Completed in 2 iterations
## Warning in private$.train(processed_task, trained_sublearners): The following
## learners failed for one or more folds and will be dropped from all folds:
## Lrnr_xgboost_20_1
## 0.9% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 4% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 1% of observations outside training support...predictions trimmed.
## 0.2% of observations outside training support...predictions trimmed.
## A tmle3_Fit that took 1 step(s)
##      type         param init_est tmle_est         se    lower    upper
##    <char>        <char>    <num>    <num>      <num>    <num>    <num>
## 1:    TSM E[Y_{A=NULL}] 2.050792 2.050106 0.05990814 1.932689 2.167524
##    psi_transformed lower_transformed upper_transformed
##              <num>             <num>             <num>
## 1:        2.050106          1.932689          2.167524
The print method of the resultant tmle_fit
object conveniently displays the results from computing our TML
estimator.
Recall that the asymptotic distribution of TML estimators has been studied thoroughly: \[\psi_n - \psi_0 = (P_n - P_0) \cdot D(\bar{Q}_n^*, g_n) + R(\hat{P}^*, P_0),\] which, provided that (1) \(D(\bar{Q}_n^*, g_n)\) converges to \(D(P_0)\) in \(L^2(P_0)\) norm and (2) \(D(\bar{Q}_n^*, g_n)\) falls in a Donsker class with probability tending to one, reduces to \[\psi_n - \psi_0 = (P_n - P_0) \cdot D(P_0) + R(\hat{P}^*, P_0).\]
Under the additional condition that the remainder term \(R(\hat{P}^*, P_0)\) decays as \(o_P \left( \frac{1}{\sqrt{n}} \right),\) we have that \[\psi_n - \psi_0 = (P_n - P_0) \cdot D(P_0) + o_P \left( \frac{1}{\sqrt{n}} \right),\] which, by a central limit theorem, establishes a Gaussian limiting distribution for the estimator:
\[\sqrt{n}(\psi_n - \psi_0) \to N(0, V(D(P_0))),\] where \(V(D(P_0))\) is the variance of the efficient influence curve (canonical gradient) when \(\psi\) admits an asymptotically linear representation.
The above implies that \(\psi_n\) is a \(\sqrt{n}\)-consistent estimator of \(\psi_0\), that it is asymptotically normal (as given above), and that it is locally efficient. This allows us to build Wald-type confidence intervals in a straightforward manner:
\[\psi_n \pm z_{1 - \alpha/2} \cdot \frac{\sigma_n}{\sqrt{n}},\] where \(z_{1 - \alpha/2}\) is the \((1 - \alpha/2)\) quantile of the standard normal distribution and \(\sigma_n^2\) is an estimator of \(V(D(P_0))\). The estimator \(\sigma_n^2\) may be obtained using the bootstrap or computed directly via
\[\sigma_n^2 = \frac{1}{n} \sum_{i = 1}^{n} D^2(\bar{Q}_n^*, g_n)(O_i).\]
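The variance estimator and Wald-type interval above translate directly into a few lines of base R. In the sketch below, `wald_ci` is a hypothetical helper (not a tmle3 function) that takes a point estimate and a vector of estimated influence curve values; the toy influence curve values are made up for illustration. The last lines recompute a 95% interval from the estimate and standard error reported in the printed tmle3_Fit table, as a check on the arithmetic.

```r
# Hypothetical helper: Wald-type interval from estimated influence curve values
wald_ci <- function(psi_n, eif, level = 0.95) {
  n <- length(eif)
  sigma2_n <- mean(eif^2)           # sigma_n^2 = (1 / n) * sum of D^2(O_i)
  se <- sqrt(sigma2_n / n)          # standard error: sigma_n / sqrt(n)
  z <- qnorm(1 - (1 - level) / 2)   # two-sided standard normal quantile
  c(lower = psi_n - z * se, upper = psi_n + z * se, se = se)
}

# toy influence curve values (made up; not taken from the analysis above)
toy_ci <- wald_ci(psi_n = 1, eif = c(-1, 1, -1, 1))
print(toy_ci)

# recompute the reported interval from the reported estimate and standard error
est <- 2.050106
se <- 0.05990814
reported_ci <- est + c(-1, 1) * qnorm(0.975) * se
print(round(reported_ci, 6))
```

The recomputed limits agree with the lower and upper columns of the printed table up to rounding of the inputs.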
Having now re-examined these facts, let’s simply examine the results of computing our TML estimator:
## A tmle3_Fit that took 1 step(s)
##      type         param init_est tmle_est         se    lower    upper
##    <char>        <char>    <num>    <num>      <num>    <num>    <num>
## 1:    TSM E[Y_{A=NULL}] 2.050792 2.050106 0.05990814 1.932689 2.167524
##    psi_transformed lower_transformed upper_transformed
##              <num>             <num>             <num>
## 1:        2.050106          1.932689          2.167524
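Because the data were simulated from a known mechanism, the estimate can be compared against the value implied by that mechanism. The base R sketch below does so under one explicit assumption that is not stated in the text above: that the additive shift moves each observed treatment from \(A\) to \(A + \delta\). Under that assumption, and with \(Y = A + W + \text{noise}\), the counterfactual mean is \(E[A] + \delta + E[W]\). The seed and Monte Carlo sample size are arbitrary choices for this illustration.

```r
set.seed(3)      # arbitrary seed, used only for this illustration
n_mc <- 1e6      # Monte Carlo sample size
delta <- 0.5     # shift value used in the analysis above
tx_mult <- 2     # same multiplier as in the simulation above

# regenerate W and A from the data-generating mechanism
W <- rbinom(n_mc, 1, 0.5)
A <- rnorm(n_mc, mean = tx_mult * W, sd = 1)

# counterfactual outcome under the ASSUMED upward additive shift A + delta
Y_shift <- (A + delta) + W + rnorm(n_mc, mean = 0, sd = 0.5)
psi_mc <- mean(Y_shift)

# closed form under the same assumption: E[A] + delta + E[W]
psi_exact <- tx_mult * 0.5 + delta + 0.5

print(c(monte_carlo = psi_mc, closed_form = psi_exact))
```

Under this assumption the closed-form value is 2, which lies inside the reported confidence interval of (1.932689, 2.167524).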