Issue
Previously, I set "grad_checkpoint=true" for my EfficientDetLite4 model in config.yaml, and it successfully generated some checkpoints. However, I can't figure out how to use these checkpoints to continue training from them.
Every time I train the model, it just starts from the beginning instead of from my checkpoints.
The following picture shows my Colab file system structure:
The following picture shows where my checkpoints are stored:
The following code shows how I configure and train the model.
import numpy as np
import os
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

# Load the train/validation/test splits from a single CSV annotations file.
# (The assignment must stay on one logical line; splitting it after `=` is a
# SyntaxError.)
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('csv_path')

# Model spec: checkpoints are written to `model_dir` during training.
spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    # NOTE(review): `spec_strategy` is not defined in this snippet; it must be
    # set earlier in the notebook (e.g. 'gpus', to match `hparams`, or None).
    tflite_max_detections=25, strategy=spec_strategy
)
model = object_detector.create(train_data, model_spec=spec, batch_size=3,
                               train_whole_model=True, validation_data=validation_data)
Solution
The source code is the answer!
I ran into the same problem and found out that the `model_dir` we pass to the TFLite Model Maker's object detector API is only used for saving the model's weights: that's why the API never restores from checkpoints.
Having a look at the source code of this API, I noticed that it internally uses the standard `model.compile` and `model.fit` functions, and that it saves the model's weights through the `callbacks` parameter of `model.fit`.
This means that, provided we can get the internal Keras model, we can simply restore our checkpoints with `model.load_weights`!
These are the links to the source code if you want to know more about what some of the functions I use below do:
- Object Detector Documentation
- Object Detector Source Code
- EfficientDetSpec Source Code
- How the TFLite Model Maker API Compiles your Model
This is the code:
#Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader
#Import the same libs that TFLiteModelMaker interally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib
# ---- Training hyper-parameters (adjust to taste) ----
epochs = 50
batch_size = 6  # or whatever batch size you want
# Directory the checkpoints are written to (and later read back from)
checkpoint_dir = "/content/..."

# ---- Object detector spec (any EfficientDet-Lite spec works the same way) ----
spec = object_detector.EfficientDetLite4Spec(
    # Model definition
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2',
    model_dir=checkpoint_dir,
    hparams='',  # e.g. hparams='grad_checkpoint=true' to enable it
    # Training behaviour
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    # Distribution / hardware
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    # Diagnostics and reproducibility
    profile=False,
    debug=False,
    verbose=1,
    tf_random_seed=111111,
)
# ---- Datasets: train/validation/test splits described by one CSV file ----
csv_path = '/path/to/csv.csv'
train_data, validation_data, test_data = object_detector.DataLoader.from_csv(csv_path)

# ---- Object detector, created WITHOUT training (do_train=False) ----
# Training is driven manually further down so that it can resume from a
# checkpoint.
detector = object_detector.create(
    train_data,
    model_spec=spec,
    validation_data=validation_data,
    epochs=epochs,
    batch_size=batch_size,
    train_whole_model=True,
    do_train=False,
)
# --------------------------------------------------------------------------
# From here on we use internal ("private") functions of the API; you can tell
# because the methods' names begin with an underscore.
# --------------------------------------------------------------------------

# Turn the DataLoaders into datasets plus their step counts.
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(
    train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(
    validation_data, batch_size, is_training=False)

# Get the internal Keras model.
model = detector.create_model()

# Replicate the setup the API performs internally before training.
config = spec.config
config.update({
    'steps_per_epoch': steps_per_epoch,
    'eval_samples': batch_size * validation_steps,
    'val_json_file': val_json_file,
    'batch_size': batch_size,
})
train.setup_model(model, config)  # basically the model.compile call
model.summary()
# --------------------------------------------------------------------------
# Restore the weights from a checkpoint, if one exists.
# Example paths:
#   checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
#   specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
# --------------------------------------------------------------------------

# Number of epochs already completed. It stays 0 when no checkpoint is found,
# so `initial_epoch` is always defined for model.fit below.
completed_epochs = 0

# Option A: load the weights from the last successfully completed epoch
# (tf.train.latest_checkpoint returns None when the directory has none).
latest = tf.train.latest_checkpoint(checkpoint_dir)
# Option B: load the weights from a specific checkpoint instead
#latest = specific_checkpoint_dir

if latest is None:
    print("Checkpoint not found in {}, training from scratch".format(checkpoint_dir))
else:
    # Load errors are deliberately NOT swallowed: resuming at a later epoch
    # with un-restored weights would silently corrupt the training run.
    model.load_weights(latest)
    # Checkpoint prefixes end in "-<epoch>" (e.g. ".../ckpt-35"). Split on the
    # LAST "-" so dashes elsewhere in the path don't break the parsing.
    completed_epochs = int(latest.rsplit("-", 1)[-1])
    print("Checkpoint found {}".format(latest))
# --------------------------------------------------------------------------
# Callbacks executed during training. train_lib.get_callbacks() includes the
# callback that saves a checkpoint at the end of every epoch. Build the list
# once so any extra callbacks appended here are actually passed to model.fit.
# --------------------------------------------------------------------------
callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)

# Optional: also log the training results to TensorBoard.
#tensorboard_dir = "/content/..."  # wherever you want the logs
#callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1))

# Train the model, resuming after `completed_epochs`. Note that `epochs` is
# the TOTAL number of epochs, not the number of additional ones.
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=callbacks,
)
# --------------------------------------------------------------------------
# Save / export the trained model.
# Tip: for integer quantization you simply have to NOT SPECIFY the
# quantization_config parameter of detector.export.
# --------------------------------------------------------------------------
export_dir = "/content/..."  # save the .tflite wherever you want
quant_config = QuantizationConfig.for_float16()  # or whatever quantization you want

# Inject the model we just trained into the object detector, then export it.
detector.model = model
detector.export(
    export_dir=export_dir,
    tflite_filename='model.tflite',
    quantization_config=quant_config,
)
Answered By - Cristian Davide Conte
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.