Issue
I am using a random forest in scikit-learn for classification, and I use the predict_proba function to get the class probabilities. However, the probabilities it outputs appear to be rounded to the first decimal place.
I tried it with the sample iris dataset:
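import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
# Randomly assign roughly 75% of the rows to the training set
df['is_train'] = np.random.uniform(0, 1, len(df)) <= .75
# iris.target holds integer codes, so map them to names with from_codes
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
df.head()
train, test = df[df['is_train']], df[~df['is_train']]
features = df.columns[:4]
clf = RandomForestClassifier(n_jobs=2)
# Encode the species names as integer labels for fitting
y, _ = pd.factorize(train['species'])
clf.fit(train[features], y)
clf.predict_proba(train[features])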
Output probabilities:
[ 1. , 0. , 0. ],
[ 1. , 0. , 0. ],
[ 1. , 0. , 0. ],
[ 1. , 0. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 0.8, 0.2],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 1. , 0. ],
Is it the default output? Is it possible to increase the decimal places?
Note: Found the solution. The default number of trees is 10; after increasing the number of trees to 100, the probabilities became more fine-grained.
Solution
Apparently the default is ten trees, and your code uses that default (this applies to scikit-learn versions before 0.22; from 0.22 onward the default is 100):
Parameters:
n_estimators : integer, optional (default=10)
The number of trees in the forest.
Try something like this, increasing the number of trees to 25 or some other number greater than 10:
RandomForestClassifier(n_estimators=25, n_jobs=2)
If you are just getting the proportion of votes across the 10 default trees, that could very well explain the probabilities you are seeing: an average over 10 trees whose leaves are mostly pure can only move in steps of 0.1.
You may still run into issues because the iris dataset is very small, with only 150 observations.
The documentation for predict_proba() reads:
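To see the effect of the number of trees, here is a minimal sketch that fits two forests on iris and lists the distinct probability values each one produces. The helper name distinct_probabilities, the train/test split, and random_state=0 are choices made for this illustration, not part of the original question.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

def distinct_probabilities(n_trees):
    """Fit a forest with n_trees trees; return its probabilities and their distinct values."""
    clf = RandomForestClassifier(n_estimators=n_trees, random_state=0)
    clf.fit(X_train, y_train)
    proba = clf.predict_proba(X_test)
    # Round only to merge values that differ by floating-point noise
    return proba, np.unique(np.round(proba, 6))

proba_10, values_10 = distinct_probabilities(10)
proba_100, values_100 = distinct_probabilities(100)

print("10 trees: ", values_10)
print("100 trees:", values_100)
```

With 10 trees you would typically expect values on a 0.1 grid, and with 100 trees values on a 0.01 grid, provided the individual trees are grown until their leaves are pure (the default settings).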
The predicted class probabilities of an input sample is computed as the
mean predicted class probabilities of the trees in the forest. The class
probability of a single tree is the fraction of samples of the same
class in a leaf.
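The averaging described in that passage can be checked by hand. The sketch below, which assumes a forest fitted on the full iris data with random_state=0 for repeatability, averages the per-tree probabilities from the fitted estimators_ attribute and compares the result with the forest's own predict_proba.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# Class probabilities from each individual tree: shape (n_trees, n_samples, n_classes)
per_tree = np.stack([tree.predict_proba(X) for tree in clf.estimators_])

# Mean over the trees, which is what the documentation describes
manual = per_tree.mean(axis=0)
forest = clf.predict_proba(X)

print(np.allclose(manual, forest))
```

Because the forest output is a mean over n_estimators per-tree values, its resolution is tied to the number of trees rather than to any rounding setting.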
I could not find any parameter in the documentation for adjusting the decimal precision of the predicted probabilities.
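Separately from the forest itself, how many decimals you see on screen is a NumPy display matter: NumPy prints the shortest representation of each float, so 0.8 is shown as 0.8 rather than 0.8000. A small sketch of forcing a fixed number of decimals, using a hand-written row taken from the output in the question:

```python
import numpy as np

# One row of probabilities, copied from the output shown in the question
row = np.array([0.0, 0.8, 0.2])

# Default display: shortest representation of each value
print(row)

# Fixed four decimals, for display only; the stored values are unchanged
fixed = np.array2string(row, formatter={"float_kind": lambda v: f"{v:.4f}"})
print(fixed)
```

This only changes the formatting of the printed array. It does not make the probabilities any finer-grained; for that you need more trees.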
Answered By - invoketheshell