How to generate covariate-adjusted Cox survival/hazard functions?

I'm using the survminer package to try to generate survival and hazard function graphs for a longitudinal student-level dataset that has 5 subgroups of interest.

I have successfully plotted the unadjusted survival functions (without student-level covariates) with ggsurvplot():

library(survival)
library(survminer)

ggsurvplot(survfit(Surv(expectedgr, sped) ~ langstatus_new, data = mydata), pvalue = TRUE)

Output example

However, I cannot manage to adjust these curves for covariates. My aim is to create graphs like these, which show covariate-adjusted survival curves for the levels of a factor variable. Does anyone know how such graphs can be obtained in R?



Solution 1:[1]

Although correct, I believe that the method described in the answer by Dion Groothof is not what is usually of interest. Researchers usually want to visualize the causal effect of a variable, adjusted for confounders. Showing the predicted survival curve for a single covariate combination does not achieve this. I would recommend reading up on confounder-adjusted survival curves; see https://arxiv.org/abs/2203.10002 for example.

These types of curves can be calculated in R using the adjustedCurves package: https://github.com/RobinDenz1/adjustedCurves

In your example, the following code could be used:

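library(survival)

# install adjustedCurves from GitHub (requires devtools), then load it
devtools::install_github("RobinDenz1/adjustedCurves")
library(adjustedCurves)

# keep complete cases of the variables used below; adjustedsurv()
# may not accept missing values (ph.ecog has one NA)
lung <- na.omit(lung[, c("time", "status", "ph.ecog", "age", "sex")])

# "event" needs to be binary (0 = censored, 1 = dead)
lung$status <- lung$status - 1

# "variable" needs to be a factor
lung$ph.ecog <- factor(lung$ph.ecog)

# x = TRUE keeps the design matrix in the fit, which adjustedCurves relies on
fit <- coxph(Surv(time, status) ~ ph.ecog + age + sex, data = lung,
             x = TRUE)

# calculate and plot curves
adj <- adjustedsurv(data = lung, variable = "ph.ecog", ev_time = "time",
                    event = "status", method = "direct",
                    outcome_model = fit, conf_int = TRUE)
plot(adj)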

Producing the following output:

Output

These survival curves are adjusted for the effect of age and sex. More information on how this adjustment works can be found in the documentation of the adjustedCurves package or the article I cited above.
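To make the "direct" method less of a black box, here is a minimal sketch of direct standardization (also called g-computation) using only the survival package. As far as I understand, this is the idea behind method = "direct" in adjustedsurv(), but the helper direct_adjusted() below is hypothetical, not part of adjustedCurves, and it returns point estimates only (no confidence intervals).

```r
library(survival)

# Complete cases of the variables in the model; status recoded to 0/1
dat <- na.omit(lung[, c("time", "status", "ph.ecog", "age", "sex")])
dat$status <- dat$status - 1
dat$ph.ecog <- factor(dat$ph.ecog)

fit <- coxph(Surv(time, status) ~ ph.ecog + age + sex, data = dat)

# Hypothetical helper. Direct standardization: for each level g of
# `variable`, set every patient's value to g, predict each patient's
# survival curve from the Cox model, and average the curves over patients.
direct_adjusted <- function(fit, data, variable, times) {
  levs <- levels(data[[variable]])
  curves <- vapply(levs, function(g) {
    nd <- data
    nd[[variable]] <- factor(rep(g, nrow(nd)), levels = levs)
    sf <- survfit(fit, newdata = nd, se.fit = FALSE)
    # rows = requested times, columns = patients
    s <- summary(sf, times = times, extend = TRUE)$surv
    rowMeans(matrix(s, nrow = length(times)))
  }, numeric(length(times)))
  curves <- matrix(curves, nrow = length(times), dimnames = list(NULL, levs))
  data.frame(time = times, curves, check.names = FALSE)
}

adj_curves <- direct_adjusted(fit, dat, "ph.ecog", times = c(100, 200, 365, 500))
adj_curves
```

Each column of adj_curves is the average predicted survival if every patient had that ph.ecog value, with age and sex left at their observed values. Evaluating on a finer grid of times gives step curves that can be drawn with geom_step().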

Solution 2:[2]

You want to obtain survival probabilities from a Cox model for certain values of a covariate of interest, while adjusting for other covariates. However, because a Cox model makes no assumption about the distribution of the survival times, we cannot obtain survival probabilities from it directly. We first have to estimate the baseline hazard function, which is typically done with the non-parametric Breslow estimator. When the Cox model is fitted with coxph() from the survival package, survfit() performs this step and returns the predicted survival probabilities. See ?survfit.coxph for more information.
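In formulas, ignoring the tie correction that survfit() applies by default when the model was fitted with ties = "efron", the Breslow estimate of the cumulative baseline hazard and the resulting predicted survival curve for a covariate vector x are shown below. Here the t_i are the distinct event times, d_i is the number of events at t_i, and R(t_i) is the set of patients still at risk just before t_i.

```latex
\hat H_0(t) = \sum_{i:\, t_i \le t} \frac{d_i}{\sum_{j \in R(t_i)} \exp(x_j^\top \hat\beta)},
\qquad
\hat S(t \mid x) = \exp\!\left\{ -\hat H_0(t)\, \exp(x^\top \hat\beta) \right\}
```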

Let's see how we can do this by using the lung data set.

library(survival)

# select covariates of interest
df <- subset(lung, select = c(time, status, age, sex, ph.karno))

# assess whether there are any missing observations
apply(df, 2, \(x) sum(is.na(x))) # 1 in ph.karno

# listwise delete missing observations
df <- df[complete.cases(df), ]

# Cox model
fit <- coxph(Surv(time, status == 2) ~ age + sex + ph.karno, data = df)

## Note that I ignore the fact that ph.karno does not satisfy the PH assumption.

# specify the combinations of values of age, sex, and
# ph.karno for which we want to derive survival probabilities
ND1 <- with(df, expand.grid(
  age = median(age),
  sex = c(1,2),
  ph.karno = median(ph.karno)
))
ND2 <- with(df, expand.grid(
  age = median(age),
  sex = 1, # males
  ph.karno = round(create_intervals(n_groups = 3L))
))

# Obtain predicted survival curves for the covariate combinations in ND1 and ND2
sfit1 <- survfit(fit, newdata = ND1)
sfit2 <- survfit(fit, newdata = ND2)

The code for the function create_intervals() can be found in this post; I simply replaced speed with ph.karno in that function.
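If the original create_intervals() is not at hand, a stand-in consistent with the ph.karno values quoted under Interpretation (67, 83 and 100) might look as follows. It assumes, without confirmation from the original post, that the function returns the upper bounds of n_groups equal-width intervals spanning the range of the variable; create_intervals_sketch() is a hypothetical name.

```r
# Hypothetical stand-in for create_intervals(): upper bounds of n_groups
# equal-width intervals spanning the range of x
create_intervals_sketch <- function(x, n_groups = 3L) {
  bounds <- seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE),
                length.out = n_groups + 1L)
  bounds[-1]  # drop the lower end of the first interval
}

round(create_intervals_sketch(c(50, 70, 90, 100), n_groups = 3L))
```

Inside with(df, ...) in ND2 one would then write ph.karno = round(create_intervals_sketch(ph.karno, n_groups = 3L)). If ph.karno in the complete-case data spans 50 to 100, this gives 67, 83 and 100.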

Printing sfit1 shows the estimated median survival times and their 95% confidence intervals for the covariate combinations specified in ND1.

> sfit1
Call: survfit(formula = fit, newdata = ND1)

    n events median 0.95LCL 0.95UCL
1 227    164    283     223     329
2 227    164    371     320     524

Survival probabilities at specific follow-up times can be obtained with the times argument of the summary() method.

# survival probabilities at 200 days of follow-up
summary(sfit1, times = 200)

The output again contains the expected survival probabilities, but now after 200 days of follow-up: survival1 and survival2 correspond to the first and second rows of ND1, i.e. a male and a female patient, respectively, both of median age and with median ph.karno.

> summary(sfit1, times = 200)
Call: survfit(formula = fit, newdata = ND1)

 time n.risk n.event survival1 survival2
  200    144      71     0.625     0.751

The 95% confidence limits associated with these two probabilities can be manually extracted from summary().

sum_sfit <- summary(sfit1, times = 200)
sum_sfit <- t(rbind(sum_sfit$surv, sum_sfit$lower, sum_sfit$upper))
colnames(sum_sfit) <- c("S_hat", "2.5 %", "97.5 %")
# ------------------------------------------------------
> sum_sfit
      S_hat     2.5 %    97.5 %
1 0.6250586 0.5541646 0.7050220
2 0.7513961 0.6842830 0.8250914
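To see where these probabilities come from, the following sketch recomputes the 200-day survival probability of the female profile in ND1 by hand, using the formula S(t | x) = exp(-H0(t) exp(x'beta)). It relies on basehaz(fit, centered = FALSE), which returns the cumulative baseline hazard at covariate values of zero; the other object names are illustrative.

```r
library(survival)

df <- na.omit(subset(lung, select = c(time, status, age, sex, ph.karno)))
fit <- coxph(Surv(time, status == 2) ~ age + sex + ph.karno, data = df)

# Cumulative baseline hazard H0(t) for covariate values of zero
H0 <- basehaz(fit, centered = FALSE)

# Covariate profile: median age, female (sex = 2), median ph.karno
x <- c(age = median(df$age), sex = 2, ph.karno = median(df$ph.karno))
lp <- sum(coef(fit)[names(x)] * x)  # linear predictor x'beta

# H0 is a step function: use its value at the last time <= 200 days
H0_200 <- tail(H0$hazard[H0$time <= 200], 1)
S_manual <- exp(-H0_200 * exp(lp))

# The same quantity via survfit() and summary()
S_survfit <- summary(survfit(fit, newdata = as.data.frame(as.list(x))),
                     times = 200)$surv
c(manual = S_manual, survfit = S_survfit)
```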

If you would like to use ggplot to depict the expected survival probabilities (and the corresponding 95% confidence intervals) for the combinations of values as specified in ND1 and ND2, we first need to make data.frames that contain all the information in an appropriate format.

# function which returns the output from a survfit.object
# in an appropriate format, which can be used in a call
# to ggplot()
df_fun <- \(surv_obj, newdata, factor) {
  len <- length(unique(newdata[[factor]]))
  out <- data.frame(
    time = rep(surv_obj[['time']], times = len),
    n.risk = rep(surv_obj[['n.risk']], times = len),
    n.event = rep(surv_obj[['n.event']], times = len),
    surv = stack(data.frame(surv_obj[['surv']]))[, 'values'],
    upper = stack(data.frame(surv_obj[['upper']]))[, 'values'],
    lower = stack(data.frame(surv_obj[['lower']]))[, 'values']
  )
  out[, 7] <- gl(len, length(surv_obj[['time']]))
  names(out)[7] <- 'factor'
  return(out)
}

# data for the first panel (A)
df_leftPanel <- df_fun(surv_obj = sfit1, newdata = ND1, factor = 'sex')

# data for the second panel (B)
df_rightPanel <- df_fun(surv_obj = sfit2, newdata = ND2, factor = 'ph.karno')
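df_fun() stacks the columns of the survfit matrices into long format. A more compact alternative is the hypothetical helper surv_to_long() below, which also attaches readable group labels. It assumes, as survfit() does for coxph fits with newdata, that the object holds one curve per column of surv, lower and upper, in the row order of newdata. It is demonstrated on a small mock list with the same components rather than on a real fit.

```r
# Hypothetical helper: reshape a multi-curve survfit-like object into
# one row per (time, curve) pair, suitable for ggplot2
surv_to_long <- function(surv_obj, group_labels) {
  n_t <- length(surv_obj$time)
  data.frame(
    time  = rep(surv_obj$time, times = length(group_labels)),
    surv  = as.vector(surv_obj$surv),   # column-major: curve 1 first
    lower = as.vector(surv_obj$lower),
    upper = as.vector(surv_obj$upper),
    group = factor(rep(group_labels, each = n_t), levels = group_labels)
  )
}

# Small mock object with two curves at two time points
mock <- list(time  = c(5, 10),
             surv  = matrix(c(0.90, 0.70, 0.95, 0.85), nrow = 2),
             lower = matrix(c(0.80, 0.60, 0.90, 0.75), nrow = 2),
             upper = matrix(c(1.00, 0.80, 1.00, 0.95), nrow = 2))
long <- surv_to_long(mock, c("Males", "Females"))
long
```

With a real fit, e.g. surv_to_long(sfit1, c("Males", "Females")), the labels would replace the '1'/'2' keys used in scale_colour_manual() and scale_fill_manual().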

Now that we have defined our data frames, we need a geom that draws the 95% CIs as step-shaped ribbons, matching the step-shaped survival curves. We define one ourselves and give it the name geom_stepribbon.

library(ggplot2)

# Function for geom_stepribbon
geom_stepribbon <- function(
  mapping     = NULL,
  data        = NULL,
  stat        = "identity",
  position    = "identity",
  na.rm       = FALSE,
  show.legend = NA,
  inherit.aes = TRUE, ...) {
  layer(
    data        = data,
    mapping     = mapping,
    stat        = stat,
    geom        = GeomStepribbon,
    position    = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params      = list(na.rm = na.rm, ... )
  )
}

GeomStepribbon <- ggproto(
  "GeomStepribbon", GeomRibbon,
  extra_params = c("na.rm"),
  draw_group = function(data, panel_scales, coord, na.rm = FALSE) {
    if (na.rm) data <- data[complete.cases(data[c("x", "ymin", "ymax")]), ]
    data   <- rbind(data, data)
    data   <- data[order(data$x), ]
    data$x <- c(data$x[2:nrow(data)], NA)
    data   <- data[complete.cases(data["x"]), ]
    GeomRibbon$draw_group(data, panel_scales, coord, na.rm = FALSE)
  }
)
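The core of GeomStepribbon is the data manipulation in draw_group(): every point is duplicated and the x values are shifted by one position, so that the ribbon's corners trace a right-continuous staircase like the one drawn by geom_step(). Here is the same manipulation pulled out into a hypothetical standalone function, step_coords(), applied to a three-point example.

```r
# Hypothetical standalone version of the transformation in draw_group()
step_coords <- function(data) {
  data   <- rbind(data, data)        # duplicate every point
  data   <- data[order(data$x), ]    # interleave the copies by x
  data$x <- c(data$x[-1], NA)        # shift x one position to the left
  data[!is.na(data$x), ]             # drop the dangling last row
}

d <- data.frame(x = c(1, 2, 3), ymin = c(0.8, 0.6, 0.4), ymax = c(1, 0.9, 0.7))
res <- step_coords(d)
res
```

Each band value is held constant until the next x, then drops vertically, which is exactly how a survival step function behaves.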

Finally, we can plot the expected survival probabilities for ND1 and ND2.

yl <- 'Expected Survival probability\n'
xl <- '\nTime (days)'

# left panel
my_colours <- c('blue4', 'darkorange')
adj_colour <- \(x) adjustcolor(x, alpha.f = 0.2)
my_colours <- c(
  my_colours, adj_colour(my_colours[1]), adj_colour(my_colours[2])
)
left_panel <- ggplot(df_leftPanel,
                     aes(x = time, colour = factor, fill = factor)) + 
  geom_step(aes(y = surv), linewidth = 0.8) + # use size = 0.8 with ggplot2 < 3.4.0
  geom_stepribbon(aes(ymin = lower, ymax = upper), colour = NA) +
  scale_colour_manual(name = 'Sex',
                      values = c('1' = my_colours[1],
                                 '2' = my_colours[2]),
                      labels = c('1' = 'Males',
                                 '2' = 'Females')) +
  scale_fill_manual(name = 'Sex',
                    values = c('1' = my_colours[3],
                               '2' = my_colours[4]),
                    labels = c('1' = 'Males',
                               '2' = 'Females')) +
  ylab(yl) + xlab(xl) +
  theme(axis.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.text = element_text(size = 12),
        legend.title = element_text(size = 12),
        legend.position = 'top')

# right panel
my_colours <- c('blue4', 'darkorange', '#00b0a4')
my_colours <- c(
  my_colours, adj_colour(my_colours[1]),
  adj_colour(my_colours[2]), adj_colour(my_colours[3])
)
right_panel <- ggplot(df_rightPanel,
                      aes(x = time, colour = factor, fill = factor)) + 
  geom_step(aes(y = surv), linewidth = 0.8) + # use size = 0.8 with ggplot2 < 3.4.0
  geom_stepribbon(aes(ymin = lower, ymax = upper), colour = NA) +
  scale_colour_manual(name = 'Ph.karno',
                      values = c('1' = my_colours[1],
                                 '2' = my_colours[2],
                                 '3' = my_colours[3]),
                      labels = c('1' = 'Low',
                                 '2' = 'Middle',
                                 '3' = 'High')) +
  scale_fill_manual(name = 'Ph.karno',
                    values = c('1' = my_colours[4],
                               '2' = my_colours[5],
                               '3' = my_colours[6]),
                    labels = c('1' = 'Low',
                               '2' = 'Middle',
                               '3' = 'High')) +
  ylab(yl) + xlab(xl) +
  theme(axis.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.text = element_text(size = 12),
        legend.title = element_text(size = 12),
        legend.position = 'top')

# composite plot
library(ggpubr)
ggarrange(left_panel, right_panel,
          ncol = 2, nrow = 1,
          labels = c('A', 'B'))

Output


Interpretation

  • Panel A shows the expected survival probabilities for a male and a female patient of median age with median ph.karno.
  • Panel B shows the expected survival probabilities for three male patients of median age with ph.karno scores of 67 (low), 83 (middle), and 100 (high).

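These survival curves satisfy the PH assumption by construction, because they are derived from the Cox model: the predicted curves for any two covariate combinations are powers of one another (S(t | x1) = S(t | x2)^exp((x1 - x2)'β)), so they never cross. They therefore illustrate the PH assumption rather than test it.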
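This can be checked numerically. Under the Cox model, log(-log S(t | x)) = log H0(t) + x'β, so the complementary log-log transforms of two predicted curves differ by a constant. The sketch below does this for the male and female profiles of ND1; the object names nd, keep and gap are illustrative.

```r
library(survival)

df <- na.omit(subset(lung, select = c(time, status, age, sex, ph.karno)))
fit <- coxph(Surv(time, status == 2) ~ age + sex + ph.karno, data = df)

# Male (sex = 1) and female (sex = 2) profiles at median age and ph.karno
nd <- data.frame(age = median(df$age), sex = c(1, 2),
                 ph.karno = median(df$ph.karno))
sf <- survfit(fit, newdata = nd)

# Drop any times where S = 1, since log(-log(1)) = -Inf
keep <- sf$surv[, 1] < 1 & sf$surv[, 2] < 1
gap <- log(-log(sf$surv[keep, 1])) - log(-log(sf$surv[keep, 2]))

# Under PH the gap is constant and equals (1 - 2) * beta_sex = -beta_sex
range(gap)
-coef(fit)[["sex"]]
```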

Note: use function(x) instead of \(x) if you are using R < 4.1.0.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Denzo
Solution 2