Updated: 2022-11-28 22:05:21
The partykit package has a function .list.rules.party() that is currently unexported but can be leveraged to do what you want. The main reason we haven't exported it yet is that its output type may change in future versions.
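Because the function is unexported and its output type may change, it can help to wrap access to it in a defensive accessor. The accessor fails loudly if the function disappears or returns something unexpected, rather than silently producing wrong rules. This is a minimal sketch. list_rules() is a hypothetical helper name and not part of partykit. The type check assumes the current output is a character vector named by terminal-node ID, which is what the lookup in pathpred() relies on.

```r
## Hypothetical helper (not part of partykit): fetch the unexported rule
## extractor defensively, since its output type may change in future versions.
list_rules <- function(object) {
  if (!requireNamespace("partykit", quietly = TRUE))
    stop("package 'partykit' is required")
  ## coerce to "party" object if necessary (e.g., an rpart fit)
  if (!inherits(object, "party")) object <- partykit::as.party(object)
  ns <- asNamespace("partykit")
  if (!exists(".list.rules.party", envir = ns, inherits = FALSE))
    stop("this partykit version does not provide .list.rules.party()")
  f <- get(".list.rules.party", envir = ns)
  rls <- f(object)
  ## assumption: one rule string per terminal node, named by node ID
  if (!is.character(rls) || is.null(names(rls)))
    stop("unexpected output type from .list.rules.party()")
  rls
}
```

Inside pathpred(), the line calling partykit:::.list.rules.party(object) could then call list_rules(object) instead.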
To obtain the predictions you describe above you can do:
pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)
  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)
  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)
  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]
  return(rval)
}
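The last step of pathpred() works because the rules are indexed by name, not by position. The predicted terminal-node IDs are converted with as.character() so that each one selects the rule whose name is that node ID. A plain numeric index would instead pick positions and return NA for any ID larger than the number of rules. Below is a base-R illustration; the node IDs and rule strings are made up and are not partykit output.

```r
## Hypothetical rules keyed by terminal-node ID (made up for illustration)
rls <- c("2" = "x < 1", "4" = "x >= 1 & y < 2", "5" = "x >= 1 & y >= 2")
## Hypothetical predicted terminal node for each of five observations
node <- c(2L, 5L, 4L, 4L, 2L)

## Correct: look up rules by name (node ID)
rules <- rls[as.character(node)]
print(unname(rules))

## Pitfall: numeric indexing uses positions 2, 5, 4, ... and yields NAs
bad <- rls[node]
print(bad)
```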
Illustration using the iris data and rpart():
library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
## response prob.setosa prob.versicolor prob.virginica
## 1 setosa 1.00000000 0.00000000 0.00000000
## 51 versicolor 0.00000000 0.90740741 0.09259259
## 101 virginica 0.00000000 0.02173913 0.97826087
## rule
## 1 Petal.Length < 2.45
## 51 Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75
(For brevity, only the first observation of each species is shown here, i.e., rows 1, 51, and 101.)
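The probabilities print as prob.setosa, prob.versicolor and prob.virginica because pathpred() assigns the whole matrix returned by predict(type = "prob") to a single data-frame column. print() then expands that matrix column as <column>.<class>. Below is a base-R illustration with made-up numbers, not the iris results.

```r
## Toy stand-in for predict(..., type = "prob"): one column per class
prob <- matrix(c(1, 0,   0,
                 0, 0.9, 0.1),
               nrow = 2, byrow = TRUE,
               dimnames = list(NULL, c("setosa", "versicolor", "virginica")))
rval <- data.frame(response = c("setosa", "versicolor"))
rval$prob <- prob   # stored as ONE matrix column, not three columns
print(ncol(rval))   # "response" plus the matrix column "prob"
print(rval)         # headers shown as prob.setosa, prob.versicolor, ...
## individual class probabilities are accessed through the matrix column
print(rval$prob[, "versicolor"])
```

So in the real output, a class probability is reached with, e.g., rp_pred$prob[, "versicolor"], rather than through a column named prob.versicolor.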
And with ctree():
ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
## response prob.setosa prob.versicolor prob.virginica
## 1 setosa 1.00000000 0.00000000 0.00000000
## 51 versicolor 0.00000000 0.97826087 0.02173913
## 101 virginica 0.00000000 0.02173913 0.97826087
## rule
## 1 Petal.Length <= 1.9
## 51 Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101 Petal.Length > 1.9 & Petal.Width > 1.7