¿Cómo extraer las reglas de decisión de scikit-learn decision-tree?

¿Puedo extraer las reglas de decisión subyacentes (o ‘rutas de decisión’) de un árbol capacitado en un árbol de decisión como una lista textual?

Algo como:

if A>0.4 then if B0.8 then class='X'

Gracias por tu ayuda.

Creo que esta respuesta es más correcta que las otras respuestas aquí:

 from sklearn.tree import _tree def tree_to_code(tree, feature_names): tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] print "def tree({}):".format(", ".join(feature_names)) def recurse(node, depth): indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] print "{}if {} <= {}:".format(indent, name, threshold) recurse(tree_.children_left[node], depth + 1) print "{}else: # if {} > {}".format(indent, name, threshold) recurse(tree_.children_right[node], depth + 1) else: print "{}return {}".format(indent, tree_.value[node]) recurse(0, 1) 

Esto imprime una función de Python válida. Aquí hay un ejemplo de salida para un árbol que está intentando devolver su entrada, un número entre 0 y 10.

 def tree(f0): if f0 <= 6.0: if f0 <= 1.5: return [[ 0.]] else: # if f0 > 1.5 if f0 <= 4.5: if f0 <= 3.5: return [[ 3.]] else: # if f0 > 3.5 return [[ 4.]] else: # if f0 > 4.5 return [[ 5.]] else: # if f0 > 6.0 if f0 <= 8.5: if f0 <= 7.5: return [[ 7.]] else: # if f0 > 7.5 return [[ 8.]] else: # if f0 > 8.5 return [[ 9.]] 

Aquí hay algunos escollos que veo en otras respuestas:

  1. Usar tree_.threshold == -2 para decidir si un nodo es una hoja no es una buena idea. ¿Y si es un nodo de decisión real con un umbral de -2? En su lugar, debes mirar tree.feature o tree.children_* .
  2. La línea features = [feature_names[i] for i in tree_.feature] bloquea con mi versión de sklearn, porque algunos valores de tree.tree_.feature son -2 (específicamente para nodos de hoja).
  3. No es necesario tener varias sentencias if en la función recursiva, solo una está bien.

Creé mi propia función para extraer las reglas de los árboles de decisión creados por sklearn:

 import pandas as pd import numpy as np from sklearn.tree import DecisionTreeClassifier # dummy data: df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]}) # create decision tree dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1) dt.fit(df.ix[:,:2], df.dv) 

Esta función comienza primero con los nodos (identificados por -1 en las matrices secundarias) y luego busca a los padres de forma recursiva. Yo llamo a esto ‘linaje’ de un nodo. En el camino, tomo los valores que necesito crear si / luego / else SAS logic:

 def get_lineage(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] # get ids of child nodes idx = np.argwhere(left == -1)[:,0] def recurse(left, right, child, lineage=None): if lineage is None: lineage = [child] if child in left: parent = np.where(left == child)[0].item() split = 'l' else: parent = np.where(right == child)[0].item() split = 'r' lineage.append((parent, split, threshold[parent], features[parent])) if parent == 0: lineage.reverse() return lineage else: return recurse(left, right, parent, lineage) for child in idx: for node in recurse(left, right, child): print node 

Los conjuntos de tuplas que aparecen a continuación contienen todo lo que necesito para crear declaraciones SAS if / then / else. No me gusta usar do blocks en SAS, por eso creo la lógica que describe la ruta completa de un nodo. El único entero después de las tuplas es el ID del nodo terminal en una ruta. Todas las tuplas anteriores se combinan para crear ese nodo.

 In [1]: get_lineage(dt, df.columns) (0, 'l', 0.5, 'col1') 1 (0, 'r', 0.5, 'col1') (2, 'l', 4.5, 'col2') 3 (0, 'r', 0.5, 'col1') (2, 'r', 4.5, 'col2') (4, 'l', 2.5, 'col1') 5 (0, 'r', 0.5, 'col1') (2, 'r', 4.5, 'col2') (4, 'r', 2.5, 'col1') 6 

Salida GraphViz del árbol de ejemplo

Modifiqué el código enviado por Zelazny7 para imprimir algunos pseudocódigos:

 def get_code(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] value = tree.tree_.value def recurse(left, right, threshold, features, node): if (threshold[node] != -2): print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {" if left[node] != -1: recurse (left, right, threshold, features,left[node]) print "} else {" if right[node] != -1: recurse (left, right, threshold, features,right[node]) print "}" else: print "return " + str(value[node]) recurse(left, right, threshold, features, 0) 

Si llama a get_code(dt, df.columns) en el mismo ejemplo, obtendrá:

 if ( col1 <= 0.5 ) { return [[ 1. 0.]] } else { if ( col2 <= 4.5 ) { return [[ 0. 1.]] } else { if ( col1 <= 2.5 ) { return [[ 1. 0.]] } else { return [[ 0. 1.]] } } } 

Hay un nuevo método DecisionTreeClassifier , decision_path , en la versión 0.18.0 . Los desarrolladores proporcionan un tutorial extenso (bien documentado).

La primera sección del código en el tutorial que imprime la estructura de árbol parece estar bien. Sin embargo, modifiqué el código en la segunda sección para interrogar una muestra. Mis cambios denotados con # <--

Editar Los cambios marcados con # <-- en el código a continuación se han actualizado en el enlace del tutorial después de que los errores fueron señalados en las solicitudes de extracción # 8653 y # 10951 . Es mucho más fácil de seguir ahora.

 sample_id = 0 node_index = node_indicator.indices[node_indicator.indptr[sample_id]: node_indicator.indptr[sample_id + 1]] print('Rules used to predict sample %s: ' % sample_id) for node_id in node_index: if leave_id[sample_id] == node_id: # <-- changed != to == #continue # <-- comment out print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <-- else: # < -- added else to iterate through decision nodes if (X_test[sample_id, feature[node_id]] <= threshold[node_id]): threshold_sign = "<=" else: threshold_sign = ">" print("decision id node %s : (X[%s, %s] (= %s) %s %s)" % (node_id, sample_id, feature[node_id], X_test[sample_id, feature[node_id]], # <-- changed i to sample_id threshold_sign, threshold[node_id])) Rules used to predict sample 0: decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921) decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927) leaf node 4 reached, no decision here 

Cambie el sample_id para ver las rutas de decisión para otras muestras. No he preguntado a los desarrolladores sobre estos cambios, simplemente me pareció más intuitivo cuando trabajé en el ejemplo.

 from StringIO import StringIO out = StringIO() out = tree.export_graphviz(clf, out_file=out) print out.getvalue() 

Puedes ver un árbol digraph. Luego, clf.tree_.feature y clf.tree_.value son matrices de características de división de nodos y matrices de valores de nodos respectivamente. Puede consultar más detalles de esta fuente de github .

Los códigos a continuación son mi enfoque en anaconda python 2.7 más un nombre de paquete “pydot-ng” para crear un archivo PDF con reglas de decisión. Espero que sea útil.

 from sklearn import tree clf = tree.DecisionTreeClassifier(max_leaf_nodes=n) clf_ = clf.fit(X, data_y) feature_names = X.columns class_name = clf_.classes_.astype(int).astype(str) def output_pdf(clf_, name): from sklearn import tree from sklearn.externals.six import StringIO import pydot_ng as pydot dot_data = StringIO() tree.export_graphviz(clf_, out_file=dot_data, feature_names=feature_names, class_names=class_name, filled=True, rounded=True, special_characters=True, node_ids=1,) graph = pydot.graph_from_dot_data(dot_data.getvalue()) graph.write_pdf("%s.pdf"%name) output_pdf(clf_, name='filename%s'%n) 

un graphy de arbol muestra aqui

Solo porque todos fueron tan útiles, solo agregaré una modificación a las hermosas soluciones de Zelazny7 y Daniele. Este es para Python 2.7, con tabs para que sea más legible:

 def get_code(tree, feature_names, tabdepth=0): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] value = tree.tree_.value def recurse(left, right, threshold, features, node, tabdepth=0): if (threshold[node] != -2): print '\t' * tabdepth, print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {" if left[node] != -1: recurse (left, right, threshold, features,left[node], tabdepth+1) print '\t' * tabdepth, print "} else {" if right[node] != -1: recurse (left, right, threshold, features,right[node], tabdepth+1) print '\t' * tabdepth, print "}" else: print '\t' * tabdepth, print "return " + str(value[node]) recurse(left, right, threshold, features, 0) 

Esto se basa en la respuesta de @paulkernfeld. Si tiene un dataframe X con sus características y un dataframe de destino con sus resonses y quiere tener una idea de qué valor de y terminó en qué nodo (y también para trazarlo en consecuencia) puede hacer lo siguiente:

  def tree_to_code(tree, feature_names): codelines = [] codelines.append('def get_cat(X_tmp):\n') codelines.append(' catout = []\n') codelines.append(' for codelines in range(0,X_tmp.shape[0]):\n') codelines.append(' Xin = X_tmp.iloc[codelines]\n') tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] #print "def tree({}):".format(", ".join(feature_names)) def recurse(node, depth): indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold)) recurse(tree_.children_left[node], depth + 1) codelines.append( '{}else: # if Xin["{}"] > {}\n'.format(indent, name, threshold)) recurse(tree_.children_right[node], depth + 1) else: codelines.append( '{}mycat = {}\n'.format(indent, node)) recurse(0, 1) codelines.append(' catout.append(mycat)\n') codelines.append(' return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n') codelines.append('node_ids = get_cat(X)\n') return codelines mycode = tree_to_code(clf,X.columns.values) # now execute the function and obtain the dataframe with all nodes exec(''.join(mycode)) node_ids = [int(x[0]) for x in node_ids.values] node_ids2 = pd.DataFrame(node_ids) print('make plot') import matplotlib.cm as cm colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids))))) #plt.figure(figsize=cm2inch(24, 21)) for i in list(set(node_ids)): plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i)) mytitle = ['y colored by node'] plt.title(mytitle ,fontsize=14) plt.xlabel('my xlabel') plt.ylabel(tagname) plt.xticks(rotation=70) plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9) plt.tight_layout() plt.show() plt.close 

No es la versión más elegante pero hace el trabajo …

Aquí hay una función, que imprime las reglas de un árbol de decisiones de scikit-learn bajo python 3 y con compensaciones para bloques condicionales para hacer que la estructura sea más legible:

 def print_decision_tree(tree, feature_names=None, offset_unit=' '): '''Plots textual representation of rules of a decision tree tree: scikit-learn representation of tree feature_names: list of feature names. They are set to f1,f2,f3,... if not specified offset_unit: a string of offset of the conditional block''' left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold value = tree.tree_.value if feature_names is None: features = ['f%d'%i for i in tree.tree_.feature] else: features = [feature_names[i] for i in tree.tree_.feature] def recurse(left, right, threshold, features, node, depth=0): offset = offset_unit*depth if (threshold[node] != -2): print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {") if left[node] != -1: recurse (left, right, threshold, features,left[node],depth+1) print(offset+"} else {") if right[node] != -1: recurse (left, right, threshold, features,right[node],depth+1) print(offset+"}") else: print(offset+"return " + str(value[node])) recurse(left, right, threshold, features, 0,0) 

He estado pasando por esto, pero necesitaba que las reglas estuvieran escritas en este formato

 if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Así que adapté la respuesta de @paulkernfeld (gracias) que puede personalizar según sus necesidades.

 def tree_to_code(tree, feature_names, Y): tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] pathto=dict() global k k = 0 def recurse(node, depth, parent): global k indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] s= "{} <= {} ".format( name, threshold, node ) if node == 0: pathto[node]=s else: pathto[node]=pathto[parent]+' & ' +s recurse(tree_.children_left[node], depth + 1, node) s="{} > {}".format( name, threshold) if node == 0: pathto[node]=s else: pathto[node]=pathto[parent]+' & ' +s recurse(tree_.children_right[node], depth + 1, node) else: k=k+1 print(k,')',pathto[parent], tree_.value[node]) recurse(0, 1, 0) 

Aquí hay una forma de traducir todo el árbol en una expresión de python única (no necesariamente legible para los humanos) usando la biblioteca SKompiler :

 from skompiler import skompile skompile(dtree.predict).to('python/code') 

Se modificó el código de Zelazny7 para obtener SQL del árbol de decisión.

 # SQL from decision tree def get_lineage(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] le='<=' g ='>' # get ids of child nodes idx = np.argwhere(left == -1)[:,0] def recurse(left, right, child, lineage=None): if lineage is None: lineage = [child] if child in left: parent = np.where(left == child)[0].item() split = 'l' else: parent = np.where(right == child)[0].item() split = 'r' lineage.append((parent, split, threshold[parent], features[parent])) if parent == 0: lineage.reverse() return lineage else: return recurse(left, right, parent, lineage) print 'case ' for j,child in enumerate(idx): clause=' when ' for node in recurse(left, right, child): if len(str(node))<3: continue i=node if i[1]=='l': sign=le else: sign=g clause=clause+i[3]+sign+str(i[2])+' and ' clause=clause[:-4]+' then '+str(j) print clause print 'else 99 end as clusters' 

Aparentemente, hace mucho tiempo, alguien ya decidió intentar agregar la siguiente función a las funciones de exportación del árbol de scikit oficial (que básicamente solo admite export_graphviz)

 def export_dict(tree, feature_names=None, max_depth=None) : """Export a decision tree in dict format. 

Aquí está su compromiso completo:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

No sé exactamente qué pasó con este comentario. Pero también podrías intentar usar esa función.

Creo que esto justifica una solicitud de documentación seria para que la gente buena de scikit-learn documente correctamente la API sklearn.tree.Tree que es la estructura de árbol subyacente que DecisionTreeClassifier expone como su atributo tree_ .

Solo usa la función de sklearn.tree como esta

 from sklearn.tree import export_graphviz export_graphviz(tree, out_file = "tree.dot", feature_names = tree.columns) //or just ["petal length", "petal width"] 

Y luego busque en la carpeta del proyecto el archivo tree.dot , copie TODO el contenido y péguelo aquí http://www.webgraphviz.com/ y genere su gráfico 🙂