Issue
After a lot of research, it seems like there is no good way to properly stop and resume training using a TensorFlow 2 / Keras model. This is true whether you are using model.fit() or a custom training loop.
There seem to be 2 supported ways to save a model while training:
Save just the weights of the model, using model.save_weights() or save_weights_only=True with tf.keras.callbacks.ModelCheckpoint. This seems to be preferred by most of the examples I've seen; however, it has a number of major issues:
- The optimizer state is not saved, meaning training resumption will not be correct.
- The learning rate schedule is reset, which can be catastrophic for some models.
- TensorBoard logs go back to step 0, making logging essentially useless unless complex workarounds are implemented.
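To see why losing the optimizer state matters, here is a toy sketch in plain Python (no TensorFlow involved): SGD with momentum on a hypothetical loss f(w) = (w - 3)^2, where the velocity term stands in for the optimizer slot variables that a weights-only save leaves out. The learning rate, momentum, and step counts are arbitrary assumptions chosen for illustration.

```python
# Toy illustration in plain Python, not TensorFlow code.
# Loss: f(w) = (w - 3)^2, optimised with SGD + momentum.
# The velocity plays the role of the optimizer's slot variables.


def grad(w):
    """Gradient of f(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)


def train(w, velocity, steps, lr=0.05, momentum=0.9):
    """Run `steps` momentum-SGD updates and return (weight, velocity)."""
    for _ in range(steps):
        velocity = momentum * velocity - lr * grad(w)
        w = w + velocity
    return w, velocity


# Reference: 20 uninterrupted steps starting from w = 0.
w_uninterrupted, _ = train(0.0, 0.0, 20)

# Training is interrupted after 5 steps; a checkpoint is taken here.
w_ckpt, v_ckpt = train(0.0, 0.0, 5)

# Resume from a weights-only checkpoint: velocity restarts at zero.
w_weights_only, _ = train(w_ckpt, 0.0, 15)

# Resume from a checkpoint that also holds the optimizer state.
w_full_state, _ = train(w_ckpt, v_ckpt, 15)

print(f"uninterrupted:       {w_uninterrupted:.6f}")
print(f"weights-only resume: {w_weights_only:.6f}")
print(f"weights + optimizer: {w_full_state:.6f}")
```

The full-state resume reproduces the uninterrupted run exactly, because the same updates are applied in the same order; the weights-only resume ends up at a different weight because the accumulated momentum was thrown away.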
Save the entire model, optimizer, etc. using model.save() or save_weights_only=False with ModelCheckpoint. The optimizer state is saved (good), but the following issues remain:
- TensorBoard logs still go back to step 0.
- The learning rate schedule is still reset (!!!)
- It is impossible to use custom metrics.
- This doesn't work at all when using a custom training loop: custom training loops use a non-compiled model, and saving/loading a non-compiled model doesn't seem to be supported.
The best workaround I've found is to use a custom training loop and manually save the step. This fixes the TensorBoard logging, and the learning rate schedule can be fixed by doing something like keras.backend.set_value(model.optimizer.iterations, step). However, since a full model save is off the table, the optimizer state is not preserved. I can see no way to save the state of the optimizer independently, at least not without a lot of work. And messing with the LR schedule as I've done feels hacky as well.
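The step-counter part of this workaround can be illustrated without TensorFlow. The sketch below mirrors, in plain Python, the formula documented for tf.keras.optimizers.schedules.ExponentialDecay with staircase=False, lr(step) = initial_lr * decay_rate ** (step / decay_steps). The constants and the saved step of 3000 are hypothetical. It shows why a restarted counter sends the learning rate back to its initial value, and how the same saved step can offset the steps passed to summary/TensorBoard logging.

```python
# Plain-Python mirror of the ExponentialDecay formula (staircase=False):
#   lr(step) = initial_lr * decay_rate ** (step / decay_steps)
# All constants below are illustrative assumptions.

INITIAL_LR = 0.1
DECAY_RATE = 0.5
DECAY_STEPS = 1000


def lr_at(step):
    """Learning rate the schedule would return at a given optimizer step."""
    return INITIAL_LR * DECAY_RATE ** (step / DECAY_STEPS)


saved_step = 3000  # hypothetical step at which training was interrupted

# If the step counter is not restored, the schedule restarts from step 0.
lr_after_naive_resume = lr_at(0)

# Restoring the counter (what assigning the saved step to
# optimizer.iterations achieves) lets the schedule continue.
lr_after_restored_resume = lr_at(saved_step)

# The same saved step can offset logging steps so that new
# TensorBoard points continue after the old ones instead of at 0.
next_log_steps = [saved_step + i for i in range(3)]

print(lr_after_naive_resume, lr_after_restored_resume, next_log_steps)
```

This only covers the schedule and logging; as noted above, the optimizer's own slot variables are still lost with this approach.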
Am I missing something? How are people out there saving/resuming using this API?
Solution
The tf.keras.callbacks.experimental.BackupAndRestore API for resuming training from interruptions has been added in TensorFlow 2.3 (tensorflow>=2.3). It works great in my experience.
Reference: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore
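According to the linked documentation, the callback backs up the training state (model weights and epoch number) to a checkpoint in backup_dir at the end of every epoch, restores it at the start of the next model.fit() call if the previous run was interrupted, and deletes the checkpoint once training completes. In TensorFlow the usage is roughly model.fit(..., callbacks=[tf.keras.callbacks.experimental.BackupAndRestore(backup_dir=...)]), and later releases also expose it as tf.keras.callbacks.BackupAndRestore. The sketch below is a framework-free toy of that same pattern, not TensorFlow's implementation; the state.json file name, the JSON format, and the fit() helper are hypothetical.

```python
# Framework-free toy of the backup-and-restore pattern (not TF code).
import json
import os
import shutil
import tempfile


class SimulatedInterruption(Exception):
    """Stands in for a pre-emption or crash in the middle of training."""


def fit(backup_dir, total_epochs, fail_at_epoch=None):
    """Train for `total_epochs`, backing up state after every epoch."""
    backup_file = os.path.join(backup_dir, "state.json")  # hypothetical name

    # On start-up, restore from a backup left by an interrupted run, if any.
    if os.path.exists(backup_file):
        with open(backup_file) as f:
            state = json.load(f)
    else:
        state = {"epoch": 0, "weight": 0.0, "log": []}

    for epoch in range(state["epoch"], total_epochs):
        if epoch == fail_at_epoch:
            raise SimulatedInterruption(f"crashed during epoch {epoch}")
        state["weight"] += 1.0       # stand-in for one epoch of training
        state["log"].append(epoch)   # epoch numbers keep counting up
        state["epoch"] = epoch + 1
        # Overwrite the single backup at the end of each completed epoch.
        os.makedirs(backup_dir, exist_ok=True)
        with open(backup_file, "w") as f:
            json.dump(state, f)

    # Training finished normally, so the backup is no longer needed.
    shutil.rmtree(backup_dir, ignore_errors=True)
    return state


backup_dir = os.path.join(tempfile.mkdtemp(), "backup")
try:
    fit(backup_dir, total_epochs=5, fail_at_epoch=3)
except SimulatedInterruption as exc:
    print("interrupted:", exc)

# Re-running the same call resumes from the last completed epoch.
final = fit(backup_dir, total_epochs=5)
print(final)
```

Because the epoch counter is part of the backed-up state, the resumed run continues at epoch 3 rather than starting over at 0.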
Answered By - yanp