[python] scikit-learn 의사 결정 트리에서 의사 결정 규칙을 추출하는 방법은 무엇입니까?

의사 결정 트리의 훈련 된 트리에서 텍스트 목록으로 기본 의사 결정 규칙 (또는 ‘결정 경로’)을 추출 할 수 있습니까?

다음과 같은 것 :

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

당신의 도움을 주셔서 감사합니다.



답변

이 답변이 다른 답변보다 더 정확하다고 생각합니다.

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

유효한 파이썬 함수를 출력합니다. 다음은 입력을 리턴하려는 트리에 대한 예제 출력입니다 (0과 10 사이의 숫자).

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

다른 답변에서 볼 수있는 걸림돌은 다음과 같습니다.

  1. 사용하여 tree_.threshold == -2노드가 잎인지 여부를 결정하는 것은 좋은 생각이 아니다. 임계 값이 -2 인 실제 의사 결정 노드 인 경우 어떻게합니까? 대신, tree.feature또는tree.children_* 합니다.
  2. 선은 features = [feature_names[i] for i in tree_.feature], sklearn의 내 버전과 충돌의 일부 값 때문에tree.tree_.feature 입니다 -2 (특히 잎 노드).
  3. 재귀 함수에 여러 개의 if 문을 가질 필요는 없으며 하나만 있으면됩니다.

답변

sklearn이 만든 의사 결정 트리에서 규칙을 추출하는 자체 기능을 만들었습니다.

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

이 함수는 먼저 노드 (자식 배열에서 -1로 식별)로 시작한 다음 부모를 재귀 적으로 찾습니다. 나는 이것을 노드의 ‘계보’라고 부릅니다. 그 과정에서 if / then / else SAS 논리를 작성하는 데 필요한 값을 가져옵니다.

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]

     def recurse(left, right, child, lineage=None):
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

아래 튜플 세트에는 SAS if / then / else 문을 작성하는 데 필요한 모든 것이 포함되어 있습니다. doSAS에서 블록을 사용하는 것을 좋아하지 않기 때문에 노드의 전체 경로를 설명하는 논리를 작성합니다. 튜플 뒤의 단일 정수는 경로에서 터미널 노드의 ID입니다. 앞의 모든 튜플이 결합되어 해당 노드를 만듭니다.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

예제 트리의 GraphViz 출력


답변

Zelazny7 이 제출 한 코드를 수정하여 의사 코드를 인쇄했습니다.

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

get_code(dt, df.columns)동일한 예제 를 호출 하면 다음을 얻을 수 있습니다.

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}


답변

Scikit Learn에서는 export_text트리에서 규칙을 추출하기 위해 버전 0.21 (2019 년 5 월)에서 맛있는 새로운 방법을 도입했습니다 . 여기에 문서 . 더 이상 사용자 정의 기능을 만들 필요가 없습니다.

모델에 적합하면 두 줄의 코드 만 있으면됩니다. 먼저 수입 export_text:

from sklearn.tree.export import export_text

둘째, 규칙을 포함 할 객체를 만듭니다. 규칙을보다 읽기 쉽게 보이게하려면 feature_names인수를 사용하고 기능 이름 목록을 전달하십시오. 예를 들어, 모델이 호출 model되고 피쳐가이라는 데이터 프레임에 이름이 지정된 경우 다음 과 X_train같은 객체를 작성할 수 있습니다 tree_rules.

tree_rules = export_text(model, feature_names=list(X_train))

그런 다음 인쇄하거나 저장하십시오 tree_rules. 결과는 다음과 같습니다.

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1


답변

새로운이 DecisionTreeClassifier방법은 decision_path에, 0.18.0 릴리스. 개발자는 광범위한 (잘 문서화 된) 연습을 제공합니다 합니다.

연습에서 트리 구조를 인쇄하는 첫 번째 코드 섹션은 정상인 것 같습니다. 그러나 두 번째 섹션의 코드를 수정하여 하나의 샘플을 조사했습니다. 내 변경 사항은# <--

편집# <-- 아래 코드에 표시된 변경 사항 은 풀 요청 # 8653# 10951 에서 오류가 지적 된 후 연습 링크에서 업데이트되었습니다 . 지금 따라 가기가 훨씬 쉽습니다.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

sample_id다른 샘플의 결정 경로를 보려면를 변경 하십시오. 개발자에게 이러한 변경 사항에 대해 묻지 않고 예제를 통해 작업 할 때 더 직관적 인 것처럼 보였습니다.


답변

from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Digraph Tree를 볼 수 있습니다. 그런 다음, clf.tree_.featureclf.tree_.value노드 분할 기능 및 노드 값의 배열의 배열은 각각이다. 이 github 소스 에서 자세한 내용을 참조 할 수 있습니다 .


답변

모두가 매우 도움이 되었기 때문에 Zelazny7과 Daniele의 아름다운 솔루션에 수정 사항을 추가합니다. 이것은 파이썬 2.7 용이며 더 읽기 쉬운 탭이 있습니다.

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)