J'ai créé ma propre fonction pour extraire les règles des arbres de décision créés par sklearn:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)
Cette fonction commence par les nœuds (identifiés par -1 dans les tableaux enfants), puis trouve récursivement les parents. J'appelle cela la «lignée» d'un nœud. En cours de route, je saisis les valeurs dont j'ai besoin pour créer une logique SAS if / then / else:
def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     
     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     for child in idx:
          for node in recurse(left, right, child):
               print node
Les ensembles de tuples ci-dessous contiennent tout ce dont j'ai besoin pour créer des instructions SAS if / then / else. Je n'aime pas utiliser des doblocs dans SAS, c'est pourquoi je crée une logique décrivant le chemin complet d'un nœud. Le seul entier après les tuples est l'ID du nœud terminal dans un chemin. Tous les tuples précédents se combinent pour créer ce nœud.
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
