Issue
I am having a lot of trouble understanding how the class_weight parameter in scikit-learn's Logistic Regression operates.
The Situation
I want to use logistic regression to do binary classification on a very unbalanced data set. The classes are labelled 0 (negative) and 1 (positive), and the observed data is in a ratio of about 19:1, with the majority of samples having a negative outcome.
First Attempt: Manually Preparing Training Data
I split the data I had into disjoint sets for training and testing (about 80/20). Then I randomly subsampled the training data by hand to get training sets with proportions other than 19:1, ranging from 2:1 to 16:1.
I then trained logistic regression on these different training subsets and plotted recall (= TP/(TP+FN)) as a function of the training proportions. Of course, the recall was computed on the disjoint TEST samples, which had the observed proportions of 19:1. Note that although I trained the different models on different training data, I computed recall for all of them on the same (disjoint) test data.
The results were as expected: the recall was about 60% at 2:1 training proportions and fell off rather fast by the time it got to 16:1. There were several proportions 2:1 -> 6:1 where the recall was decently above 5%.
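A minimal sketch of this first experiment, assuming synthetic data from make_classification stands in for the real data set (roughly 19:1 negative:positive) and using a hypothetical helper, subsample_to_ratio, for the by-hand sampling. Only the training split is subsampled; every model is scored on the same untouched test split.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split


def subsample_to_ratio(X, y, neg_per_pos, rng):
    """Hypothetical helper: keep every positive and neg_per_pos random negatives per positive."""
    pos_idx = np.flatnonzero(y == 1)
    neg_idx = np.flatnonzero(y == 0)
    n_neg = min(len(neg_idx), int(round(neg_per_pos * len(pos_idx))))
    keep = np.concatenate([pos_idx, rng.choice(neg_idx, size=n_neg, replace=False)])
    return X[keep], y[keep]


# Synthetic stand-in for the real data: roughly 19:1 negative:positive.
X, y = make_classification(n_samples=5000, weights=[0.95, 0.05], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

rng = np.random.default_rng(0)
recalls = {}
for ratio in range(2, 17, 2):
    X_sub, y_sub = subsample_to_ratio(X_train, y_train, ratio, rng)
    model = LogisticRegression(max_iter=1000).fit(X_sub, y_sub)
    # Recall is always measured on the same untouched test split (about 19:1).
    recalls[ratio] = recall_score(y_test, model.predict(X_test))

for ratio, r in recalls.items():
    print(f"{ratio}:1 -> recall {r:.2f}")
```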
Second Attempt: Grid Search
Next, I wanted to test different regularization parameters, so I used GridSearchCV and made a grid of several values of the C parameter as well as the class_weight parameter. To translate my n:m proportions of negative:positive training samples into the dictionary language of class_weight, I thought I would just specify several dictionaries as follows:
{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 } #expected 4:1
and I also included None and 'auto'.
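Roughly how such a grid could be set up, as a sketch on the same kind of synthetic data. The dictionaries are copied from the question as written; 'balanced' stands in for 'auto', which recent scikit-learn releases no longer accept, and the C values are arbitrary placeholders.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Synthetic stand-in for the real ~19:1 data.
X, y = make_classification(n_samples=5000, weights=[0.95, 0.05], random_state=0)

param_grid = {
    "C": [0.01, 0.1, 1.0, 10.0],
    "class_weight": [
        {0: 0.67, 1: 0.33},  # intended as "2:1" in the question
        {0: 0.75, 1: 0.25},  # intended as "3:1"
        {0: 0.8, 1: 0.2},    # intended as "4:1"
        None,
        "balanced",          # stand-in for the older "auto" option
    ],
}

# StratifiedKFold keeps roughly the observed class ratio in every fold.
# class_weight only enters the fitting loss; each held-out fold is scored as-is.
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid,
    scoring="recall",
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
)
search.fit(X, y)

# Mean cross-validated recall for every (C, class_weight) combination.
for params, score in zip(search.cv_results_["params"], search.cv_results_["mean_test_score"]):
    print(params, round(score, 3))
```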
This time the results were totally whacked. All my recalls came out tiny (< 0.05) for every value of class_weight except 'auto'. So I can only assume that my understanding of how to set the class_weight dictionary is wrong. Interestingly, with class_weight='auto' the recall in the grid search was around 59% for all values of C, and I guessed it balances to 1:1?
My Questions
1. How do you properly use class_weight to achieve different balances in training data from what you actually give it? Specifically, what dictionary do I pass to class_weight to use n:m proportions of negative:positive training samples?
2. If you pass various class_weight dictionaries to GridSearchCV, during cross-validation will it rebalance the training fold data according to the dictionary but use the true given sample proportions for computing my scoring function on the test fold? This is critical, since any metric is only useful to me if it comes from data in the observed proportions.
3. What does the 'auto' value of class_weight do as far as proportions? I read the documentation and I assume "balances the data inversely proportional to their frequency" just means it makes it 1:1. Is this correct? If not, can someone clarify?
Solution
First off, it might not be good to just go by recall alone. You can simply achieve a recall of 100% by classifying everything as the positive class. I usually suggest using AUC for selecting parameters, and then finding a threshold for the operating point (say a given precision level) that you are interested in.
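A sketch of that two-step workflow under stated assumptions: synthetic data, an arbitrary C grid, and a hypothetical target precision of 0.6. The threshold is chosen on a held-out validation split rather than on the data used for the final report.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split

# Synthetic stand-in for the real ~19:1 data.
X, y = make_classification(n_samples=5000, weights=[0.95, 0.05], random_state=0)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Step 1: choose hyperparameters by ROC AUC, which does not depend on a threshold.
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    {"C": [0.01, 0.1, 1.0, 10.0], "class_weight": [None, "balanced"]},
    scoring="roc_auc",
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
)
search.fit(X_train, y_train)

# Step 2: choose an operating point on held-out data.
target_precision = 0.6  # hypothetical requirement
scores = search.predict_proba(X_val)[:, 1]
precision, recall, thresholds = precision_recall_curve(y_val, scores)

# precision/recall have one more entry than thresholds; skip the final (1, 0) point.
meets_target = np.flatnonzero(precision[:-1] >= target_precision)
if meets_target.size == 0:
    raise ValueError("no threshold reaches the target precision")
# Thresholds increase with the index, so the first qualifying index keeps the most recall.
best = meets_target[0]
threshold = thresholds[best]
y_pred = (scores >= threshold).astype(int)
print(f"threshold={threshold:.3f} precision={precision[best]:.2f} recall={recall[best]:.2f}")
```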
For how class_weight works: it penalizes mistakes in samples of class i with class_weight[i] instead of 1. So a higher class weight means you want to put more emphasis on that class. From what you say, it seems class 0 is 19 times more frequent than class 1. So you should increase the class_weight of class 1 relative to class 0, say {0:.1, 1:.9}.
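One way to turn a desired n:m negative:positive balance into a class_weight dictionary, as a sketch: with an observed ratio of 19:1, weights w0 and w1 give weighted class totals of 19*w0 : w1, so fixing w0 = 1 and setting w1 = 19*m/n reproduces n:m in the loss. The helper name class_weight_for_ratio is hypothetical. By the same arithmetic, {0: 0.67, 1: 0.33} gives roughly 38.6:1, which puts even more emphasis on the negatives than the raw data does. Weighting keeps every negative sample, so it approximates subsampling rather than reproducing it exactly.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def class_weight_for_ratio(observed_neg_per_pos, target_neg, target_pos):
    """Hypothetical helper: class weights whose weighted totals are
    target_neg:target_pos, given observed_neg_per_pos negatives per positive."""
    # Weighted totals are observed_neg_per_pos * w0 : 1 * w1; fix w0 = 1 and solve for w1.
    return {0: 1.0, 1: observed_neg_per_pos * target_pos / target_neg}


print(class_weight_for_ratio(19, 2, 1))  # {0: 1.0, 1: 9.5}
print(class_weight_for_ratio(19, 1, 1))  # {0: 1.0, 1: 19.0}

# class_weight scales each sample's loss term by its class's weight, so it should
# give the same model as passing those weights per sample via sample_weight.
X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=0)
cw = class_weight_for_ratio(19, 2, 1)
by_class = LogisticRegression(class_weight=cw, max_iter=1000).fit(X, y)
per_sample = np.where(y == 1, cw[1], cw[0])
by_sample = LogisticRegression(max_iter=1000).fit(X, y, sample_weight=per_sample)
print(np.allclose(by_class.coef_, by_sample.coef_))
```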
If the class_weight values don't sum to 1, this will basically change the regularization parameter as well.
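A quick check of the regularization point, as a sketch on synthetic data. The default class_weight=None gives each class weight 1. Multiplying every class weight by a common factor k should give the same fit as leaving the weights alone and multiplying C by k, because scikit-learn's penalized objective scales each sample's loss term by C times that sample's weight.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=0)

# Equal relative weighting (1:1), but scaled by k compared with the default of 1 per class.
k = 0.5
scaled_weights = LogisticRegression(
    C=1.0, class_weight={0: k, 1: k}, tol=1e-10, max_iter=10_000
).fit(X, y)

# Unweighted model whose C is multiplied by the same factor k.
scaled_c = LogisticRegression(C=1.0 * k, tol=1e-10, max_iter=10_000).fit(X, y)

print(np.allclose(scaled_weights.coef_, scaled_c.coef_, atol=1e-4))
```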
For how class_weight="auto" works, you can have a look at this discussion.
In the dev version you can use class_weight="balanced", which is easier to understand: it basically means replicating the smaller class until you have as many samples as in the larger one, but in an implicit way.
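What "balanced" computes can be checked with compute_class_weight from sklearn.utils.class_weight, whose documented formula is n_samples / (n_classes * np.bincount(y)). On toy labels with 190 negatives and 10 positives (an assumed stand-in for the 19:1 data), each class ends up with the same total weight, i.e. an effective 1:1 balance.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy 19:1 labels: 190 negatives, 10 positives.
y = np.array([0] * 190 + [1] * 10)
classes = np.array([0, 1])

weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
# Same numbers from the documented formula.
manual = len(y) / (len(classes) * np.bincount(y))

# Each class's total weight (count * weight) equals n_samples / n_classes.
totals = np.bincount(y) * weights
print(weights, totals)
```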
Answered By - Andreas Mueller