How to add legend to plot from CausalImpact package?

I see the other post here about this, but I'm relatively new to R so the answers weren't helpful to me. I'd really appreciate some more in-depth help with how to do this.

I've already made a plot using the commands from the Causal Impact package. In the package documentation, it clearly says that the plots are ggplot2 objects and can be customized the same way as any other object like that. I've successfully done that, adding titles and customizing colors. I need to add a legend (it's required at the journal I'm submitting to). Here is an example of what my graph currently looks like and the code I used to get there.


library(ggplot2)
devtools::install_github("google/CausalImpact")
library(CausalImpact)

## note that I took this example code from the package documentation up until I customize the plot

#create data
set.seed(1)
x1 <- 100 + arima.sim(model = list(ar = 0.999), n = 100)
y <- 1.2 * x1 + rnorm(100)
y[71:100] <- y[71:100] + 10
data <- cbind(y, x1)

#causal impact analysis
pre.period <- c(1, 70)
post.period <- c(71, 100)
impact <- CausalImpact(data, pre.period, post.period)

#graph
example<-plot(impact, c("original", "cumulative")) +
    labs(
        x = "Time",
        y = "Clicks (Millions)",
        title = "Figure. Analysis of click behavior after intervention.") +
    theme(plot.title = element_text(hjust = 0.5),
          plot.caption = element_text(hjust = 0),
          panel.background = element_rect(fill = "transparent"), # panel bg
          plot.background = element_rect(fill = "transparent", color = NA), # plot bg
          panel.grid.major = element_blank(), # get rid of major grid
          panel.grid.minor = element_blank())  # get rid of minor grid

In my head, the solution I'd like is to have a legend for each panel of the plot. The first legend (next to the 'original' panel) would show that the solid line represents the observed data, the dotted line represents the estimated counterfactual, and the colored band represents the 95% CrI around the estimated counterfactual. The second legend (next to the 'cumulative' panel) would show that the dotted line represents the estimated change in trend associated with the intervention and that the colored band again represents the 95% CrI around the estimate. Maybe there's a better solution than that, but that's what I've thought of.

Here is a section of the underlying code that runs when you plot:

# Initialize plot
  q <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
  q <- q + xlab("") + ylab("")
  if (length(metrics) > 1) {
    q <- q + facet_grid(metric ~ ., scales = "free_y")
  }

  # Add prediction intervals
  q <- q + geom_ribbon(aes(ymin = lower, ymax = upper),
                       data, fill = "slategray2")

  # Add pre-period markers
  xintercept <- CreatePeriodMarkers(impact$model$pre.period,
                                    impact$model$post.period,
                                    time(impact$series))
  q <- q + geom_vline(xintercept = xintercept,
                      colour = "darkgrey", size = 0.8, linetype = "dashed")

  # Add zero line to pointwise and cumulative plot
  q <- q + geom_line(aes(y = baseline),
                     colour = "darkgrey", size = 0.8, linetype = "solid", 
                     na.rm = TRUE)

  # Add point predictions
  q <- q + geom_line(aes(y = mean), data,
                     size = 0.6, colour = "darkblue", linetype = "dashed",
                     na.rm = TRUE)

  # Add observed data
  q <- q + geom_line(aes(y = response), size = 0.6,  na.rm = TRUE)
  return(q)
}

One of the answers in that older post here said that I'd have to adapt the pre-existing function to get a legend, and I don't really have the skills yet to see what I'd have to change or add. I thought that legends were supposed to be automatically added according to what's in the aes() bit of the ggplot code, so I'm a little confused why there isn't one in the first place. Can someone help me with this?



Solution 1:[1]

I rewrote the plot function. Instead of using facet_grid(), I created individual plots with their own legends and used patchwork to combine them into a single plot. To run this, you need to load all of the package source code into memory, including impact_analysis.R, impact_misc.R, impact_model.R, impact_inference.R and impact_plot.R, and then run the code below, which replaces the package's CreateImpactPlot function. You will also need to load ggplot2, tidyr, dplyr, and patchwork.

This version only runs for the original and cumulative metrics. I partly revised the pointwise code, but left it commented out because I didn't have an example to reproduce. I worked your theme preferences directly into the function, so you should be able to find and change those elements at your leisure. To be clear, the plots are q1 = original, q2 = pointwise, and q3 = cumulative.

I don't see how to bring the confidence band into the legend, as it is not part of aes(). It might be possible to create a grob from scratch. For now I just referenced it in the title, which you can change if it doesn't suit you. Hopefully this helps.

                                                 "cumulative")) {
    # Creates a plot of observed data and counterfactual predictions.
    #
    # Args:
    #   impact:  \code{CausalImpact} results object returned by
    #            \code{CausalImpact()}.
    #   metrics: Which metrics to include in the plot. Can be any combination of
    #            "original", "pointwise", and "cumulative".
    #
    # Returns:
    #   A ggplot2 object that can be plotted using plot().
    
    # Create data frame of: time, response, mean, lower, upper, metric
    data <- CreateDataFrameForPlot(impact)
    
    # Select metrics to display (and their order)
    assert_that(is.vector(metrics))
    metrics <- match.arg(metrics, several.ok = TRUE)
    data <- data[data$metric %in% metrics, , drop = FALSE]
    data$metric <- factor(data$metric, metrics)
    
    # Initialize plot
    #q <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
    #q <- q + xlab("") + ylab("")
    #if (length(metrics) > 1) {
    #    q <- q + facet_grid(metric ~ ., scales = "free_y")
    #}
    
    q1 <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
    q1 <- q1 + xlab("") + ylab("")
    
    q2 <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
    q2 <- q2 + xlab("") + ylab("")
    
    q3 <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
    q3 <- q3 + xlab("") + ylab("")
    
    
    # Add prediction intervals
    #q <- q + geom_ribbon(aes(ymin = lower, ymax = upper),
    #                     data, fill = "slategray2")
    
    
    
    q1 <- q1 + geom_ribbon(data = data %>% dplyr::filter(metric == "original"), aes(x = time, ymin = lower, ymax = upper),
                         fill = "slategray2")
    
    q2 <- q2 + geom_ribbon(data = data %>% dplyr::filter(metric == "pointwise"), aes(x = time, ymin = lower, ymax = upper),
                           fill = "slategray2")
    
    q3 <- q3 + geom_ribbon(data = data %>% dplyr::filter(metric == "cumulative"), aes(x = time, ymin = lower, ymax = upper),
                           fill = "slategray2")
    
    
    
    # Add pre-period markers
    xintercept <- CreatePeriodMarkers(impact$model$pre.period,
                                      impact$model$post.period,
                                      time(impact$series))
    
    
    #q <- q + geom_vline(xintercept = xintercept,
    #                    colour = "darkgrey", size = 0.8, linetype = "dashed")
    
    q1 <- q1 + geom_vline(xintercept = xintercept,
                        colour = "darkgrey", size = 0.8, linetype = "dotted")
    
    q2 <- q2 + geom_vline(xintercept = xintercept,
                        colour = "darkgrey", size = 0.8, linetype = "dotted")
    
    q3 <- q3 + geom_vline(xintercept = xintercept,
                        colour = "darkgrey", size = 0.8, linetype = "dotted")
    
    
    
    data_long <- data %>%
        tidyr::pivot_longer(cols = c("baseline", "mean", "response"), names_to = "variable",
            values_to = "value", values_drop_na = TRUE)
    
    
    # Add zero line to pointwise and cumulative plot
    #q <- q + geom_line(aes(y = baseline),
    #                   colour = "darkgrey", size = 0.8, linetype = "solid", 
    #                   na.rm = TRUE)
    
    q1 <- q1 + geom_line(data = data_long %>% dplyr::filter(metric == "original"), 
                      aes(x = time, y = value, linetype = variable, group = variable,
                            size = variable),
                       na.rm = TRUE)+
                scale_linetype_manual(name = "Legend", labels = c("estimated counterfactual", "observed"), 
                                      values = c("dashed", "solid")) +
            scale_size_manual(values = c(0.6, 0.8)) +
            scale_color_manual(values = c("darkblue", "darkgrey")) +
            theme(legend.position = "right") +
            guides(linetype = guide_legend("Legend", nrow=2), size = "none", color = "none")+
            labs(title = "Original", y = "Clicks (Millions)") +
            theme(
                panel.background = element_rect(fill = "transparent"), # panel bg
                plot.background = element_rect(fill = "transparent", color = NA), # plot bg
                panel.grid.major = element_blank(), # get rid of major grid
                panel.grid.minor = element_blank())
            
    
    #q2 <- q2 + geom_line(data = data_long %>% dplyr::filter(metric == "pointwise"), 
    #                     aes(x = time, y = value, linetype = Line, group = Line),
    #                     na.rm = TRUE) +
    #        scale_linetype_manual(title = "Legend", labels = c("estimated counterfactual", "observed"), 
    #                          values = c("dashed", "solid")) +
    #        scale_size_manual(values = c(0.6, 0.8)) +
    #        scale_color_manual(values = c("darkblue", "darkgrey")) +
    #        theme(legend.position = "right") +
    #        guides(linetype = guide_legend("Legend", nrow=2), size = "none", color = "none")+
    #        labs(title = "Pointwise", y = "Clicks (Millions)")
    
    
    q3 <- q3 + geom_line(data = data_long %>% dplyr::filter(metric == "cumulative"), 
                         aes(x = time, y = value, linetype = variable, group = variable),
                         na.rm = TRUE) +
        scale_linetype_manual(labels = c("observed", "estimated trend change"), 
                              values = c("solid", "dashed")) +
            theme(legend.position = "right")+ 
            guides(linetype = guide_legend("Legend", nrow=2))+
            labs(title = "Cumulative",x = "Time",  y = "Clicks (Millions)")+
            theme(
                panel.background = element_rect(fill = "transparent"), # panel bg
                plot.background = element_rect(fill = "transparent", color = NA), # plot bg
                panel.grid.major = element_blank(), # get rid of major grid
                panel.grid.minor = element_blank())
        
    patchwork <- q1 / q3
    
    q <- patchwork + plot_annotation(title = "Figure. Analysis of click behavior after intervention with\n95% Credible Interval") 
    
    # Add point predictions
    #q <- q + geom_line(aes(y = mean), data,
    #                   size = 0.6, colour = "darkblue", linetype = "dashed",
    #                   na.rm = TRUE)
    
    # Add observed data
    #q <- q + geom_line(aes(y = response), size = 0.6,  na.rm = TRUE)
    
    
    return(q)
}

plot(impact, c("original", "cumulative")) 

[Image: modified plot]
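
One possible way to get the band into the legend without building a grob is to put a constant string inside aes(), so that fill becomes a mapped aesthetic. The sketch below is only an illustration on made-up data: the data frame toy and its values are assumptions, not output from CausalImpact, and the label text is a placeholder.

```r
library(ggplot2)

# Toy data standing in for a single panel (made-up values, not CausalImpact output).
set.seed(1)
toy <- data.frame(time = 1:50)
toy$mean  <- cumsum(rnorm(50))
toy$lower <- toy$mean - 1
toy$upper <- toy$mean + 1

# A constant string inside aes() is treated like a one-level variable,
# so ggplot2 builds a legend entry for it. The manual scales then
# assign the actual fill colour and line type to those entries.
p <- ggplot(toy, aes(x = time)) +
    geom_ribbon(aes(ymin = lower, ymax = upper, fill = "95% CrI")) +
    geom_line(aes(y = mean, linetype = "estimated counterfactual")) +
    scale_fill_manual(name = NULL, values = c("95% CrI" = "slategray2")) +
    scale_linetype_manual(name = NULL,
                          values = c("estimated counterfactual" = "dashed"))

print(p)
```

The same idea could be applied to the geom_ribbon() calls for q1 and q3 in the function above, replacing the fixed fill = "slategray2" argument with a mapped fill plus scale_fill_manual().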

Solution 2:[2]

Here is a rebuild of the CreateImpactPlot() function that will work for all three metrics. The legends can be modified. I introduced more colors and linetypes so that the legends could be applicable across all the facets.

The base case looks like this:

plot(impact)

[Image: base case]

You will note that the labels in the legend for the ribbons and for the lines refer to the metrics. These are placeholder labels that you can then modify.

line_labels <- c("cumulative_mean" = "change in trend", "baseline" = "baseline", "original_mean" =
                     "estimated counterfactual", "original_response" = "observed")

plot(impact, c("original", "cumulative")) +
    labs(
        x = "Time",
        y = "Clicks (Millions)",
        title = "Figure. Analysis of click behavior after intervention.") +
    theme(plot.title = element_text(hjust = 0.5),
          plot.caption = element_text(hjust = 0),
          panel.background = element_rect(fill = "transparent"), # panel bg
          plot.background = element_rect(fill = "transparent", color = NA), # plot bg
          panel.grid.major = element_blank(), # get rid of major grid
          panel.grid.minor = element_blank()) + # get rid of minor grid
    
        scale_fill_manual(name = "95% CrI", values = c("original" = "slategray2", "cumulative" = "darkseagreen"),
                      labels = c("original" = "counterfactual", "cumulative" = "estimation")) +
    
        scale_linetype_manual(name = "Legend", labels = line_labels, 
                          values = c("cumulative_mean" = "dotted", "baseline" = "solid", "original_mean" =
                                         "dotted", "original_response" = "solid")) +
    
        scale_color_manual(name = "Legend", labels = line_labels,
                       values = c("cumulative_mean" = "red", "baseline" = "darkgrey", "original_mean"= "darkblue", "original_response" = "goldenrod"))

The vector "line_labels" is where you define the text you want to appear in the legend. Note that I removed the pointwise-related values because I am excluding the pointwise metric from the plot. scale_linetype_manual() and scale_color_manual() must be given the same name and the same labels in order to produce a combined legend; otherwise you will get two separate legends. scale_fill_manual() is for the ribbons. For these scales, you can change the names, the labels and the values as you wish. You can copy the code out of the function, revise it, and add it to the plot call as shown above.

[Image: modified]
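
To see the effect of keeping name and labels in sync, here is a small sketch on made-up data (the data frame toy and the object names are assumptions for illustration only). The first plot gives both scales the same name and labels, the second gives them different names.

```r
library(ggplot2)

# Toy long-format data (made-up values): one row per time point and series.
toy <- data.frame(
    time     = rep(1:10, times = 2),
    value    = c(1:10, 1.5 * (1:10)),
    variable = rep(c("original_mean", "original_response"), each = 10)
)

line_labels <- c("original_mean" = "estimated counterfactual",
                 "original_response" = "observed")
line_types  <- c("original_mean" = "dotted", "original_response" = "solid")
line_cols   <- c("original_mean" = "darkblue", "original_response" = "goldenrod")

base_plot <- ggplot(toy, aes(x = time, y = value,
                             linetype = variable, colour = variable)) +
    geom_line()

# Same name and same labels on both scales: one combined legend.
merged <- base_plot +
    scale_linetype_manual(name = "Legend", labels = line_labels, values = line_types) +
    scale_colour_manual(name = "Legend", labels = line_labels, values = line_cols)

# Different names: ggplot2 draws two separate legends.
separate <- base_plot +
    scale_linetype_manual(name = "Line type", labels = line_labels, values = line_types) +
    scale_colour_manual(name = "Colour", labels = line_labels, values = line_cols)

print(merged)
print(separate)
```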

Here is the code for the revised function. First run the example from the question so that "impact" has been generated by the CausalImpact package. Then load all of the package code into memory, including impact_analysis.R, impact_misc.R, impact_model.R, impact_inference.R and impact_plot.R. Finally, load the code below.
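
One way to load those files is to source() them from a local copy of the package's R/ directory. This is a rough sketch: the helper name source_package_files and the directory path in the commented-out call are hypothetical, and it assumes you have already downloaded the package source.

```r
# Hypothetical helper: source a set of R files from one directory.
# Stops early with a readable message if any file is missing.
source_package_files <- function(src_dir, files, envir = globalenv()) {
    paths <- file.path(src_dir, files)
    missing_paths <- paths[!file.exists(paths)]
    if (length(missing_paths) > 0) {
        stop("Missing source files: ", paste(missing_paths, collapse = ", "))
    }
    for (p in paths) {
        # local = envir evaluates each file in the chosen environment
        source(p, local = envir)
    }
    invisible(paths)
}

package_files <- c("impact_analysis.R", "impact_misc.R", "impact_model.R",
                   "impact_inference.R", "impact_plot.R")

# The path below is a placeholder; point it at your own copy of the source.
# source_package_files("path/to/CausalImpact/R", package_files)
```

Sourcing into the global environment (the default here) means the definitions of CreateImpactPlot2() and plot.CausalImpact() that follow can find the package's internal helpers.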

CreateImpactPlot2 <- function(impact, metrics = c("original", "pointwise","cumulative")) {
    # Creates a plot of observed data and counterfactual predictions.
    #
    # Args:
    #   impact:  \code{CausalImpact} results object returned by
    #            \code{CausalImpact()}.
    #   metrics: Which metrics to include in the plot. Can be any combination of
    #            "original", "pointwise", and "cumulative".
    #
    # Returns:
    #   A ggplot2 object that can be plotted using plot().
    
    # Create data frame of: time, response, mean, lower, upper, metric
    data <- CreateDataFrameForPlot(impact)
    
    # Select metrics to display (and their order)
    assert_that(is.vector(metrics))
    metrics <- match.arg(metrics, several.ok = TRUE)
    data <- data[data$metric %in% metrics, , drop = FALSE]
    data$metric <- factor(data$metric, metrics)
    
    data_long <- data %>%
        tidyr::pivot_longer(cols = c("baseline", "mean", "response"), names_to = "variable",
                            values_to = "value", values_drop_na = TRUE) %>%
        mutate(variable2 = factor(ifelse(variable == "baseline", variable, paste0(metric,"_", variable))),
               variable = factor(variable))
    
    
    
    # Initialize plot
    q <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
    q <- q + xlab("") + ylab("")
    if (length(metrics) > 1) {
        q <- q + facet_grid(metric ~ ., scales = "free_y")
    }
    
  
    
     #Add prediction intervals
    q <- q + geom_ribbon(aes(x = time, ymin = lower, ymax = upper, fill = metric), data_long)
    
  
    
    # Add pre-period markers
    xintercept <- CreatePeriodMarkers(impact$model$pre.period,
                                      impact$model$post.period,
                                      time(impact$series))
    
    
    q <- q + geom_vline(xintercept = xintercept,
                        colour = "darkgrey", size = 0.8, linetype = "dashed")

    
    
    # Add zero line to pointwise and cumulative plot
    q <- q + geom_line(data = data_long %>% dplyr::filter(variable == "baseline"), 
                       aes(x = time, y = value, linetype = variable2, group = variable2, size = variable2, color = variable2),
                       na.rm = TRUE)
 
    # Add point predictions
    q <- q + geom_line(data = data_long %>% dplyr::filter(variable == "mean"), 
                       aes(x = time, y = value, linetype = variable2, group = variable2, size = variable2, color = variable2),
                       na.rm = TRUE)
    
    
    # Add observed data
    q <- q + geom_line(data = data_long %>% dplyr::filter(variable == "response"), 
                       aes(x = time, y = value, linetype = variable2, group = variable2, size = variable2, color = variable2),
                       na.rm = TRUE)
    
    
    #Add scales
    line_labels <- c("cumulative_mean" = "cumulative_mean", "baseline" = "baseline", "original_mean" =
                         "original_mean", "original_response" = "original_response", "pointwise_mean"=
                         "pointwise_mean")
    
    q <- q + scale_linetype_manual(name = "Legend", labels = line_labels, 
                                                        values = c("cumulative_mean" = "dotted", "baseline" = "solid", "original_mean" =
                                                                       "dotted", "original_response" = "solid", "pointwise_mean"=
                                                                       "solid")) +
        
            scale_size_manual(values = c("cumulative_mean" = 0.6, "baseline" = 0.8, "original_mean"= 0.6, "original_response" = 0.5, 
                                         "pointwise_mean"= 0.6)) +
        
            scale_color_manual(name = "Legend", labels = line_labels,
                                                values = c("cumulative_mean" = "red", "baseline" = "darkgrey", "original_mean"= "darkblue", "original_response" = "goldenrod", 
                                                            "pointwise_mean"= "darkgreen")) +
                                  
            scale_fill_manual(name = "95% CrI", values = c("original" = "slategray2", "pointwise" = "pink3", "cumulative" = "darkseagreen"),
                                         labels = c("original" = "original", "pointwise" = "pointwise", "cumulative" = "cumulative")) +
    
            guides(size = "none") 
    
    return(q)
}


plot.CausalImpact <- function(x, ...) {
    # Creates a plot of observed data and counterfactual predictions.
    #
    # Args:
    #   x:   A \code{CausalImpact} results object, as returned by
    #        \code{CausalImpact()}.
    #   ...: Can be used to specify \code{metrics}, which determines which panels
    #        to include in the plot. The argument \code{metrics} can be any
    #        combination of "original", "pointwise", "cumulative". Partial matches
    #        are allowed.
    #
    # Returns:
    #   A ggplot2 object that can be plotted using plot().
    #
    # Examples:
    #   \dontrun{
    #   impact <- CausalImpact(...)
    #
    #   # Default plot:
    #   plot(impact)
    #
    #   # Customized plot:
    #   impact.plot <- plot(impact) + ylab("Sales")
    #   plot(impact.plot)
    #   }
    
    return(CreateImpactPlot2(x, ...))
}

Solution 3:[3]

Here is an updated/edited version of an earlier solution in order to merge aesthetics into one legend. The requirement was to merge linetype and fill (ribbon color) into one legend.


In order to merge legends, the same aesthetics have to be used in the geoms and the scales have to account for the different variables, have the same name and the same labels. So geom_ribbon() needs to have a linetype in the aes() as well as fill, and the geom_line() needs to have a fill in the aes() as well as the linetype. One side effect of adding a linetype to geom_ribbon() is that you then get a line around both edges of the band. On the other hand, fill is not applicable to geom_line so you just get a warning message that the fill aesthetic will be ignored.

The way to address this is to apply a linetype of "blank" to the relevant value in scale_linetype_manual(). Similarly, we use "transparent" in scale_fill_manual() to avoid applying a color to the other elements of the scale.

What I didn't realize before working through this is that it is possible to create a legend for an aesthetic for values across multiple variables. The values just have to be mapped appropriately in the scale. So I truly learned something new putting this together.
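
Before the full function, here is a stripped-down sketch of the same idea on made-up data. The data frame toy, the key name "band" and the label text are assumptions for illustration. Unlike the function below, this sketch does not map fill on the line layer (which avoids the warning); it relies on both scales sharing the same name, labels and limits, which should be enough for ggplot2 to combine them.

```r
library(ggplot2)

# Toy data (made-up values, not CausalImpact output).
set.seed(1)
toy <- data.frame(time = 1:50)
toy$mean     <- cumsum(rnorm(50))
toy$response <- toy$mean + rnorm(50, sd = 0.5)
toy$lower    <- toy$mean - 1
toy$upper    <- toy$mean + 1

# Long format for the two lines, built with base R.
lines_long <- rbind(
    data.frame(time = toy$time, value = toy$mean,     variable = "mean"),
    data.frame(time = toy$time, value = toy$response, variable = "response")
)

# One set of keys shared by both scales; limits fixes the legend order.
keys <- c("mean", "response", "band")
key_labels <- c("mean" = "estimated counterfactual",
                "response" = "observed",
                "band" = "95% CrI")

p <- ggplot(toy, aes(x = time)) +
    geom_ribbon(aes(ymin = lower, ymax = upper,
                    fill = "band", linetype = "band")) +
    geom_line(data = lines_long,
              aes(y = value, linetype = variable, group = variable)) +
    # "blank" hides the line for the band entry
    scale_linetype_manual(name = "Legend", limits = keys, labels = key_labels,
                          values = c("mean" = "dashed", "response" = "solid",
                                     "band" = "blank")) +
    # "transparent" hides the fill for the two line entries
    scale_fill_manual(name = "Legend", limits = keys, labels = key_labels,
                      values = c("mean" = "transparent",
                                 "response" = "transparent",
                                 "band" = "slategray2"))

print(p)
```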

CreateImpactPlot <- function(impact, metrics = c("original",  "cumulative")) {
    # Creates a plot of observed data and counterfactual predictions.
    #
    # Args:
    #   impact:  \code{CausalImpact} results object returned by
    #            \code{CausalImpact()}.
    #   metrics: Which metrics to include in the plot. Can be any combination of
    #            "original", "pointwise", and "cumulative".
    #
    # Returns:
    #   A ggplot2 object that can be plotted using plot().
    
    # Create data frame of: time, response, mean, lower, upper, metric
    data <- CreateDataFrameForPlot(impact)
    
    # Select metrics to display (and their order)
    assert_that(is.vector(metrics))
    metrics <- match.arg(metrics, several.ok = TRUE)
    data <- data[data$metric %in% metrics, , drop = FALSE]
    data$metric <- factor(data$metric, metrics)
    
    # Make data longer
    data_long <- data %>%
        tidyr::pivot_longer(cols = c("baseline", "mean", "response"), names_to = "variable",
                            values_to = "value", values_drop_na = TRUE)
    
    # Initialize plot
    q1 <- ggplot(data, aes(x = time)) + theme_bw(base_size = 15)
    q1 <- q1 + xlab("") + ylab("")
    

    q3 <- ggplot(data %>% 
                     filter(metric == "cumulative") %>%
                     mutate(metric = factor(metric, levels = c("cumulative"))), aes(x = time)) + theme_bw(base_size = 15)
    q3 <- q3 + xlab("") + ylab("")
    
    
    # Add prediction intervals

    q1 <- q1 + geom_ribbon(data = data %>% 
                               filter(metric == "original") %>%
                               mutate(metric = factor(metric, levels = c("original"))), aes(x = time, ymin = lower, ymax = upper, fill = metric, 
                                                                                            linetype = metric))
    q3 <- q3 + geom_ribbon(data = data %>% 
                               filter(metric == "cumulative") %>%
                               mutate(metric = factor(metric, levels = c("cumulative"))), aes(x = time, ymin = lower, ymax = upper, fill = metric))
    

    # Add pre-period markers
    xintercept <- CreatePeriodMarkers(impact$model$pre.period,
                                      impact$model$post.period,
                                      time(impact$series))

    q1 <- q1 + geom_vline(xintercept = xintercept,
                          colour = "darkgrey", size = 0.8, linetype = "dashed")
    
    
    q3 <- q3 + geom_vline(xintercept = xintercept,
                          colour = "darkgrey", size = 0.8, linetype = "dashed")
    
    

    # Add zero line to cumulative plot
    # Add point predictions
    # Add observed data
    
    q1 <- q1 + geom_line(data = data_long %>% dplyr::filter(metric == "original"), 
                         aes(x = time, y = value, linetype = variable, group = variable,
                             size = variable, fill = variable, color = variable),
                         na.rm = TRUE)+
        scale_linetype_manual(name = "Legend", labels = c("mean"= "estimated counterfactual", "response" = "observed", "original" = "95% CrI counterfactual"), 
                              values = c("dashed", "solid", "blank"), limits = c("mean", "response","original")) +
        
        scale_fill_manual(name = "Legend", labels = c("mean"= "estimated counterfactual", "response" = "observed", "original" = "95% CrI counterfactual"), 
                              values = c("transparent", "transparent","slategray2"), limits = c("mean", "response","original")) +  #limits controls the order in the legend
        
        scale_size_manual(values = c(0.6, 0.8, 0.5)) +
        scale_color_manual(values = c("darkgray", "darkblue")) +
        theme(legend.position = "right", axis.text.x = element_blank(), axis.title.y = element_blank()) +
        guides(size = "none", color = "none")+
        facet_wrap(~metric[1], strip.position = "right", drop = TRUE) #use facet_wrap to generate the strip
    

    
    q3 <- q3 + geom_line(data = data_long %>% dplyr::filter(metric == "cumulative"), 
                         aes(x = time, y = value, linetype = variable, group = variable,
                              fill = variable),
                         na.rm = TRUE) +
        scale_linetype_manual(name = "Legend", labels = c("mean"= "estimated trend change", "baseline" = "observed", "cumulative" = "95% CrI estimation"),
                              values = c("dashed", "solid", "blank"), limits = c("mean", "baseline","cumulative")) +
        
        scale_fill_manual(name = "Legend", labels = c("mean"= "estimated trend change", "baseline" = "observed", "cumulative" = "95% CrI estimation"),
                          values = c("transparent", "transparent","slategray2"), limits = c("mean", "baseline","cumulative")) +  #limits controls the order in the legend
        
        theme(legend.position = "right", axis.title.y = element_blank())+ 
        labs(x = "Time") +
        facet_wrap(~metric, strip.position = "right", drop = TRUE) #use facet_wrap to generate the strip
    
    
    
    # Shared y-axis label, drawn as a text grob to the left of the stacked panels
    g1 <- grid::textGrob("Clicks (Millions)", rot = 90, gp = grid::gpar(fontsize = 15), x = 0.85)
    
    patchwork <- wrap_elements(g1) | (q1/q3) 
    
    q <- patchwork 
    
 
    return(q)
}


# To run the function

plot(impact, c("original", "cumulative")) + 
    plot_annotation(title = "Figure. Analysis of click behavior after intervention"
                    , theme = theme(plot.title = element_text(hjust = 0.5))) &
    theme(
        panel.background = element_rect(fill = "transparent"), # panel bg
        plot.background = element_rect(fill = "transparent", color = NA), # plot bg
        panel.grid.major = element_blank(), # get rid of major grid
        panel.grid.minor = element_blank())

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 stomper
Solution 2 stomper
Solution 3