且构网

Traversing an sklearn decision tree

Updated: 2022-06-21 22:54:24

This should do it:

from collections import deque

tree = clf.tree_  # clf is the fitted decision tree estimator

queue = deque()
queue.append(0)  # node 0 is the tree root

while queue:
    current_node = queue.popleft()  # oldest pending node first (FIFO)

    # do whatever you want with current node
    # ...

    # a negative child index means this child does not exist
    left_child = tree.children_left[current_node]
    if left_child >= 0:
        queue.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:
        queue.append(right_child)

This uses a deque to hold the nodes still to be processed. Because nodes are removed from the left end and added at the right end, the deque behaves as a first-in, first-out queue, so the loop visits the tree in breadth-first order.
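To make the visiting order concrete, here is a small sketch on a hand-built stand-in for `clf.tree_`. The `toy_tree` object, its seven nodes, and the helper `traversal_order` are assumptions made up for illustration; only the `children_left` / `children_right` attributes, with -1 marking a missing child, mirror what the loop above uses. It also shows that replacing `popleft()` with `pop()` turns the same loop into a depth-first traversal.

```python
from collections import deque
from types import SimpleNamespace

# Hypothetical stand-in for clf.tree_: node 0 is the root, -1 means "no child".
toy_tree = SimpleNamespace(
    children_left=[1, 2, -1, -1, 5, -1, -1],
    children_right=[4, 3, -1, -1, 6, -1, -1],
)

def traversal_order(tree, breadth_first=True):
    """Return the node ids in the order the loop visits them."""
    pending = deque([0])
    order = []
    while pending:
        # popleft() treats the deque as a FIFO queue (breadth-first);
        # pop() treats it as a LIFO stack (depth-first).
        node = pending.popleft() if breadth_first else pending.pop()
        order.append(node)
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:
                pending.append(child)
    return order

bfs_order = traversal_order(toy_tree, breadth_first=True)
dfs_order = traversal_order(toy_tree, breadth_first=False)
print(bfs_order)
print(dfs_order)
```

In the breadth-first result, the root comes first, then both of its children, then all grandchildren, which is the level-by-level order described above.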

For actual use, I suggest you turn this into a generator:

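from collections import deque

def breadth_first_traversal(tree):
    """Yield the node ids of an sklearn tree (e.g. clf.tree_) in breadth-first order."""
    queue = deque()
    queue.append(0)  # node 0 is the tree root

    while queue:
        current_node = queue.popleft()  # oldest pending node first (FIFO)

        yield current_node

        # a negative child index means this child does not exist
        left_child = tree.children_left[current_node]
        if left_child >= 0:
            queue.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            queue.append(right_child)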
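If the level of each node matters as well, one possible variation is to carry the depth along in the queue. This is only a sketch: `breadth_first_with_depth` and `toy_tree` are hypothetical names, and the stand-in object assumes nothing beyond the `children_left` / `children_right` arrays used above.

```python
from collections import deque
from types import SimpleNamespace

def breadth_first_with_depth(tree):
    """Yield (node_id, depth) pairs in breadth-first order."""
    queue = deque([(0, 0)])  # (node id, depth); the root sits at depth 0
    while queue:
        node, depth = queue.popleft()
        yield node, depth
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:  # negative index: no such child
                queue.append((child, depth + 1))

# Hypothetical stand-in for clf.tree_ with five nodes.
toy_tree = SimpleNamespace(
    children_left=[1, 2, -1, -1, -1],
    children_right=[4, 3, -1, -1, -1],
)

pairs = list(breadth_first_with_depth(toy_tree))
print(pairs)
```

Because the function is a generator, nodes are produced lazily; wrapping the call in `list(...)` as above forces the full traversal.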

With `breadth_first_traversal` defined, your original function needs only minimal changes:

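import numpy as np

def encoding(clf, features):
    l1 = list()  # feature name per node (None for leaves)
    l2 = list()  # split threshold per node (predicted class index for leaves)

    for i in breadth_first_traversal(clf.tree_):
        if clf.tree_.feature[i] >= 0:
            # internal node: record the split feature and its threshold
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            # leaf node: record the index of the majority class
            l1.append(None)
            print(np.max(clf.tree_.value))  # debug output kept from the original function
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1, l2]

    return np.array(l)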
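As a rough illustration of what this function collects, the sketch below replays the same loop on plain Python lists instead of a fitted classifier. Everything in it is hypothetical: the `toy_tree` values, the feature names, and the helper `encode_rows`. It assumes that leaves carry a negative entry in `feature` (which is what the `>= 0` test relies on) and, as a simplification, that `value[i]` is a flat list of per-class counts.

```python
from collections import deque
from types import SimpleNamespace

# Hypothetical stand-in for clf.tree_: one split on feature 1, two leaves.
toy_tree = SimpleNamespace(
    children_left=[1, -1, -1],
    children_right=[2, -1, -1],
    feature=[1, -2, -2],          # negative entry: leaf, no split feature
    threshold=[0.5, -2.0, -2.0],  # leaf thresholds are placeholders, never read
    value=[[5, 5], [5, 0], [0, 5]],
)

def breadth_first_traversal(tree):
    """Yield node ids in breadth-first order."""
    queue = deque([0])
    while queue:
        node = queue.popleft()
        yield node
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:
                queue.append(child)

def encode_rows(tree, features):
    """Pure-Python analogue of the loop in encoding (returns lists, not an array)."""
    names, values = [], []
    for i in breadth_first_traversal(tree):
        if tree.feature[i] >= 0:
            names.append(features[tree.feature[i]])
            values.append(tree.threshold[i])
        else:
            counts = tree.value[i]
            names.append(None)
            # index of the largest count, i.e. an argmax over the classes
            values.append(max(range(len(counts)), key=counts.__getitem__))
    return [names, values]

rows = encode_rows(toy_tree, ["height", "width"])
print(rows)
```

The first row holds the split feature of each node (or `None` for a leaf) and the second row holds the threshold (or the majority class index for a leaf), both in breadth-first order.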