Issue
Below is a snippet of the decision tree, as the full tree is pretty huge.
[image: snippet of the decision tree]
How can I make the tree stop growing when the lowest value in a node (the count of the minority class) is under 5? Here is the code that produces the decision tree. Looking at the scikit-learn DecisionTreeClassifier documentation, the only option that seems to do this is min_impurity_decrease, but I am not sure how it works.
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
X, y = make_classification(n_samples=1000,
n_features=6,
n_informative=3,
n_classes=2,
random_state=0,
shuffle=False)
# Creating a dataFrame
df = pd.DataFrame({'Feature 1':X[:,0],
'Feature 2':X[:,1],
'Feature 3':X[:,2],
'Feature 4':X[:,3],
'Feature 5':X[:,4],
'Feature 6':X[:,5],
'Class':y})
y_train = df['Class']
X_train = df.drop('Class',axis = 1)
dt = DecisionTreeClassifier( random_state=42)
dt.fit(X_train, y_train)
from IPython.display import display, Image
import os
import pydotplus
from sklearn import tree
# make the Graphviz executables visible to pydotplus (Windows/Anaconda path from the original setup)
os.environ["PATH"] += os.pathsep + 'C:\\Anaconda3\\Library\\bin\\graphviz'
# write the tree to a .dot file (export_graphviz returns None when out_file is given)
tree.export_graphviz(dt, out_file='thisIsTheImagetree.dot',
                     feature_names=list(X_train.columns), filled=True,
                     rounded=True, special_characters=True)
graph = pydotplus.graph_from_dot_file('thisIsTheImagetree.dot')
thisIsTheImage = Image(graph.create_png())
display(thisIsTheImage)
#print(dt.tree_.feature)
from subprocess import check_call
check_call(['dot','-Tpng','thisIsTheImagetree.dot','-o','thisIsTheImagetree.png'])
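The pruning effect of min_impurity_decrease is easy to observe by refitting the same kind of tree with a few values and counting the leaves. A minimal sketch, assuming the same make_classification data as above; the list of values tried is arbitrary. Note that this shrinks the tree by impurity decrease, not by the lowest class count directly.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# same synthetic data as in the question
X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)

# arbitrary values to try; a node is split only if the split's weighted
# impurity decrease is at least this large
thresholds = [0.0, 0.001, 0.005, 0.01, 0.05]
leaf_counts = []
for t in thresholds:
    clf = DecisionTreeClassifier(random_state=42, min_impurity_decrease=t)
    clf.fit(X, y)
    leaf_counts.append(clf.get_n_leaves())
    print(f"min_impurity_decrease={t}: {clf.get_n_leaves()} leaves")
```

Larger values reject more of the weak splits deep in the tree, so the leaf count drops.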
Update
I think min_impurity_decrease can, in a way, help reach the goal, since tweaking min_impurity_decrease does actually prune the tree. Can anyone explain how min_impurity_decrease works?
I am trying to understand the equation in the scikit-learn documentation, but I am not sure what the values of right_impurity and left_impurity are.
N = 256
N_t = 256
impurity = ??
N_t_R = 242
N_t_L = 14
right_impurity = ??
left_impurity = ??
New_Value = N_t / N * (impurity - ((N_t_R / N_t) * right_impurity)
- ((N_t_L / N_t) * left_impurity))
New_Value
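With the default criterion='gini', impurity, right_impurity and left_impurity are the Gini impurities of the node and of its right and left children, each computed from that node's class counts. N is the total (weighted) number of training samples and N_t, N_t_R, N_t_L the numbers at the node and its children. The sketch below plugs counts into the formula; only the sample totals (256 split into 242 and 14) come from the question, so the per-class counts [128, 128], [128, 114] and [0, 14] are hypothetical.

```python
def gini(counts):
    """Gini impurity 1 - sum(p_k^2) of a node, given its per-class counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def weighted_impurity_decrease(N, parent, left, right):
    """N_t / N * (impurity - N_t_R / N_t * right_impurity
                           - N_t_L / N_t * left_impurity)"""
    N_t, N_t_L, N_t_R = sum(parent), sum(left), sum(right)
    return N_t / N * (gini(parent)
                      - N_t_R / N_t * gini(right)
                      - N_t_L / N_t * gini(left))


# hypothetical class counts: 256 samples split into 242 (right) and 14 (left)
parent = [128, 128]
right = [128, 114]
left = [0, 14]
value = weighted_impurity_decrease(256, parent, left, right)
print(round(value, 4))  # 0.0289
```

A node is split only if this value is at least min_impurity_decrease, so a setting of 0.03 would reject this particular split. On a fitted tree the per-node inputs are available as dt.tree_.impurity and dt.tree_.weighted_n_node_samples, indexed by node id, with the children given by dt.tree_.children_left and dt.tree_.children_right.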
Update 2
Instead of pruning at a certain value, I would like to prune under a certain condition. For example, we do split at 6/4 and 5/5 but not at 6000/4 or 5000/5; that is, stop when one class count in a node is below a certain percentage of the other class count, rather than below a fixed value.
         11/9
        /    \
      6/4     5/5
     /  \     /  \
   6/0  0/4 2/2  3/3
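One way to express this condition is as a ratio between the smallest and the largest class count in a node. A minimal sketch; the function name should_split and the 1% cut-off are hypothetical choices, not scikit-learn parameters.

```python
def should_split(counts, min_ratio=0.01):
    """Hypothetical rule: split a node only if its smallest class count is
    at least min_ratio times its largest class count."""
    smallest, largest = min(counts), max(counts)
    return smallest >= min_ratio * largest


for node in ([11, 9], [6, 4], [5, 5], [6000, 4], [5000, 5]):
    print(node, should_split(node))
```

With a 1% cut-off, 6/4 and 5/5 may be split while 6000/4 and 5000/5 may not. Because the rule compares the counts to each other, it gives the same answer if the counts are replaced by class fractions.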
Solution
Directly restricting the lowest value (the number of occurrences of a particular class) in a leaf cannot be done with min_impurity_decrease or any other built-in stopping criterion.
I think the only way to accomplish this without changing the source code of scikit-learn is to post-prune your tree. To do so, you can traverse the tree and remove all children of the nodes whose minimum class count is less than 5 (or that meet any other condition you can think of). Continuing your example:
from sklearn.tree._tree import TREE_LEAF

def prune_index(inner_tree, index, threshold):
    value = inner_tree.value[index]
    n = inner_tree.weighted_n_node_samples[index]
    # test "smallest class count < threshold": older scikit-learn releases store
    # raw class counts in tree_.value, recent ones (1.4 and later) store class
    # fractions; value.sum() is n or 1 respectively, so this works for both
    if value.min() * n < threshold * value.sum():
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
    # if there are children, visit them as well
    if inner_tree.children_left[index] != TREE_LEAF:
        prune_index(inner_tree, inner_tree.children_left[index], threshold)
        prune_index(inner_tree, inner_tree.children_right[index], threshold)

print(sum(dt.tree_.children_left < 0))
# start pruning from the root
prune_index(dt.tree_, 0, 5)
print(sum(dt.tree_.children_left < 0))
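This code first prints 74 and then 91, which means it has turned 17 nodes into new leaves (by removing the links from those nodes to their children). The pruned subtrees are still stored in the tree's arrays, but they are no longer reachable from the root. The tree, which looked like this before:
[image: the decision tree before pruning]
now looks like this:
[image: the decision tree after pruning]
so you can see that it has indeed shrunk a lot.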
Answered By - David Dale
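The ratio condition from Update 2 can be plugged into the same post-pruning approach (unlinking children with TREE_LEAF). A minimal sketch, assuming the question's data; the helper names prune_by_ratio and count_reachable_leaves and the 10% cut-off are hypothetical and chosen arbitrarily. Comparing the smallest class value with the largest also avoids depending on whether tree_.value holds raw counts or class fractions in a given scikit-learn version.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import TREE_LEAF

X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)
dt = DecisionTreeClassifier(random_state=42).fit(X, y)


def prune_by_ratio(inner_tree, index, min_ratio):
    """Hypothetical helper: make a node a leaf when its smallest class value
    is below min_ratio times its largest class value."""
    value = inner_tree.value[index]
    if value.min() < min_ratio * value.max():
        # unlink the children; their subtrees stay in the arrays but are unreachable
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
    # recurse only if the node still has children
    if inner_tree.children_left[index] != TREE_LEAF:
        prune_by_ratio(inner_tree, inner_tree.children_left[index], min_ratio)
        prune_by_ratio(inner_tree, inner_tree.children_right[index], min_ratio)


def count_reachable_leaves(inner_tree, index=0):
    """Count leaves reachable from the root, skipping orphaned subtrees."""
    left = inner_tree.children_left[index]
    if left == TREE_LEAF:
        return 1
    return (count_reachable_leaves(inner_tree, left)
            + count_reachable_leaves(inner_tree, inner_tree.children_right[index]))


leaves_before = count_reachable_leaves(dt.tree_)
prune_by_ratio(dt.tree_, 0, 0.1)
leaves_after = count_reachable_leaves(dt.tree_)
print(leaves_before, leaves_after)
```

Unlike sum(dt.tree_.children_left < 0), count_reachable_leaves ignores nodes inside pruned subtrees, so it reports the size of the tree that predict actually walks.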