Issue
I'm using a custom scikit-learn pipeline (sklearn.pipeline.Pipeline) in conjunction with RandomizedSearchCV for hyper-parameter optimization. This works great.
Now I would like to insert a Keras model as the first step of the pipeline, and the parameters of the model should be optimized. The fitted Keras model should then be used by later steps in the pipeline, so I think I have to store the model in a global variable so that the other steps can access it. Is this right?
I know that Keras offers wrappers for the scikit-learn API, but these wrappers already perform classification or regression, whereas I only want to build the Keras model and nothing else.
How can this be done?
For example, I have a method which returns the model:

    def create_model(file_path, argument2,...):
        ...
        return model

The method needs some fixed parameters such as file_path, but X and y are not needed (or can be ignored). The parameters of the model (number of layers, etc.) should be optimized.
Solution
You need to wrap your Keras model as a scikit-learn estimator first and then proceed as usual.
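Here's a quick example.
Here is a full blog post with this one and many other examples: Scikit-learn Pipeline Examples

    # Imports assumed for the legacy Keras wrapper (keras.wrappers.scikit_learn);
    # recent Keras releases no longer ship it, and the separate SciKeras package
    # provides a similar wrapper.
    from keras.models import Sequential
    from keras.layers import Dense, Dropout
    from keras.wrappers.scikit_learn import KerasClassifier
    from sklearn.pipeline import Pipeline

    # create a function that returns a model, taking as parameters things you
    # want to verify using cross-validation and model selection
    def create_model(optimizer='adagrad',
                     kernel_initializer='glorot_uniform',
                     dropout=0.2):
        model = Sequential()
        model.add(Dense(64, activation='relu', kernel_initializer=kernel_initializer))
        model.add(Dropout(dropout))
        model.add(Dense(1, activation='sigmoid', kernel_initializer=kernel_initializer))
        model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        return model

    # wrap the model using the function you created; the model is a binary
    # classifier (sigmoid output, binary cross-entropy), so use KerasClassifier
    clf = KerasClassifier(build_fn=create_model, verbose=0)

    # just create the pipeline
    pipeline = Pipeline([
        ('clf', clf)
    ])

    # X_train and y_train are your training features and binary labels
    pipeline.fit(X_train, y_train)
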
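Since the question mentions RandomizedSearchCV, here is a minimal sketch of how the search could address the parameters of a pipeline step through the step__parameter naming convention. To keep it runnable without Keras, scikit-learn's MLPClassifier stands in for the wrapped Keras model, and the data, parameter values, and search settings are made up for illustration. With the Keras wrapper, the keys would presumably be the create_model arguments instead, e.g. clf__dropout or clf__optimizer.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import RandomizedSearchCV
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline

# Synthetic binary-classification data (illustrative only).
X_train, y_train = make_classification(
    n_samples=200, n_features=10, random_state=0
)

# MLPClassifier is a stand-in (assumption) for the wrapped Keras model,
# so that this sketch runs with scikit-learn alone.
pipeline = Pipeline([
    ('clf', MLPClassifier(max_iter=300, random_state=0)),
])

# Keys follow the "<step name>__<parameter name>" convention of Pipeline.
param_distributions = {
    'clf__hidden_layer_sizes': [(16,), (64,), (32, 32)],
    'clf__alpha': [1e-5, 1e-3, 1e-1],
}

# Try 4 randomly sampled parameter combinations with 3-fold cross-validation.
search = RandomizedSearchCV(
    pipeline,
    param_distributions,
    n_iter=4,
    cv=3,
    random_state=0,
)
search.fit(X_train, y_train)

# The refitted best pipeline keeps its fitted steps, which can be looked
# up by name instead of being stored in a global variable.
best_clf = search.best_estimator_.named_steps['clf']
print(search.best_params_)
```

The fitted step can be read back from the pipeline via named_steps, as in the last lines, so a global variable should not be needed just to get hold of the fitted model.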
Answered By - Felipe