Issue
I have a dataset with several different types of data for each sample. I'd like to use a separate model for each data type and combine them in sklearn.ensemble.StackingClassifier. However, StackingClassifier takes a single feature matrix and applies every algorithm to it, then sends the probabilities to the meta classifier.
Is there a way to specify particular feature matrices (representing the same samples) that correspond to specific algorithms in the StackingClassifier?
If not, how can you use class inheritance of a StackingClassifier to add this type of functionality?
Below is a very quick and inelegant example (for demonstration only, not for practical use) of using 2 feature sets (sepal features and petal features from iris) from the same samples (iris samples). Each feature set uses a different algorithm, and the resulting probabilities are used as input to the meta classifier.
Doing it this way is very tedious...
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import train_test_split
# Data
X_sepal = pd.DataFrame({'sepal_length': {'iris_0': 5.1,'iris_1': 4.9,'iris_2': 4.7,'iris_3': 4.6,'iris_4': 5.0,'iris_5': 5.4,'iris_6': 4.6,'iris_7': 5.0,'iris_8': 4.4,'iris_9': 4.9,'iris_10': 5.4,'iris_11': 4.8,'iris_12': 4.8,'iris_13': 4.3,'iris_14': 5.8,'iris_15': 5.7,'iris_16': 5.4,'iris_17': 5.1,'iris_18': 5.7,'iris_19': 5.1,'iris_20': 5.4,'iris_21': 5.1,'iris_22': 4.6,'iris_23': 5.1,'iris_24': 4.8,'iris_25': 5.0,'iris_26': 5.0,'iris_27': 5.2,'iris_28': 5.2,'iris_29': 4.7,'iris_30': 4.8,'iris_31': 5.4,'iris_32': 5.2,'iris_33': 5.5,'iris_34': 4.9,'iris_35': 5.0,'iris_36': 5.5,'iris_37': 4.9,'iris_38': 4.4,'iris_39': 5.1,'iris_40': 5.0,'iris_41': 4.5,'iris_42': 4.4,'iris_43': 5.0,'iris_44': 5.1,'iris_45': 4.8,'iris_46': 5.1,'iris_47': 4.6,'iris_48': 5.3,'iris_49': 5.0,'iris_50': 7.0,'iris_51': 6.4,'iris_52': 6.9,'iris_53': 5.5,'iris_54': 6.5,'iris_55': 5.7,'iris_56': 6.3,'iris_57': 4.9,'iris_58': 6.6,'iris_59': 5.2,'iris_60': 5.0,'iris_61': 5.9,'iris_62': 6.0,'iris_63': 6.1,'iris_64': 5.6,'iris_65': 6.7,'iris_66': 5.6,'iris_67': 5.8,'iris_68': 6.2,'iris_69': 5.6,'iris_70': 5.9,'iris_71': 6.1,'iris_72': 6.3,'iris_73': 6.1,'iris_74': 6.4,'iris_75': 6.6,'iris_76': 6.8,'iris_77': 6.7,'iris_78': 6.0,'iris_79': 5.7,'iris_80': 5.5,'iris_81': 5.5,'iris_82': 5.8,'iris_83': 6.0,'iris_84': 5.4,'iris_85': 6.0,'iris_86': 6.7,'iris_87': 6.3,'iris_88': 5.6,'iris_89': 5.5,'iris_90': 5.5,'iris_91': 6.1,'iris_92': 5.8,'iris_93': 5.0,'iris_94': 5.6,'iris_95': 5.7,'iris_96': 5.7,'iris_97': 6.2,'iris_98': 5.1,'iris_99': 5.7,'iris_100': 6.3,'iris_101': 5.8,'iris_102': 7.1,'iris_103': 6.3,'iris_104': 6.5,'iris_105': 7.6,'iris_106': 4.9,'iris_107': 7.3,'iris_108': 6.7,'iris_109': 7.2,'iris_110': 6.5,'iris_111': 6.4,'iris_112': 6.8,'iris_113': 5.7,'iris_114': 5.8,'iris_115': 6.4,'iris_116': 6.5,'iris_117': 7.7,'iris_118': 7.7,'iris_119': 6.0,'iris_120': 6.9,'iris_121': 5.6,'iris_122': 7.7,'iris_123': 6.3,'iris_124': 6.7,'iris_125': 7.2,'iris_126': 6.2,'iris_127': 6.1,'iris_128': 
6.4,'iris_129': 7.2,'iris_130': 7.4,'iris_131': 7.9,'iris_132': 6.4,'iris_133': 6.3,'iris_134': 6.1,'iris_135': 7.7,'iris_136': 6.3,'iris_137': 6.4,'iris_138': 6.0,'iris_139': 6.9,'iris_140': 6.7,'iris_141': 6.9,'iris_142': 5.8,'iris_143': 6.8,'iris_144': 6.7,'iris_145': 6.7,'iris_146': 6.3,'iris_147': 6.5,'iris_148': 6.2,'iris_149': 5.9},'sepal_width': {'iris_0': 3.5,'iris_1': 3.0,'iris_2': 3.2,'iris_3': 3.1,'iris_4': 3.6,'iris_5': 3.9,'iris_6': 3.4,'iris_7': 3.4,'iris_8': 2.9,'iris_9': 3.1,'iris_10': 3.7,'iris_11': 3.4,'iris_12': 3.0,'iris_13': 3.0,'iris_14': 4.0,'iris_15': 4.4,'iris_16': 3.9,'iris_17': 3.5,'iris_18': 3.8,'iris_19': 3.8,'iris_20': 3.4,'iris_21': 3.7,'iris_22': 3.6,'iris_23': 3.3,'iris_24': 3.4,'iris_25': 3.0,'iris_26': 3.4,'iris_27': 3.5,'iris_28': 3.4,'iris_29': 3.2,'iris_30': 3.1,'iris_31': 3.4,'iris_32': 4.1,'iris_33': 4.2,'iris_34': 3.1,'iris_35': 3.2,'iris_36': 3.5,'iris_37': 3.6,'iris_38': 3.0,'iris_39': 3.4,'iris_40': 3.5,'iris_41': 2.3,'iris_42': 3.2,'iris_43': 3.5,'iris_44': 3.8,'iris_45': 3.0,'iris_46': 3.8,'iris_47': 3.2,'iris_48': 3.7,'iris_49': 3.3,'iris_50': 3.2,'iris_51': 3.2,'iris_52': 3.1,'iris_53': 2.3,'iris_54': 2.8,'iris_55': 2.8,'iris_56': 3.3,'iris_57': 2.4,'iris_58': 2.9,'iris_59': 2.7,'iris_60': 2.0,'iris_61': 3.0,'iris_62': 2.2,'iris_63': 2.9,'iris_64': 2.9,'iris_65': 3.1,'iris_66': 3.0,'iris_67': 2.7,'iris_68': 2.2,'iris_69': 2.5,'iris_70': 3.2,'iris_71': 2.8,'iris_72': 2.5,'iris_73': 2.8,'iris_74': 2.9,'iris_75': 3.0,'iris_76': 2.8,'iris_77': 3.0,'iris_78': 2.9,'iris_79': 2.6,'iris_80': 2.4,'iris_81': 2.4,'iris_82': 2.7,'iris_83': 2.7,'iris_84': 3.0,'iris_85': 3.4,'iris_86': 3.1,'iris_87': 2.3,'iris_88': 3.0,'iris_89': 2.5,'iris_90': 2.6,'iris_91': 3.0,'iris_92': 2.6,'iris_93': 2.3,'iris_94': 2.7,'iris_95': 3.0,'iris_96': 2.9,'iris_97': 2.9,'iris_98': 2.5,'iris_99': 2.8,'iris_100': 3.3,'iris_101': 2.7,'iris_102': 3.0,'iris_103': 2.9,'iris_104': 3.0,'iris_105': 3.0,'iris_106': 2.5,'iris_107': 2.9,'iris_108': 
2.5,'iris_109': 3.6,'iris_110': 3.2,'iris_111': 2.7,'iris_112': 3.0,'iris_113': 2.5,'iris_114': 2.8,'iris_115': 3.2,'iris_116': 3.0,'iris_117': 3.8,'iris_118': 2.6,'iris_119': 2.2,'iris_120': 3.2,'iris_121': 2.8,'iris_122': 2.8,'iris_123': 2.7,'iris_124': 3.3,'iris_125': 3.2,'iris_126': 2.8,'iris_127': 3.0,'iris_128': 2.8,'iris_129': 3.0,'iris_130': 2.8,'iris_131': 3.8,'iris_132': 2.8,'iris_133': 2.8,'iris_134': 2.6,'iris_135': 3.0,'iris_136': 3.4,'iris_137': 3.1,'iris_138': 3.0,'iris_139': 3.1,'iris_140': 3.1,'iris_141': 3.1,'iris_142': 2.7,'iris_143': 3.2,'iris_144': 3.3,'iris_145': 3.0,'iris_146': 2.5,'iris_147': 3.0,'iris_148': 3.4,'iris_149': 3.0}})
X_petal = pd.DataFrame({'petal_length': {'iris_0': 1.4,'iris_1': 1.4,'iris_2': 1.3,'iris_3': 1.5,'iris_4': 1.4,'iris_5': 1.7,'iris_6': 1.4,'iris_7': 1.5,'iris_8': 1.4,'iris_9': 1.5,'iris_10': 1.5,'iris_11': 1.6,'iris_12': 1.4,'iris_13': 1.1,'iris_14': 1.2,'iris_15': 1.5,'iris_16': 1.3,'iris_17': 1.4,'iris_18': 1.7,'iris_19': 1.5,'iris_20': 1.7,'iris_21': 1.5,'iris_22': 1.0,'iris_23': 1.7,'iris_24': 1.9,'iris_25': 1.6,'iris_26': 1.6,'iris_27': 1.5,'iris_28': 1.4,'iris_29': 1.6,'iris_30': 1.6,'iris_31': 1.5,'iris_32': 1.5,'iris_33': 1.4,'iris_34': 1.5,'iris_35': 1.2,'iris_36': 1.3,'iris_37': 1.4,'iris_38': 1.3,'iris_39': 1.5,'iris_40': 1.3,'iris_41': 1.3,'iris_42': 1.3,'iris_43': 1.6,'iris_44': 1.9,'iris_45': 1.4,'iris_46': 1.6,'iris_47': 1.4,'iris_48': 1.5,'iris_49': 1.4,'iris_50': 4.7,'iris_51': 4.5,'iris_52': 4.9,'iris_53': 4.0,'iris_54': 4.6,'iris_55': 4.5,'iris_56': 4.7,'iris_57': 3.3,'iris_58': 4.6,'iris_59': 3.9,'iris_60': 3.5,'iris_61': 4.2,'iris_62': 4.0,'iris_63': 4.7,'iris_64': 3.6,'iris_65': 4.4,'iris_66': 4.5,'iris_67': 4.1,'iris_68': 4.5,'iris_69': 3.9,'iris_70': 4.8,'iris_71': 4.0,'iris_72': 4.9,'iris_73': 4.7,'iris_74': 4.3,'iris_75': 4.4,'iris_76': 4.8,'iris_77': 5.0,'iris_78': 4.5,'iris_79': 3.5,'iris_80': 3.8,'iris_81': 3.7,'iris_82': 3.9,'iris_83': 5.1,'iris_84': 4.5,'iris_85': 4.5,'iris_86': 4.7,'iris_87': 4.4,'iris_88': 4.1,'iris_89': 4.0,'iris_90': 4.4,'iris_91': 4.6,'iris_92': 4.0,'iris_93': 3.3,'iris_94': 4.2,'iris_95': 4.2,'iris_96': 4.2,'iris_97': 4.3,'iris_98': 3.0,'iris_99': 4.1,'iris_100': 6.0,'iris_101': 5.1,'iris_102': 5.9,'iris_103': 5.6,'iris_104': 5.8,'iris_105': 6.6,'iris_106': 4.5,'iris_107': 6.3,'iris_108': 5.8,'iris_109': 6.1,'iris_110': 5.1,'iris_111': 5.3,'iris_112': 5.5,'iris_113': 5.0,'iris_114': 5.1,'iris_115': 5.3,'iris_116': 5.5,'iris_117': 6.7,'iris_118': 6.9,'iris_119': 5.0,'iris_120': 5.7,'iris_121': 4.9,'iris_122': 6.7,'iris_123': 4.9,'iris_124': 5.7,'iris_125': 6.0,'iris_126': 4.8,'iris_127': 4.9,'iris_128': 
5.6,'iris_129': 5.8,'iris_130': 6.1,'iris_131': 6.4,'iris_132': 5.6,'iris_133': 5.1,'iris_134': 5.6,'iris_135': 6.1,'iris_136': 5.6,'iris_137': 5.5,'iris_138': 4.8,'iris_139': 5.4,'iris_140': 5.6,'iris_141': 5.1,'iris_142': 5.1,'iris_143': 5.9,'iris_144': 5.7,'iris_145': 5.2,'iris_146': 5.0,'iris_147': 5.2,'iris_148': 5.4,'iris_149': 5.1},'petal_width': {'iris_0': 0.2,'iris_1': 0.2,'iris_2': 0.2,'iris_3': 0.2,'iris_4': 0.2,'iris_5': 0.4,'iris_6': 0.3,'iris_7': 0.2,'iris_8': 0.2,'iris_9': 0.1,'iris_10': 0.2,'iris_11': 0.2,'iris_12': 0.1,'iris_13': 0.1,'iris_14': 0.2,'iris_15': 0.4,'iris_16': 0.4,'iris_17': 0.3,'iris_18': 0.3,'iris_19': 0.3,'iris_20': 0.2,'iris_21': 0.4,'iris_22': 0.2,'iris_23': 0.5,'iris_24': 0.2,'iris_25': 0.2,'iris_26': 0.4,'iris_27': 0.2,'iris_28': 0.2,'iris_29': 0.2,'iris_30': 0.2,'iris_31': 0.4,'iris_32': 0.1,'iris_33': 0.2,'iris_34': 0.2,'iris_35': 0.2,'iris_36': 0.2,'iris_37': 0.1,'iris_38': 0.2,'iris_39': 0.2,'iris_40': 0.3,'iris_41': 0.3,'iris_42': 0.2,'iris_43': 0.6,'iris_44': 0.4,'iris_45': 0.3,'iris_46': 0.2,'iris_47': 0.2,'iris_48': 0.2,'iris_49': 0.2,'iris_50': 1.4,'iris_51': 1.5,'iris_52': 1.5,'iris_53': 1.3,'iris_54': 1.5,'iris_55': 1.3,'iris_56': 1.6,'iris_57': 1.0,'iris_58': 1.3,'iris_59': 1.4,'iris_60': 1.0,'iris_61': 1.5,'iris_62': 1.0,'iris_63': 1.4,'iris_64': 1.3,'iris_65': 1.4,'iris_66': 1.5,'iris_67': 1.0,'iris_68': 1.5,'iris_69': 1.1,'iris_70': 1.8,'iris_71': 1.3,'iris_72': 1.5,'iris_73': 1.2,'iris_74': 1.3,'iris_75': 1.4,'iris_76': 1.4,'iris_77': 1.7,'iris_78': 1.5,'iris_79': 1.0,'iris_80': 1.1,'iris_81': 1.0,'iris_82': 1.2,'iris_83': 1.6,'iris_84': 1.5,'iris_85': 1.6,'iris_86': 1.5,'iris_87': 1.3,'iris_88': 1.3,'iris_89': 1.3,'iris_90': 1.2,'iris_91': 1.4,'iris_92': 1.2,'iris_93': 1.0,'iris_94': 1.3,'iris_95': 1.2,'iris_96': 1.3,'iris_97': 1.3,'iris_98': 1.1,'iris_99': 1.3,'iris_100': 2.5,'iris_101': 1.9,'iris_102': 2.1,'iris_103': 1.8,'iris_104': 2.2,'iris_105': 2.1,'iris_106': 1.7,'iris_107': 1.8,'iris_108': 
1.8,'iris_109': 2.5,'iris_110': 2.0,'iris_111': 1.9,'iris_112': 2.1,'iris_113': 2.0,'iris_114': 2.4,'iris_115': 2.3,'iris_116': 1.8,'iris_117': 2.2,'iris_118': 2.3,'iris_119': 1.5,'iris_120': 2.3,'iris_121': 2.0,'iris_122': 2.0,'iris_123': 1.8,'iris_124': 2.1,'iris_125': 1.8,'iris_126': 1.8,'iris_127': 1.8,'iris_128': 2.1,'iris_129': 1.6,'iris_130': 1.9,'iris_131': 2.0,'iris_132': 2.2,'iris_133': 1.5,'iris_134': 1.4,'iris_135': 2.3,'iris_136': 2.4,'iris_137': 1.8,'iris_138': 1.8,'iris_139': 2.1,'iris_140': 2.4,'iris_141': 2.3,'iris_142': 1.9,'iris_143': 2.3,'iris_144': 2.5,'iris_145': 2.3,'iris_146': 1.9,'iris_147': 2.0,'iris_148': 2.3,'iris_149': 1.8}})
y_iris = pd.Series({'iris_0': 'setosa','iris_1': 'setosa','iris_2': 'setosa','iris_3': 'setosa','iris_4': 'setosa','iris_5': 'setosa','iris_6': 'setosa','iris_7': 'setosa','iris_8': 'setosa','iris_9': 'setosa','iris_10': 'setosa','iris_11': 'setosa','iris_12': 'setosa','iris_13': 'setosa','iris_14': 'setosa','iris_15': 'setosa','iris_16': 'setosa','iris_17': 'setosa','iris_18': 'setosa','iris_19': 'setosa','iris_20': 'setosa','iris_21': 'setosa','iris_22': 'setosa','iris_23': 'setosa','iris_24': 'setosa','iris_25': 'setosa','iris_26': 'setosa','iris_27': 'setosa','iris_28': 'setosa','iris_29': 'setosa','iris_30': 'setosa','iris_31': 'setosa','iris_32': 'setosa','iris_33': 'setosa','iris_34': 'setosa','iris_35': 'setosa','iris_36': 'setosa','iris_37': 'setosa','iris_38': 'setosa','iris_39': 'setosa','iris_40': 'setosa','iris_41': 'setosa','iris_42': 'setosa','iris_43': 'setosa','iris_44': 'setosa','iris_45': 'setosa','iris_46': 'setosa','iris_47': 'setosa','iris_48': 'setosa','iris_49': 'setosa','iris_50': 'versicolor','iris_51': 'versicolor','iris_52': 'versicolor','iris_53': 'versicolor','iris_54': 'versicolor','iris_55': 'versicolor','iris_56': 'versicolor','iris_57': 'versicolor','iris_58': 'versicolor','iris_59': 'versicolor','iris_60': 'versicolor','iris_61': 'versicolor','iris_62': 'versicolor','iris_63': 'versicolor','iris_64': 'versicolor','iris_65': 'versicolor','iris_66': 'versicolor','iris_67': 'versicolor','iris_68': 'versicolor','iris_69': 'versicolor','iris_70': 'versicolor','iris_71': 'versicolor','iris_72': 'versicolor','iris_73': 'versicolor','iris_74': 'versicolor','iris_75': 'versicolor','iris_76': 'versicolor','iris_77': 'versicolor','iris_78': 'versicolor','iris_79': 'versicolor','iris_80': 'versicolor','iris_81': 'versicolor','iris_82': 'versicolor','iris_83': 'versicolor','iris_84': 'versicolor','iris_85': 'versicolor','iris_86': 'versicolor','iris_87': 'versicolor','iris_88': 'versicolor','iris_89': 'versicolor','iris_90': 
'versicolor','iris_91': 'versicolor','iris_92': 'versicolor','iris_93': 'versicolor','iris_94': 'versicolor','iris_95': 'versicolor','iris_96': 'versicolor','iris_97': 'versicolor','iris_98': 'versicolor','iris_99': 'versicolor','iris_100': 'virginica','iris_101': 'virginica','iris_102': 'virginica','iris_103': 'virginica','iris_104': 'virginica','iris_105': 'virginica','iris_106': 'virginica','iris_107': 'virginica','iris_108': 'virginica','iris_109': 'virginica','iris_110': 'virginica','iris_111': 'virginica','iris_112': 'virginica','iris_113': 'virginica','iris_114': 'virginica','iris_115': 'virginica','iris_116': 'virginica','iris_117': 'virginica','iris_118': 'virginica','iris_119': 'virginica','iris_120': 'virginica','iris_121': 'virginica','iris_122': 'virginica','iris_123': 'virginica','iris_124': 'virginica','iris_125': 'virginica','iris_126': 'virginica','iris_127': 'virginica','iris_128': 'virginica','iris_129': 'virginica','iris_130': 'virginica','iris_131': 'virginica','iris_132': 'virginica','iris_133': 'virginica','iris_134': 'virginica','iris_135': 'virginica','iris_136': 'virginica','iris_137': 'virginica','iris_138': 'virginica','iris_139': 'virginica','iris_140': 'virginica','iris_141': 'virginica','iris_142': 'virginica','iris_143': 'virginica','iris_144': 'virginica','iris_145': 'virginica','iris_146': 'virginica','iris_147': 'virginica','iris_148': 'virginica','iris_149': 'virginica'})
# Training/Testing
idx_training, idx_testing = train_test_split(y_iris.index, stratify=y_iris, random_state=0)
# Classifiers
# Note: scikit-learn 1.2 renamed base_estimator to estimator (base_estimator was removed in 1.4)
clf_sepal = AdaBoostClassifier(base_estimator=LinearSVC(random_state=0), random_state=0, algorithm='SAMME')
clf_petal = RandomForestClassifier(random_state=0)
clf_meta = LogisticRegression(random_state=0)
# Fitting base classifiers
clf_sepal.fit(X_sepal.loc[idx_training], y_iris.loc[idx_training])
clf_petal.fit(X_petal.loc[idx_training], y_iris.loc[idx_training])  # petal model is fit on the petal features
# Fitting meta classifier
clf_meta.fit(
    X=pd.concat([
        pd.DataFrame(clf_sepal.predict_proba(X_sepal.loc[idx_training]), index=idx_training, columns=pd.Index(clf_sepal.classes_).map(lambda j: "sepal__{}".format(j))),
        pd.DataFrame(clf_petal.predict_proba(X_petal.loc[idx_training]), index=idx_training, columns=pd.Index(clf_petal.classes_).map(lambda j: "petal__{}".format(j))),
    ], axis=1),
    y=y_iris.loc[idx_training],
)
# Predicting with meta classifier
y_hat = pd.Series(
    clf_meta.predict(
        X=pd.concat([
            pd.DataFrame(clf_sepal.predict_proba(X_sepal.loc[idx_testing]), index=idx_testing, columns=pd.Index(clf_sepal.classes_).map(lambda j: "sepal__{}".format(j))),
            pd.DataFrame(clf_petal.predict_proba(X_petal.loc[idx_testing]), index=idx_testing, columns=pd.Index(clf_petal.classes_).map(lambda j: "petal__{}".format(j))),
        ], axis=1),
    ),
    index=idx_testing,
)
print("Accuracy on test set:", np.mean(y_hat == y_iris.loc[idx_testing]))
# Accuracy on test set: 0.9736842105263158
Solution
You can do the column selection as part of a pipeline for the base estimators. One approach to that is a ColumnTransformer, which is a little verbose for the purpose, but the alternatives I know about (e.g. FunctionTransformer) are a little less robust.
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import StackingClassifier
from sklearn.pipeline import Pipeline
sepal_cols = ['sepal_length', 'sepal_width']
petal_cols = ['petal_length', 'petal_width']
X = X_iris # as loaded from sklearn, or the hstack of your examples
# Each base estimator is a pipeline: select its own columns, then classify
pipe_sepal = Pipeline([
    ('select', ColumnTransformer([('sel', 'passthrough', sepal_cols)], remainder='drop')), # remainder='drop' is the default, but I've included it for clarity
    ('clf', clf_sepal)
])
pipe_petal = Pipeline([
    ('select', ColumnTransformer([('sel', 'passthrough', petal_cols)], remainder='drop')),
    ('clf', clf_petal)
])
stack = StackingClassifier(
    estimators=[
        ('sepal', pipe_sepal),
        ('petal', pipe_petal),
    ],
    final_estimator=clf_meta,
    # ...
)
# X_train, X_test, y_train: a train/test split of X and the labels
stack.fit(X_train, y_train)
y_hat = stack.predict(X_test)
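For reference, a self-contained version of the approach above might look like the following sketch. It makes a few assumptions that are not in the original: iris is loaded with load_iris(as_frame=True), so the column names differ from the question's; KNeighborsClassifier stands in for the AdaBoost/LinearSVC model to keep the example independent of scikit-learn version differences; and on_columns is a hypothetical helper, not part of scikit-learn.

```python
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline

# Load iris as a DataFrame so columns can be selected by name.
iris = load_iris(as_frame=True)
X, y = iris.data, iris.target

# Column names as provided by scikit-learn's iris loader.
sepal_cols = ["sepal length (cm)", "sepal width (cm)"]
petal_cols = ["petal length (cm)", "petal width (cm)"]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0
)


def on_columns(cols, clf):
    """Hypothetical helper: wrap clf so that it only sees the given columns."""
    return Pipeline([
        ("select", ColumnTransformer([("sel", "passthrough", cols)])),
        ("clf", clf),
    ])


stack = StackingClassifier(
    estimators=[
        ("sepal", on_columns(sepal_cols, KNeighborsClassifier())),
        ("petal", on_columns(petal_cols, RandomForestClassifier(random_state=0))),
    ],
    final_estimator=LogisticRegression(random_state=0),
    cv=5,  # the meta classifier is trained on cross-validated predictions
)
stack.fit(X_train, y_train)
print("Accuracy on test set:", stack.score(X_test, y_test))
```

Every base pipeline receives the full X, so fit and predict are called once on the stack and the column bookkeeping stays inside the pipelines.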
Your manual method, in addition to being tedious, is statistically unsound: your base models are making predictions on their own training set to be used as inputs to the meta-estimator. This can generally lead to the meta-estimator giving preference to the most-overfit base estimator; I assume your high testing score (which does appear to be valid) is just due to iris being relatively easy?
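If you do want to keep a manual version, one way to avoid the leakage described above is to feed the meta-estimator out-of-fold predictions, which is essentially what StackingClassifier does internally through cross-validation. The following is only a sketch under assumptions: it uses the NumPy arrays from load_iris rather than the question's DataFrames, KNeighborsClassifier in place of the AdaBoost model, and a 5-fold StratifiedKFold chosen arbitrarily.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_sepal, X_petal = X[:, :2], X[:, 2:]  # columns 0-1 are sepal, 2-3 are petal
idx_train, idx_test = train_test_split(
    np.arange(len(y)), stratify=y, random_state=0
)

# One (classifier, feature matrix) pair per data type.
base = {
    "sepal": (KNeighborsClassifier(), X_sepal),
    "petal": (RandomForestClassifier(random_state=0), X_petal),
}
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Out-of-fold probabilities: each training row is scored by a model
# that did not see that row during fitting.
meta_train = np.hstack([
    cross_val_predict(clf, X_part[idx_train], y[idx_train], cv=cv, method="predict_proba")
    for clf, X_part in base.values()
])
clf_meta = LogisticRegression(random_state=0).fit(meta_train, y[idx_train])

# Refit each base model on the whole training split, then score the test split.
meta_test = np.hstack([
    clf.fit(X_part[idx_train], y[idx_train]).predict_proba(X_part[idx_test])
    for clf, X_part in base.values()
])
print("Accuracy on test set:", clf_meta.score(meta_test, y[idx_test]))
```

The same cv splitter is reused for both base models so that their out-of-fold rows line up before being stacked side by side.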
Answered By - Ben Reiniger