Extract split values from rpart object in R

I can't find the split values (or other per-node data) in an rpart object. I can see them in the output of summary(sample_model), but not in the model's list elements or in its frame data frame.

Some sample data

foo.df <- structure(list(type = c("fudai", "fudai", "fudai", "fudai", "fudai", 
                              "fudai", "fudai", "tozama", "fudai", "fudai", "tozama", "tozama", 
                              "fudai", "tozama", "fudai", "fudai", "tozama", "fudai", "fudai", 
                              "tozama", "fudai", "fudai", "fudai", "tozama", "fudai", "fudai", 
                              "tozama", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama", 
                              "fudai", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama", 
                              "tozama", "fudai", "tozama", "tozama", "tozama", "tozama", "fudai", 
                              "fudai", "tozama", "tozama"), distance = c(12.5366985071383, 
                                                                         272.697138147139, 40.4780423740381, 109.806349869662, 147.781805212839, 
                                                                         89.4280438527415, 49.1425850803745, 555.414271440522, 119.365138867582, 
                                                                         182.902536555383, 310.019126513348, 277.122207392514, 214.510428881317, 
                                                                         235.111617874157, 104.494518693549, 50.7561853895564, 343.308898045237, 
                                                                         151.796857505073, 36.0391449169937, 30.8214406651022, 343.294467363406, 
                                                                         135.841501028422, 154.798119311647, 317.739208576563, 3.33794280697559, 
                                                                         98.9182898110913, 422.915369767251, 194.957988642709, 87.6548263591412, 
                                                                         187.571370158631, 236.292608259126, 17.915709270268, 193.548578374405, 
                                                                         262.190146422316, 21.6219797945323, 121.199009527283, 261.670997612517, 
                                                                         202.2051991431, 125.418459536787, 275.964068539003, 190.112226847932, 
                                                                         20.1753302760961, 488.80323504215, 579.25515722891, 233.500797034697, 
                                                                         207.588349435329, 183.770003408524, 168.739293254246, 313.140075747773, 
                                                                         131.69228390613), age = c(1756, 1711, 1712, 1746, 1868, 1866, 
                                                                                                   1682, 1617, 1771, 1764, 1672, 1636, 1864, 1704, 1762, 1868, 1694, 
                                                                                                   1749, 1703, 1616, 1691, 1702, 1723, 1683, 1742, 1691, 1623, 1721, 
                                                                                                   1704, 1745, 1749, 1723, 1639, 1661, 1843, 1845, 1669, 1698, 1698, 
                                                                                                   1664, 1868, 1633, 1783, 1642, 1615, 1648, 1734, 1758, 1725, 1635
                                                                         )), class = c("tbl_df", "tbl", "data.frame"), row.names = c(NA, 
                                                                                                                                     -50L))

And a basic model

library("rpart")
sample_model <- rpart(formula = type ~ ., 
                  data = foo.df, 
                  method = "class",
                  control = rpart.control(xval = 50, minbucket = 5, cp = 0.05),
                  parms = list(split = "gini"))

The rpart documentation says that sample_model$frame is supposed to contain a column called "splits", but it's not there. To quote: "splits, a two column matrix of left and right split labels for each node" https://www.rdocumentation.org/packages/rpart/versions/4.1-15/topics/rpart.object

Where are those columns in sample_model$frame or sample_model? However, I do see the data I want in

summary(sample_model)

What's going on?



Solution 1:[1]

I see that now, but the documentation doesn't seem to describe the current structure. The splits item is a separate list element of the model, not a column of $frame:

  sample_model$splits

 #----------

         count ncat  improve     index adj
distance    50   -1 9.134639  274.3306   0
age         50    1 7.910588 1687.0000   0
age         39    1 6.062937 1654.5000   0
distance    39   -1 1.950142  188.8418   0

To see the full structure of the sample_model, do this:

str(sample_model)
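
If what you are after is the "left and right split labels for each node" that the documentation quote describes, the labels() method for rpart objects may be a shortcut. The following is a minimal sketch, not the documented replacement for the missing column. It assumes the kyphosis data that ships with rpart rather than the question's data, and the names fit, node_labels and branch_labels are made up for illustration.

```r
library(rpart)

# kyphosis ships with rpart, so this runs without the question's data
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# One label per row of fit$frame: the branch leading into each node
node_labels <- labels(fit)

# Two-column matrix: labels of the left and right branches leaving each node
branch_labels <- labels(fit, collapse = FALSE)

# Line the labels up with the node numbers stored as row names of the frame
data.frame(node = row.names(fit$frame), label = node_labels)
```

The labels are formatted strings, so for the numeric split values themselves you would still read the index column of fit$splits.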

I was unable to confirm my hunch about the docs lagging the code:

news(grepl('splits', Text), 'rpart')     #--------------------

Changes in version 4.1-0

Surrogate splits are now considered only if they send two or more cases with non-zero weight each way. For numeric/ordinal variables the restriction to non-zero weights is new: for categorical variables this is a new restriction. Surrogate splits which improve only by rounding error over the default split are no longer returned. Where weights and missing values are present, the splits component for some of these was not returned correctly.

Changes in version 4.0-1

The other major change was an error for asymmetric loss matrices, prompted by a user query. With L=loss asymmetric, the altered priors were computed incorrectly - they were using L' instead of L. Upshot - the tree would not not necessarily choose optimal splits for the given loss matrix. Once chosen, splits were evaluated correctly. The printed “improvement” values are of course the wrong ones as well. It is interesting that for my little test case, with L quite asymmetric, the early splits in the tree are unchanged - a good split still looks good.

To get a canonical answer you would need to contact the maintainer:

 maintainer('rpart')

Solution 2:[2]

The docs are indeed outdated. Here is an extractor derived by inspecting the summary.rpart function:


rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    ix <- cumsum(c(1L, nn))
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
                   paste("<", format(signif(splits[i, 4L], digits)))
                 else if (side[i] == 1L)
                   paste(">=", format(signif(splits[i, 4L], digits)))
                 else {
                   catside <- fit$csplit[splits[i, 4L], 1:side[i]]
                   paste(c("L", "-", "R")[catside], collapse = "", sep = "")
                 }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}

Filter on type == "main" to get only the main splits. The full, unfiltered output looks like this:

> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart_splits(fit)
      var      type node ix    left count ncat    improve index       adj
1   Start      main    1  1  >= 8.5    81    1 6.76232996   8.5 0.0000000
2  Number   primary    1  1   < 5.5    81   -1 2.86679493   5.5 0.0000000
3     Age   primary    1  1  < 39.5    81   -1 2.25021152  39.5 0.0000000
4  Number surrogate    1  1   < 6.5     0   -1 0.80246914   6.5 0.1578947
5   Start      main    2  2 >= 14.5    62    1 1.02052786  14.5 0.0000000
6     Age   primary    2  2    < 55    62   -1 0.68486352  55.0 0.0000000
7  Number   primary    2  2   < 4.5    62   -1 0.29753321   4.5 0.0000000
8  Number surrogate    2  2   < 3.5     0   -1 0.64516129   3.5 0.2413793
9     Age surrogate    2  2    < 16     0   -1 0.59677419  16.0 0.1379310
10    Age      main    5  4    < 55    33   -1 1.24675325  55.0 0.0000000
11  Start   primary    5  4 >= 12.5    33    1 0.28877005  12.5 0.0000000
12 Number   primary    5  4  >= 3.5    33    1 0.17532468   3.5 0.0000000
13  Start surrogate    5  4   < 9.5     0   -1 0.75757576   9.5 0.3333333
14 Number surrogate    5  4  >= 5.5     0    1 0.69696970   5.5 0.1666667
15    Age      main   11  6  >= 111    21    1 1.71428571 111.0 0.0000000
16  Start   primary   11  6 >= 12.5    21    1 0.79365079  12.5 0.0000000
17 Number   primary   11  6  >= 3.5    21    1 0.07142857   3.5 0.0000000
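
If only the main splits are needed, the same bookkeeping can be done without the full extractor. The sketch below is one possible way to do it and does not come from the rpart package. It assumes, as the extractor above does, that each internal node owns one main split row followed by its competing and surrogate rows in fit$splits. The names is_leaf, rows_per_node, first_row, main_rows and main_splits are made up for illustration.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

ff <- fit$frame
is_leaf <- ff$var == "<leaf>"

# Rows of fit$splits owned by each node: 1 main split (internal nodes only)
# plus its competing and surrogate splits
rows_per_node <- ff$ncompete + ff$nsurrogate + !is_leaf

# Row of fit$splits where each node's block starts; keep internal nodes only
first_row <- cumsum(c(1L, rows_per_node))[seq_len(nrow(ff))]
main_rows <- first_row[!is_leaf]

main_splits <- data.frame(
  node      = as.integer(row.names(ff))[!is_leaf],
  var       = rownames(fit$splits)[main_rows],
  threshold = unname(fit$splits[main_rows, "index"]),
  ncat      = unname(fit$splits[main_rows, "ncat"]),  # -1: "<" goes left, 1: ">=" goes left
  row.names = NULL
)
main_splits
```

For a tree with categorical predictors, the index column points into fit$csplit instead of holding a threshold, so this sketch only reads sensibly for numeric splits.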

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2