| Title: | Causal Distillation Trees |
|---|---|
| Description: | Causal Distillation Tree (CDT) is a novel machine learning method for estimating interpretable subgroups with heterogeneous treatment effects. CDT allows researchers to fit any machine learning model (or metalearner) to estimate heterogeneous treatment effects for each individual, and then "distills" these predicted heterogeneous treatment effects into interpretable subgroups by fitting an ordinary decision tree to predict the previously-estimated heterogeneous treatment effects. This package provides tools to estimate causal distillation trees (CDT), as detailed in Huang, Tang, and Kenney (2025) <doi:10.48550/arXiv.2502.07275>. |
| Authors: | Tiffany Tang [aut, cre] (ORCID: <https://orcid.org/0000-0002-8079-6867>), Melody Huang [aut], Ana Kenney [aut] |
| Maintainer: | Tiffany Tang |
| License: | MIT + file LICENSE |
| Version: | 1.0.0 |
| Built: | 2026-05-30 15:07:15 UTC |
| Source: | https://github.com/tiffanymtang/causaldt |
This function implements causal distillation trees (CDT), developed in Huang et al. (2025). Briefly, CDT is a two-stage procedure that allows researchers to identify interpretable subgroups with heterogeneous treatment effects. In the first stage, researchers are free to use any machine learning model or metalearner to predict the heterogeneous treatment effects for each individual in the dataset. In the second stage, CDT “distills” these predicted heterogeneous treatment effects into interpretable subgroups by fitting an ordinary decision tree using the predicted heterogeneous treatment effects from the first stage as the response variable.
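The sketch below walks through the two stages, plus an honest holdout step, using only base R and the rpart package. It is not how `causalDT()` is implemented. The teacher here is a simple T-learner built from two regression trees rather than a causal forest. The helper `leaf_of()`, the 30% holdout split, and the simulated data are illustrative assumptions.

```r
# Illustrative sketch of the CDT idea, NOT the causalDT implementation.
# Assumptions: T-learner teacher (one rpart tree per arm), rpart student,
# hypothetical helper leaf_of(), simulated data.
library(rpart)

set.seed(1)
n <- 400
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X$x1 > 0) + X$x2 + rnorm(n, sd = 0.1)

# Honest split: the holdout set is used only for the final subgroup estimates
holdout <- sample(n, size = floor(0.3 * n))
train <- setdiff(seq_len(n), holdout)
df <- cbind(X, Y = Y)

# Stage 1 (teacher): one regression tree per treatment arm; the predicted
# individual effect is the difference between the two arm predictions
fit1 <- rpart(Y ~ ., data = df[train[Z[train] == 1], ])
fit0 <- rpart(Y ~ ., data = df[train[Z[train] == 0], ])
tau_hat <- predict(fit1, X[train, ]) - predict(fit0, X[train, ])

# Stage 2 (student): an ordinary regression tree fit to the teacher's predictions
student <- rpart(tau ~ ., data = cbind(X[train, ], tau = tau_hat),
                 control = rpart.control(maxdepth = 2))

# Hypothetical helper: leaf (node id) of each new row, obtained by swapping
# each node's fitted value for its node number before calling predict()
leaf_of <- function(fit, newdata) {
  fit$frame$yval <- as.numeric(rownames(fit$frame))
  predict(fit, newdata)
}

# Subgroup effects: difference in means within each leaf, holdout data only
leaf <- leaf_of(student, X[holdout, ])
est <- tapply(seq_along(holdout), leaf, function(i) {
  y <- Y[holdout][i]
  z <- Z[holdout][i]
  mean(y[z == 1]) - mean(y[z == 0])
})
print(est)
```

In practice `causalDT(X, Y, Z)` handles these steps; the sketch only makes the data flow between the stages explicit.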
```r
causalDT(
  X,
  Y,
  Z,
  W = NULL,
  holdout_prop = 0.3,
  holdout_idxs = NULL,
  teacher_model = "causal_forest",
  teacher_predict = NULL,
  student_model = "rpart",
  rpart_control = NULL,
  rpart_prune = c("none", "min", "1se"),
  nfolds_crossfit = NULL,
  nreps_crossfit = NULL,
  B_stability = 100,
  max_depth_stability = NULL,
  ...
)
```
X |
A tibble, data.frame, or matrix of covariates. |
Y |
A vector of outcomes. |
Z |
A vector of treatments. |
W |
A vector of weights corresponding to treatment propensities. |
holdout_prop |
Proportion of data to hold out for honest estimation of
treatment effects. Used only if |
holdout_idxs |
A vector of indices to hold out for honest estimation of
treatment effects. If NULL, a holdout set of size |
teacher_model |
Teacher model used to estimate individual-level
treatment effects. Should be either "causal_forest" (default),
"bcf", or a function.
If "causal_forest", |
teacher_predict |
Function used to predict individual-level treatment
effects from the teacher model. Should take two arguments as input: the
first being the model object returned by |
student_model |
Student model used to estimate subgroups of individuals
and their corresponding estimated treatment effects. Should be either
"rpart" (default) or a function. If "rpart", |
rpart_control |
A list of control parameters for the |
rpart_prune |
Method for pruning the tree. Default is |
nfolds_crossfit |
Number of folds in cross-fitting procedure.
If |
nreps_crossfit |
Number of repetitions of the cross-fitting procedure.
If |
B_stability |
Number of bootstrap samples to use in evaluating stability
diagnostics (which can be used to select an appropriate teacher model).
Default is 100. Stability diagnostics are only performed if
|
max_depth_stability |
Maximum depth of the decision tree used in
evaluating stability diagnostics. If |
... |
Additional arguments passed to the |
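The `nfolds_crossfit` and `nreps_crossfit` arguments refer to cross-fitting the teacher model. The sketch below shows the generic technique under stated assumptions. Each unit's predicted effect comes from a teacher trained on the other folds, and predictions are averaged over repetitions that use fresh fold assignments. The `lm()`-based T-learner and the helper name `crossfit_predict()` are hypothetical. causalDT's own fold construction and defaults may differ.

```r
# Hypothetical sketch of repeated K-fold cross-fitting for teacher predictions.
# The teacher is an assumed T-learner (one lm() per arm), not a causalDT model.
crossfit_predict <- function(X, Y, Z, nfolds = 5, nreps = 3) {
  n <- nrow(X)
  preds <- matrix(NA_real_, nrow = n, ncol = nreps)
  df <- cbind(X, Y = Y)
  for (r in seq_len(nreps)) {
    # random fold labels, reshuffled on every repetition
    folds <- sample(rep(seq_len(nfolds), length.out = n))
    for (k in seq_len(nfolds)) {
      out <- folds == k            # units predicted in this fold
      tr <- which(!out)            # units used to train the teacher
      fit1 <- lm(Y ~ ., data = df[tr[Z[tr] == 1], ])
      fit0 <- lm(Y ~ ., data = df[tr[Z[tr] == 0], ])
      preds[out, r] <- predict(fit1, X[out, , drop = FALSE]) -
        predict(fit0, X[out, , drop = FALSE])
    }
  }
  # average each unit's out-of-fold predictions across repetitions
  rowMeans(preds)
}

set.seed(2)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
Z <- rbinom(n, 1, 0.5)
Y <- 1 + 2 * Z * X$x1 + X$x2 + rnorm(n, sd = 0.1)
tau_hat <- crossfit_predict(X, Y, Z)
```

Because no unit's prediction comes from a model that saw that unit, the distilled tree is fit to out-of-sample teacher predictions rather than in-sample fits.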
A list with the following elements:
estimate |
A tibble of estimated subgroup average treatment effects with the following columns:
|
student_fit |
Output of
|
teacher_fit |
A list of (cross-fitted) teacher model fits. |
teacher_predictions |
The predicted individual-level treatment effects, averaged across all cross-fitted teacher models. |
teacher_predictions_ls |
A list of predicted individual-level treatment effects from each (cross-fitted) teacher model fit. |
crossfit_idxs_ls |
A list of fold indices used in each cross-fit. |
stability_diagnostics |
A list of stability diagnostics with the following elements:
|
holdout_idxs |
Indices of the holdout set. |
Huang, M., Tang, T. M., and Kenney, A. M. (2025). Distilling heterogeneous treatment effects: Stable subgroup estimation in causal inference. arXiv preprint arXiv:2502.07275.
```r
n <- 50
p <- 3
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)

# causal distillation trees using causal forest teacher model
out <- causalDT(X, Y, Z)
```
This function estimates the conditional average treatment effect for each subgroup given by the fitted decision tree. The conditional average treatment effect is estimated as the difference in the average outcome between treated and control units that fall within each subgroup (i.e., each leaf node in the decision tree).
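The sketch below implements the within-leaf difference-in-means estimator and assumes leaf memberships have already been computed. The helper `group_diff_in_means()` is hypothetical. The variance formula `.var1 / .n1 + .var0 / .n0` (the usual Neyman-style form) is an assumption that fits the columns listed below; this page does not confirm it.

```r
# Hypothetical helper, not the package's estimate_group_cates().
# Assumed variance: var(treated) / n1 + var(control) / n0 within each leaf.
group_diff_in_means <- function(leaf, Y, Z) {
  res <- lapply(split(seq_along(Y), leaf), function(i) {
    y1 <- Y[i][Z[i] == 1]   # treated outcomes in this leaf
    y0 <- Y[i][Z[i] == 0]   # control outcomes in this leaf
    data.frame(
      estimate = mean(y1) - mean(y0),
      variance = var(y1) / length(y1) + var(y0) / length(y0),
      .var1 = var(y1), .var0 = var(y0),
      .n1 = length(y1), .n0 = length(y0),
      check.names = FALSE
    )
  })
  out <- do.call(rbind, res)
  out <- cbind(leaf_id = names(res), out)
  rownames(out) <- NULL
  out
}

set.seed(3)
leaf <- rep(c(2, 3), each = 100)          # two leaves of 100 units each
Z <- rep(c(0, 1), times = 100)            # alternating treatment assignment
Y <- ifelse(leaf == 3, 2, 0) * Z + rnorm(200, sd = 0.1)
cates <- group_diff_in_means(leaf, Y, Z)
print(cates)
```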
```r
estimate_group_cates(fit, X, Y, Z)
```
fit |
Fitted subgroup model used to determine subgroup membership of
individuals. Typically, this is a |
X |
A tibble, data.frame, or matrix of covariates. |
Y |
A vector of outcomes. |
Z |
A vector of treatments. |
A tibble of estimated subgroup average treatment effects with the following columns:
leaf_id |
Leaf node identifier. |
subgroup |
String representation of the subgroup. |
estimate |
Estimated conditional average treatment effect for the subgroup. |
variance |
Asymptotic variance of the estimated conditional average treatment effect. |
.var1 |
Sample variance for treated observations in the subgroup. |
.var0 |
Sample variance for control observations in the subgroup. |
.n1 |
Number of treated observations in the subgroup. |
.n0 |
Number of control observations in the subgroup. |
.sample_idxs |
Indices of (holdout) observations in the subgroup. |
```r
n <- 50
p <- 3
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)

# causal distillation tree output
out <- causalDT(X, Y, Z)

# compute subgroup CATEs manually
group_cates <- estimate_group_cates(
  out$student_fit$fit,
  X = X[out$holdout_idxs, , drop = FALSE],
  Y = Y[out$holdout_idxs],
  Z = Z[out$holdout_idxs]
)
all.equal(out$estimate, group_cates)
```
This function evaluates the stability of the estimated subgroups from causal distillation trees (CDT) using the Jaccard subgroup similarity index (SSI), developed in Huang et al. (2025). It is generally recommended to choose teacher models in CDT that result in the most stable subgroups, as indicated by high SSI values.
```r
evaluate_subgroup_stability(
  estimator,
  fit,
  X,
  y,
  Z = NULL,
  rpart_control = NULL,
  B = 100,
  max_depth = NULL
)
```
estimator |
Function used to estimate subgroups of individuals and their
corresponding estimated treatment effects. The function should take in
|
fit |
Fitted subgroup model (often, the output of |
X |
A tibble, data.frame, or matrix of covariates. |
y |
A vector of responses to predict. |
Z |
A vector of treatments. |
rpart_control |
A list of control parameters for the |
B |
Number of bootstrap samples to use in evaluating stability diagnostics. Default is 100. |
max_depth |
Maximum depth of the tree to consider when evaluating
stability diagnostics. If |
A list with the following elements:
jaccard_mean |
Vector of mean Jaccard similarity indices, one per tree depth (the tree depth is given by the vector index). |
jaccard_distribution |
List of Jaccard similarity indices across all bootstraps for each tree depth. |
bootstrap_predictions |
List, per tree depth, of the mean student model predictions on the training (non-holdout) data across all bootstraps. |
bootstrap_predictions_var |
List, per tree depth, of the variances of student model predictions on the training (non-holdout) data across all bootstraps. |
leaf_ids |
List of leaf node identifiers, indicating the leaf membership of each training sample in the (original) fitted student model. |
Huang, M., Tang, T. M., and Kenney, A. M. (2025). Distilling heterogeneous treatment effects: Stable subgroup estimation in causal inference. arXiv preprint arXiv:2502.07275.
```r
n <- 200
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)

# run causal distillation trees without stability diagnostics
out <- causalDT(X, Y, Z, B_stability = 0)

# run stability diagnostics
stability_out <- evaluate_subgroup_stability(
  estimator = student_rpart,
  fit = out$student_fit$fit,
  X = X[-out$holdout_idxs, , drop = FALSE],
  y = out$student_fit$predictions
)
```
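The sketch below illustrates the general idea behind these diagnostics. For each tree depth, it refits an rpart tree on bootstrap samples and scores how well the bootstrap leaves reproduce the original leaves on the same units. The scoring rule in the hypothetical `leaf_jaccard()` is one plausible reading of "each leaf node contributes equal weight" and is an assumption. It computes a per-unit Jaccard overlap of co-members, averages within each leaf, then averages across leaves. `jaccardSSI()` and `evaluate_subgroup_stability()` may define and compute the index differently.

```r
# Illustrative bootstrap stability sketch; helper names and the scoring rule
# are assumptions, not the causalDT implementation.
library(rpart)

# Assumed similarity score: for each unit, the Jaccard overlap between the
# units sharing its leaf under partition `a` and under partition `b`;
# averaged within each leaf of `a`, then across leaves (equal leaf weights).
leaf_jaccard <- function(a, b) {
  per_unit <- vapply(seq_along(a), function(i) {
    in_a <- a == a[i]
    in_b <- b == b[i]
    sum(in_a & in_b) / sum(in_a | in_b)
  }, numeric(1))
  mean(tapply(per_unit, a, mean))
}

# Leaf (node id) that each row of `newdata` falls into
leaf_of <- function(fit, newdata) {
  fit$frame$yval <- as.numeric(rownames(fit$frame))
  predict(fit, newdata)
}

# Mean bootstrap similarity for trees of depth 1, ..., max_depth
stability_by_depth <- function(X, y, B = 20, max_depth = 3) {
  df <- cbind(X, y = y)
  sapply(seq_len(max_depth), function(d) {
    ctrl <- rpart.control(maxdepth = d)
    # reference partition: tree fit on the full training data
    ref <- leaf_of(rpart(y ~ ., data = df, control = ctrl), X)
    mean(replicate(B, {
      boot <- df[sample(nrow(df), replace = TRUE), ]
      # refit on the bootstrap sample, then partition the original units
      leaf_jaccard(ref, leaf_of(rpart(y ~ ., data = boot, control = ctrl), X))
    }))
  })
}

set.seed(4)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
# stand-in for teacher-predicted effects with a clear subgroup structure
tau_hat <- 2 * (X$x1 > 0) + rnorm(n, sd = 0.1)
ssi <- stability_by_depth(X, tau_hat)
print(ssi)
```

High values at a given depth mean that bootstrap refits keep grouping the same units together, which is the property used to compare teacher models.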
Return the decision paths for each leaf node in an rpart model as character
strings.
```r
get_rpart_paths(rpart_fit)
```
rpart_fit |
An |
A list of character vectors, where each element corresponds to the decision
path for a leaf node in the rpart_fit model.
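rpart itself can report the path to any node through `rpart::path.rpart()`. The sketch below collects the paths to all leaves, which yields output of the same shape (a list of character vectors). This page does not say whether `get_rpart_paths()` relies on `path.rpart()` internally, and `leaf_paths()` is a hypothetical helper.

```r
library(rpart)

# Hypothetical helper: decision path to every leaf, via rpart::path.rpart()
leaf_paths <- function(fit) {
  # leaves are the rows of fit$frame whose split variable is "<leaf>"
  nodes <- as.integer(rownames(fit$frame)[fit$frame$var == "<leaf>"])
  path.rpart(fit, nodes = nodes, print.it = FALSE)
}

set.seed(5)
df <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
df$y <- 2 * (df$x1 > 0) + rnorm(200, sd = 0.1)
fit <- rpart(y ~ ., data = df, control = rpart.control(maxdepth = 1))
paths <- leaf_paths(fit)
print(paths)
```

Each element is named by its node number and starts with `"root"`, followed by one split condition per level of the tree.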
Return the split information for each node in an rpart model as a data frame.
```r
get_rpart_tree_info(rpart_fit, X = NULL, digits = getOption("digits"))
```
rpart_fit |
An |
X |
Optional data frame containing the features used in the |
digits |
Number of digits to round the split values to. |
A data.frame with information regarding the feature/threshold used
for each split in the rpart model.
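A rough equivalent can be assembled from the `frame` component of an rpart fit and rpart's `labels()` method, which describes the split leading into each node. The column set below is an assumption and may not match what `get_rpart_tree_info()` returns.

```r
library(rpart)

set.seed(6)
df <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
df$y <- 2 * (df$x1 > 0) + df$x2 + rnorm(200, sd = 0.1)
fit <- rpart(y ~ ., data = df, control = rpart.control(maxdepth = 2))

# One row per node (assumed layout): split variable ("<leaf>" for terminal
# nodes), node size, fitted value, and the split condition leading into it
tree_info <- data.frame(
  node = as.integer(rownames(fit$frame)),
  var = as.character(fit$frame$var),
  n = fit$frame$n,
  yval = round(fit$frame$yval, 2),
  label = labels(fit, digits = 2)
)
print(tree_info)
```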
This function computes the Jaccard similarity index between two vectors of subgroup membership labels, scaling it such that each leaf node contributes equal weight to the overall similarity.
```r
jaccardSSI(x, y)
```
x |
Numeric vector of subgroup memberships. Must be encoded as contiguous integers beginning at 0 (i.e., if there are k unique values, they must be 0, 1, ..., k-1). |
y |
Numeric vector of subgroup memberships. Must be encoded as contiguous integers beginning at 0 (i.e., if there are k unique values, they must be 0, 1, ..., k-1). |
Computed Jaccard subgroup similarity metric
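Leaf identifiers from rpart are node numbers (for example 2, 6, 7), which are neither contiguous nor zero-based, so they need recoding before use here. A minimal base R sketch follows; `to_contiguous()` is a hypothetical helper.

```r
# Hypothetical helper: map arbitrary leaf ids to 0, 1, ..., k-1
# (ordered by the sorted unique ids)
to_contiguous <- function(ids) {
  as.integer(factor(ids)) - 1L
}

leaves_a <- c(2, 2, 6, 7, 7, 6)
leaves_b <- c(4, 5, 5, 3, 3, 5)
x <- to_contiguous(leaves_a)   # 0 0 1 2 2 1
y <- to_contiguous(leaves_b)   # 1 2 2 0 0 2
print(x)
print(y)
```

The recoded vectors can then be passed as `jaccardSSI(x, y)`.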
Visualize the subgroups (i.e., the student tree) from a causal distillation tree object.
```r
plot_cdt(cdt, show_digits = 2)
```
cdt |
A causal distillation tree object, typically the output of
|
show_digits |
Number of digits to show in the plot labels. Default is 2. |
A plot of the causal distillation tree.
```r
n <- 200
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
cdt <- causalDT(X, Y, Z)
plot_cdt(cdt)
```
The Jaccard subgroup similarity index (SSI) measures the similarity between two candidate partitions of the data into subgroups. When choosing a teacher model for CDT, the Jaccard SSI can be used to pick the teacher model that recovers the most stable subgroups.
```r
plot_jaccard(...)
```
... |
Two or more causal distillation tree objects, each typically
the output of |
A plot of the Jaccard SSI for each tree depth.
```r
n <- 50
p <- 2
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)

# number of bootstraps for stability diagnostics (setting to small value for faster example)
B <- 10

# run CDT with default causal forest teacher model
cdt1 <- causalDT(X, Y, Z, B_stability = B)

# run CDT with custom BCF teacher model
sink(tempfile())  # to suppress printed output from BCF
cdt2 <- causalDT(
  X, Y, Z,
  # set BCF training parameters to be small for faster example
  teacher_model = purrr::partial(bcf::bcf, nsim = 100, nburn = 10),
  teacher_predict = predict_bcf,
  # set number of cross-fitting replications to be small for faster example
  nreps_crossfit = 5,
  B_stability = B
)
sink()  # restore normal output

# plot Jaccard SSI for both teacher models (note: in practice, use larger B)
plot_jaccard(`Causal Forest` = cdt1, `BCF` = cdt2)
```
This is a wrapper function to convert any of the rlearner
model functions into a format that can be used as a teacher model in the
causal distillation tree framework.
```r
rlearner_teacher(rlearner_fun, ...)
```
rlearner_fun |
One of |
... |
Additional arguments to pass to the base model functions. |
Outputs a function that can be used as a teacher model in the
causal distillation tree framework. The returned function has the
signature function(X, Y, Z, W = NULL, ...).
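To show what a function with this signature can look like, the sketch below writes one by hand: a hypothetical T-learner built from `lm()`, plus a matching prediction function that takes the fitted object and new covariates. It is an assumption that `causalDT()` accepts this exact return structure through `teacher_model` and `teacher_predict`, because this page does not fully specify the second argument of `teacher_predict`.

```r
# Hypothetical teacher with signature function(X, Y, Z, W = NULL, ...):
# a T-learner with one lm() per treatment arm. W is accepted for signature
# compatibility but ignored in this sketch.
tlearner_lm <- function(X, Y, Z, W = NULL, ...) {
  df <- data.frame(X, .y = Y)
  list(
    fit1 = lm(.y ~ ., data = df[Z == 1, ]),
    fit0 = lm(.y ~ ., data = df[Z == 0, ])
  )
}

# Hypothetical matching predictor: (fitted teacher, new covariates) -> effects
tlearner_lm_predict <- function(fit, X) {
  newdata <- data.frame(X)
  unname(predict(fit$fit1, newdata) - predict(fit$fit0, newdata))
}

set.seed(8)
n <- 200
X <- matrix(rnorm(n * 3), nrow = n, ncol = 3)
Z <- rbinom(n, 1, 0.5)
Y <- 1 + 2 * Z * X[, 1] + X[, 2] + rnorm(n, sd = 0.1)
teacher <- tlearner_lm(X, Y, Z)
tau_hat <- tlearner_lm_predict(teacher, X)
```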
This function is a wrapper around rpart::rpart() that can be easily
used as a student model in the causal distillation tree framework.
```r
student_rpart(
  X,
  y,
  method = "anova",
  rpart_control = NULL,
  prune = c("none", "min", "1se"),
  fit_only = FALSE
)
```
X |
A tibble, data.frame, or matrix of covariates. |
y |
A vector of responses to predict. |
method |
Same as |
rpart_control |
A list of control parameters for the |
prune |
Method for pruning the tree. Default is |
fit_only |
Logical. If |
If fit_only = TRUE, the fitted model is returned. Otherwise, a list
with the following components is returned:
fit |
Fitted model. An |
tree_info |
Data frame with tree structure/split information. |
subgroups |
List of subgroups given by their string representation. |
predictions |
Student model predictions for the given |
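The `prune` options (and `rpart_prune` in `causalDT()`) presumably correspond to rpart's cross-validated cost-complexity pruning. The sketch below shows the two conventional rules. `"min"` prunes at the complexity parameter with the lowest cross-validated error, and `"1se"` keeps the simplest tree whose error is within one standard error of that minimum. The mapping onto `student_rpart()` is an assumption, and `prune_tree()` is a hypothetical helper.

```r
library(rpart)

# Hypothetical helper: conventional "min" and "1se" pruning rules.
# fit$cptable rows are ordered from the smallest tree to the largest.
prune_tree <- function(fit, rule = c("none", "min", "1se")) {
  rule <- match.arg(rule)
  if (rule == "none") return(fit)
  cpt <- fit$cptable
  best <- which.min(cpt[, "xerror"])
  if (rule == "min") {
    cp <- cpt[best, "CP"]
  } else {
    # first (i.e. simplest) tree within one SE of the minimum CV error
    thresh <- cpt[best, "xerror"] + cpt[best, "xstd"]
    cp <- cpt[which(cpt[, "xerror"] <= thresh)[1], "CP"]
  }
  prune(fit, cp = cp)
}

set.seed(9)
n <- 300
df <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
df$y <- 2 * (df$x1 > 0) + rnorm(n, sd = 0.5)
# cp = 0 grows a deliberately large tree so pruning has something to remove
fit <- rpart(y ~ ., data = df, control = rpart.control(cp = 0, xval = 10))
sizes <- sapply(c("none", "min", "1se"), function(r) {
  sum(prune_tree(fit, r)$frame$var != "<leaf>")   # number of splits
})
print(sizes)
```

The `"1se"` rule trades a small increase in cross-validated error for a smaller tree, which typically yields fewer, easier-to-interpret subgroups.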