Issue
Can loss terms be manually added with add_loss inside a TensorFlow graph? The example below is largely copied from
https://www.tensorflow.org/guide/keras/custom_layers_and_models#the_add_loss_method
but with @tf.function added to the layer's call method.
import tensorflow as tf
from tensorflow import keras


def main():
    layer = ActivityRegularizationLayer()
    inputs = tf.constant(5.)
    with tf.GradientTape() as tape:
        y = layer(inputs)
        loss = tf.reduce_mean(y)
        loss += sum(layer.losses)  # add the term recorded via add_loss
    grad = tape.gradient(loss, layer.trainable_weights)
    print(f"loss={float(loss)}, grad={grad}")


class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    @tf.function
    def call(self, inputs):
        # Record a regularization term proportional to the activations.
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs


if __name__ == "__main__":
    main()
Running the above leads to the following error:
The tensor <tf.Tensor 'mul:0' shape=() dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=call, id=46917885252656), which is out of scope.
Removing the @tf.function decorator makes things run successfully:
loss=5.050000190734863, grad=[]
as does removing the line that adds sum(layer.losses) to the total loss:
loss=5.0, grad=[]
Additional details
Python 3.9.12
TensorFlow 2.8.0
Solution
This is addressed here:
https://github.com/tensorflow/tensorflow/issues/32058#issuecomment-592664998
In summary, this is known behavior, and the fix is to "wrap your whole training step or training loop in a tf.function":
import tensorflow as tf
from tensorflow import keras


def main():
    model = MyModel()
    inputs = tf.constant(5.)
    loss, grad = model.train_step(inputs)
    print(f"loss={float(loss)}, grad={grad}")


class MyModel(keras.models.Model):
    def __init__(self):
        super().__init__()
        self.reg = ActivityRegularizationLayer()

    def call(self, inputs):
        return self.reg(inputs)

    @tf.function
    def train_step(self, data):
        # The forward pass and the read of self.losses are traced together,
        # so the loss tensors created by add_loss stay in scope.
        with tf.GradientTape() as tape:
            y = self(data)
            loss = tf.reduce_mean(y)
            loss += sum(self.losses)
        grad = tape.gradient(loss, self.trainable_weights)
        return loss, grad


class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        # add_loss now runs inside the graph built by train_step.
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs


if __name__ == "__main__":
    main()
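For completeness, the same advice extends to a full training loop with trainable weights. The sketch below is not part of the original answer: the Dense layer, Adam optimizer, learning rate, and dummy batch are assumptions added purely for illustration. It reuses the ActivityRegularizationLayer defined above; because the forward pass, the read of model.losses, and the weight update are all traced inside one tf.function, the tensors recorded by add_loss remain in scope.

class RegularizedModel(keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)        # gives the model trainable weights (illustrative)
        self.reg = ActivityRegularizationLayer()  # layer defined above

    def call(self, inputs):
        return self.reg(self.dense(inputs))


@tf.function
def train_step(model, optimizer, x):
    # The forward pass and the read of model.losses happen inside the same
    # traced function, so the tensors created by add_loss stay in scope.
    with tf.GradientTape() as tape:
        y = model(x)
        loss = tf.reduce_mean(y) + sum(model.losses)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss


model = RegularizedModel()
optimizer = keras.optimizers.Adam(1e-3)  # hypothetical optimizer and learning rate
x = tf.ones((4, 3))                      # dummy batch, assumed shape
for _ in range(3):
    print(float(train_step(model, optimizer, x)))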
Answered By - LexTron