Bootstrapping and plotting 95% confidence bands: 'Causal Inference: What If' Causal Survival Analysis. Parametric g-formula

In this post, I explore fitting the parametric g-formula in the causal survival analysis setting. I use the tidyverse throughout and finish by plotting 95% bootstrap confidence bands around the g-formula survival curves for those who quit smoking vs those who did not (see Chapter 17, Hernán MA, Robins JM (2020). Causal Inference: What If. Boca Raton: Chapman & Hall/CRC).

Getting the data

library(tidyverse)
library(magrittr)
library(conflicted)  
library(readxl)
library(wesanderson)
library(patchwork)

conflict_prefer("filter", "dplyr")
temp <- tempfile()

temp2 <- tempfile()

download.file("https://cdn1.sph.harvard.edu/wp-content/uploads/sites/1268/2017/01/nhefs_excel.zip", temp)

unzip(zipfile = temp, exdir = temp2)

data <- read_xls(file.path(temp2, "NHEFS.xls"))

unlink(temp)
# temp2 is a directory, so it is only removed with recursive = TRUE
unlink(temp2, recursive = TRUE)
data
## # A tibble: 1,629 x 64
##     seqn  qsmk death yrdth modth dadth   sbp   dbp   sex   age  race income
##    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>
##  1   233     0     0    NA    NA    NA   175    96     0    42     1     19
##  2   235     0     0    NA    NA    NA   123    80     0    36     0     18
##  3   244     0     0    NA    NA    NA   115    75     1    56     1     15
##  4   245     0     1    85     2    14   148    78     0    68     1     15
##  5   252     0     0    NA    NA    NA   118    77     0    40     0     18
##  6   257     0     0    NA    NA    NA   141    83     1    43     1     11
##  7   262     0     0    NA    NA    NA   132    69     1    56     0     19
##  8   266     0     0    NA    NA    NA   100    53     1    29     0     22
##  9   419     0     1    84    10    13   163    79     0    51     0     18
## 10   420     0     1    86    10    17   184   106     0    43     0     16
## # … with 1,619 more rows, and 52 more variables: marital <dbl>, school <dbl>,
## #   education <dbl>, ht <dbl>, wt71 <dbl>, wt82 <dbl>, wt82_71 <dbl>,
## #   birthplace <dbl>, smokeintensity <dbl>, smkintensity82_71 <dbl>,
## #   smokeyrs <dbl>, asthma <dbl>, bronch <dbl>, tb <dbl>, hf <dbl>, hbp <dbl>,
## #   pepticulcer <dbl>, colitis <dbl>, hepatitis <dbl>, chroniccough <dbl>,
## #   hayfever <dbl>, diabetes <dbl>, polio <dbl>, tumor <dbl>,
## #   nervousbreak <dbl>, alcoholpy <dbl>, alcoholfreq <dbl>, alcoholtype <dbl>,
## #   alcoholhowmuch <dbl>, pica <dbl>, headache <dbl>, otherpain <dbl>,
## #   weakheart <dbl>, allergies <dbl>, nerves <dbl>, lackpep <dbl>,
## #   hbpmed <dbl>, boweltrouble <dbl>, wtloss <dbl>, infection <dbl>,
## #   active <dbl>, exercise <dbl>, birthcontrol <dbl>, pregnancies <dbl>,
## #   cholesterol <dbl>, hightax82 <dbl>, price71 <dbl>, price82 <dbl>,
## #   tax71 <dbl>, tax82 <dbl>, price71_82 <dbl>, tax71_82 <dbl>

As before, for this section I partly re-used the code by Joy Shi and Sean McGrath available here.

# define/recode variables
data %<>% 
  mutate(
    # define censoring
    cens = if_else(is.na(wt82_71), 1, 0),
    # categorize the school variable (as in R code by Joy Shi and Sean McGrath)
    education = cut(school, breaks = c(0, 8, 11, 12, 15, 20),
                    include.lowest = TRUE, 
                    labels = c('1. 8th Grade or Less',
                               '2. HS Dropout',
                               '3. HS',
                               '4. College Dropout',
                               '5. College or More')),
    # establish active as a factor variable
    active = factor(active),
    # establish exercise as a factor variable
    exercise = factor(exercise),
    # create a treatment label variable
    qsmklabel = if_else(qsmk == 1,
                        'Quit Smoking 1971-1982',
                        'Did Not Quit Smoking 1971-1982'),
    # survival time in months: month of death counted from January 1983 (yrdth ranges from 83 to 92), survivors get 120
    survtime = if_else(death == 0, 120, (yrdth - 83)*12 + modth)
  ) %>% 
  # ignore those with missing values on some variables
  filter(!is.na(education))
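A quick arithmetic check of the survtime definition against the rows printed above. The values are copied from rows 1, 4, 9 and 10 of the tibble, and `check` is a throwaway name. Subject 245 died in month 2 of 1985, so (85 - 83) * 12 + 2 = 26 months, while a survivor such as subject 233 gets the full 120 months.

```r
library(dplyr)

# yrdth/modth copied from rows 1, 4, 9 and 10 of the printed NHEFS tibble
check <- tibble(
  seqn  = c(233, 245, 419, 420),
  death = c(0, 1, 1, 1),
  yrdth = c(NA, 85, 84, 86),
  modth = c(NA, 2, 10, 10)
) %>%
  # same definition as above: survivors get 120 months
  mutate(survtime = if_else(death == 0, 120, (yrdth - 83) * 12 + modth))

check$survtime
```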

Section 17.4

Here we estimate the average causal effect of quitting smoking (qsmk) on mortality using the parametric g-formula. I use tidyr::uncount() to expand the wide data (one row per subject) into long format, so that each row corresponds to one subject in one month of follow-up.
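A toy illustration of the expansion step, using a made-up two-subject tibble. The column names mirror the NHEFS ones, but the values are invented. Subject 1 dies in month 3, so it contributes three rows and is flagged only in the last one. Subject 2 survives all four months.

```r
library(dplyr)
library(tidyr)

# hypothetical toy data: subject 1 died in month 3, subject 2 survived all 4 months
toy <- tibble(seqn = c(1, 2), survtime = c(3, 4), death = c(1, 0))

toy_long <- toy %>%
  # one row per subject-month of follow-up; keep survtime (.remove = FALSE)
  uncount(weights = survtime, .remove = FALSE) %>%
  group_by(seqn) %>%
  mutate(
    time = row_number() - 1,                               # months start at 0
    event = as.numeric(time == survtime - 1 & death == 1)  # death flagged in the last row only
  ) %>%
  ungroup()

toy_long
```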

# expand original data set until the last observed month for everyone
months_gform <- data %>% 
  # using observed data and observed censoring times
  uncount(weights = survtime, .remove = F) %>% 
  group_by(seqn) %>% 
  mutate (
    time = row_number() - 1,
    event = case_when(
      time == survtime -1 & death == 1 ~ 1,
      TRUE ~ 0
    ),
    timesq = time*time
  ) %>% 
  ungroup() %>% 
  select(seqn, qsmk, time, timesq, age, sex, race, education, smokeintensity, smkintensity82_71, smokeyrs, exercise, active, wt71, event)

# fit the Q-model (pooled logistic model for the probability of surviving each month) with confounders and outcome predictors (for better precision) in the observation-month data
q_model <- glm(
  event == 0 ~ qsmk + I(qsmk*time) + I(qsmk*timesq) + time + timesq +
    sex + race + age + I(age*age) + as.factor(education) +
    smokeintensity + I(smokeintensity*smokeintensity) + smkintensity82_71 +
    smokeyrs + I(smokeyrs*smokeyrs) + as.factor(exercise) + as.factor(active) +
    wt71 + I(wt71*wt71),
  data = months_gform,
  family = binomial(link = "logit")
)

q_model %>% broom::tidy(.)
## # A tibble: 25 x 5
##    term               estimate std.error statistic  p.value
##    <chr>                 <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)       9.27      1.38          6.72  1.76e-11
##  2 qsmk              0.0596    0.415         0.143 8.86e- 1
##  3 I(qsmk * time)   -0.0149    0.0151       -0.987 3.24e- 1
##  4 I(qsmk * timesq)  0.000170  0.000125      1.37  1.72e- 1
##  5 time             -0.0227    0.00844      -2.69  7.14e- 3
##  6 timesq            0.000117  0.0000671     1.75  8.00e- 2
##  7 sex               0.437     0.141         3.10  1.93e- 3
##  8 race             -0.0524    0.173        -0.302 7.63e- 1
##  9 age              -0.0875    0.0591       -1.48  1.39e- 1
## 10 I(age * age)      0.0000813 0.000547      0.149 8.82e- 1
## # … with 15 more rows
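For reference, here is a sketch of the quantity that the next steps estimate. The notation loosely follows Chapter 17 of What If but is my own transcription, so treat the indices with care. $D_m$ indicates death by month $m$, $A$ is qsmk, and $L$ are the baseline covariates in the Q-model. The average over the $n$ subjects replaces the sum over $l$ weighted by $\Pr[L = l]$. Because `time` starts at 0 in the code, `p.surv` at `time` $= t$ corresponds to survival through month $t + 1$:

```latex
\hat{\Pr}\left[D^{a}_{t+1} = 0\right]
  = \frac{1}{n} \sum_{i=1}^{n} \prod_{m=0}^{t}
    \hat{\Pr}\left[D_{m+1} = 0 \mid D_m = 0, L = L_i, A = a\right],
  \qquad a \in \{0, 1\}
```

For each subject, `cumprod(p.not.event)` computes the product. The average over $i$ is the per-month mean of `p.surv`.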
# create separate datasets for observation-months assigning exposure levels in each copy to either 0 or 1
# unexposed
months0_gform <- data %>% 
  mutate(uncount = max(survtime)) %>% 
  # use the same censoring time point of 120 months for all individuals
  uncount(weights = uncount, .remove = T) %>% 
  group_by(seqn) %>% 
  mutate (
    qsmk = 0, # set exposure level
    time = row_number() - 1,
    event = case_when(
      time == survtime -1 & death == 1 ~ 1,
      TRUE ~ 0
    ),
    timesq = time*time
  ) %>% 
  ungroup() %>% 
  select(seqn, qsmk, time, timesq, age, sex, race, education, smokeintensity, smkintensity82_71, smokeyrs, exercise, active, wt71, event) %>%  
  mutate(
    p.not.event = predict(q_model, type = "response", newdata = .)
  ) %>% 
  group_by(seqn) %>% 
  # compute probability of survival for each individual over all time points
  mutate(p.surv = cumprod(p.not.event)) %>% 
  ungroup()

summary(months0_gform$p.surv)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## 0.008535 0.891414 0.966453 0.906467 0.989128 0.999966
# exposed
months1_gform <- data %>% 
  mutate(uncount = max(survtime)) %>% 
  uncount(weights = uncount, .remove = T) %>% 
  group_by(seqn) %>% 
  mutate (
    qsmk = 1, # set exposure level
    time = row_number() - 1,
    event = case_when(
      time == survtime -1 & death == 1 ~ 1,
      TRUE ~ 0
    ),
    timesq = time*time
  ) %>% 
  ungroup() %>% 
  select(seqn, qsmk, time, timesq, age, sex, race, education, smokeintensity, smkintensity82_71, smokeyrs, exercise, active, wt71, event) %>%  
  mutate(
    p.not.event = predict(q_model, type = "response", newdata = .)
  ) %>% 
  group_by(seqn) %>% 
  # compute probability of survival for each individual over all time points
  mutate(p.surv = cumprod(p.not.event)) %>% 
  ungroup()

summary(months1_gform$p.surv)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## 0.007834 0.875271 0.961561 0.895782 0.987537 0.999968
# compute average survival over each observation-month in each dataset (under exposure and no exposure regimes) and combine in one dataset
# unexposed
months0_gform_mean <- months0_gform %>% 
  select(seqn, time, qsmk, p.surv) %>% 
  group_by(time) %>% 
  mutate(p.surv.mean = mean(p.surv)) %>% 
  ungroup() %>% 
  select(time, qsmk, p.surv.mean) %>% 
  # keep one row per month
  distinct()

# exposed
months1_gform_mean <- months1_gform %>% 
  select(seqn, qsmk, time, p.surv) %>% 
  group_by(time) %>% 
  mutate(p.surv.mean = mean(p.surv)) %>% 
  ungroup() %>% 
  select(time, qsmk, p.surv.mean) %>% 
  # keep one row per month
  distinct()

# combine data
# difference
months_gform_mean_diff <- bind_cols(months0_gform_mean, months1_gform_mean[, "p.surv.mean"]) %>% 
  mutate(
    p.surv.diff = p.surv.mean...4 - p.surv.mean...3
  )
## New names:
## * p.surv.mean -> p.surv.mean...3
## * p.surv.mean -> p.surv.mean...4
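The bind_cols() step relies on both tables having identical row order, and the column names it produces (p.surv.mean...3, p.surv.mean...4) are generated automatically. The sketch below shows an alternative on invented predictions for two subjects, each cloned under qsmk = 0 and qsmk = 1. It computes individual curves with cumprod(), averages them per month, and reshapes with tidyr::pivot_wider(), so the difference is taken between named columns. The names `toy_pred`, `toy_curves` and `toy_diff` are hypothetical.

```r
library(dplyr)
library(tidyr)

# hypothetical predicted monthly survival probabilities, two subjects x two regimes x three months
toy_pred <- tibble(
  seqn = rep(c(1, 2), each = 3, times = 2),
  qsmk = rep(c(0, 1), each = 6),
  time = rep(0:2, times = 4),
  p.not.event = c(0.99, 0.98, 0.97, 0.95, 0.94, 0.93,
                  0.99, 0.99, 0.98, 0.96, 0.95, 0.95)
)

toy_curves <- toy_pred %>%
  group_by(seqn, qsmk) %>%
  arrange(time, .by_group = TRUE) %>%
  mutate(p.surv = cumprod(p.not.event)) %>%                 # individual survival curves
  group_by(qsmk, time) %>%
  summarise(p.surv.mean = mean(p.surv), .groups = "drop")   # standardize over subjects

toy_diff <- toy_curves %>%
  pivot_wider(names_from = qsmk, values_from = p.surv.mean, names_prefix = "surv_qsmk") %>%
  mutate(p.surv.diff = surv_qsmk1 - surv_qsmk0)             # exposed minus unexposed

toy_diff
```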
# bind and plot
surv_diff_gform <- bind_rows(months0_gform_mean, months1_gform_mean)

# survival
p_gform_1 <- surv_diff_gform %>%
  ggplot(aes(x = time, y = p.surv.mean, color = factor(qsmk), fill = factor(qsmk))) +
  geom_line() +
  xlab("Months") + 
  scale_x_continuous(limits = c(0, 120), breaks = seq(0, 120, 12)) +
  scale_y_continuous(limits = c(0.7, 1), breaks = seq(0.7, 1, 0.1)) +
  ylab("Survival, probability") + 
  ggtitle("Fitting g-formula model") +
  theme_minimal()+
  labs(colour = "Smoking", fill = "Smoking") +
  scale_fill_manual(values = wes_palette("IsleofDogs1")) +
  scale_color_manual(values = wes_palette("IsleofDogs1"))

# 1-survival
p_gform_2 <- surv_diff_gform %>%
  ggplot(aes(x = time, y = 1 - p.surv.mean, color = factor(qsmk), fill = factor(qsmk))) +
  geom_line() +
  xlab("Months") +
  scale_x_continuous(limits = c(0, 120), breaks = seq(0, 120, 12)) +
  scale_y_continuous(limits = c(0, 0.3), breaks = seq(0, 0.3, 0.1)) +
  ylab("Death, probability") + 
  theme_minimal()+
  labs(colour = "Smoking", fill = "Smoking") +
  scale_fill_manual(values = wes_palette("IsleofDogs1")) +
  scale_color_manual(values = wes_palette("IsleofDogs1"))

# plot difference
p_gform_3 <- months_gform_mean_diff %>%
  ggplot(aes(x = time, y = p.surv.diff)) +
  geom_line() +
  xlab("Months") + 
  ylab("Difference") + 
  theme_minimal()

# combine plots
p_gform_1 / p_gform_2 / p_gform_3 +
  plot_layout(guides = "collect")
# combine all plots
comb3 <- p_gform_1 + p_gform_2 +
  plot_layout(ncol = 2, guides = "collect") + 
  plot_annotation(
    caption = "dataviz by Elena Dudukina @evpatora"
  )
comb3
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 0.1.2 ──
## ✓ broom     0.7.2      ✓ recipes   0.1.15
## ✓ dials     0.0.9      ✓ rsample   0.0.8 
## ✓ infer     0.5.3      ✓ tune      0.1.2 
## ✓ modeldata 0.1.0      ✓ workflows 0.2.1 
## ✓ parsnip   0.1.4      ✓ yardstick 0.0.7
times <- 100

# reduce size of the dataset by keeping only relevant vars
data %<>% select(seqn, survtime, death, qsmk, age, sex, race, education, smokeintensity, smkintensity82_71, smokeyrs, exercise, active, wt71)

set.seed(123456789)
boots <- bootstraps(data, times = times, apparent = FALSE)
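The idea behind bootstraps() can be sketched in base R on simulated data. The exponential sample `x` and `n_boot` are illustrative assumptions and are not NHEFS data. Rows are drawn with replacement to the original sample size, the statistic is re-estimated on each draw, and the 2.5th and 97.5th percentiles of the replicates form the interval. In the g-formula pipeline that follows, the "statistic" is the whole standardized survival curve, and each replicate refits the Q-model on its own person-month data.

```r
# minimal nonparametric bootstrap with a percentile interval, on simulated data
set.seed(2020)
x <- rexp(200, rate = 1)   # hypothetical sample

n_boot <- 1000
boot_means <- replicate(n_boot, {
  idx <- sample.int(length(x), replace = TRUE)  # resample observations with replacement
  mean(x[idx])                                  # re-estimate the statistic
})

# 2.5th and 97.5th percentiles of the bootstrap distribution
ci <- quantile(boot_means, probs = c(0.025, 0.975))
ci
```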

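# make seqn unique within each bootstrap sample: original seqn + number of the repeat
boots %<>% mutate(
  splits = map(splits, ~ as_tibble(.x)),
  splits = map(splits, ~ group_by(.x, seqn) %>% 
                 mutate(rep = row_number()) %>% 
                 ungroup() %>% 
                 mutate(
                   # a separator keeps e.g. 233 (repeat 11) and 2331 (repeat 1) distinct
                   seqn = paste(seqn, rep, sep = "_")
                 ))
)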
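Why relabel seqn at all: a subject drawn twice into a bootstrap sample would otherwise share one seqn. group_by(seqn) followed by row_number() and cumprod() would then treat the two copies as a single person with doubled follow-up. Below is a toy check of the relabelling on a made-up three-row sample (`boot_toy` is a hypothetical name). The underscore separator is a suggestion, because plain paste0() can in principle map different seqn/repeat pairs to the same string.

```r
library(dplyr)

# hypothetical bootstrap sample in which subject 233 was drawn twice
boot_toy <- tibble(seqn = c(233, 233, 235))

boot_toy <- boot_toy %>%
  group_by(seqn) %>%
  mutate(rep = row_number()) %>%   # 1 for the first copy, 2 for the second, ...
  ungroup() %>%
  mutate(seqn = paste(seqn, rep, sep = "_"))

boot_toy$seqn
```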

months_gform_list <- map(boots$splits, ~ uncount(.x, weights = survtime, .remove = F) %>% 
  group_by(seqn) %>% 
  mutate (
    time = row_number() - 1,
    event = case_when(
      time == survtime - 1 & death == 1 ~ 1,
      TRUE ~ 0
    ),
    timesq = time*time
  ) %>% 
  ungroup()
  )

# Q-model function
gform_model <- function(data) {
  glm(
    formula = event == 0 ~ qsmk + I(qsmk*time) + I(qsmk*timesq) + time + timesq +
      sex + race + age + I(age*age) + as.factor(education) +
      smokeintensity + I(smokeintensity*smokeintensity) + smkintensity82_71 +
      smokeyrs + I(smokeyrs*smokeyrs) + as.factor(exercise) + as.factor(active) +
      wt71 + I(wt71*wt71),
    family = binomial(link = "logit"),
    data = data
  )
}

# apply gform_model function to all re-sampled datasets
boot_models <- boots %>% 
  mutate(model = map(months_gform_list, ~gform_model(data = .x)))

# compute the potential outcomes in each resampled long dataset
# helper: expand every subject of a resampled dataset to 120 months and set qsmk to the chosen level;
# predictions are added afterwards with the Q-model fitted to the same bootstrap sample
set_qsmk <- function(data, qsmk_level){
  data %>%  
    mutate(uncount = max(survtime)) %>% 
    # use the same censoring time point of 120 months for all individuals
    uncount(weights = uncount, .remove = TRUE) %>% 
    group_by(seqn) %>% 
    mutate(
      qsmk = qsmk_level, # set exposure level
      time = row_number() - 1,
      event = case_when(
        time == survtime - 1 & death == 1 ~ 1,
        TRUE ~ 0
      ),
      timesq = time*time
    ) %>% 
    ungroup() %>% 
    select(seqn, qsmk, time, timesq, age, sex, race, education, smokeintensity, smkintensity82_71, smokeyrs, exercise, active, wt71, event)
}

# set exposure to 0
months0_gform_list <- map(.x = boots$splits, ~set_qsmk(data = .x, qsmk_level = 0))

# predict P.O.s
months0_gform_list <- map2(.x = months0_gform_list, .y = boot_models$model, ~  mutate(.x ,
             p.not.event = predict(.y, type = "response", newdata = .x)
  ) %>% 
  group_by(seqn) %>% 
  # compute probability of survival for each individual over all time points
  mutate(p.surv = cumprod(p.not.event)) %>% 
  ungroup()
  )

# list: every observation-month with the assigned exposure level = 1
# exposed
# set exposure to 1
months1_gform_list <- map(.x = boots$splits, ~set_qsmk(data = .x, qsmk_level = 1))

# predict P.O.s
months1_gform_list <- map2(.x = months1_gform_list, .y = boot_models$model, ~ mutate(.x ,
             p.not.event = predict(.y, type = "response", newdata = .x)
  ) %>% 
  group_by(seqn) %>% 
  # compute probability of survival for each individual over all time points
  mutate(p.surv = cumprod(p.not.event)) %>% 
  ungroup())

# compute average survival over each observation-month in each dataset (under exposure and no exposure) and combine
# unexposed
months0_gform_list %<>% 
  map(.x = ., ~ select(.x, seqn, qsmk, time, p.surv) %>% 
        group_by(time) %>% 
        # mean surv
        mutate(p.surv.mean = mean(p.surv)) %>% 
        ungroup() %>% 
        select(time, qsmk, p.surv.mean) %>% 
        distinct()
      )

# exposed
months1_gform_list %<>% 
  map(.x = ., ~ select(.x, seqn, qsmk, time, p.surv) %>% 
        group_by(time) %>% 
        # mean surv
        mutate(p.surv.mean = mean(p.surv)) %>% 
        ungroup() %>% 
        select(time, qsmk, p.surv.mean) %>% 
        distinct()
  )

# difference in survival in each observation month
surv_diff_gform_list <- map(list(months0_gform_list, months1_gform_list), ~bind_rows(., .id = "iteration"))

surv_diff_gform <- surv_diff_gform_list %>% bind_cols(.) %>% 
  mutate(
    surv_diff_gform = p.surv.mean...8 - p.surv.mean...4,
    iteration...1 = as.numeric(iteration...1)
  )
## New names:
## * iteration -> iteration...1
## * time -> time...2
## * qsmk -> qsmk...3
## * p.surv.mean -> p.surv.mean...4
## * iteration -> iteration...5
## * ...
surv_diff_gform %<>% 
  group_by(time...2) %>% 
  summarise_at(.vars = vars(surv_diff_gform), .funs = list(Q2.5 = ~ quantile(., probs = 0.025), Q50 = ~ quantile(., probs = 0.50), Q97.5 = ~ quantile(., probs = 0.975)))
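Since dplyr 1.0.0, summarise_at() has been superseded in favour of across(). The sketch below computes the same pointwise percentile summary with across(), using invented bootstrap differences; `boot_diff` and `band` are hypothetical names. across() names its outputs {column}_{function}, for example surv_diff_Q50.

```r
library(dplyr)

# hypothetical survival differences from 3 bootstrap iterations at 2 time points
boot_diff <- tibble(
  iteration = rep(1:3, each = 2),
  time = rep(0:1, times = 3),
  surv_diff = c(0.001, 0.010, 0.003, 0.015, 0.002, 0.020)
)

# pointwise percentile band: quantiles across iterations, separately for each month
band <- boot_diff %>%
  group_by(time) %>%
  summarise(
    across(
      surv_diff,
      list(Q2.5 = ~ quantile(.x, 0.025), Q50 = ~ quantile(.x, 0.5), Q97.5 = ~ quantile(.x, 0.975))
    ),
    .groups = "drop"
  )

band
```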

# combine and plot
surv_diff_gform_plot <- surv_diff_gform_list %>% bind_rows(., .id = "id")

surv_diff_gform_plot %<>% 
  group_by(qsmk, time) %>% 
  summarise_at(.vars = vars(p.surv.mean), .funs = list(Q2.5 = ~ quantile(., probs = 0.025), Q50 = ~ quantile(., probs = 0.50), Q97.5 = ~ quantile(., probs = 0.975)))

# survival gform with 95% CIs
p_gform_1 <- surv_diff_gform_plot %>% 
  group_by(qsmk) %>% 
  ggplot(aes(x = time, y = Q50, color = factor(qsmk), fill = factor(qsmk))) +
  geom_line() +
  geom_ribbon(aes(ymin = Q2.5, ymax = Q97.5),
              alpha = .2, colour = NA) +
  xlab("Months") + 
  scale_x_continuous(limits = c(0, 120), breaks = seq(0,120,12)) +
  scale_y_continuous(limits = c(0.7, 1), breaks = seq(0.7, 1, 0.1)) +
  ylab("Survival, probability") + 
  ggtitle("Fitting g-formula") +
  theme_minimal()+
  labs(colour = "Smoking", fill = "Smoking") +
  scale_fill_manual(values = wes_palette("IsleofDogs1")) +
  scale_color_manual(values = wes_palette("IsleofDogs1"))

# 1-survival
p_gform_2 <- surv_diff_gform_plot %>% 
  group_by(qsmk) %>% 
  ggplot(aes(x = time, y = 1 - Q50, color = factor(qsmk), fill = factor(qsmk))) +
  geom_line() +
  geom_ribbon(aes(ymin = 1 - Q97.5, ymax = 1 - Q2.5), alpha = .2, colour = NA) +
  scale_x_continuous(limits = c(0, 120), breaks = seq(0, 120, 12)) +
  scale_y_continuous(limits = c(0, 0.3), breaks = seq(0, 0.3, 0.1)) +
  xlab("Months") + 
  ylab("Death, probability") + 
  theme_minimal()+
  labs(colour = "Smoking", fill = "Smoking") +
  scale_fill_manual(values = wes_palette("IsleofDogs1")) +
  scale_color_manual(values = wes_palette("IsleofDogs1"))

p_gform_3 <- surv_diff_gform %>% 
  ggplot(aes(x = time...2, y = Q50)) +
  geom_line() +
  geom_hline(yintercept = 0, color = "grey", linetype = 2) + 
  geom_ribbon(aes(ymin = Q2.5, ymax = Q97.5), alpha = .2, colour = NA) +
  scale_x_continuous(limits = c(0, 120), breaks = seq(0, 120, 12)) +
  xlab("Months") + 
  ylab("Difference") +
  theme_minimal()+
  labs(colour = "Smoking", fill = "Smoking")

# combine gform plots
p_gform_1 / p_gform_2 / p_gform_3 +
  plot_layout(guides = "collect")
# combine plots
comb4 <- p_gform_1 + p_gform_2 +
  plot_layout(ncol = 2, guides = "collect") + 
  plot_annotation(
    caption = "dataviz by Elena Dudukina @evpatora"
  )

comb4

Final notes

As in the previous post, I want to stress that a well-defined causal research question, a causal diagram encoding the assumed data-generating mechanism, and careful consideration of the causal assumptions are pivotal for causal analysis ✌️
