Issue
I want to do some linear algebra (e.g. tf.matmul) using the gradient. By default the gradient is returned as a list of tensors, where the tensors may have different shapes. My solution has been to reshape the gradient into a single vector. This works in eager mode, but now I want to compile my code using tf.function. It seems there is no way to write a function which can 'flatten' the gradient in graph mode (tf.function).
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like

# this works for flattening the gradient in eager mode only
def flatten_grad(grad):
    return tf.concat([tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))) for i in range(len(grad))], 0)
I tried converting it to a loop over tf.range like this, but it doesn't work with tf.function either.
@tf.function
def flatten_grad1(grad):
    temp = [None]*len(grad)
    for i in tf.range(len(grad)):
        i = tf.cast(i, tf.int32)
        temp[i] = tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i])))
    return tf.concat(temp, 0)
I also tried a TensorArray, but that does not work either.
@tf.function
def flatten_grad2(grad):
    temp = tf.TensorArray(tf.float32, size=len(grad), infer_shape=False)
    for i in tf.range(len(grad)):
        i = tf.cast(i, tf.int32)
        temp = temp.write(i, tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))))
    return temp.concat()
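To isolate which part of these attempts trips up tf.function, here is a minimal sketch that keeps only the tf.range loop and the grad[i] lookup. The function name sum_by_tensor_index is hypothetical, and the exact exception type and message may differ between TensorFlow versions, so the sketch only checks that tracing fails.

```python
import tensorflow as tf

grad = [tf.ones((2, 10)), tf.ones((3,))]  # same example gradient as above

@tf.function
def sum_by_tensor_index(grad):  # hypothetical helper, for illustration only
    total = tf.constant(0.0)
    # Inside tf.function, AutoGraph converts a loop over tf.range(...) into a
    # tf.while_loop, so i is a symbolic int32 tensor, not a Python int.
    for i in tf.range(len(grad)):
        # A Python list cannot be indexed with a symbolic tensor, so tracing
        # fails on this lookup.
        total += tf.reduce_sum(grad[i])
    return total

try:
    sum_by_tensor_index(grad)
    failed_as_expected = False
except Exception as err:  # exact type/message may vary by TF version
    failed_as_expected = True
    print(type(err).__name__)
```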
Solution
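Maybe you could try iterating directly over your list of tensors instead of getting individual tensors by their index. Inside tf.function, a loop over tf.range is converted into a tf.while_loop, so i becomes a symbolic tensor, and a Python list cannot be indexed with a symbolic tensor. A plain Python loop over the list, on the other hand, is simply unrolled when the function is traced: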
import tensorflow as tf

grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like

@tf.function
def flatten_grad1(grad):
    temp = [None]*len(grad)
    for i, g in enumerate(grad):
        temp[i] = tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), ))
    return tf.concat(temp, axis=0)

print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)
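A slightly shorter variant of the same idea, offered as a sketch rather than as part of the original answer: tf.reshape accepts -1 as a dimension and infers its length, so the explicit tf.math.reduce_prod(tf.shape(g)) can be dropped, and the list can be built with a comprehension. The name flatten_grad3 is hypothetical.

```python
import tensorflow as tf

grad = [tf.ones((2, 10)), tf.ones((3,))]  # same example gradient as above

@tf.function
def flatten_grad3(grad):  # hypothetical name
    # A shape of [-1] flattens g to rank 1 and lets TensorFlow infer the length.
    # The comprehension runs in plain Python while tracing, once per tensor.
    return tf.concat([tf.reshape(g, [-1]) for g in grad], axis=0)

flat = flatten_grad3(grad)
print(flat.shape)  # (23,)
```

Like the loop version, this is unrolled at trace time, so tf.function retraces if the number or shapes of the gradient tensors change.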
With tf.TensorArray:
@tf.function
def flatten_grad2(grad):
    temp = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
    for g in grad:
        temp = temp.write(temp.size(), tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), )))
    return temp.concat()

print(flatten_grad2(grad))
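Since the original goal was linear algebra on the gradient, here is a sketch of the whole path inside one tf.function: take gradients with tf.GradientTape, flatten them the same way as flatten_grad1, and pass the result to tf.matmul. The variables w and b, the input x, the loss, and the function name flat_grad_and_sq_norm are all hypothetical, chosen so the gradients have the same (2, 10) and (3,) shapes as the example.

```python
import tensorflow as tf

# Hypothetical variables whose gradients have shapes (2, 10) and (3,).
w = tf.Variable(tf.ones((2, 10)))
b = tf.Variable(tf.ones((3,)))
x = tf.ones((10,))  # hypothetical input

@tf.function
def flat_grad_and_sq_norm():  # hypothetical name
    with tf.GradientTape() as tape:
        # Hypothetical scalar loss that depends on both variables.
        loss = tf.reduce_sum(tf.linalg.matvec(w, x)) + tf.reduce_sum(b)
    grads = tape.gradient(loss, [w, b])  # list of tensors with different shapes
    # Flatten each gradient and concatenate into one vector (iterating the list,
    # not indexing it with a tensor).
    flat = tf.concat(
        [tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)),)) for g in grads], axis=0)
    # View the vector as a column so tf.matmul can use it;
    # col^T @ col is the squared L2 norm of the gradient, shape (1, 1).
    col = tf.reshape(flat, (-1, 1))
    sq_norm = tf.matmul(col, col, transpose_a=True)
    return flat, sq_norm

flat, sq_norm = flat_grad_and_sq_norm()
print(sq_norm.numpy())  # [[23.]]
```

Every gradient entry here is 1 (d loss / d w_ij = x_j = 1 and d loss / d b_k = 1), so the squared norm is 23, the number of elements in the flattened gradient.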
Answered By - AloneTogether