Issue
I am converting my code to TensorFlow 2 and I keep getting the following error:
AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.
Here is a minimal example that reproduces the error:
import tensorflow as tf

class TEST:
    def __init__(self, a=1):
        self.a = tf.Variable(a)

    @tf.function
    def increment(self):
        self.a = self.a + 1
        return self.a

tst = TEST()
tst.increment()
How should I fix this?
Solution
When you do:
self.a = self.a + 1
You are overwriting the reference in self.a, which initially pointed to the variable created in __init__, with the result of that operation. You are not updating the value of the TensorFlow variable, only rebinding the Python attribute. The new tensor (the result of self.a + 1) does, in turn, use that variable in its computation. The problem is that the moment self.a is overwritten, nothing references the variable any more, so it can no longer be used. It's a bit of a chicken-and-egg situation, but tf.function considers it invalid. If you want to keep the variable and assign it a new value, do something like this:
@tf.function
def increment(self):
    self.a.assign(self.a + 1)
    return self.a
Or just this:
@tf.function
def increment(self):
    self.a.assign_add(1)
    return self.a
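For completeness, here is a minimal sketch of the whole class with the in-place update applied, so it can be run end to end. The class name Counter and the variable names counter, first and second are only illustrative (the question calls the class TEST), and the sketch assumes TensorFlow 2 with eager execution enabled, which is the default.

```python
import tensorflow as tf


class Counter:  # illustrative name; the question uses TEST
    def __init__(self, a=1):
        # Create the variable once and keep it as an object attribute.
        self.a = tf.Variable(a)

    @tf.function
    def increment(self):
        # Update the variable in place instead of rebinding self.a,
        # so the traced function keeps referring to a live variable.
        self.a.assign_add(1)
        return self.a


counter = Counter()
first = counter.increment()   # value after the first update
second = counter.increment()  # value after the second update
print(int(first), int(second), int(counter.a.numpy()))
```

Because self.a still refers to the same tf.Variable after each call, the function can be called repeatedly without triggering the AssertionError.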
Answered By - jdehesa