Issue
plt.figure(figsize=(15,7.5))
plot_tree(clf_dt, filled=True, rounded=True, class_names=["No HD", "Yes HD"], feature_names=X_encoded.columns)
I am new to machine learning and Python. I am trying to plot a classification tree for the heart disease data from the UCI repository. After one-hot encoding the categorical variables and storing the result in X_encoded, I try to plot the tree and get this error message:
InvalidParameterError: The 'feature_names' parameter of plot_tree must be an instance of 'list' or None. Got Index(['age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg', 'thalach',
'exang', 'oldpeak', 'slope', 'ca', 'thal'],
dtype='object') instead.
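The message makes more sense once you check what DataFrame.columns actually returns. Here is a minimal sketch using a small made-up DataFrame. The column names and values are placeholders, not the UCI data.

```python
import pandas as pd

# Hypothetical toy data standing in for X_encoded
df = pd.DataFrame({"age": [63, 37], "chol": [233, 250]})

cols = df.columns
print(type(cols).__name__)      # a pandas Index, not a Python list
print(isinstance(cols, list))   # False: this is the type plot_tree rejects

names = cols.tolist()           # convert the Index to a plain list
print(names)
print(isinstance(names, list))  # True
```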
Solution
You are passing a pandas Index object (which is what X_encoded.columns returns), but the error says feature_names must be a list or None. Convert the columns to a list with .tolist():
feature_names=X_encoded.columns.tolist()
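Below is a self-contained sketch of the fix in context. It assumes scikit-learn, pandas and matplotlib are installed. It uses a tiny hypothetical dataset instead of the UCI file, so the values and the "cp" column here are illustrative only. The Agg backend is selected so the script runs without a display. In a notebook you can drop that line and call plt.show() instead of saving the figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove in a notebook
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Hypothetical toy data, not the real UCI heart disease records
X = pd.DataFrame({
    "age": [63, 37, 41, 56, 57, 62],
    "chol": [233, 250, 204, 236, 354, 268],
    "cp": [1, 3, 2, 2, 4, 4],  # categorical chest-pain type
})
y = [1, 0, 0, 0, 1, 1]  # 1 = heart disease, 0 = none

# One-hot encode the categorical column; dtype=int keeps every column numeric
X_encoded = pd.get_dummies(X, columns=["cp"], dtype=int)

clf_dt = DecisionTreeClassifier(random_state=42)
clf_dt.fit(X_encoded, y)

plt.figure(figsize=(15, 7.5))
annotations = plot_tree(
    clf_dt,
    filled=True,
    rounded=True,
    class_names=["No HD", "Yes HD"],           # order matches clf_dt.classes_ = [0, 1]
    feature_names=X_encoded.columns.tolist(),  # a plain list, not a pandas Index
)
plt.savefig("tree.png")
```

list(X_encoded.columns) produces the same list and would work equally well here.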
Answered By - Jesse Sealand