Issue
After you train a model in Tensorflow:
- How do you save the trained model?
- How do you later restore this saved model?
Solution
Tensorflow 2 Docs
Saving Checkpoints
Adapted from the docs
# -------------------------
# ----- Toy Context -----
# -------------------------
import tensorflow as tf
class Net(tf.keras.Model):
    """A simple linear model: a single Dense layer with 5 units."""

    def __init__(self):
        super().__init__()
        # The attribute name `l1` is part of the tracked object graph that
        # checkpoints are matched against, so it must stay stable.
        self.l1 = tf.keras.layers.Dense(units=5)

    def call(self, x):
        """Return the output of the dense layer applied to `x`."""
        return self.l1(x)
def toy_dataset():
    """Build a small, endlessly repeating dataset of {"x", "y"} examples.

    `x` is the column vector 0.0 .. 9.0 and `y` is `x * 5.0` plus the row
    offsets 0.0 .. 4.0. Examples are repeated indefinitely and grouped in
    batches of 2.
    """
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    examples = tf.data.Dataset.from_tensor_slices({"x": inputs, "y": labels})
    return examples.repeat().batch(2)
def train_step(net, example, optimizer):
    """Run one optimization step of `net` on `example` with `optimizer`.

    Args:
        net: callable model exposing `trainable_variables`.
        example: mapping with an "x" entry (model input) and a "y" entry
            (target).
        optimizer: object exposing `apply_gradients`.

    Returns:
        The mean absolute error between `net(example["x"])` and
        `example["y"]`, computed before the update is applied.
    """
    # Only the forward pass needs to be recorded on the tape.
    with tf.GradientTape() as tape:
        prediction = net(example["x"])
        loss = tf.reduce_mean(tf.abs(prediction - example["y"]))

    trainable = net.trainable_variables
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss
# ----------------------------
# ----- Create Objects -----
# ----------------------------
# Model, optimizer and input pipeline whose state will be checkpointed.
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)

# Group everything that must be saved together; `step` counts training steps.
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1),
    optimizer=opt,
    net=net,
    iterator=iterator,
)
# The manager writes into ./tf_ckpts and is configured with max_to_keep=3.
manager = tf.train.CheckpointManager(
    checkpoint=ckpt,
    directory="./tf_ckpts",
    max_to_keep=3,
)
# ----------------------------
# ----- Train and Save -----
# ----------------------------
# Resume from the newest checkpoint the manager knows about, if any.
ckpt.restore(manager.latest_checkpoint)
print(
    "Restored from {}".format(manager.latest_checkpoint)
    if manager.latest_checkpoint
    else "Initializing from scratch."
)

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)

    # Only save (and report) on every 10th step.
    if int(ckpt.step) % 10 != 0:
        continue
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
    print("loss {:1.2f}".format(loss.numpy()))
# ---------------------
# ----- Restore -----
# ---------------------
# In another script, re-initialize objects
# Rebuild the same objects as at training time so that the checkpoint
# contents can be matched onto them.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1),
    optimizer=opt,
    net=net,
    iterator=iterator,
)
manager = tf.train.CheckpointManager(
    checkpoint=ckpt,
    directory="./tf_ckpts",
    max_to_keep=3,
)

# Re-use the manager code above ^
ckpt.restore(manager.latest_checkpoint)
print(
    "Restored from {}".format(manager.latest_checkpoint)
    if manager.latest_checkpoint
    else "Initializing from scratch."
)

for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.
More links
exhaustive and useful tutorial on saved_model -> https://www.tensorflow.org/guide/saved_model
keras detailed guide to save models -> https://www.tensorflow.org/guide/keras/save_and_serialize
Checkpoints capture the exact value of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when source code that will use the saved parameter values is available.
The SavedModel format on the other hand includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created the model. They are thus suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).
(Highlights are my own)
Tensorflow < 2
From the docs:
Save
# Two zero-initialized variables plus ops that modify them.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Op that initializes the variables.
init_op = tf.global_variables_initializer()

# Saver with ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables, do some work, then save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)

    # Do some work with the model.
    sess.run(inc_v1.op)
    sess.run(dec_v2.op)

    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
Restore
# Start from an empty default graph before re-declaring the variables.
tf.reset_default_graph()

# Declare the variables again; no initializer is given here because their
# values are loaded by `saver.restore` below.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Saver with ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, restore the variables from disk, then use the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")

    # Check the values of the variables.
    print("v1 : %s" % sess.run(v1))
    print("v2 : %s" % sess.run(v2))
simple_save
Many good answers; for completeness I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.
Python 3 ; Tensorflow 1.14
import tensorflow as tf
from tensorflow.saved_model import tag_constants
with tf.Graph().as_default(), tf.Session() as sess:
    ...  # build (and train) the model here

    # Saving: `inputs` and `outputs` map signature keys to the tensors of
    # the current graph that should be exposed by the exported model.
    tf.saved_model.simple_save(
        sess,
        'path/to/your/location/',
        inputs={
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        },
        outputs={"prediction": model_output},
    )
Restoring:
# Load the SavedModel into a fresh graph and run a prediction.
#
# The graph object is bound to the single name `graph` and that same name is
# used for every lookup, so no undefined name is referenced.
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        # Populate `graph` (and the session's variables) from the export
        # directory written by `tf.saved_model.simple_save`.
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )

        # Tensors are looked up by their graph names ("<op name>:<index>").
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })
Standalone example
The following code generates random data for the sake of the demonstration.
- We start by creating the placeholders. They will hold the data at runtime. From them, we create the `Dataset` and then its `Iterator`. We get the iterator's generated tensor, called `input_tensor`, which will serve as input to our model.
- The model itself is built from `input_tensor`: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
- The loss is a `softmax_cross_entropy_with_logits`, optimized with `Adam`. After 2 epochs (of 2 batches each), we save the "trained" model with `tf.saved_model.simple_save`. If you run the code as is, then the model will be saved in a folder called `simple/` in your current working directory.
- In a new graph, we then restore the saved model with `tf.saved_model.loader.load`. We grab the placeholders and logits with `graph.get_tensor_by_name` and the `Iterator` initializing operation with `graph.get_operation_by_name`.
- Lastly we run an inference for both batches in the dataset, and check that the saved and restored model both yield the same values. They do!
Code:
import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
def model(graph, input_tensor):
    """Build a bidirectional GRU(10) network followed by a dense classifier.

    Args:
        graph (tf.Graph): graph in which the ops are created
        input_tensor (tf.Tensor): tensor fed as input to the recurrent layer

    Returns:
        tf.Tensor: output of the final dense layer (5 units, no activation)
    """
    # A single GRUCell instance is handed to both directions.
    gru_cell = tf.nn.rnn_cell.GRUCell(10)

    with graph.as_default():
        # Only the per-step outputs are used; the final states are dropped.
        # `sequence_length` is hard-coded to 32 sequences of length 10.
        rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=gru_cell,
            cell_bw=gru_cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None,
        )
        # Join forward and backward outputs along axis 2, then average
        # over axis 1 before classifying.
        merged = tf.concat(rnn_outputs, 2)
        pooled = tf.reduce_mean(merged, axis=1)
        return tf.layers.dense(pooled, 5, activation=None)
def get_opt_op(graph, logits, labels_tensor):
    """Create the training operation from the model's logits and labels.

    Args:
        graph (tf.Graph): graph in which the ops are created
        logits (tf.Tensor): the model's output without activation
        labels_tensor (tf.Tensor): target labels

    Returns:
        tf.Operation: the operation performing one step of the Adam
            optimizer (learning rate 1e-2) on the mean softmax
            cross-entropy loss
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            xent = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=labels_tensor, name='xent')
            loss = tf.reduce_mean(xent, name="mean-xent")
        with tf.variable_scope('optimizer'):
            return tf.train.AdamOptimizer(1e-2).minimize(loss)
if __name__ == '__main__':
    # Seed NumPy so that the synthetic data is reproducible.
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    # ------------------------------------------------------------------
    # First graph: build the input pipeline and model, train, run a final
    # inference pass and export everything with simple_save.
    # ------------------------------------------------------------------
    graph1 = tf.Graph()
    with graph1.as_default():
        # Graph-level seed for reproducibility.
        tf.set_random_seed(0)

        # Placeholders (their names are used for lookup after restoring).
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')

        # Dataset and re-initializable iterator.
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model and optimization op.
        logits = model(graph1, input_tensor)
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            sess.run(tf.global_variables_initializer())

            for epoch in range(3):
                batch = 0
                # (Re-)initialize the dataset for this epoch
                # (could feed epochs in Dataset.repeat(epochs)).
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                # Consume batches until the iterator is exhausted.
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                    except tf.errors.OutOfRangeError:
                        break
                    # Reached only when a batch was actually processed.
                    batch += 1

            # Save model state into ./simple, replacing any previous export.
            print('\nSaving...')
            path = os.path.join(os.getcwd(), 'simple')
            shutil.rmtree(path, ignore_errors=True)
            tf.saved_model.simple_save(
                sess,
                path,
                {
                    "batch_size_ph": batch_size_ph,
                    "features_data_ph": features_data_ph,
                    "labels_data_ph": labels_data_ph
                },
                {
                    "logits": logits
                },
            )
            print('Ok')

    # ------------------------------------------------------------------
    # Second graph: restore the export and repeat the inference pass.
    # ------------------------------------------------------------------
    graph2 = tf.Graph()
    with graph2.as_default(), tf.Session(graph=graph2) as sess:
        # Restore saved values
        print('\nRestoring...')
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            path
        )
        print('Ok')

        # Look up the restored placeholders, model output and dataset
        # initializer by the names they had in the original graph.
        labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
        features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
        batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
        restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
        dataset_init_op = graph2.get_operation_by_name('dataset_init')

        # Initialize restored dataset
        sess.run(
            dataset_init_op,
            feed_dict={
                features_data_ph: features,
                labels_data_ph: labels,
                batch_size_ph: 32
            }
        )

        # Compute inference for both batches in dataset
        restored_values = []
        for _ in range(2):
            restored_values.append(sess.run(restored_logits))
            print('Restored values: ', restored_values[-1][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)
This will print:
$ python3 save_and_restore.py
Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595 0.12804556 0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045 -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792 -0.00602257 0.07465433 0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984 0.05981954 -0.15913513 -0.3244143 0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]
Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok
Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Restored values: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]
Inferences match: True
Answered By - ted
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.