Updated: 2022-06-21 22:54:24
This should do it:
from collections import deque

tree = clf.tree_          # clf: an already-fitted scikit-learn decision tree
queue = deque()
queue.append(0)           # push the root node (id 0)
while queue:
    current_node = queue.popleft()  # take the oldest pending node (FIFO)
    # do whatever you want with the current node
    # ...
    left_child = tree.children_left[current_node]
    if left_child >= 0:   # -1 means "no child" (leaf)
        queue.append(left_child)
    right_child = tree.children_right[current_node]
    if right_child >= 0:
        queue.append(right_child)
This uses a deque as a FIFO queue of the nodes still to be processed. Since nodes are removed from the left and added on the right, they are visited level by level, i.e. this is a breadth-first traversal.
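To see why popping from the left matters, here is a small sketch on a hypothetical seven-node stand-in for clf.tree_ (a types.SimpleNamespace holding children_left/children_right lists, with -1 meaning "no child" as in scikit-learn). The node ids are assigned depth-first, so level order differs from simply counting 0..6. The helper traverse is illustrative only: with popleft() the deque behaves as a FIFO queue and gives breadth-first order; with pop() it behaves as a LIFO stack and gives a depth-first order instead.

```python
from collections import deque
from types import SimpleNamespace

# Hypothetical stand-in for clf.tree_ (not a real fitted tree).
#          0
#        /   \
#       1     4
#      / \   / \
#     2   3 5   6
toy_tree = SimpleNamespace(
    children_left=[1, 2, -1, -1, 5, -1, -1],
    children_right=[4, 3, -1, -1, 6, -1, -1],
)

def traverse(tree, fifo=True):
    """Return node ids in visit order: FIFO (popleft) or LIFO (pop)."""
    pending = deque([0])  # start at the root
    order = []
    while pending:
        node = pending.popleft() if fifo else pending.pop()
        order.append(node)
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:  # skip the -1 "no child" marker
                pending.append(child)
    return order

print(traverse(toy_tree, fifo=True))   # [0, 1, 4, 2, 3, 5, 6]  breadth-first
print(traverse(toy_tree, fifo=False))  # [0, 4, 6, 5, 1, 3, 2]  depth-first
```

The LIFO variant visits the right subtree first because the right child is pushed last; pushing the right child before the left one would give the usual left-first preorder.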
For actual use, I suggest you turn this into a generator:
from collections import deque

def breadth_first_traversal(tree):
    """Yield the node ids of a fitted sklearn tree (clf.tree_) in breadth-first order."""
    queue = deque()
    queue.append(0)  # start at the root
    while queue:
        current_node = queue.popleft()
        yield current_node
        left_child = tree.children_left[current_node]
        if left_child >= 0:
            queue.append(left_child)
        right_child = tree.children_right[current_node]
        if right_child >= 0:
            queue.append(right_child)
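Because the traversal is a generator, nodes are produced lazily, one per iteration, so a caller can stop early without walking the whole tree. A possible extension, sketched below on a hypothetical stand-in for clf.tree_, stores each node's depth alongside it in the queue so level boundaries are available too; breadth_first_with_depth is an illustrative name, not part of scikit-learn or of the original answer.

```python
from collections import deque
from itertools import islice
from types import SimpleNamespace

def breadth_first_with_depth(tree):
    """Hypothetical variant of breadth_first_traversal that yields (node, depth)."""
    queue = deque([(0, 0)])  # root node 0 at depth 0
    while queue:
        node, depth = queue.popleft()
        yield node, depth
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:  # -1 means "no child"
                queue.append((child, depth + 1))

# Hypothetical stand-in for clf.tree_ (same shape as a small depth-2 tree).
toy_tree = SimpleNamespace(
    children_left=[1, 2, -1, -1, 5, -1, -1],
    children_right=[4, 3, -1, -1, 6, -1, -1],
)

# Lazy consumption: only the first three nodes are generated.
print(list(islice(breadth_first_with_depth(toy_tree), 3)))
# [(0, 0), (1, 1), (4, 1)]
print(list(breadth_first_with_depth(toy_tree)))
# [(0, 0), (1, 1), (4, 1), (2, 2), (3, 2), (5, 2), (6, 2)]
```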
Then, you only need minimal changes to your original function:
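import numpy as np

def encoding(clf, features):
    l1 = list()  # feature name per node (None for leaves)
    l2 = list()  # threshold per split node, predicted class index per leaf
    for i in breadth_first_traversal(clf.tree_):  # the only change: visit nodes in BFS order
        if clf.tree_.feature[i] >= 0:  # split node
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:  # leaf node (negative feature id)
            l1.append(None)
            l2.append(np.argmax(clf.tree_.value[i]))
    l = [l1, l2]
    return np.array(l, dtype=object)  # mixed str/None/float/int values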
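A minimal worked example of what encoding returns, assuming a hypothetical three-node stand-in for a fitted classifier: the root splits on feature index 1 at 2.5 and has two leaves. The stand-in only mimics the tree_ attributes the function reads (children_left, children_right, feature, threshold, value, with negative feature ids marking leaves), and the feature names are made up. Row 0 of the result holds feature names (None for leaves); row 1 holds thresholds for split nodes and the argmax class index for leaves.

```python
from collections import deque
from types import SimpleNamespace
import numpy as np

def breadth_first_traversal(tree):
    # FIFO queue => node ids in level order
    queue = deque([0])
    while queue:
        node = queue.popleft()
        yield node
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:
                queue.append(child)

def encoding(clf, features):
    l1, l2 = [], []
    for i in breadth_first_traversal(clf.tree_):
        if clf.tree_.feature[i] >= 0:   # split node: name + threshold
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:                           # leaf: None + predicted class index
            l1.append(None)
            l2.append(np.argmax(clf.tree_.value[i]))
    return np.array([l1, l2], dtype=object)

# Hypothetical stand-in for a fitted classifier (2 classes, 3 nodes).
toy_clf = SimpleNamespace(tree_=SimpleNamespace(
    children_left=np.array([1, -1, -1]),
    children_right=np.array([2, -1, -1]),
    feature=np.array([1, -2, -2]),
    threshold=np.array([2.5, -2.0, -2.0]),
    value=np.array([[[5.0, 5.0]], [[5.0, 0.0]], [[0.0, 5.0]]]),
))

enc = encoding(toy_clf, ["sepal_len", "petal_len"])
print(enc.shape)                   # (2, 3)
print(enc[0].tolist())             # ['petal_len', None, None]
print([float(x) for x in enc[1]])  # [2.5, 0.0, 1.0]
```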