且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

为R中的rpart / ctree包获取预测数据集的每一行的决策树规则/路径模式

更新时间:2022-11-28 22:05:21

partykit 软件包具有函数 .list.rules.party()目前未导出,但可以利用它来完成您想做的事情。我们尚未导出它的主要原因是,它的输出类型在将来的版本中可能会发生变化。

The partykit package has a function .list.rules.party() which is currently unexported but can be leveraged to do what you want to do. The main reason that we haven't exported it, yet, is that its type of output may change in future versions.

要获得您在上面描述的预测,您可以执行:

To obtain the predictions you describe above you can do:

pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}

使用 iris的插图数据和 rpart()

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(此处为简洁起见,仅显示了每个物种的第一个观察结果。这对应于索引1、51和101。)

(Only the first observation of each species is shown for brevity here. This corresponds to indexes 1, 51, and 101.)

并使用 ctree()

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7