Issue
I have a Keras/TensorFlow Probability model where I would like to use values from the prior layer in the convert_to_tensor_fn parameter of the DistributionLambda layer that follows it. Ideally, I could do something like this:
from functools import partial
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow_probability as tfp
from typing import Union

tfd = tfp.distributions
zero_buffer = 1e-5

def quantile(s: tfd.Distribution, q: Union[tf.Tensor, float]) -> Union[tf.Tensor, float]:
    return s.quantile(q)

# 4 records (1st value represents CDF value,
#            2nd represents location,
#            3rd represents scale)
sample_input = tf.constant([[0.25, 0.0, 1.0],
                            [0.5, 1.0, 0.5],
                            [0.75, -1.0, 2.0],
                            [0.95, 3.0, 2.5]], dtype=tf.float32)

# Build toy model for demonstration
input_layer = layers.Input(3)
dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Normal(loc=t[..., 1],
                                              scale=zero_buffer + tf.nn.softplus(t[..., 2])),
    convert_to_tensor_fn=lambda t, s: partial(quantile, q=t[..., 0])(s)
)(input_layer)
model = Model(input_layer, dist)
However, according to the documentation, convert_to_tensor_fn must take only a tfd.Distribution as input, so the two-argument convert_to_tensor_fn=lambda t, s: in the code above doesn't work.
How can I access data from the prior layer in convert_to_tensor_fn? I'm assuming there's a clever way to create a partial function, or something similar, to get this to work.
Outside of the Keras model framework, this is fairly easy to do using code similar to the example below:
# input data in Tensor Constant form
cdf_data = tf.constant([0.25, 0.5, 0.75, 0.95], dtype=tf.float32)
norm_mu = tf.constant([0.0, 1.0, -1.0, 3.0], dtype=tf.float32)
norm_scale = tf.constant([1.0, 0.5, 2.0, 2.5], dtype=tf.float32)
quant = partial(quantile, q=cdf_data)
norm = tfd.Normal(loc=norm_mu, scale=norm_scale)
quant(norm)
Output:
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([-0.6744898, 1. , 0.3489796, 7.112134 ], dtype=float32)>
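The values above can be cross-checked without TensorFlow. The following is a minimal sketch that uses statistics.NormalDist from the Python standard library as a stand-in for tfd.Normal; the helper name normal_quantiles is made up for illustration and is not part of any library.

```python
from statistics import NormalDist

def normal_quantiles(cdf_vals, locs, scales):
    """Element-wise inverse CDF of Normal(loc, scale) at each CDF value."""
    return [NormalDist(mu=m, sigma=s).inv_cdf(p)
            for p, m, s in zip(cdf_vals, locs, scales)]

# Same records as the tf.constant inputs above, as plain Python lists.
cdf_data = [0.25, 0.5, 0.75, 0.95]
norm_mu = [0.0, 1.0, -1.0, 3.0]
norm_scale = [1.0, 0.5, 2.0, 2.5]

result = normal_quantiles(cdf_data, norm_mu, norm_scale)
print(result)
```

Each entry should agree with the corresponding entry of the tensor shown above, up to float32 rounding.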
Solution
I found a solution to this problem on my own, and decided to post it here.
You can create a wrapper class for the tfd.Normal distribution that takes the CDF values as an extra argument, and then override a couple of methods to do what you want. In particular, you need to override the _sample_n method so that it returns the quantile at the stored CDF values instead of a random draw from the distribution. The class would look something like this:
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import dtype_util, tensor_util, reparameterization, samplers
from tensorflow_probability.python.internal import prefer_static as ps

tfd = tfp.distributions

class NormalWrapper(tfp.distributions.Normal):
    """Normal distribution whose "samples" are quantiles at stored CDF values."""

    def __init__(self,
                 loc,
                 scale,
                 cdf_vals,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='NormalCDF'):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
            # Store the CDF values alongside loc and scale.
            self._cdf_vals = tensor_util.convert_nonref_to_tensor(
                cdf_vals, dtype=dtype, name='cdf_vals')
            super(NormalWrapper, self).__init__(loc=loc,
                                                scale=scale,
                                                validate_args=validate_args,
                                                allow_nan_stats=allow_nan_stats,
                                                name=name)
        # Re-register the parameters so that cdf_vals is included.
        self._parameters = parameters

    @classmethod
    def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
        return dict(
            loc=tfp.util.ParameterProperties(),
            scale=tfp.util.ParameterProperties(
                # The constraining function must be a bijector;
                # tf.nn.softplus does not accept a `low` argument.
                default_constraining_bijector_fn=(
                    lambda: tfp.bijectors.Softplus(low=dtype_util.eps(dtype)))),
            cdf_vals=tfp.util.ParameterProperties(),
        )

    @property
    def cdf_vals(self):
        return self._cdf_vals

    def _sample_n(self, n, seed=None):
        # Return the quantile at cdf_vals instead of a random draw.
        # The reshape assumes n == 1, which is the case for a plain sample() call.
        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        cdf_vals = tf.convert_to_tensor(self.cdf_vals)
        shape = ps.concat([[n], self._batch_shape_tensor(loc=loc, scale=scale, cdf_vals=cdf_vals)], axis=0)
        return tf.reshape(self.quantile(cdf_vals), shape=shape)
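To see why carrying the CDF values on the distribution gets around the single-argument restriction, here is a small plain-Python sketch of the same pattern. QuantileNormal and the converter below are hypothetical stand-ins built on statistics.NormalDist, not TensorFlow Probability classes.

```python
from statistics import NormalDist

class QuantileNormal:
    """Hypothetical stand-in for NormalWrapper: per-record Normal plus a CDF level."""

    def __init__(self, loc, scale, cdf_vals):
        self.loc = list(loc)
        self.scale = list(scale)
        self.cdf_vals = list(cdf_vals)

    def quantile(self, q):
        # Element-wise inverse CDF, one Normal per record.
        return [NormalDist(m, s).inv_cdf(p)
                for m, s, p in zip(self.loc, self.scale, q)]

    def sample(self):
        # Deterministic "draw": the quantile at the stored CDF levels.
        return self.quantile(self.cdf_vals)

def convert_to_tensor_fn(dist):
    # Takes only the distribution, mirroring the restriction in the question.
    return dist.sample()

dist = QuantileNormal(loc=[0.0, 1.0], scale=[1.0, 0.5], cdf_vals=[0.25, 0.5])
values = convert_to_tensor_fn(dist)
print(values)
```

The converter never sees the upstream values directly; it only reads what the distribution object already holds.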
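Once you have that class, you can create your DistributionLambda layer like this. No convert_to_tensor_fn is passed, because the layer's default is to call sample on the distribution, which now returns the quantile: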
dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: NormalWrapper(loc=t[..., 1],
                                                 scale=zero_buffer + tf.nn.softplus(t[..., 2]),
                                                 cdf_vals=t[..., 0]),
)(input_layer)
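As a rough sanity check for the layer, the sketch below computes in plain Python what the model would be expected to return for sample_input, assuming the layer behaves as described. It uses statistics.NormalDist as a stand-in for the TensorFlow Probability distribution, and the softplus helper is written out by hand for illustration.

```python
import math
from statistics import NormalDist

zero_buffer = 1e-5

def softplus(x):
    """log(1 + exp(x)), the transform tf.nn.softplus applies element-wise."""
    return math.log1p(math.exp(x))

# Columns: CDF value, location, raw (pre-softplus) scale.
sample_input = [[0.25, 0.0, 1.0],
                [0.5, 1.0, 0.5],
                [0.75, -1.0, 2.0],
                [0.95, 3.0, 2.5]]

expected = [
    NormalDist(mu=loc, sigma=zero_buffer + softplus(raw_scale)).inv_cdf(p)
    for p, loc, raw_scale in sample_input
]
print(expected)
```

These values differ from the standalone example in the question because the layer passes the third column through softplus before using it as the scale.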
Answered By - Jed