Issue
I am working with a Decision Tree model (sklearn.tree.DecisionTreeRegressor) and I would like to look at the detailed structure of the tree itself. I am currently using matplotlib.pyplot.figure and tree.export_text to output the tree; however, neither of these meets my requirements.
I would like to output the tree as a table with 1 row for each node in the tree. Suppose the tree looks like the following:
                Node 1
               /      \
             /          \
           /              \
     Node 2.1            Node 2.2
      /    \              /    \
    /        \          /        \
Node 3.1  Node 3.2  Node 3.3  Node 3.4
Then I would like to produce a table with the following rows and columns.
Node | Variable | Threshold | Value | MSE | Samples |
---|---|---|---|---|---|
1 | | | | | |
2.1 | | | | | |
2.2 | | | | | |
3.1 | | | | | |
3.2 | | | | | |
3.3 | | | | | |
3.4 | | | | | |
I know there is the tree_ attribute, which could help. However, I am not familiar with it and am not sure where to start.
Solution
You can use the code below. It traverses all the nodes and collects or calculates the pieces of information you are interested in. Note that I took the liberty of modifying the columns a little and changing the enumeration of the nodes, so that each node ID encodes the path from the root (.1 for a left child, .2 for a right child):
              1
            /   \
         /         \
     1.1             1.2
    /   \           /   \
   /     \         /     \
1.1.1   1.1.2   1.2.1   1.2.2
Function
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

def decision_tree_to_tabular(clf, feature_names):
    total_samples = clf.tree_.n_node_samples[0]  # total number of samples at the root node
    tabular_tree = {
        "Node": [],
        "Depth": [],
        "Type": [],
        "Splitting Feature": [],
        "Splitting Threshold": [],
        "Prediction": [],
        "MSE": [],
        "Number of Samples": [],
        "Proportion of Total Samples": [],
        "Proportion of Parent Samples": []
    }

    def traverse_nodes(node_id=0, parent_node_id=None, parent_samples=None, current_node_id='1', depth=1):
        samples = clf.tree_.n_node_samples[node_id]
        prop_total_samples = samples / total_samples
        prop_parent_samples = samples / parent_samples if parent_samples else None
        if clf.tree_.children_left[node_id] != clf.tree_.children_right[node_id]:  # internal node
            tabular_tree["Node"].append(current_node_id)
            tabular_tree["Depth"].append(depth)
            tabular_tree["Type"].append("Node")
            tabular_tree["Splitting Feature"].append(feature_names[clf.tree_.feature[node_id]])
            tabular_tree["Splitting Threshold"].append(clf.tree_.threshold[node_id])
            tabular_tree["Prediction"].append(None)
            tabular_tree["MSE"].append(clf.tree_.impurity[node_id])
            tabular_tree["Number of Samples"].append(samples)
            tabular_tree["Proportion of Total Samples"].append(prop_total_samples)
            tabular_tree["Proportion of Parent Samples"].append(prop_parent_samples)
            traverse_nodes(clf.tree_.children_left[node_id], current_node_id, samples, current_node_id + ".1", depth + 1)  # left child
            traverse_nodes(clf.tree_.children_right[node_id], current_node_id, samples, current_node_id + ".2", depth + 1)  # right child
        else:  # leaf
            tabular_tree["Node"].append(current_node_id)
            tabular_tree["Depth"].append(depth)
            tabular_tree["Type"].append("Leaf")
            tabular_tree["Splitting Feature"].append(None)
            tabular_tree["Splitting Threshold"].append(None)
            tabular_tree["Prediction"].append(clf.tree_.value[node_id].mean())
            tabular_tree["MSE"].append(clf.tree_.impurity[node_id])
            tabular_tree["Number of Samples"].append(samples)
            tabular_tree["Proportion of Total Samples"].append(prop_total_samples)
            tabular_tree["Proportion of Parent Samples"].append(prop_parent_samples)

    traverse_nodes()
    return pd.DataFrame(tabular_tree)
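For readers unfamiliar with tree_, the function above reads from a handful of parallel arrays that are all indexed by node id, with node 0 as the root. The following is a minimal sketch of how those arrays can be inspected directly. The synthetic data and the variable names (rng, reg, t) are assumptions made for illustration only and are not part of the original answer.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic regression data (an assumption for illustration; any data works).
rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))
y = 2.0 * X[:, 0] + rng.normal(scale=0.1, size=200)

reg = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)
t = reg.tree_

# Every attribute used below is an array indexed by node id; node 0 is the root.
for node_id in range(t.node_count):
    left = t.children_left[node_id]
    right = t.children_right[node_id]
    if left == right:  # both are -1 for a leaf
        print(
            f"node {node_id}: leaf, "
            f"value={t.value[node_id][0, 0]:.3f}, "
            f"impurity={t.impurity[node_id]:.3f}, "
            f"samples={t.n_node_samples[node_id]}"
        )
    else:
        print(
            f"node {node_id}: split on feature {t.feature[node_id]} "
            f"at threshold {t.threshold[node_id]:.3f}, "
            f"children=({left}, {right})"
        )
```

With the default squared-error criterion, the impurity array holds the per-node MSE, which is why the function uses it to fill the MSE column.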
Test
from sklearn.datasets import fetch_california_housing
# Load the dataset
california = fetch_california_housing()
X = california.data
y = california.target
feature_names = california.feature_names
# Train a DecisionTreeRegressor
clf = DecisionTreeRegressor(random_state=0, max_depth=2).fit(X, y)
# Get the tree as a DataFrame
tabular_tree = decision_tree_to_tabular(clf, feature_names)
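# Show the tree level by level (sort by depth, then by node ID).
# display() is available in Jupyter/IPython; in a plain script use print() instead.
display(tabular_tree.sort_values(["Depth", "Node"]))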
Output
Node | Depth | Type | Splitting Feature | Splitting Threshold | Prediction | MSE | Number of Samples | Proportion of Total Samples | Proportion of Parent Samples |
---|---|---|---|---|---|---|---|---|---|
1 | 1 | Node | MedInc | 5.03515 | NaN | 1.331550 | 20640 | 1.000000 | NaN |
1.1 | 2 | Node | MedInc | 3.07430 | NaN | 0.837354 | 16255 | 0.787548 | 0.787548 |
1.2 | 2 | Node | MedInc | 6.81955 | NaN | 1.220713 | 4385 | 0.212452 | 0.212452 |
1.1.1 | 3 | Leaf | None | NaN | 1.356930 | 0.561155 | 7860 | 0.380814 | 0.483544 |
1.1.2 | 3 | Leaf | None | NaN | 2.088733 | 0.836995 | 8395 | 0.406734 | 0.516456 |
1.2.1 | 3 | Leaf | None | NaN | 2.905507 | 0.890550 | 3047 | 0.147626 | 0.694869 |
1.2.2 | 3 | Leaf | None | NaN | 4.216431 | 0.778440 | 1338 | 0.064826 | 0.305131 |
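If the level-based numbering from the question (1, 2.1, 2.2, 3.1, ...) is preferred, the path-style IDs in the Node column can be mapped back to it. The sketch below is one possible way to do that. The helper name relabel_by_level is hypothetical, and it assumes that "position" means the left-to-right rank among the nodes that actually exist at a given depth.

```python
from collections import defaultdict

def relabel_by_level(path_ids):
    """Map path-style IDs such as "1.2.1" to "depth.position" labels such as "3.3".

    Hypothetical helper: position is the left-to-right rank among the nodes
    present at that depth, and the root keeps the label "1".
    """
    by_depth = defaultdict(list)
    for pid in path_ids:
        depth = pid.count(".") + 1  # "1" -> 1, "1.2" -> 2, "1.2.1" -> 3
        by_depth[depth].append(pid)

    labels = {}
    for depth, ids in by_depth.items():
        # Left-to-right order in the tree matches the order of the path components.
        ordered = sorted(ids, key=lambda p: [int(part) for part in p.split(".")])
        for position, pid in enumerate(ordered, start=1):
            labels[pid] = "1" if depth == 1 else f"{depth}.{position}"
    return labels

ids = ["1", "1.1", "1.1.1", "1.1.2", "1.2", "1.2.1", "1.2.2"]
print(relabel_by_level(ids))
```

Applied to the table above, something like tabular_tree["Node"].map(relabel_by_level(tabular_tree["Node"])) should yield the question's labels. For an unbalanced tree, the positions count only the nodes that exist, so they do not correspond to slots in a full binary tree.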
Answered By - DataJanitor