library(DiagrammeR)
grViz("
digraph causal {
# Nodes
node [shape=plaintext]
W [label = 'w1, w2']
L1 [label = 'L1_1, L2_1']
a_1 [label = 'a_1']
L2 [label = 'L1_2, L2_2']
a_2 [label = 'a_2']
Y [label = 'Y']
# Edges
edge [color=black, arrowhead=vee]
rankdir = LR
W->L1
W->a_1
W->L2
W->a_2
W->Y
L1->a_1
L1->L2
L1->a_2
L1->Y
a_1->L2
a_1->a_2
a_1->Y
L2->a_2
L2->Y
a_2->Y
# Graph
graph [overlap=true, fontsize=14]
}")
Causal Inference for Longitudinal Data
1. Hands-on exercise performing causal analysis for a time-dependent treatment/exposure with longitudinal observational data
Special acknowledgement to Yutong Lu, my undergraduate research student, who is undertaking an R package review project for causal analysis with longitudinal data. Yutong developed the code to simulate the dataset we will be using in this tutorial.
1.1 Simulated observational data with a time-dependent treatment
- The simulated dataset
- 1000 patients and 3 visits (treatment was assigned at the first 2 visits)
- y, an end-of-study continuous outcome
- a, a binary treatment (a_1 and a_2 at visits 1 and 2)
- w1 and w2 are two baseline covariates (one binary and one continuous, mimicking sex and age)
- L1 and L2 are two time-dependent covariates (also one binary and one continuous)
- no missing data
- The simulated DAG
library(tidyverse)
library(DT)
options(scipen = 999)
causaldata <- read.csv("continuous_outcome_data.csv", header = TRUE, fileEncoding = "UTF-8-BOM")
# look at the data;
datatable(causaldata,
rownames = FALSE,
options = list(dom = 't')) %>%
formatRound(columns=c('w2', 'L2_1', 'L2_2', 'y'), digits=2)
# frequency counts by treatment combinations;
table(causaldata$a_1, causaldata$a_2)
0 1
0 520 201
1 111 168
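- Rows of the table are a_1 and columns are a_2: 520 patients were never treated, 168 were treated at both visits, 201 were treated only at visit 2, and 111 only at visit 1 (520 + 201 + 111 + 168 = 1000).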
- Suppose the causal parameter of interest is the average treatment effect (ATE) between always treated and never treated, \[ ATE = E(Y^{a_1 = 1, a_2 = 1}) - E(Y^{a_1 = 0, a_2 = 0}), \] the contrast of mean potential end-of-study outcomes under the two static treatment sequences.
1.2 Implementing marginal structural models
- Step 1, get a glimpse of covariate balance by visit using the cobalt package
library(cobalt) #package to assess covariates balance by treatment;
#covariates balance at each visit;
bal.tab(list(a_1 ~ w1 + w2 + L1_1 + L2_1,
             a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1),
        data = causaldata,
int = FALSE,
poly = 1,
estimand = "ATE",
stats = c("m"),
thresholds = c(m = 0.1),
which.time = .all)
Balance by Time Point
- - - Time: 1 - - -
Balance Measures
Type Diff.Un M.Threshold.Un
w1 Binary 0.0470 Balanced, <0.1
w2 Contin. 0.0035 Balanced, <0.1
L1_1 Binary -0.0432 Balanced, <0.1
L2_1 Contin. 0.0292 Balanced, <0.1
Balance tally for mean differences
count
Balanced, <0.1 4
Not Balanced, >0.1 0
Variable with the greatest mean difference
Variable Diff.Un M.Threshold.Un
w1 0.047 Balanced, <0.1
Sample sizes
Control Treated
All 721 279
- - - Time: 2 - - -
Balance Measures
Type Diff.Un M.Threshold.Un
w1 Binary 0.0347 Balanced, <0.1
w2 Contin. 0.0052 Balanced, <0.1
L1_1 Binary -0.0674 Balanced, <0.1
L2_1 Contin. 0.1119 Not Balanced, >0.1
L1_2 Binary -0.0525 Balanced, <0.1
L2_2 Contin. 0.0257 Balanced, <0.1
a_1 Binary 0.2794 Not Balanced, >0.1
Balance tally for mean differences
count
Balanced, <0.1 5
Not Balanced, >0.1 2
Variable with the greatest mean difference
Variable Diff.Un M.Threshold.Un
a_1 0.2794 Not Balanced, >0.1
Sample sizes
Control Treated
All 631 369
- - - - - - - - - - -
- Step 2, use the WeightIt package to calculate visit-specific propensity scores; we will use stabilized weights
library(WeightIt)
Wmsm <- weightitMSM(list(a_1 ~ w1 + w2 + L1_1 + L2_1,
                         a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1),
                    data = causaldata,
                    method = "ps",
                    stabilize = TRUE)
Wmsm
A weightitMSM object
- method: "glm" (propensity score weighting with GLM)
- number of obs.: 1000
- sampling weights: none
- number of time points: 2 (a_1, a_2)
- treatment:
+ time 1: 2-category
+ time 2: 2-category
- covariates:
+ baseline: w1, w2, L1_1, L2_1
+ after time 1: w1, w2, L1_1, L2_1, L1_2, L2_2, a_1
- stabilized; stabilization factors:
+ baseline: (none)
+ after time 1: a_1
summary(Wmsm) # examine if there are extreme weights
Summary of weights
Time 1
Summary of weights
- Weight ranges:
Min Max
treated 0.7346 |---------------------------| 1.5541
control 0.7644 |--------------------| 1.3881
- Units with the 5 most extreme weights by group:
427 40 550 528 927
treated 1.4448 1.4965 1.5176 1.5276 1.5541
812 598 305 713 819
control 1.3372 1.344 1.3496 1.3498 1.3881
- Weight statistics:
Coef of Var MAD Entropy # Zeros
treated 0.154 0.120 0.011 0
control 0.103 0.082 0.005 0
- Mean of Weights = 1
- Effective Sample Sizes:
Control Treated
Unweighted 721. 279.
Weighted 713.41 272.58
Time 2
Summary of weights
- Weight ranges:
Min Max
treated 0.7346 |---------------------------| 1.5541
control 0.7778 |---------------| 1.2647
- Units with the 5 most extreme weights by group:
427 40 550 528 927
treated 1.4448 1.4965 1.5176 1.5276 1.5541
804 692 956 201 830
control 1.2129 1.2149 1.2151 1.2321 1.2647
- Weight statistics:
Coef of Var MAD Entropy # Zeros
treated 0.158 0.126 0.012 0
control 0.090 0.074 0.004 0
- Mean of Weights = 1
- Effective Sample Sizes:
Control Treated
Unweighted 631. 369.
Weighted 625.95 360.05
- Step 3, assess post-weighting covariate balance. If any covariates remain unbalanced by treatment, go back to Step 2 and update the treatment models (consider adding interaction and polynomial terms).
- Imbalance on the previous treatment (a_1) is not an issue here, because both the visit 1 and visit 2 treatments will be modelled in the marginal outcome model.
bal.tab(Wmsm,
stats = c("m"),
thresholds = c(m = .1),
which.time = .none)
Balance summary across all time points
Times Type Max.Diff.Adj M.Threshold
prop.score 1, 2 Distance 0.5946
w1 1, 2 Binary 0.0054 Balanced, <0.1
w2 1, 2 Contin. 0.0082 Balanced, <0.1
L1_1 1, 2 Binary 0.0029 Balanced, <0.1
L2_1 1, 2 Contin. 0.0060 Balanced, <0.1
L1_2 2 Binary 0.0046 Balanced, <0.1
L2_2 2 Contin. 0.0572 Balanced, <0.1
a_1 2 Binary 0.2792 Not Balanced, >0.1
Balance tally for mean differences
count
Balanced, <0.1 6
Not Balanced, >0.1 1
Variable with the greatest mean difference
Variable Max.Diff.Adj M.Threshold
a_1 0.2792 Not Balanced, >0.1
Effective sample sizes
- Time 1
Control Treated
Unadjusted 721. 279.
Adjusted 713.41 272.58
- Time 2
Control Treated
Unadjusted 631. 369.
Adjusted 625.95 360.05
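- Balance can also be inspected graphically. A minimal sketch using cobalt's love.plot() on the same weights (assuming love.plot() accepts the weightitMSM object directly, as bal.tab() does):
love.plot(Wmsm, stats = "m", thresholds = c(m = .1), abs = TRUE) # absolute mean differences against the 0.1 threshold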
- Step 4, fit a weighted linear regression using the survey package. We treat the stabilized weights (the product of the visit-specific weights) as survey weights in the marginal outcome model.
- The estimated ATE between always and never treated is -3.1134.
library(survey)
# first create a survey object;
msm_design <- svydesign(~1, weights = Wmsm$weights, data = causaldata)

fitMSM <- svyglm(y ~ a_1*a_2, design = msm_design)
summary(fitMSM)
Call:
svyglm(formula = y ~ a_1 * a_2, design = msm_design)
Survey design:
svydesign(~1, weights = Wmsm$weights, data = causaldata)
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 2.33541 0.05156 45.299 <0.0000000000000002 ***
a_1 -1.21435 0.11887 -10.216 <0.0000000000000002 ***
a_2 -2.15788 0.09641 -22.382 <0.0000000000000002 ***
a_1:a_2 0.25885 0.16929 1.529 0.127
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
(Dispersion parameter for gaussian family taken to be 1.331051)
Number of Fisher Scoring iterations: 2
APO_11 <- predict(fitMSM, newdata = data.frame(a_1=1,a_2=1))
APO_00 <- predict(fitMSM, newdata = data.frame(a_1=0,a_2=0))
APO_11 - APO_00
     link     SE
1 -3.1134 0.0888
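Equivalently, because the marginal structural model is saturated in (a_1, a_2), the ATE can be read directly off the fitted coefficients: \[ ATE = \beta_{a_1} + \beta_{a_2} + \beta_{a_1:a_2} = -1.21435 - 2.15788 + 0.25885 \approx -3.1134 \]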
# How to trim weights?;
# generally, weights greater than 10 are considered large;
# weight truncation, if needed, can be done as follows;
# trim <- quantile(Wmsm$weights, c(.99)) #obtain 99th percentile of the weights;
# sw_trim <- ifelse(Wmsm$weights > trim, trim, Wmsm$weights)
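# a minimal sketch (not run) of refitting the MSM with the truncated weights sw_trim from above;
# illustrative only, since no extreme weights were observed in these data;
# msm_design_trim <- svydesign(~1, weights = sw_trim, data = causaldata)
# fitMSM_trim <- svyglm(y ~ a_1*a_2, design = msm_design_trim)
# summary(fitMSM_trim)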
# using bootstrap to obtain the SE and confidence interval of the ATE;
set.seed(123)
boot.est <- rep(NA, 1000)
for (i in 1:1000){
  boot.idx <- sample(1:dim(causaldata)[1], size = dim(causaldata)[1], replace = T)
  boot.data <- causaldata[boot.idx,]

  # weights are subset to the resampled rows so they stay aligned with boot.data;
  msm_design <- svydesign(~1, weights = Wmsm$weights[boot.idx], data = boot.data)

  fitMSM <- svyglm(y ~ a_1*a_2, design = msm_design)

  boot.est[i] <- predict(fitMSM, newdata = data.frame(a_1=1,a_2=1))[1] -
    predict(fitMSM, newdata = data.frame(a_1=0,a_2=0))[1]
}
# SE of ATE;
sd(boot.est)
[1] 0.1000114
#95% CI
quantile(boot.est, probs = c(0.025, 0.975))
2.5% 97.5%
-3.314099 -2.910470
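Note that the bootstrap above keeps the weights estimated on the full sample (subset to the resampled rows). A more rigorous bootstrap would re-estimate the weights within each resampled dataset; a minimal sketch of that extra step inside the loop, assuming the same treatment model specification as in Step 2:
# boot.W <- weightitMSM(list(a_1 ~ w1 + w2 + L1_1 + L2_1,
#                            a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1),
#                       data = boot.data, method = "ps", stabilize = TRUE)
# msm_design <- svydesign(~1, weights = boot.W$weights, data = boot.data)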
1.3 Implementing parametric g-computation
- Variance of the g-computation estimator is obtained via the bootstrap, so the procedure takes some time to run
- The gfoRmula package requires long-format data and a time variable that starts at 0 for the baseline visit
- Step 1, preparing the long-format data for the analysis
# preparing the data;
# first transform wide data to long data;
causaldata_long <- causaldata %>%
  mutate(id = rep(1:1000)) %>%
  pivot_longer(cols = -c(w1, w2, y, id),
               names_to = c("variable", "visit"),
               names_sep = "_",
               values_to = "value") %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(time = case_when(visit == 1 ~ 0,
                          visit == 2 ~ 1))

# Y is only measured at the end of study;
# thus, when we pivot to long format, visit 1's y is set to missing;
causaldata_long$y[causaldata_long$visit == 1] <- NA
# look at the new data;
datatable(causaldata_long,
rownames = FALSE,
options = list(dom = 't')) %>%
formatRound(columns=c('w2', 'L2', 'y'), digits=2)
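A quick sanity check on the reshaped data: each of the 1000 patients should now contribute one row per visit.
dim(causaldata_long) # expect 2000 rows (1000 patients x 2 visits)
table(causaldata_long$time) # expect 1000 rows at time 0 and 1000 at time 1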
- Step 2, implement parametric g-computation using gfoRmula
library(gfoRmula)
id <- 'id'
time_name <- 'time'
covnames <- c("L1", "L2", "a")
outcome_name <- 'y'
covtypes <- c('binary', 'normal', 'binary')
histories <- c(lagged) # lagged feature to call for lagged values from the long-format data;
histvars <- list(c('a', 'L1', 'L2'))

covparams <- list(
  covmodels = c(L1 ~ w1 + w2 + lag1_L1 + lag1_a,
                L2 ~ lag1_L2 + w1 + w2 + lag1_a,
                a ~ w1 + w2 + lag1_L1 + lag1_L2 + lag1_a))

ymodel <- y ~ lag1_a*a + w1 + w2 + lag1_L1 + lag1_L2 + L1 + L2

intvars <- list('a', 'a')
interventions <- list(list(c(static, rep(0, 2))),
                      list(c(static, rep(1, 2))))
int_descript <- c('Never treat', 'Always treat')
gform_cont_eof <- gformula_continuous_eof(
  obs_data = causaldata_long,
  id = id,
  time_name = time_name,
  covnames = covnames,
  outcome_name = outcome_name,
  covtypes = c("binary", "normal", "binary"),
  covparams = covparams,
  ymodel = ymodel,
  intvars = intvars,
  interventions = interventions,
  int_descript = int_descript,
  ref_int = 1,
  histories = c(lagged),
  histvars = list(c('a', "L1", "L2")), # variables that are time-dependent;
  basecovs = c("w1", "w2"),            # time-independent baseline variables;
  nsimul = 1000,
  nsamples = 1000,
  parallel = TRUE,
  ncores = 6,                          # bootstrap settings;
  seed = 123)
summary(gform_cont_eof)
PREDICTED RISK UNDER MULTIPLE INTERVENTIONS
Intervention Description
0 Natural course
1 Never treat
2 Always treat
Sample size = 1000, Monte Carlo sample size = 1000
Number of bootstrap samples = 1000
Reference intervention = 1
k Interv. NP mean g-form mean Mean SE Mean lower 95% CI Mean upper 95% CI
1 0 1.244963 1.3036349 0.06182675 1.1277058 1.3639202
1 1 NA 2.3297637 0.04988542 2.2393043 2.4367416
1 2 NA -0.7816274 0.08063701 -0.9322829 -0.6090699
Mean ratio MR SE MR lower 95% CI MR upper 95% CI Mean difference
0.5595567 0.02353069 0.4871446 0.5767769 -1.026129
1.0000000 0.00000000 1.0000000 1.0000000 0.000000
-0.3354964 0.03562982 -0.4004484 -0.2592769 -3.111391
MD SE MD lower 95% CI MD upper 95% CI % Intervened On
0.05905593 -1.207822 -0.9763218 0.0
0.00000000 0.000000 0.0000000 45.1
0.09248469 -3.297456 -2.9325170 83.3
Aver % Intervened On
0.00
26.40
54.95
- Using g-computation, the estimated ATE is -3.111391 with SE = 0.09248469.
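- As a check, this equals the difference between the g-formula means under 'Always treat' and 'Never treat': -0.7816274 - 2.3297637 = -3.111391.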
1.4 Implementing targeted maximum likelihood estimation (TMLE)
- the ltmle package requires the input data to include only the variables used in the models!
- make sure to remove any variables you will not be modelling from the data (e.g., an id column); see the sketch after this list
- this package uses wide-format data
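In this example, causaldata already contains only the variables that enter the models (see the colnames() check below), so nothing needs to be removed. If the data carried extra columns, say a hypothetical patient identifier id, a minimal sketch of the subsetting would be:
# causaldata_tmle <- dplyr::select(causaldata, -id) # drop the hypothetical 'id' column before calling ltmle()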
library(ltmle)
# Step 1, if applicable remove variables we don't need;
colnames(causaldata)
[1] "w1" "w2" "L1_1" "L2_1" "a_1" "L1_2" "L2_2" "a_2" "y"
# Step 2, fitting conventional tmle without superlearner (machine learning algorithm);
tmle_model <- ltmle(data = causaldata,
                    Anodes = c("a_1", "a_2"),
                    Lnodes = c("L1_1", "L2_1", "L1_2", "L2_2"),
                    Ynodes = c("y"),
                    survivalOutcome = FALSE,
                    gform = c("a_1 ~ w1 + w2 + L1_1 + L2_1",
                              "a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1"),
                    abar = list(c(1,1), c(0,0)))
summary(tmle_model, estimator="tmle")
Estimator: tmle
Call:
ltmle(data = causaldata, Anodes = c("a_1", "a_2"), Lnodes = c("L1_1",
"L2_1", "L1_2", "L2_2"), Ynodes = c("y"), survivalOutcome = FALSE,
gform = c("a_1 ~ w1 + w2 + L1_1 + L2_1", "a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1"),
abar = list(c(1, 1), c(0, 0)))
Treatment Estimate:
Parameter Estimate: -0.77939
Estimated Std Err: 0.082294
p-value: <2e-16
95% Conf Interval: (-0.94068, -0.61809)
Control Estimate:
Parameter Estimate: 2.3371
Estimated Std Err: 0.048802
p-value: <2e-16
95% Conf Interval: (2.2415, 2.4328)
Additive Treatment Effect:
Parameter Estimate: -3.1165
Estimated Std Err: 0.093326
p-value: <2e-16
95% Conf Interval: (-3.2994, -2.9336)
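As a check, the additive treatment effect equals the treatment estimate minus the control estimate: -0.77939 - 2.3371 ≈ -3.1165.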
# Step 3, fitting tmle with superlearner on gform and Qform models;
tmle_model_sup <- ltmle(causaldata,
                        Anodes = c("a_1", "a_2"),
                        Lnodes = c("L1_1", "L2_1", "L1_2", "L2_2"),
                        Ynodes = c("y"),
                        survivalOutcome = FALSE,
                        gform = c("a_1 ~ w1 + w2 + L1_1 + L2_1",
                                  "a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1"),
                        SL.library = c("SL.mean"), # see the SuperLearner package for available learners, e.g., SL.glm, SL.glmnet, SL.randomForest, SL.gbm;
                        abar = list(c(1,1), c(0,0)),
                        estimate.time = FALSE)
summary(tmle_model_sup, estimator="tmle")
Estimator: tmle
Call:
ltmle(data = causaldata, Anodes = c("a_1", "a_2"), Lnodes = c("L1_1",
"L2_1", "L1_2", "L2_2"), Ynodes = c("y"), survivalOutcome = FALSE,
gform = c("a_1 ~ w1 + w2 + L1_1 + L2_1", "a_2 ~ w1 + w2 + L1_1 + L2_1 + L1_2 + L2_2 + a_1"),
abar = list(c(1, 1), c(0, 0)), SL.library = c("SL.mean"),
estimate.time = FALSE)
Treatment Estimate:
Parameter Estimate: -0.78091
Estimated Std Err: 0.14249
p-value: <2e-16
95% Conf Interval: (-1.0602, -0.50164)
Control Estimate:
Parameter Estimate: 2.336
Estimated Std Err: 0.058324
p-value: <2e-16
95% Conf Interval: (2.2217, 2.4504)
Additive Treatment Effect:
Parameter Estimate: -3.117
Estimated Std Err: 0.15396
p-value: <2e-16
95% Conf Interval: (-3.4187, -2.8152)
The estimated ATE under conventional TMLE is -3.1165 with SE = 0.093326 and 95% CI (-3.2994, -2.9336).
The estimated ATE under TMLE with super learner is -3.117 with SE = 0.15396 (quite large!) and 95% CI (-3.4187, -2.8152).
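Overall, the three approaches (MSM with stabilized inverse probability of treatment weights, parametric g-computation, and TMLE) return nearly identical estimates of the ATE between always treated and never treated on this simulated dataset, approximately -3.11.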