Causal Inference for Longitudinal Data

1. Hands-on exercise: causal analysis of a time-dependent treatment/exposure with longitudinal observational data

Special acknowledgement to Yutong Lu, my undergraduate research student who is undertaking an R package review project for causal analysis with longitudinal data. Yutong developed the code to simulate the dataset used in this tutorial.

1.1 Simulated observational data with a time-dependent treatment

  • The simulated dataset
    • 1000 patients and 3 visits; treatment was assigned at 2 of them (visits 1 and 2)
    • y, an end-of-study continuous outcome
    • Z_1 and Z_2, a binary treatment at visits 1 and 2
    • w1 and w2, two baseline covariates (w1 binary and w2 continuous, mimicking sex and age)
    • L1 and L2, two time-dependent covariates measured at visits 1 and 2 (L1 binary and L2 continuous)
    • no missing data
  • The simulated DAG

    digraph causal {
    # Nodes
    node [shape=plaintext]
    W [label = 'w1, w2']
    L1 [label = 'L11, L21']
    Z1 [label = 'Z1']
    L2 [label = 'L12, L22']
    Z2 [label = 'Z2']
    Y [label = 'Y']
    # Edges
    edge [color=black, arrowhead=vee]
    rankdir = LR
    # Graph
    graph [overlap=true, fontsize=14]
    }
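library(dplyr) # mutate(), %>%
library(tidyr) # pivot_longer(), pivot_wider()
library(DT)    # datatable(), formatRound()
options(scipen = 999)

causaldata <- read.csv("continuous_outcome_data.csv", header = TRUE, fileEncoding = "UTF-8-BOM")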

# look at the data;
datatable(causaldata,
          rownames = FALSE,
          options = list(dom = 't')) %>%
  formatRound(columns=c('w2', 'L2_1', 'L2_2', 'y'), digits=2)
# frequency counts by treatment combinations (rows: Z_1, columns: Z_2);
table(causaldata$Z_1, causaldata$Z_2)
      0   1
  0 520 201
  1 111 168
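  • Suppose the causal parameter of interest is the average treatment effect (ATE) comparing always treated with never treated, \[ ATE = E\left[Y^{(1,1)}\right] - E\left[Y^{(0,0)}\right] \] where \(Y^{(z_1,z_2)}\) is the potential outcome had a patient received treatment \(z_1\) at visit 1 and \(z_2\) at visit 2. This is generally not equal to the associational contrast \(E(Y \mid Z_1 = 1, Z_2 = 1) - E(Y \mid Z_1 = 0, Z_2 = 0)\), because the time-dependent covariates can confound the treatment-outcome relationship.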
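To see why the conditional contrast can differ from the causal one, here is a minimal base-R simulation. The data-generating model is a hypothetical one chosen for illustration (it is not the model behind continuous_outcome_data.csv): treatment at each visit depends on the current covariate, Z_1 lowers L2, and L2 affects both Z_2 and y. Because the potential outcomes are simulated directly, the true ATE is known to be -1 - 2 - 0.7 = -3.7.

```r
set.seed(2024)
n  <- 100000
w  <- rnorm(n)                       # baseline covariate
L1 <- 0.5 * w + rnorm(n)             # time-dependent covariate, visit 1
e2 <- rnorm(n)                       # noise shared by the potential L2 values
eY <- rnorm(n)                       # noise shared by the potential outcomes

# potential outcome Y^(z1, z2): Z1 acts on Y directly and through L2
po <- function(z1, z2) {
  L2 <- 0.5 * L1 - 0.7 * z1 + e2
  2 - z1 - 2 * z2 + L2 + 0.5 * w + eY
}
true_ate <- mean(po(1, 1) - po(0, 0))          # -1 - 2 - 0.7 = -3.7

# observed data: treatment at each visit depends on the current covariates
Z1 <- rbinom(n, 1, plogis(-0.5 + 0.8 * L1))
L2 <- 0.5 * L1 - 0.7 * Z1 + e2
Z2 <- rbinom(n, 1, plogis(-1 + 0.8 * L2 + 1.5 * Z1))
Y  <- 2 - Z1 - 2 * Z2 + L2 + 0.5 * w + eY      # equals po(Z1, Z2)

# associational contrast E(Y | 1, 1) - E(Y | 0, 0)
naive <- mean(Y[Z1 == 1 & Z2 == 1]) - mean(Y[Z1 == 0 & Z2 == 0])
c(true_ate = true_ate, naive = naive)
```

In this setup the naive contrast is pulled toward zero, because patients who end up always treated tend to have higher L1 and L2, which raise y.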

1.2 Implementing marginal structural models

  • Step 1: get a glimpse of covariate balance at each visit using the cobalt package
library(cobalt) # package to assess covariate balance by treatment;

# covariate balance at each visit (before weighting);
bal.tab(list(Z_1 ~ w1 + w2 + L1_1 + L2_1,
             Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1),
        data = causaldata, 
        int = FALSE,
        poly = 1, 
        estimand = "ATE", 
        stats = c("m"),
        thresholds = c(m = 0.1),
        which.time = .all)
Balance by Time Point

 - - - Time: 1 - - - 
Balance Measures
        Type Diff.Un M.Threshold.Un
w1    Binary  0.0470 Balanced, <0.1
w2   Contin.  0.0035 Balanced, <0.1
L1_1  Binary -0.0432 Balanced, <0.1
L2_1 Contin.  0.0292 Balanced, <0.1

Balance tally for mean differences
Balanced, <0.1         4
Not Balanced, >0.1     0

Variable with the greatest mean difference
 Variable Diff.Un M.Threshold.Un
       w1   0.047 Balanced, <0.1

Sample sizes
    Control Treated
All     721     279

 - - - Time: 2 - - - 
Balance Measures
        Type Diff.Un     M.Threshold.Un
w1    Binary  0.0347     Balanced, <0.1
w2   Contin.  0.0052     Balanced, <0.1
L1_1  Binary -0.0674     Balanced, <0.1
L2_1 Contin.  0.1119 Not Balanced, >0.1
L1_2  Binary -0.0525     Balanced, <0.1
L2_2 Contin.  0.0257     Balanced, <0.1
Z_1   Binary  0.2794 Not Balanced, >0.1

Balance tally for mean differences
Balanced, <0.1         5
Not Balanced, >0.1     2

Variable with the greatest mean difference
 Variable Diff.Un     M.Threshold.Un
      Z_1  0.2794 Not Balanced, >0.1

Sample sizes
    Control Treated
All     631     369
 - - - - - - - - - - - 
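The balance statistic in these tables can be reproduced by hand. The sketch below uses a hypothetical helper (balance_diff() is not a cobalt function) and assumes cobalt's defaults for the ATE: a raw difference in proportions for binary covariates, and a standardized mean difference using the pooled SD of the unweighted sample for continuous ones. The optional wt argument lets the same function check balance after weighting.

```r
# Hypothetical helper; assumed to mirror cobalt's default balance statistics:
#   binary covariate     -> raw difference in (weighted) proportions
#   continuous covariate -> (weighted) mean difference / pooled SD of the unweighted sample
balance_diff <- function(x, z, wt = rep(1, length(x)),
                         binary = all(x %in% c(0, 1))) {
  m1 <- weighted.mean(x[z == 1], wt[z == 1])
  m0 <- weighted.mean(x[z == 0], wt[z == 0])
  if (binary) return(m1 - m0)
  sd_pooled <- sqrt((var(x[z == 1]) + var(x[z == 0])) / 2)
  (m1 - m0) / sd_pooled
}

z <- c(0, 0, 0, 1, 1, 1)
balance_diff(c(1, 2, 3, 4, 5, 6), z)   # continuous: (5 - 2) / 1 = 3
balance_diff(c(0, 1, 0, 1, 1, 1), z)   # binary: 1 - 1/3
abs(balance_diff(c(0, 1, 0, 1, 1, 1), z)) < 0.1   # the 0.1 threshold used above
```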
  • Step 2: use the WeightIt package to estimate visit-specific propensity scores and construct stabilized weights

library(WeightIt)

Wmsm <- weightitMSM(
  list(Z_1 ~ w1 + w2 + L1_1 + L2_1,
       Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1),
  data = causaldata, 
  method = "ps",
  stabilize = TRUE)
Wmsm # print the weighting specification

A weightitMSM object
 - method: "ps" (propensity score weighting)
 - number of obs.: 1000
 - sampling weights: none
 - number of time points: 2 (Z_1, Z_2)
 - treatment: 
    + time 1: 2-category
    + time 2: 2-category
 - covariates: 
    + baseline: w1, w2, L1_1, L2_1
    + after time 1: w1, w2, L1_1, L2_1, L1_2, L2_2, Z_1
 - stabilized; stabilization factors:
    + baseline: (none)
    + after time 1: Z_1
summary(Wmsm) # examine if there are extreme weights
                 Summary of weights

                       Time 1                       
                 Summary of weights

- Weight ranges:

           Min                                  Max
treated 0.7346 |---------------------------| 1.5541
control 0.7644  |--------------------|       1.3881

- Units with 5 most extreme weights by group:
            427     40    550    528    927
 treated 1.4448 1.4965 1.5176 1.5276 1.5541
            812    598    305    713    819
 control 1.3372  1.344 1.3496 1.3498 1.3881

- Weight statistics:

        Coef of Var   MAD Entropy # Zeros
treated       0.154 0.120   0.011       0
control       0.103 0.082   0.005       0

- Mean of Weights = 1

- Effective Sample Sizes:

           Control Treated
Unweighted  721.    279.  
Weighted    713.41  272.58

                       Time 2                       
                 Summary of weights

- Weight ranges:

           Min                                  Max
treated 0.7346 |---------------------------| 1.5541
control 0.7778   |---------------|           1.2647

- Units with 5 most extreme weights by group:
            427     40    550    528    927
 treated 1.4448 1.4965 1.5176 1.5276 1.5541
            804    692    956    201    830
 control 1.2129 1.2149 1.2151 1.2321 1.2647

- Weight statistics:

        Coef of Var   MAD Entropy # Zeros
treated       0.158 0.126   0.012       0
control       0.090 0.074   0.004       0

- Mean of Weights = 1

- Effective Sample Sizes:

           Control Treated
Unweighted  631.    369.  
Weighted    625.95  360.05
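Stabilized MSM weights are a product over visits of P(treatment | stabilization factors) divided by P(treatment | full history). The base-R sketch below builds them with glm() on data from a hypothetical simulation model, using the structure the weightitMSM() printout reports (no stabilization factor at visit 1, Z_1 at visit 2). We assume, but have not verified, that this matches WeightIt's internal computation for method = "ps". The last line shows the 99th-percentile truncation discussed further below.

```r
set.seed(1)
n  <- 5000
w  <- rnorm(n)
L1 <- 0.5 * w + rnorm(n)
Z1 <- rbinom(n, 1, plogis(-0.5 + 0.8 * L1))
L2 <- 0.5 * L1 - 0.7 * Z1 + rnorm(n)
Z2 <- rbinom(n, 1, plogis(-1 + 0.8 * L2 + 1.5 * Z1))
d  <- data.frame(w, L1, Z1, L2, Z2)

# probability of the treatment actually received, from a fitted logistic model
p_obs <- function(fit, z) {
  p <- fitted(fit)
  ifelse(z == 1, p, 1 - p)
}

# denominators: full covariate and treatment history at each visit
den1 <- p_obs(glm(Z1 ~ w + L1, family = binomial, data = d), d$Z1)
den2 <- p_obs(glm(Z2 ~ w + L1 + Z1 + L2, family = binomial, data = d), d$Z2)
# numerators: stabilization factors (none at visit 1, prior treatment Z1 at visit 2)
num1 <- p_obs(glm(Z1 ~ 1, family = binomial, data = d), d$Z1)
num2 <- p_obs(glm(Z2 ~ Z1, family = binomial, data = d), d$Z2)

sw <- (num1 / den1) * (num2 / den2)   # stabilized weight = product over visits
summary(sw)                           # look for extreme values

# optional truncation at the 99th percentile
sw_trunc <- pmin(sw, quantile(sw, 0.99))
```

Stabilized weights should average close to 1; a mean far from 1 or a long right tail points to a misspecified treatment model or near-violations of positivity.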
  • Step 3: assess covariate balance after weighting. If some covariates remain imbalanced across treatment groups, go back to Step 2 and update the treatment models (e.g., by adding interaction or polynomial terms).
    • Residual imbalance in the previous treatment (Z_1 at visit 2) is not an issue, because the treatments at both visits are included in the marginal structural (outcome) model.
bal.tab(Wmsm,
        stats = c("m"),
        thresholds = c(m = .1),
        which.time = .none)
Call
 weightitMSM(formula.list = list(Z_1 ~ w1 + w2 + L1_1 + L2_1, 
    Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1), data = causaldata, 
    method = "ps", stabilize = TRUE)

Balance summary across all time points
           Times     Type Max.Diff.Adj        M.Threshold
prop.score  1, 2 Distance       0.5946                   
w1          1, 2   Binary       0.0054     Balanced, <0.1
w2          1, 2  Contin.       0.0082     Balanced, <0.1
L1_1        1, 2   Binary       0.0029     Balanced, <0.1
L2_1        1, 2  Contin.       0.0060     Balanced, <0.1
L1_2           2   Binary       0.0046     Balanced, <0.1
L2_2           2  Contin.       0.0572     Balanced, <0.1
Z_1            2   Binary       0.2792 Not Balanced, >0.1

Balance tally for mean differences
Balanced, <0.1         6
Not Balanced, >0.1     1

Variable with the greatest mean difference
 Variable Max.Diff.Adj        M.Threshold
      Z_1       0.2792 Not Balanced, >0.1

Effective sample sizes
 - Time 1
           Control Treated
Unadjusted  721.    279.  
Adjusted    713.41  272.58
 - Time 2
           Control Treated
Unadjusted  631.    369.  
Adjusted    625.95  360.05
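Why Z_1 stays imbalanced at visit 2: the numerator of the visit-2 weight is P(Z_2 | Z_1), so the weighted pseudo-population keeps the observed joint distribution of (Z_1, Z_2) while breaking the link between Z_2 and the covariates. The base-R sketch below uses hand-built weights on data from a hypothetical simulation model (an illustration, not WeightIt itself) and compares the unweighted and weighted differences in the proportion with Z_1 = 1 between the Z_2 groups. The two should be close, mirroring 0.2794 vs 0.2792 in the tables above.

```r
set.seed(3)
n  <- 5000
w  <- rnorm(n)
L1 <- 0.5 * w + rnorm(n)
Z1 <- rbinom(n, 1, plogis(-0.5 + 0.8 * L1))
L2 <- 0.5 * L1 - 0.7 * Z1 + rnorm(n)
Z2 <- rbinom(n, 1, plogis(-1 + 0.8 * L2 + 1.5 * Z1))
d  <- data.frame(w, L1, Z1, L2, Z2)

# probability of the treatment actually received
p_obs <- function(fit, z) { p <- fitted(fit); ifelse(z == 1, p, 1 - p) }

# stabilized weights: numerators P(Z1) and P(Z2 | Z1), denominators use the full history
sw <- p_obs(glm(Z1 ~ 1, family = binomial, data = d), d$Z1) /
      p_obs(glm(Z1 ~ w + L1, family = binomial, data = d), d$Z1) *
      p_obs(glm(Z2 ~ Z1, family = binomial, data = d), d$Z2) /
      p_obs(glm(Z2 ~ w + L1 + Z1 + L2, family = binomial, data = d), d$Z2)

# difference in the proportion with Z1 = 1 between the Z2 = 1 and Z2 = 0 groups
diff_z1 <- function(wt) {
  t2 <- d$Z2 == 1
  weighted.mean(d$Z1[t2], wt[t2]) - weighted.mean(d$Z1[!t2], wt[!t2])
}
res <- c(unweighted = diff_z1(rep(1, n)), weighted = diff_z1(sw))
res
```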
  • Step 4: fit the marginal structural model as a weighted linear regression with the survey package, treating the stabilized weights (the product of the visit-specific weights) as survey weights.
    • The estimated ATE between always treated and never treated is -3.1134.

# first create a survey object;
library(survey)
msm_design <- svydesign(~1, weights = Wmsm$weights, data = causaldata)

fitMSM <- svyglm(y ~ Z_1*Z_2, 
                 design = msm_design)
summary(fitMSM)

Call:
svyglm(formula = y ~ Z_1 * Z_2, design = msm_design)

Survey design:
svydesign(~1, weights = Wmsm$weights, data = causaldata)

Coefficients:
            Estimate Std. Error t value            Pr(>|t|)    
(Intercept)  2.33541    0.05156  45.299 <0.0000000000000002 ***
Z_1         -1.21435    0.11887 -10.216 <0.0000000000000002 ***
Z_2         -2.15788    0.09641 -22.382 <0.0000000000000002 ***
Z_1:Z_2      0.25885    0.16929   1.529               0.127    
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for gaussian family taken to be 1.331051)

Number of Fisher Scoring iterations: 2
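APO_11 <- predict(fitMSM, newdata = data.frame(Z_1=1,Z_2=1))
APO_00 <- predict(fitMSM, newdata = data.frame(Z_1=0,Z_2=0))

# the point estimate below is the ATE, but the printed SE appears to be carried over
# from APO_11 (arithmetic keeps the first operand's attributes), so it is not the SE of the ATE;
APO_11 - APO_00
     link     SE
1 -3.1134 0.0888
# design-based SE of the ATE = Z_1 + Z_2 + Z_1:Z_2;
svycontrast(fitMSM, c("Z_1" = 1, "Z_2" = 1, "Z_1:Z_2" = 1))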
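Because the MSM y ~ Z_1*Z_2 is saturated in the two treatments, each average potential outcome is a sum of coefficients, and the ATE equals the sum of the Z_1, Z_2 and Z_1:Z_2 coefficients. The short check below redoes that arithmetic with the estimates printed in the svyglm output. It also writes the ATE as a contrast vector: its SE needs the full covariance matrix of the coefficients, not only their individual SEs.

```r
# Coefficients copied from the svyglm output above
b <- c("(Intercept)" = 2.33541, "Z_1" = -1.21435, "Z_2" = -2.15788, "Z_1:Z_2" = 0.25885)

apo_11 <- sum(b)                          # E[Y^(1,1)]: all four terms
apo_00 <- unname(b["(Intercept)"])        # E[Y^(0,0)]: intercept only
ate    <- apo_11 - apo_00                 # = b["Z_1"] + b["Z_2"] + b["Z_1:Z_2"] = -3.11338
round(c(apo_11 = apo_11, apo_00 = apo_00, ate = ate), 5)

# contrast vector for the ATE; with V = vcov(fitMSM) its SE would be
# sqrt(drop(t(cvec) %*% V %*% cvec))
cvec <- c(0, 1, 1, 1)
sum(cvec * b)
```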
# How to trim weights?;
# weights greater than 10 are generally considered large (here the largest weight is about 1.55);
# if needed, weights can be truncated, e.g., at the 99th percentile, as follows;
# trim <- quantile(Wmsm$weights, 0.99) # 99th percentile of the weights;
# sw_trim <- ifelse(Wmsm$weights > trim, trim, Wmsm$weights)

# using bootstrap to obtain SE and confidence interval of the ATE;
set.seed(123)
boot.est <- rep(NA, 1000)
for (i in 1:1000){
  boot.idx <- sample(1:nrow(causaldata), size = nrow(causaldata), replace = TRUE)
  boot.data <- causaldata[boot.idx,]
  # keep each resampled row paired with its own weight;
  boot.design <- svydesign(~1, weights = Wmsm$weights[boot.idx], data = boot.data)
  boot.fit <- svyglm(y ~ Z_1*Z_2, design = boot.design)
  boot.est[i] <- predict(boot.fit, newdata = data.frame(Z_1=1,Z_2=1))[1] -
                 predict(boot.fit, newdata = data.frame(Z_1=0,Z_2=0))[1]
}

# SE of ATE;
sd(boot.est)
[1] 0.1000114
#95% CI
quantile(boot.est, probs = c(0.025, 0.975))
     2.5%     97.5% 
-3.314099 -2.910470 
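The bootstrap above resamples patients but keeps the weights estimated on the full data, so it ignores uncertainty in the propensity-score models. A common alternative is to re-estimate the weights inside every bootstrap sample. The base-R sketch below does this on data from a hypothetical simulation model (true ATE -3.7), with hand-built stabilized weights and lm(weights = ) standing in for weightitMSM() and svyglm(); B = 200 is kept small only so it runs quickly.

```r
set.seed(42)
n  <- 2000
w  <- rnorm(n)
L1 <- 0.5 * w + rnorm(n)
Z1 <- rbinom(n, 1, plogis(-0.5 + 0.8 * L1))
L2 <- 0.5 * L1 - 0.7 * Z1 + rnorm(n)
Z2 <- rbinom(n, 1, plogis(-1 + 0.8 * L2 + 1.5 * Z1))
Y  <- 2 - Z1 - 2 * Z2 + L2 + 0.5 * w + rnorm(n)   # true ATE = -3.7
d  <- data.frame(w, L1, Z1, L2, Z2, Y)

# probability of the treatment actually received
p_obs <- function(fit, z) { p <- fitted(fit); ifelse(z == 1, p, 1 - p) }

# whole pipeline: estimate stabilized weights, fit the weighted MSM, return the ATE
msm_ate <- function(d) {
  sw <- p_obs(glm(Z1 ~ 1, family = binomial, data = d), d$Z1) /
        p_obs(glm(Z1 ~ w + L1, family = binomial, data = d), d$Z1) *
        p_obs(glm(Z2 ~ Z1, family = binomial, data = d), d$Z2) /
        p_obs(glm(Z2 ~ w + L1 + Z1 + L2, family = binomial, data = d), d$Z2)
  fit <- lm(Y ~ Z1 * Z2, data = d, weights = sw)
  sum(coef(fit)[c("Z1", "Z2", "Z1:Z2")])
}

est  <- msm_ate(d)
boot <- replicate(200, msm_ate(d[sample(nrow(d), replace = TRUE), ]))
c(ate = est, boot_se = sd(boot))
quantile(boot, c(0.025, 0.975))     # percentile 95% CI
```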

1.3 Implementing parametric g-computation

  • The variance of the parametric g-formula estimate is obtained via bootstrap, so it takes some time to run

  • The gfoRmula package requires long-format data and a time variable that starts at 0 for the baseline visit

  • Step 1: prepare the long-format data for the analysis

# preparing the data;
# first transform wide data to long data;
causaldata_long <- causaldata %>%
  mutate(id = row_number()) %>% 
  pivot_longer(cols = -c(w1,w2,y,id), 
               names_to = c("variable","visit"), 
               names_sep = "_", 
               values_to = "value") %>% 
  pivot_wider(names_from = variable, values_from = value) %>% 
  mutate(time = case_when(visit == 1 ~ 0,
                          visit == 2 ~ 1)) 

# y is only measured at the end of study, but pivot_longer() repeats it on both visit rows;
# set visit 1's y to missing;
causaldata_long$y[causaldata_long$visit == 1] <- NA

# look at the new data;
datatable(causaldata_long,
          rownames = FALSE,
          options = list(dom = 't')) %>%
  formatRound(columns=c('w2', 'L2', 'y'), digits=2)
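For reference, the same wide-to-long step can be done with base R's reshape(). The sketch below uses a made-up three-patient data frame with the tutorial's column names and checks the layout described above: one row per patient-visit, a time variable starting at 0, and y present only at the last visit.

```r
# made-up wide data with the tutorial's column names (three patients)
wide <- data.frame(id = 1:3, w1 = c(0, 1, 0), w2 = c(51, 63, 47),
                   L1_1 = c(1, 0, 1), L2_1 = c(0.2, -1.1, 0.5), Z_1 = c(0, 1, 1),
                   L1_2 = c(1, 1, 0), L2_2 = c(0.4, -0.3, 0.9), Z_2 = c(1, 1, 0),
                   y = c(2.1, -0.8, 1.7))

# one row per patient-visit; each list element holds one variable's visit-specific columns
long <- reshape(wide, direction = "long",
                varying = list(c("L1_1", "L1_2"), c("L2_1", "L2_2"), c("Z_1", "Z_2")),
                v.names = c("L1", "L2", "Z"),
                timevar = "visit", times = 1:2, idvar = "id")
long <- long[order(long$id, long$visit), ]
long$time <- long$visit - 1          # time starts at 0 for the baseline visit
long$y[long$time == 0] <- NA         # y is only observed at the end of study
long
```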
  • Step 2: implement parametric g-computation using gfoRmula

library(gfoRmula)

id <- 'id'
time_name <- 'time'
covnames <- c("L1", "L2", "Z")
outcome_name <- 'y'
covtypes <- c('binary', 'normal', 'binary')
histories <- c(lagged) # create lagged versions (lag1_*) of the variables in histvars;
histvars <- list(c('Z', 'L1', 'L2'))

covparams <- list(
  covmodels = c(L1 ~ w1 + w2 + lag1_L1 + lag1_Z,
                L2 ~ lag1_L2 + w1 + w2 + lag1_Z,
                Z ~ w1 + w2 + lag1_L1 + lag1_L2 + lag1_Z))

ymodel <- y ~ lag1_Z*Z + w1 + w2 + lag1_L1 + lag1_L2 + L1 + L2

intvars <- list('Z', 'Z')
interventions <- list(list(c(static, rep(0, 2))),
                      list(c(static, rep(1, 2))))
int_descript <- c('Never treat', 'Always treat')

gform_cont_eof <- gformula_continuous_eof(
  obs_data = causaldata_long,
  id = id,
  time_name = time_name,
  covnames = covnames,
  outcome_name = outcome_name, 
  covtypes = covtypes,
  covparams = covparams,  
  ymodel = ymodel,
  intvars = intvars, 
  interventions = interventions,
  int_descript = int_descript, 
  ref_int = 1, # reference intervention: 1 = never treat;
  histories = histories, 
  histvars = histvars, # time-dependent variables whose lagged values are used;
  basecovs = c("w1","w2"), # time-independent baseline variables;
  nsimul = 1000, # Monte Carlo sample size;
  nsamples = 1000, # number of bootstrap samples;
  parallel = TRUE, # run the bootstrap in parallel;
  ncores = 6,
  seed = 123)
gform_cont_eof # print the results


Intervention     Description
0        Natural course
1        Never treat
2        Always treat

Sample size = 1000, Monte Carlo sample size = 1000
Number of bootstrap samples = 1000
Reference intervention = 1


 k Interv.  NP mean g-form mean    Mean SE Mean lower 95% CI Mean upper 95% CI
 1       0 1.244963   1.3036349 0.06182675         1.1277058         1.3639202
 1       1       NA   2.3297637 0.04988542         2.2393043         2.4367416
 1       2       NA  -0.7816274 0.08063701        -0.9322829        -0.6090699
 Mean ratio      MR SE MR lower 95% CI MR upper 95% CI Mean difference
  0.5595567 0.02353069       0.4871446       0.5767769       -1.026129
  1.0000000 0.00000000       1.0000000       1.0000000        0.000000
 -0.3354964 0.03562982      -0.4004484      -0.2592769       -3.111391
      MD SE MD lower 95% CI MD upper 95% CI
 0.05905593       -1.207822      -0.9763218
 0.00000000        0.000000       0.0000000
 0.09248469       -3.297456      -2.9325170
  • Using g-computation, the estimated ATE (always treat vs. the reference intervention, never treat) is -3.111391 with SE = 0.09248469 and 95% CI (-3.297456, -2.932517).
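The parametric g-formula that gformula_continuous_eof() implements can be written out by hand for two visits: fit a model for each time-dependent covariate and for the outcome, simulate the covariate history forward with treatment set by the intervention, and average the predicted outcomes. The base-R sketch below does this on data from a hypothetical linear simulation model (true ATE -3.7), taking the visit-1 covariates from the observed data. It omits the bootstrap and is not the package's actual code.

```r
set.seed(7)
n  <- 20000
w  <- rnorm(n)
L1 <- 0.5 * w + rnorm(n)
Z1 <- rbinom(n, 1, plogis(-0.5 + 0.8 * L1))
L2 <- 0.5 * L1 - 0.7 * Z1 + rnorm(n)
Z2 <- rbinom(n, 1, plogis(-1 + 0.8 * L2 + 1.5 * Z1))
Y  <- 2 - Z1 - 2 * Z2 + L2 + 0.5 * w + rnorm(n)   # true ATE = -3.7
d  <- data.frame(w, L1, Z1, L2, Z2, Y)

# Step 1: parametric models for the time-dependent covariate and the outcome
fit_L2 <- lm(L2 ~ w + L1 + Z1, data = d)
fit_Y  <- lm(Y ~ w + L1 + Z1 + L2 + Z2, data = d)

# Step 2: Monte Carlo simulation of the history under a static regime (z1, z2)
g_mean <- function(z1, z2) {
  mc    <- d[, c("w", "L1")]           # baseline history taken from the observed data
  mc$Z1 <- z1                          # set treatment at visit 1
  mc$L2 <- predict(fit_L2, newdata = mc) + rnorm(nrow(mc), 0, sigma(fit_L2))
  mc$Z2 <- z2                          # set treatment at visit 2
  mean(predict(fit_Y, newdata = mc))   # mean predicted end-of-study outcome
}

ate_g <- g_mean(1, 1) - g_mean(0, 0)
ate_g
```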

1.4 Implementing targeted maximum likelihood estimation

  • The ltmle package requires the input data to contain only the variables used in the models, with columns in temporal order!
    • remove any variables you will not be modelling (e.g., id) from the data
    • ltmle uses wide-format data
# Step 1, if applicable remove variables we don't need (here every column is used and already in temporal order);
library(ltmle)
names(causaldata)
[1] "w1"   "w2"   "L1_1" "L2_1" "Z_1"  "L1_2" "L2_2" "Z_2"  "y"   
# Step 2, fit a conventional TMLE without SuperLearner (with the default SL.library = "glm", the Q and g models are fit with glm);

tmle_model <- ltmle(data = causaldata,
                    Anodes = c("Z_1","Z_2"),
                    Lnodes = c("L1_1", "L2_1", "L1_2", "L2_2"), 
                    Ynodes = c("y"), 
                    survivalOutcome =FALSE,
                    gform = c("Z_1 ~ w1 + w2 + L1_1 + L2_1",
                              "Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1"),
                    abar = list(c(1,1), c(0,0)))

summary(tmle_model, estimator="tmle")
Estimator:  tmle 
ltmle(data = causaldata, Anodes = c("Z_1", "Z_2"), Lnodes = c("L1_1", 
    "L2_1", "L1_2", "L2_2"), Ynodes = c("y"), survivalOutcome = FALSE, 
    gform = c("Z_1 ~ w1 + w2 + L1_1 + L2_1", "Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1"), 
    abar = list(c(1, 1), c(0, 0)))

Treatment Estimate:
   Parameter Estimate:  -0.77939 
    Estimated Std Err:  0.082294 
              p-value:  <2e-16 
    95% Conf Interval: (-0.94068, -0.61809) 

Control Estimate:
   Parameter Estimate:  2.3371 
    Estimated Std Err:  0.048802 
              p-value:  <2e-16 
    95% Conf Interval: (2.2415, 2.4328) 

Additive Treatment Effect:
   Parameter Estimate:  -3.1165 
    Estimated Std Err:  0.093326 
              p-value:  <2e-16 
    95% Conf Interval: (-3.2994, -2.9336) 
# Step 3, fit TMLE with SuperLearner for the treatment (g) and outcome (Q) models;
tmle_model_sup <- ltmle(causaldata,
                    Anodes = c("Z_1","Z_2"),
                    Lnodes = c("L1_1", "L2_1", "L1_2", "L2_2"), 
                    Ynodes = c("y"), 
                    survivalOutcome = FALSE,
                    gform = c("Z_1 ~ w1 + w2 + L1_1 + L2_1",
                              "Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1"),
                    # SL.mean alone fits an intercept-only mean that ignores the covariates;
                    # in practice pass several learners, e.g., c("SL.mean", "SL.glm", "SL.randomForest", "SL.gbm");
                    # see SuperLearner::listWrappers() for the available learners;
                    SL.library = c("SL.mean"),
                    abar = list(c(1,1), c(0,0)),
                    estimate.time = FALSE)

summary(tmle_model_sup, estimator="tmle")
Estimator:  tmle 
ltmle(data = causaldata, Anodes = c("Z_1", "Z_2"), Lnodes = c("L1_1", 
    "L2_1", "L1_2", "L2_2"), Ynodes = c("y"), survivalOutcome = FALSE, 
    gform = c("Z_1 ~ w1 + w2 + L1_1 + L2_1", "Z_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + Z_1"), 
    abar = list(c(1, 1), c(0, 0)), SL.library = c("SL.mean"), 
    estimate.time = FALSE)

Treatment Estimate:
   Parameter Estimate:  -0.78091 
    Estimated Std Err:  0.14249 
              p-value:  <2e-16 
    95% Conf Interval: (-1.0602, -0.50164) 

Control Estimate:
   Parameter Estimate:  2.336 
    Estimated Std Err:  0.058324 
              p-value:  <2e-16 
    95% Conf Interval: (2.2217, 2.4504) 

Additive Treatment Effect:
   Parameter Estimate:  -3.117 
    Estimated Std Err:  0.15396 
              p-value:  <2e-16 
    95% Conf Interval: (-3.4187, -2.8152) 
  • The estimated ATE under conventional TMLE is -3.1165 with SE = 0.093326 and 95% CI: (-3.2994, -2.9336).

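  • The estimated ATE under TMLE with SuperLearner is -3.117 with SE = 0.15396 and 95% CI (-3.4187, -2.8152). The SE is noticeably larger than under conventional TMLE, likely because SL.library = c("SL.mean") fits intercept-only models, so neither the treatment nor the outcome regressions use the covariates.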
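To illustrate the kind of computation ltmle performs for two visits, the base-R sketch below implements sequential regression with a targeting step on data from a hypothetical simulation model (true ATE -3.7). Working backwards from the outcome, it fits an outcome regression among patients who followed the regime (similar in spirit to ltmle's stratify = TRUE option), shifts the predictions by a fluctuation estimated with weights 1/(product of treatment probabilities), and uses the targeted predictions as the outcome for the previous visit. Unlike ltmle, which rescales a continuous outcome to [0, 1] and uses a logistic fluctuation, this sketch assumes a simpler linear fluctuation and returns only the point estimate (no influence-curve SE).

```r
set.seed(11)
n  <- 20000
w  <- rnorm(n)
L1 <- 0.5 * w + rnorm(n)
Z1 <- rbinom(n, 1, plogis(-0.5 + 0.8 * L1))
L2 <- 0.5 * L1 - 0.7 * Z1 + rnorm(n)
Z2 <- rbinom(n, 1, plogis(-1 + 0.8 * L2 + 1.5 * Z1))
Y  <- 2 - Z1 - 2 * Z2 + L2 + 0.5 * w + rnorm(n)   # true ATE = -3.7
d  <- data.frame(w, L1, Z1, L2, Z2, Y)

# treatment models (the "g" part), fitted once on all patients
g1 <- glm(Z1 ~ w + L1, family = binomial, data = d)
g2 <- glm(Z2 ~ w + L1 + Z1 + L2, family = binomial, data = d)

tmle_mean <- function(a1, a2) {
  # probability of receiving the regime's treatment at each visit
  p1 <- predict(g1, newdata = d, type = "response")
  p2 <- predict(g2, newdata = transform(d, Z1 = a1), type = "response")
  if (a1 == 0) p1 <- 1 - p1
  if (a2 == 0) p2 <- 1 - p2
  f1 <- d$Z1 == a1                 # followed the regime at visit 1
  f2 <- f1 & d$Z2 == a2            # followed the regime at both visits

  # visit 2: initial outcome regression among followers, predicted for everyone
  Q2   <- predict(lm(Y ~ w + L1 + L2, data = d[f2, ]), newdata = d)
  # targeting: weighted mean residual with weights 1 / (p1 * p2)
  eps2 <- weighted.mean((d$Y - Q2)[f2], (1 / (p1 * p2))[f2])
  d$Q2s <- Q2 + eps2

  # visit 1: regress the targeted predictions on earlier history, then target again
  Q1   <- predict(lm(Q2s ~ w + L1, data = d[f1, ]), newdata = d)
  eps1 <- weighted.mean((d$Q2s - Q1)[f1], (1 / p1)[f1])
  mean(Q1 + eps1)
}

ate_tmle <- tmle_mean(1, 1) - tmle_mean(0, 0)
ate_tmle
```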