Issue
I'm currently using TensorFlow Probability to build a mixture density network (MDN) for a regression problem. Everything works great; however, I would like to explore some properties of the model. Because I'm using a mixture of Gaussians, I should be able to see the mean and std of each Gaussian component. Indeed, I can extract the weights from the model, and there seem to be three numbers for each Gaussian component. I'm wondering which (if any) of them are the mean and std of each component of the mixture.
The model I am using is built as follows:
import tensorflow as tf
import tensorflow_probability as tfp

def keras_model_2gauss_mdn(n_variables, name='gauss2_mdn'):
    event_shape = [1]
    num_components = 2
    # Number of raw parameters MixtureNormal expects from the previous layer
    param_size = tfp.layers.MixtureNormal.params_size(num_components, event_shape)
    x_1 = tf.keras.Input(shape=n_variables)
    hidden_0 = tf.keras.layers.Dense(192, activation='relu')(x_1)
    hidden_1 = tf.keras.layers.Dense(192, activation='relu')(hidden_0)
    hidden_2 = tf.keras.layers.Dense(192, activation='relu')(hidden_1)
    hidden_3 = tf.keras.layers.Dense(128, activation='relu')(hidden_2)
    hidden_4 = tf.keras.layers.Dense(64, activation='relu')(hidden_3)
    hidden_5 = tf.keras.layers.Dense(param_size, activation=None)(hidden_4)
    # Turns the param_size outputs into a mixture-of-Gaussians distribution
    output = tfp.layers.MixtureNormal(num_components, event_shape)(hidden_5)
    return tf.keras.Model(inputs=x_1, outputs=output, name=name)
After compiling and fitting (i.e., after training), I can get the weights of the whole model by calling .get_weights(). By selecting the last vector from this output, I can get the weights of the MixtureNormal layer. This looks something like
array([ 0.09415845, -0.0941584 , -0.02495631, -0.05152947, -0.04510244,
-0.00484127], dtype=float32)
I suspect the first number in each group of three is the weight, the second is the mean, and the third is the std, but I need some clarity on whether this is actually the case.
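For what it's worth, a MixtureNormal layer has no trainable weights of its own, so the last array returned by model.get_weights() is most likely the bias of the final Dense(param_size) layer rather than per-component (weight, mean, std) triples. The NumPy-only sketch below assumes the parameter layout TFP appears to use for MixtureNormal with event_shape=[1]: the first num_components values are mixture logits, and the remaining values are reshaped into one [loc, raw_scale] pair per component, with softplus applied to the scale. The helper split_mixture_params is hypothetical (not a TFP function), and for a real model it would apply to the Dense layer's output for a particular input, not to its bias.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)


def split_mixture_params(params, num_components):
    """Hypothetical helper: split a flat MixtureNormal parameter vector
    (event_shape=[1]) into mixture weights, means and stddevs.

    Assumed layout:
    [logit_0, ..., logit_{K-1}, loc_0, raw_scale_0, ..., loc_{K-1}, raw_scale_{K-1}]
    """
    params = np.asarray(params, dtype=np.float64)
    logits = params[:num_components]
    # Remaining values: one (loc, raw_scale) pair per component
    component_params = params[num_components:].reshape(num_components, 2)
    means = component_params[:, 0]
    stddevs = softplus(component_params[:, 1])
    # Softmax turns the logits into mixture weights that sum to 1
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return weights, means, stddevs


# Illustrative (made-up) raw values for two components
raw = np.array([0.0, 0.0, 1.0, 0.0, -1.0, 0.0])
weights, means, stddevs = split_mixture_params(raw, num_components=2)
print("weights:", weights)
print("means:  ", means)
print("stddevs:", stddevs)
```

With these made-up values, both weights are 0.5, the means are 1 and -1, and both stddevs are softplus(0) = ln 2.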
Note that I've also tried the solution given here, and it doesn't seem to work for tfp.layers.MixtureNormal.
I'm rather new to ML and TensorFlow, so any help is greatly appreciated!
Solution
The idea here is that when you pass an input to your network, you get a distribution back. To make things work nicely with Keras and with other things you might do with the output of a NN, the resulting distribution is wrapped in something called _TensorCoercible. This means that when you pass the distribution into a TF op, the distribution turns itself into a tensor. By default this is done by sampling the distribution, but it's configurable via the convert_to_tensor_fn argument that all TFP layers accept. E.g., you could use convert_to_tensor_fn=lambda dist: dist.mean() (or whatever you like!). Anyway, this means that when you invoke your model on some input, you don't directly get the MixtureSameFamily (Distribution!) instance underlying the MixtureNormal (TFP layer!) output; you get a _TensorCoercible wrapper around it.
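If you pass convert_to_tensor_fn=lambda dist: dist.mean(), the tensor you get back for a Gaussian mixture should be the weighted average of the component means. The NumPy sketch below computes that quantity, plus the mixture's overall standard deviation via the law of total variance, from the mixture weights, component means and component stddevs. mixture_moments is a hypothetical helper for illustration only, not part of TFP.

```python
import numpy as np


def mixture_moments(weights, means, stddevs):
    """Hypothetical helper: mean and stddev of a 1-D Gaussian mixture."""
    weights = np.asarray(weights, dtype=np.float64)
    means = np.asarray(means, dtype=np.float64)
    stddevs = np.asarray(stddevs, dtype=np.float64)
    # E[X] = sum_k w_k * mu_k
    mean = np.sum(weights * means)
    # Var[X] = sum_k w_k * (sigma_k^2 + mu_k^2) - E[X]^2  (law of total variance)
    variance = np.sum(weights * (stddevs**2 + means**2)) - mean**2
    return mean, np.sqrt(variance)


mean, std = mixture_moments([0.25, 0.75], [0.0, 4.0], [1.0, 1.0])
print(mean, std)
```

For this example the mixture mean is 3.0 and the standard deviation is 2.0; note that the overall spread is larger than either component's stddev because the component means are far apart.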
To get the MixtureSameFamily instance, look at the tensor_distribution member on the resulting TC (_TensorCoercible) object. It appears that, within the MSF (MixtureSameFamily) instance, the mixture distribution is not a TC, but the components distribution is. Not sure why. Here's a runnable snippet adapted from your code:
import tensorflow as tf
import tensorflow_probability as tfp
n_variables=[1]
name='blah'
event_shape = [1]
num_components = 2
param_size = tfp.layers.MixtureNormal.params_size(num_components, event_shape)
x_1 = tf.keras.Input(shape=n_variables)
hidden_0 = tf.keras.layers.Dense(192, activation='relu')(x_1)
hidden_1 = tf.keras.layers.Dense(192, activation='relu')(hidden_0)
hidden_2 = tf.keras.layers.Dense(192, activation='relu')(hidden_1)
hidden_3 = tf.keras.layers.Dense(128, activation='relu')(hidden_2)
hidden_4 = tf.keras.layers.Dense(64, activation='relu')(hidden_3)
hidden_5 = tf.keras.layers.Dense(param_size, activation=None)(hidden_4)
output = tfp.layers.MixtureNormal(num_components, event_shape)(hidden_5)
model = tf.keras.Model(inputs=x_1, outputs=output, name=name)
model.compile()
dist = model(tf.constant([[1.]]))
# Mixture weights are stored as unnormalised logits on the Categorical
print('mixture component logits: ',
      dist.tensor_distribution.mixture_distribution.logits.numpy())
# The components distribution is itself wrapped, hence the second .tensor_distribution
print('mixture component means: ',
      dist.tensor_distribution.components_distribution.tensor_distribution.mean().numpy())
print('mixture component stddevs: ',
      dist.tensor_distribution.components_distribution.tensor_distribution.stddev().numpy())
Output:
mixture component logits: [[0.01587015 0.03365375]]
mixture component means: [[[ 0.04741365]
[-0.01594907]]]
mixture component stddevs: [[[0.68762577]
[0.687484 ]]]
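The logits above are unnormalised, so to read them as mixture weights you apply a softmax. The NumPy sketch below uses the numbers copied from that output (flattened to 1-D) and also forms the weighted average of the component means, which is what a dist.mean() call on the mixture should give for this input.

```python
import numpy as np

# Values copied from the printed output above, flattened to 1-D
logits = np.array([0.01587015, 0.03365375])
means = np.array([0.04741365, -0.01594907])
stddevs = np.array([0.68762577, 0.687484])

# Softmax: mixture weights are exp(logit) normalised to sum to 1
weights = np.exp(logits - logits.max())
weights /= weights.sum()

# Mixture mean: weighted average of the component means
mixture_mean = np.dot(weights, means)

print("mixture weights:", weights)
print("mixture mean:", mixture_mean)
print("component stddevs:", stddevs)
```

Here both weights come out close to 0.5, which is what you'd expect from an untrained network whose logits are nearly equal.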
HTH!
Answered By - Chris Suter