Extract split values from rpart object in R
I can't find the split values (or other data) for the nodes in an rpart object. I can see them with summary(sample_model), but not in the returned list or in its frame data frame.
Some sample data:
sample_data <- structure(list(type = c("fudai", "fudai", "fudai", "fudai", "fudai",
"fudai", "fudai", "tozama", "fudai", "fudai", "tozama", "tozama",
"fudai", "tozama", "fudai", "fudai", "tozama", "fudai", "fudai",
"tozama", "fudai", "fudai", "fudai", "tozama", "fudai", "fudai",
"tozama", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama",
"fudai", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama",
"tozama", "fudai", "tozama", "tozama", "tozama", "tozama", "fudai",
"fudai", "tozama", "tozama"), distance = c(12.5366985071383,
272.697138147139, 40.4780423740381, 109.806349869662, 147.781805212839,
89.4280438527415, 49.1425850803745, 555.414271440522, 119.365138867582,
182.902536555383, 310.019126513348, 277.122207392514, 214.510428881317,
235.111617874157, 104.494518693549, 50.7561853895564, 343.308898045237,
151.796857505073, 36.0391449169937, 30.8214406651022, 343.294467363406,
135.841501028422, 154.798119311647, 317.739208576563, 3.33794280697559,
98.9182898110913, 422.915369767251, 194.957988642709, 87.6548263591412,
187.571370158631, 236.292608259126, 17.915709270268, 193.548578374405,
262.190146422316, 21.6219797945323, 121.199009527283, 261.670997612517,
202.2051991431, 125.418459536787, 275.964068539003, 190.112226847932,
20.1753302760961, 488.80323504215, 579.25515722891, 233.500797034697,
207.588349435329, 183.770003408524, 168.739293254246, 313.140075747773,
131.69228390613), age = c(1756, 1711, 1712, 1746, 1868, 1866,
1682, 1617, 1771, 1764, 1672, 1636, 1864, 1704, 1762, 1868, 1694,
1749, 1703, 1616, 1691, 1702, 1723, 1683, 1742, 1691, 1623, 1721,
1704, 1745, 1749, 1723, 1639, 1661, 1843, 1845, 1669, 1698, 1698,
1664, 1868, 1633, 1783, 1642, 1615, 1648, 1734, 1758, 1725, 1635
)), class = c("tbl_df", "tbl", "data.frame"), row.names = c(NA,
-50L))
And a basic model:
library("rpart")
sample_model <- rpart(formula = type ~ .,
                      data = sample_data,
                      method = "class",
                      control = rpart.control(xval = 50, minbucket = 5, cp = 0.05),
                      parms = list(split = "gini"))
The rpart documentation says that sample_model$frame should contain a column called "splits", but it isn't there. To quote: "splits, a two column matrix of left and right split labels for each node" https://www.rdocumentation.org/packages/rpart/versions/4.1-15/topics/rpart.object
Where are those columns, in sample_model$frame or elsewhere in sample_model? I can see the data I want in the output of
summary(sample_model)
What's going on?
Solution 1:[1]
I see that now, but it doesn't seem to describe the current structure. The splits are not a column of $frame; they are stored as a separate element of the fitted object, the $splits matrix:
sample_model$splits
#----------
count ncat improve index adj
distance 50 -1 9.134639 274.3306 0
age 50 1 7.910588 1687.0000 0
age 39 1 6.062937 1654.5000 0
distance 39 -1 1.950142 188.8418 0
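The rows of this matrix are stacked node by node (the main split first, then its competitors, then its surrogates), so picking out the split actually used at each node takes some bookkeeping with the ncompete and nsurrogate columns of $frame. A minimal sketch, assuming the same row layout that summary.rpart relies on; main_splits is a hypothetical helper name, and the kyphosis data shipped with rpart stands in for sample_data so the block runs on its own:

```r
library(rpart)

# Hypothetical helper: one row per internal node with the split actually used there.
main_splits <- function(fit) {
  ff <- fit$frame
  is_leaf <- ff$var == "<leaf>"
  # Each internal node owns 1 (main) + ncompete + nsurrogate rows of fit$splits;
  # leaves own none.
  n_rows <- ff$ncompete + ff$nsurrogate + !is_leaf
  # Index of each node's first (main) row in fit$splits, in frame order
  first_row <- cumsum(c(1L, n_rows))[seq_len(nrow(ff))]
  s <- fit$splits[first_row[!is_leaf], , drop = FALSE]
  data.frame(
    node = as.integer(row.names(ff))[!is_leaf],
    var = rownames(s),
    # ncat = -1: "< index" goes left; ncat = 1: ">= index" goes left;
    # ncat > 1: categorical split, index is a row number of fit$csplit
    left = ifelse(s[, "ncat"] == -1, paste("<", s[, "index"]),
                  ifelse(s[, "ncat"] == 1, paste(">=", s[, "index"]), "categorical")),
    index = s[, "index"],
    row.names = NULL
  )
}

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
main_splits(fit)
```

The variable names it returns should match the non-leaf entries of fit$frame$var, which is a cheap sanity check that the row bookkeeping is right.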
To see the full structure of sample_model, run:
str(sample_model)
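If you only want to confirm where the matrix lives, comparing component names is quicker than reading through the str() output. A small check, again using kyphosis from rpart rather than the question's data:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# "splits" is a top-level component of the fitted object...
"splits" %in% names(fit)         # TRUE
# ...but not a column of the node table in fit$frame
"splits" %in% names(fit$frame)   # FALSE
names(fit$frame)
# One row per candidate split, one column per statistic
colnames(fit$splits)
```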
I was unable to confirm my hunch that the docs lag behind the code; searching the package NEWS for "splits" returns nothing about a frame column:
news(grepl('splits', Text), 'rpart') #--------------------
Changes in version 4.1-0
Surrogate splits are now considered only if they send two or more cases with non-zero weight each way. For numeric/ordinal variables the restriction to non-zero weights is new: for categorical variables this is a new restriction. Surrogate splits which improve only by rounding error over the default split are no longer returned. Where weights and missing values are present, the splits component for some of these was not returned correctly.
Changes in version 4.0-1
The other major change was an error for asymmetric loss matrices, prompted by a user query. With L=loss asymmetric, the altered priors were computed incorrectly - they were using L' instead of L. Upshot - the tree would not necessarily choose optimal splits for the given loss matrix. Once chosen, splits were evaluated correctly. The printed “improvement” values are of course the wrong ones as well. It is interesting that for my little test case, with L quite asymmetric, the early splits in the tree are unchanged - a good split still looks good.
To get a canonical answer you would need to contact the maintainer:
maintainer('rpart')
Solution 2:[2]
The docs are indeed outdated. Here is an extractor derived by inspecting the summary.rpart function:
rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    # rows of fit$splits owned by each node: main split + competitors + surrogates
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    # index of each node's first (main) row
    ix <- cumsum(c(1L, nn))
    # the main split plus its competitors are the "primary" splits
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    # readable rule for the cases sent left (same logic as summary.rpart)
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
        paste("<", format(signif(splits[i, 4L], digits)))
      else if (side[i] == 1L)
        paste(">=", format(signif(splits[i, 4L], digits)))
      else {
        # categorical split: index is a row number of fit$csplit
        catside <- fit$csplit[splits[i, 4L], 1:side[i]]
        paste(c("L", "-", "R")[catside], collapse = "", sep = "")
      }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}
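The result lists, for every internal node, the main split, the competing ("primary") splits and the surrogate splits. Filter on type == "main" to keep only the splits actually used in the tree. For example: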
> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart_splits(fit)
var type node ix left count ncat improve index adj
1 Start main 1 1 >= 8.5 81 1 6.76232996 8.5 0.0000000
2 Number primary 1 1 < 5.5 81 -1 2.86679493 5.5 0.0000000
3 Age primary 1 1 < 39.5 81 -1 2.25021152 39.5 0.0000000
4 Number surrogate 1 1 < 6.5 0 -1 0.80246914 6.5 0.1578947
5 Start main 2 2 >= 14.5 62 1 1.02052786 14.5 0.0000000
6 Age primary 2 2 < 55 62 -1 0.68486352 55.0 0.0000000
7 Number primary 2 2 < 4.5 62 -1 0.29753321 4.5 0.0000000
8 Number surrogate 2 2 < 3.5 0 -1 0.64516129 3.5 0.2413793
9 Age surrogate 2 2 < 16 0 -1 0.59677419 16.0 0.1379310
10 Age main 5 4 < 55 33 -1 1.24675325 55.0 0.0000000
11 Start primary 5 4 >= 12.5 33 1 0.28877005 12.5 0.0000000
12 Number primary 5 4 >= 3.5 33 1 0.17532468 3.5 0.0000000
13 Start surrogate 5 4 < 9.5 0 -1 0.75757576 9.5 0.3333333
14 Number surrogate 5 4 >= 5.5 0 1 0.69696970 5.5 0.1666667
15 Age main 11 6 >= 111 21 1 1.71428571 111.0 0.0000000
16 Start primary 11 6 >= 12.5 21 1 0.79365079 12.5 0.0000000
17 Number primary 11 6 >= 3.5 21 1 0.07142857 3.5 0.0000000
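If what you actually need is the two-column matrix of left and right split labels that the rpart.object documentation describes, rpart's labels() method (labels.rpart) can, as I understand its help page, return one when called with collapse = FALSE: one row per node of fit$frame, with placeholder labels for terminal nodes. Treat the exact label strings as an assumption; the sketch below only relies on the shape of the result:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# labels() dispatches to labels.rpart; collapse = FALSE asks for a matrix
# with the left-branch label in column 1 and the right-branch label in column 2
lr <- labels(fit, collapse = FALSE)

# Attach node numbers and splitting variables for readability
data.frame(node = as.integer(row.names(fit$frame)),
           var = as.character(fit$frame$var),
           left = lr[, 1],
           right = lr[, 2])
```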
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | |
