Issue
I think this warning is coming from a problem with shapes, but I have no idea where. The full warning message suggests the following:
Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing.
When I add this argument to the function decorator, the warning indeed goes away:
@tf.function(experimental_relax_shapes=True)
What could the cause be? Here's the full code:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
print(f'Tensorflow version {tf.__version__}')
from tensorflow import keras
from tensorflow.keras.layers import Dense, Conv1D, GlobalAveragePooling1D, Embedding
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
(train_data, test_data), info = tfds.load('imdb_reviews/subwords8k',
                                           split=[tfds.Split.TRAIN, tfds.Split.TEST],
                                           as_supervised=True, with_info=True)
padded_shapes = ([None], ())
train_dataset = train_data.shuffle(25000).padded_batch(
    padded_shapes=padded_shapes, batch_size=16)
test_dataset = test_data.shuffle(25000).padded_batch(
    padded_shapes=padded_shapes, batch_size=16)
n_words = info.features['text'].encoder.vocab_size
class ConvModel(Model):
    def __init__(self):
        super(ConvModel, self).__init__()
        self.embe = Embedding(n_words, output_dim=16)
        self.conv = Conv1D(32, kernel_size=6, activation='elu')
        self.glob = GlobalAveragePooling1D()
        self.dens = Dense(2)

    def call(self, x, training=None, mask=None):
        x = self.embe(x)
        x = self.conv(x)
        x = self.glob(x)
        x = self.dens(x)
        return x
conv = ConvModel()
conv(next(iter(train_dataset))[0])
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()
# The labels are sparse integer class indices, so use the sparse accuracy metric
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = conv(inputs, training=True)
        loss = loss_object(labels, logits)
    train_loss(loss)
    train_acc(labels, logits)  # metrics expect (y_true, y_pred)
    gradients = tape.gradient(loss, conv.trainable_variables)
    optimizer.apply_gradients(zip(gradients, conv.trainable_variables))
@tf.function
def test_step(inputs, labels):
    logits = conv(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)  # metrics expect (y_true, y_pred)
def learn():
    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()
    for text, target in train_dataset:
        train_step(inputs=text, labels=target)
    for text, target in test_dataset:
        test_step(inputs=text, labels=target)
def main(epochs=2):
    for epoch in tf.range(1, epochs + 1):
        learn()
        template = 'TRAIN LOSS {:>5.3f} TRAIN ACC {:.2f} TEST LOSS {:>5.3f} TEST ACC {:.2f}'
        print(template.format(
            train_loss.result(),
            train_acc.result(),
            test_loss.result(),
            test_acc.result()
        ))

if __name__ == '__main__':
    main(epochs=1)
Solution
TL;DR: The root cause of this warning is that the shape of train_data varies from batch to batch. Fixing the size/shape of train_data resolves the tracing warning. I changed the following line, and then everything works as expected:
padded_shapes = ([9000], ())  # was ([None], ())
Details:
As mentioned in the warning message:
WARNING:tensorflow:10 out of the last 11 calls to <function train_step at 0x7f4825f6d400> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing.
this retracing warning happens for one of the three reasons listed there. Reason (1) is not the root cause, because @tf.function is not created in a loop; reason (3) is not the root cause either, because both arguments of train_step and test_step are tensor objects. That leaves reason (2) as the root cause.
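To make reason (2) concrete, here is a small standalone sketch (my own illustration, not from the original answer): tf.function caches one graph per input shape, so each new shape forces a fresh trace.
import tensorflow as tf

@tf.function
def f(x):
    # Python side effects such as print() run only while the function is being
    # traced, so every printed line marks one (re)trace.
    print('tracing for input shape', x.shape)
    return tf.reduce_sum(x)

f(tf.zeros([16, 100]))  # first call: traces
f(tf.zeros([16, 100]))  # same shape: cached graph is reused, nothing printed
f(tf.zeros([16, 250]))  # new shape: traces again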
When I printed the shape of train_data, it showed a different size for each batch.
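For example, a quick check along these lines (my own snippet, not part of the original answer) makes the varying batch shapes visible:
# Print the padded shape of the first few batches; with padded_shapes = ([None], ())
# the time dimension differs from batch to batch (exact lengths will vary per run).
for i, (text, target) in enumerate(train_dataset.take(5)):
    print(f'batch {i}: text shape = {text.shape}')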
So I padded train_data to a fixed length, making the shape the same for every batch:
# padded_shapes = ([None], ())  # varying text length per batch, so tf.function retraces at every step
padded_shapes = ([9000], ())    # fixed length; 9000 is used here for demonstration, please change it to suit your data
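Alternatively, if you prefer to keep variable-length batches, you can give tf.function a fixed, generic input signature so that a single trace covers every batch shape. A sketch, assuming the int64 dtypes this dataset yields under as_supervised=True:
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.int64),  # padded text: (batch, time)
    tf.TensorSpec(shape=[None], dtype=tf.int64),        # labels: (batch,)
])
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = conv(inputs, training=True)
        loss = loss_object(labels, logits)
    train_loss(loss)
    train_acc(labels, logits)
    gradients = tape.gradient(loss, conv.trainable_variables)
    optimizer.apply_gradients(zip(gradients, conv.trainable_variables))
With an input signature in place, call the function positionally, e.g. train_step(text, target), since some TensorFlow versions reject keyword arguments when input_signature is set.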
Answered By - Vishnuvardhan Janapati