Issue
I have this function:
def sampling(x):
    zeros = x*0
    samples = tf.random.categorical(tf.math.log(x), 1)
    samples = tf.squeeze(tf.one_hot(samples, depth=2), axis=1)
    return zeros+samples
I call it from this layer:
x = layers.Lambda(sampling, name="lambda")(x)
But I need to be able to change the depth value (the number of one-hot classes) used in the sampling function, so I would need something like this:
def sampling(x, depth):
But how can I make this work with the Lambda layer, which calls the function with only the input tensor?
Thanks a lot
Solution
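Wrap the call in a Python lambda inside the Lambda layer, so that depth is fixed when the layer is created and the layer still sees a one-argument function: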
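import tensorflow as tf

def sampling(x, depth):
    zeros = x*0
    # draw one class index per row from the probabilities in x
    samples = tf.random.categorical(tf.math.log(x), 1)
    # one-hot encode the index: (batch, 1, depth) -> (batch, depth)
    samples = tf.squeeze(tf.one_hot(samples, depth=depth), axis=1)
    return zeros+samples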
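Usage (with depth=3, the last dimension of x must also be 3 so that zeros+samples has matching shapes):
layers.Lambda(lambda t: sampling(t, depth=3), name="lambda")(x)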
Answered By - Marco Cerliani
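For reference, here is a small self-contained sketch that runs the answer's closure approach next to one alternative: the `arguments` parameter of `tf.keras.layers.Lambda`, a dict of keyword arguments that Keras forwards to the wrapped function on every call. The layer names, the example probabilities and depth=3 are made up for illustration, and the input is assumed to be a (batch, depth) tensor of strictly positive class probabilities.

```python
import tensorflow as tf
from tensorflow.keras import layers

def sampling(x, depth):
    zeros = x * 0
    # one sampled class index per row, shape (batch, 1)
    samples = tf.random.categorical(tf.math.log(x), 1)
    # one-hot encode and drop the extra axis: (batch, depth)
    samples = tf.squeeze(tf.one_hot(samples, depth=depth), axis=1)
    return zeros + samples

depth = 3  # illustrative value; must match the last dimension of the input

# Option 1: bind depth with a Python lambda (the approach in the answer)
closure_layer = layers.Lambda(lambda t: sampling(t, depth=depth), name="lambda_closure")

# Option 2: let Keras pass depth as a keyword argument via `arguments`
args_layer = layers.Lambda(sampling, arguments={"depth": depth}, name="lambda_args")

# hypothetical batch of class probabilities, shape (2, 3)
probs = tf.constant([[0.2, 0.3, 0.5],
                     [0.6, 0.3, 0.1]])

out_closure = closure_layer(probs)  # (2, 3) one-hot rows
out_args = args_layer(probs)        # (2, 3) one-hot rows
print(out_closure.numpy())
print(out_args.numpy())
```

Both layers behave the same way. The `arguments` form avoids an anonymous lambda, which can be convenient when the model is saved and reloaded, although a custom function like `sampling` still has to be available at load time.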