Issue
I am writing a training pipeline for time series forecasting models. As a baseline I am using a naive seasonal model, which simply returns the last `out_steps` time steps of the input.
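```python
import tensorflow as tf


class Naive(tf.keras.Model):
    """Naive baseline: repeat the last `out_steps` input time steps."""

    def __init__(self, out_steps: int, **kwargs):
        super().__init__(**kwargs)
        self.out_steps = out_steps

    def call(self, inputs, training=None):
        # inputs: (batch, time, features) -> (batch, out_steps, features)
        features = inputs
        return features[:, -self.out_steps:, :]
```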
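As a quick sanity check, here is a sketch assuming TensorFlow 2.x with its bundled Keras (the toy batch is made up). Calling the baseline on a tensor of shape `(batch, time, features)` returns the last `out_steps` steps, and the model owns no trainable variables at all:

```python
import numpy as np
import tensorflow as tf


class Naive(tf.keras.Model):
    """Baseline that repeats the last `out_steps` input time steps."""

    def __init__(self, out_steps: int, **kwargs):
        super().__init__(**kwargs)
        self.out_steps = out_steps

    def call(self, inputs, training=None):
        return inputs[:, -self.out_steps:, :]


# Toy batch (hypothetical): 2 series, 10 time steps, 1 feature.
x = tf.constant(np.arange(20, dtype="float32").reshape(2, 10, 1))
naive = Naive(out_steps=3)
y = naive(x).numpy()

print(y.shape)                         # (2, 3, 1)
print(y[0, :, 0])                      # [7. 8. 9.]
print(len(naive.trainable_variables))  # 0
```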
I can then use the generic training stage:
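```python
def train_model(model_name, **model_params):
    # instantiate_model and train_dataset are defined elsewhere in the pipeline (not shown)
    model = instantiate_model(model_name, **model_params)
    model.compile(loss='mse', optimizer='adam')
    model.fit(train_dataset)
```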
Is there a way to make `fit` recognize that the model has no trainable weights and return immediately?
Solution
Just wrap the call in a one-line check:
```python
if model.trainable_variables:
    model.fit(...)
```
`model.trainable_variables` contains the, well, trainable variables. For a model without any, this list is empty, so the `if` condition is false and `fit` is never called.
Note, though, that if you create a model with `Sequential`, do not provide an input shape to the first layer, and also do not call the model's `build()` method, that model will not have any variables before `fit` is called either! The variables are only created once the model is called for the first time. So be careful with this, or add something like this:
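```python
if model.trainable_variables:
    model.fit(...)
elif not model.built:
    # The variables may simply not exist yet: fail loudly (or just print a warning)
    raise RuntimeError("Model is not built yet, so its trainable variables are unknown.")
```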
Answered By - xdurch0