Issue
Good morning/afternoon, I would like to use cross-validation in sklearn to predict a continuous variable.
I have referred to the "Visualizing cross-validation behavior in scikit-learn" page to select the cross-validation method suited to my problem. https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_indices.html#sphx-glr-auto-examples-model-selection-plot-cv-indices-py
I want to use StratifiedKFold but it does not provide a way to use a "stratifying" variable that is not the target variable ("class") as in the example below.
What I would like is to use the "group" variable to stratify instead.
Currently, what I do is this:
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
skf = StratifiedKFold(n_splits=5,
                      shuffle=True,
                      random_state=57)
cross_val_score(regr, X, y, cv=skf.split(training,groups))
where regr is my regressor, X my features, y my target, and groups a pandas Series of my preferred "stratifying" variable. I have checked that skf.split(training, groups) provides splits suited to my needs, i.e., train and test sets where the original distribution of my groups is maintained.
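For reference, here is a minimal, self-contained sketch of this approach on synthetic data (the regressor, feature names, and group values below are illustrative, not from the question). Passing skf.split(X, groups) as cv makes cross_val_score stratify on groups rather than on the continuous target:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(57)

# Synthetic features, a categorical "stratifying" variable, and a continuous target
X = pd.DataFrame({"f1": rng.normal(size=100), "f2": rng.normal(size=100)})
groups = pd.Series(rng.integers(0, 3, size=100))
y = 2 * X["f1"] + groups + rng.normal(scale=0.1, size=100)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=57)
regr = LinearRegression()

# cv accepts an iterable of (train, test) index arrays, so the folds
# are stratified on `groups` instead of the target `y`
scores = cross_val_score(regr, X, y, cv=skf.split(X, groups))
print(scores)
```

Note that skf.split(...) returns a generator, so it is consumed by a single cross_val_score call; call split again if you need the folds a second time.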
However, I have no means of checking that the cross-validation has the behavior I am expecting. Am I correct? Can I check?
Solution
Your approach looks correct to me, even if it is rather uncommon.
You could check if the stratification worked with this code:
import numpy as np

# Set up StratifiedKFold, just as you did
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=57)

# Set in which the seen test indices are collected
seen_test_indices = set()

# Iterate over each fold
for split_id, (train_index, test_index) in enumerate(skf.split(X, groups), start=1):
    # Check whether any of the test indices have been seen before
    overlapping_indices = seen_test_indices.intersection(test_index)
    if overlapping_indices:
        print(f"Overlap detected in Split ID {split_id} with indices {overlapping_indices}")
        break
    seen_test_indices.update(test_index)

    # Distribution of 'groups' in the train and test splits
    train_groups_distribution = np.bincount(groups[train_index])
    test_groups_distribution = np.bincount(groups[test_index])

    print(f"Split ID: {split_id}")
    print("Train Groups Distribution:", train_groups_distribution)
    print("Test Groups Distribution:", test_groups_distribution)
    print("-----")
I wouldn't use this approach if the groups variable has too many distinct/unique values. If each group contains only a small number of samples, StratifiedKFold
may raise an error (or emit a warning) because there are not enough samples per group to create stratified folds.
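As a follow-up to the count-based check above, you can also compare each fold's group proportions against the overall distribution; if stratification worked, the deviation should be small. A minimal sketch on synthetic data (the group values and sizes are illustrative):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)

# Synthetic stratifying variable; features are irrelevant to the split itself
groups = pd.Series(rng.integers(0, 3, size=300))
X = np.zeros((300, 1))

# Overall group proportions to compare each fold against
overall = groups.value_counts(normalize=True).sort_index()

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=57)
deviations = []
for split_id, (train_idx, test_idx) in enumerate(skf.split(X, groups), start=1):
    # Group proportions within this fold's test set
    test_dist = groups.iloc[test_idx].value_counts(normalize=True).sort_index()
    max_dev = (test_dist - overall).abs().max()
    deviations.append(max_dev)
    print(f"Split {split_id}: max deviation from overall proportions = {max_dev:.3f}")
```

A small maximum deviation in every fold confirms that the splits preserve the original distribution of groups.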
Answered By - DataJanitor