Issue
tf.function doesn't change an object's attributes
import tensorflow as tf

class f:
    v = 7
    def __call__(self):
        self.v = self.v + 1

@tf.function
def call(c):
    tf.print(c.v)  # always 7
    c()
    tf.print(c.v)  # always 8

c = f()
call(c)
call(c)
Expected output: 7 8 8 9
Actual output: 7 8 7 8
Everything works as expected when I remove the @tf.function decorator. How can I make my function work as expected with @tf.function?
Solution
This behavior is documented in the TensorFlow guide to tf.function:
Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a Function, sometimes executing twice or not at all. They only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code. The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.
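A minimal sketch of the rule quoted above, assuming TensorFlow 2.x (the names trace_log, counter and step are hypothetical): the Python list append runs only while the function is being traced, whereas the tf.Variable update is recorded in the graph and runs on every call. Both calls pass int32 scalar tensors, so the second call reuses the existing trace instead of running the Python body again.

```python
import tensorflow as tf

trace_log = []            # plain Python state: mutated only during tracing
counter = tf.Variable(0)  # TensorFlow state: updated on every graph execution

@tf.function
def step(x):
    trace_log.append("traced")  # Python side effect, runs at trace time only
    counter.assign_add(1)       # TF op, part of the graph, runs on every call
    return x + 1

step(tf.constant(1))  # first call: traces the Python function, then runs the graph
step(tf.constant(2))  # same dtype and shape: reuses the graph, no Python execution

print(len(trace_log))         # → 1
print(int(counter.numpy()))   # → 2
```

This is the same thing happening in the question: `self.v = self.v + 1` is a Python side effect, so it ran once during tracing and the two tf.print calls captured the constants 7 and 8.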
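So, try using a tf.Variable to see the expected changes. Because tf.Variable.assign_add is recorded as an op in the traced graph, it runs on every call, and tf.print reads the variable's current value each time: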
import tensorflow as tf

class f:
    v = tf.Variable(7)
    def __call__(self):
        self.v.assign_add(1)

@tf.function
def call(c):
    tf.print(c.v)  # 7 on the first call, 8 on the second
    c()
    tf.print(c.v)  # 8 on the first call, 9 on the second

c = f()
call(c)
call(c)
Output:
7
8
8
9
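Note that `v = tf.Variable(7)` above is a class attribute, so every instance of f shares the same variable. If each object should keep its own count, one option is to create the variable in `__init__`. The sketch below assumes TensorFlow 2.x; the class name Counter is hypothetical, and subclassing tf.Module is an optional choice that lets TensorFlow track the variable.

```python
import tensorflow as tf

class Counter(tf.Module):
    def __init__(self, start=7):
        super().__init__()
        # Per-instance variable, unlike the shared class-level attribute
        self.v = tf.Variable(start)

    def __call__(self):
        self.v.assign_add(1)  # TF op: runs on every graph execution

@tf.function
def call(c):
    tf.print(c.v)  # reads the variable's value at run time
    c()
    tf.print(c.v)

a = Counter()
b = Counter()
call(a)  # prints 7, then 8
call(a)  # prints 8, then 9
call(b)  # prints 7, then 8: b has its own variable
```

Passing a different Python object typically makes tf.function trace again, and each trace captures that object's own variable, so a and b count independently.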
Answered By - AloneTogether