3. Heterogeneous Treatment Effects: A Causal Forest Example

Author

Kuan Liu

Published

2024-10-12

Outline
  • Supplementary notes for the AI4PH causal inference workshop
    • Provides code to fit a causal random forest for estimating conditional average treatment effects (CATE)
    • In this data example, we estimate CATE by race.
# Packages: tidyverse provides the dplyr verbs (select, rename) and the %>%
# pipe used below; grf provides causal_forest() and its helper functions.
library(tidyverse)
library(grf)
# Print numbers in fixed rather than scientific notation (e.g. the p-values
# in the calibration test output).
options(scipen = 999)

# Data prep;
data <- read.csv("data/rhc.csv", header = TRUE)

# Exposure A: 0 when swang1 is "No RHC", 1 otherwise.
data$A <- ifelse(data$swang1 == "No RHC", 0, 1)

# Outcome Y: dth30 is a binary survival status at day 30.
# Y = 0 when dth30 is "No", 1 otherwise (presumably death by day 30).
data$Y <- ifelse(data$dth30 == "No", 0, 1)

# Drop the raw exposure/outcome columns (swang1, dth30) and the other listed
# variables so they are not used as covariates. dplyr:: is written out to
# guard against select()/rename() being masked by another attached package.
data2 <- dplyr::select(
  data,
  -c(cat2, adld3p, urin1, swang1,
     sadmdte, dschdte, dthdte, lstctdte, death, dth30,
     surv2md1, das2d3pc, t3d30, ptid)
)
# The column named `X` (presumably an unnamed row-index column in the CSV)
# is renamed to `id`.
data2 <- dplyr::rename(data2, id = X)

# Candidate covariates: every remaining column except id, exposure and outcome.
covariates <- dplyr::select(data2, -c(id, A, Y))
# Potential subgroups of interest: row counts by race;
table(data2[["race"]])

black other white 
  920   355  4460 
# Row counts by income bracket;
table(data2[["income"]])

  $11-$25k   $25-$50k     > $50k Under $11k 
      1165        893        451       3226 
# Build the covariate matrix for causal_forest(): numeric covariates as-is,
# factor/character covariates expanded to dummy columns (the first level is
# the reference and gets no column).
# The model frame is built with na.action = na.pass so rows with missing
# covariates are kept (as NA) instead of being silently dropped by
# model.matrix(), which would leave X with fewer rows than Y and W.
# grf can split on missing values in X.
X <- model.matrix(
  ~ .,
  data = model.frame(~ ., data = covariates, na.action = na.pass)
)
X <- X[, -1, drop = FALSE]  # drop the intercept column
Y <- data2$Y
W <- data2$A
stopifnot(nrow(X) == length(Y), nrow(X) == length(W))

# causal_forest() draws its internal seed from R's RNG by default,
# so set.seed() makes the fit reproducible.
set.seed(123)
cforest <- causal_forest(
  X,
  Y,
  W,
  num.trees = 2000,
  sample.fraction = 0.5,  # fraction of data drawn for each tree
  mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
  min.node.size = 10,
  honesty = TRUE,         # separate halves for choosing splits vs leaf estimates
  honesty.fraction = 0.5,
  honesty.prune.leaves = TRUE,
  alpha = 0.05,
  imbalance.penalty = 0,
  stabilize.splits = TRUE,
  ci.group.size = 2,      # >= 2 needed for variance / std.err estimates
  tune.parameters = "none",
  compute.oob.predictions = TRUE  # store out-of-bag CATE predictions
)


# Summary of the fitted forest: tree count, sample size, variable importance.
print(cforest)
GRF forest object of type causal_forest 
Number of trees: 2000 
Number of training samples: 5735 
Variable importance: 
    1     2     3     4     5     6     7     8     9    10    11    12    13 
0.003 0.000 0.000 0.000 0.000 0.000 0.003 0.014 0.003 0.002 0.004 0.001 0.003 
   14    15    16    17    18    19    20    21    22    23    24    25    26 
0.001 0.001 0.001 0.001 0.000 0.002 0.004 0.007 0.000 0.078 0.005 0.039 0.057 
   27    28    29    30    31    32    33    34    35    36    37    38    39 
0.022 0.055 0.038 0.035 0.051 0.051 0.052 0.018 0.047 0.038 0.062 0.028 0.062 
   40    41    42    43    44    45    46    47    48    49    50    51    52 
0.044 0.038 0.040 0.004 0.016 0.001 0.001 0.015 0.002 0.004 0.003 0.000 0.004 
   53    54    55    56    57    58    59    60    61    62    63 
0.004 0.000 0.003 0.005 0.000 0.000 0.005 0.003 0.002 0.005 0.013 
# Which variables appear to drive treatment-effect heterogeneity (HTE)?
# variable_importance() is a depth-weighted count of how often each column of
# X was split on. A high score suggests, but does not prove, effect modification.
varimp <- variable_importance(cforest)
ranked.vars <- order(varimp, decreasing = TRUE)

# Top 5 variables according to this measure. head() returns fewer names,
# rather than NA entries, if X has fewer than 5 columns.
colnames(X)[head(ranked.vars, 5)]
[1] "age"     "pot1"    "crea1"   "aps1"    "meanbp1"
# Inspect individual trees of the forest: the 1st and the 30th;
tree_1 <- get_tree(cforest, index = 1)
plot(tree_1)
tree_30 <- get_tree(cforest, index = 30)
plot(tree_30)

# Average treatment effect (ATE) over all units;
average_treatment_effect(forest = cforest, target.sample = "all")
  estimate    std.err 
0.05720678 0.01128300 
# Average treatment effect on the treated (ATT);
average_treatment_effect(forest = cforest, target.sample = "treated")
  estimate    std.err 
0.05880718 0.01292041 
# Test for heterogeneous treatment effects (HTE) with the best-linear-fit
# calibration test;
test_calibration(forest = cforest)

Best linear fit using forest predictions (on held-out data)
as well as the mean forest prediction as regressors, along
with one-sided heteroskedasticity-robust (HC3) SEs:

                               Estimate Std. Error t value     Pr(>t)    
mean.forest.prediction          1.01764    0.23667  4.2998 0.00000869 ***
differential.forest.prediction  1.06305    0.66444  1.5999    0.05484 .  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
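# The differential.forest.prediction coefficient has p = 0.055, so we cannot
# reject the null of no heterogeneity at the 5% level; the evidence for HTE
# is only marginal (significant at 10%). The mean.forest.prediction
# coefficient near 1 suggests the average prediction is well calibrated;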

# Conditional ATE by race.
# white: units whose racewhite dummy equals 1;
in_white <- X[, "racewhite"] == 1
average_treatment_effect(cforest, target.sample = "all", subset = in_white)
  estimate    std.err 
0.06151071 0.01272640 
# other: units whose raceother dummy equals 1;
in_other <- X[, "raceother"] == 1
average_treatment_effect(cforest, target.sample = "all", subset = in_other)
  estimate    std.err 
0.00977772 0.04376047 
# black: the reference level of race, so X has no dummy column for it.
# Select it explicitly from the original column instead of "neither other
# nor white", which would silently absorb any additional race level.
# The length check confirms data2 rows line up with the rows of X.
in_black <- data2$race %in% "black"
stopifnot(length(in_black) == nrow(X))
average_treatment_effect(cforest, target.sample = "all", subset = in_black)
  estimate    std.err 
0.05464348 0.02926098