Issue
I found a PyTorch implementation that decays the batch norm momentum parameter from 0.1 in the first epoch to 0.001 in the final epoch. Any suggestions on how to do this with the batch norm momentum parameter in TF2 (i.e., start at 0.9 and end at 0.999)? For example, this is what is done in the PyTorch code:
# in training script
momentum = initial_momentum * np.exp(-epoch / args.epochs * np.log(initial_momentum / final_momentum))
model_pos_train.set_bn_momentum(momentum)

# model class function
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum
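(For context on why the target values flip between frameworks: Keras's BatchNormalization momentum weights the existing moving average, while PyTorch's momentum weights the new batch statistics, so a PyTorch momentum m corresponds to roughly 1 - m in Keras. A quick sketch of the relationship:)

# Keras:   moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
# PyTorch: running_mean = (1 - momentum) * running_mean + momentum * batch_mean
pytorch_momentum = 0.1                  # PyTorch starting value
keras_momentum = 1 - pytorch_momentum   # 0.9, the TF2 starting value above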
SOLUTION:
The selected answer below provides a viable solution when using the tf.keras.Model.fit()
API. However, I was using a custom training loop. Here is what I did instead:
After each epoch:
mi = 1 - initial_momentum  # i.e., initial_momentum = 0.9, mi = 0.1
mf = 1 - final_momentum    # i.e., final_momentum = 0.999, mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)
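As a quick sanity check of the schedule's endpoints (a small sketch; epochs = 80 is just an example value):

import numpy as np

epochs = 80  # example value only
mi, mf = 0.1, 0.001
print(1 - mi * np.exp(-0 / epochs * np.log(mi / mf)))       # 0.9 at the first epoch
print(1 - mi * np.exp(-epochs / epochs * np.log(mi / mf)))  # 0.999 at the final epoch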
set_bn_momentum function (credit to this article):
import os
import tempfile

import tensorflow as tf

def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)
            setattr(layer, 'momentum', momentum)

    # When we change a layer's attributes, the change only happens in the model config,
    # so rebuild the model from its config and restore the weights.
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # Load the model from the config.
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights.
    model.load_weights(tmp_weights_path, by_name=True)
    return model
This method did not add significant overhead to the training routine.
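For completeness, a rough sketch of how this fits into the custom training loop (train_one_epoch and train_dataset are placeholder names for your own loop; only the momentum update is the point here):

import numpy as np

initial_momentum, final_momentum = 0.9, 0.999
mi, mf = 1 - initial_momentum, 1 - final_momentum

for epoch in range(epochs):
    train_one_epoch(model, train_dataset)  # placeholder for your own per-batch training loop
    # after each epoch, decay the momentum and rebuild the model as described above
    momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
    model = set_bn_momentum(model, momentum)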
Solution
You can set an action at the beginning/end of each epoch or batch, so you can control any parameter during training.
Below are the available callback hooks:
from tensorflow import keras

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
You can access the momentum attribute of a BatchNormalization layer directly:
batch = tf.keras.layers.BatchNormalization()
batch.momentum = 0.001
Inside a model you have to specify the correct layer:
model.layers[1].momentum = 0.001
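Putting the two together, a sketch of a callback that applies a decay schedule like the one in the question (the class name and schedule constants here are just examples):

import numpy as np
from tensorflow import keras

class BNMomentumScheduler(keras.callbacks.Callback):
    def __init__(self, epochs, initial=0.9, final=0.999):
        super().__init__()
        self.epochs = epochs
        self.mi = 1 - initial  # 0.1
        self.mf = 1 - final    # 0.001

    def on_epoch_begin(self, epoch, logs=None):
        # same exponential schedule as in the question, expressed in Keras conventions
        momentum = 1 - self.mi * np.exp(-epoch / self.epochs * np.log(self.mi / self.mf))
        for layer in self.model.layers:
            if isinstance(layer, keras.layers.BatchNormalization):
                layer.momentum = momentum

# usage: model.fit(x, y, epochs=80, callbacks=[BNMomentumScheduler(epochs=80)])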
You can find more information and examples at writing_your_own_callbacks.
Answered By - Fernando Silva