Issue
My data is in channels_first format. When I use TensorFlow Probability layers, I get the error shown below.
Here is an example where the input shape is [1,28,28], along with the reproducible code: Gist (please make sure you are running the code on a GPU).
InvalidArgumentError: required broadcastable shapes
[[node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1 (defined at <ipython-input-22-243a182981d9>:9) ]] [Op:__inference_train_function_7663]
Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1:
model_3/mixture_same_family_4/MixtureSameFamily/independent_normal_4/IndependentNormal/Softplus (defined at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/layers/distribution_layer.py:988)
Function call stack:
train_function
I am not sure how to change the source code so that it works with a channels-first input shape. Can someone help me with this?
Solution
Your preprocess function is returning image, image instead of image, sample['label']. If you change this, it should work!
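A minimal sketch of the change, assuming the dataset yields dicts with an 'image' key (an assumed name) and the 'label' key mentioned above, with each image in channels_first layout [1,28,28]. NumPy arrays stand in for the real tensors, and make_fake_sample and preprocess_buggy are hypothetical helpers used only for illustration; the point is that the second element of the returned tuple is what Keras passes to the loss as the target.

```python
import numpy as np


def make_fake_sample(label=3):
    # Hypothetical stand-in for one dataset element: a dict with an
    # 'image' key (assumed name) in channels_first layout [1, 28, 28]
    # and a 'label' key, as referenced in the answer.
    image = np.zeros((1, 28, 28), dtype=np.float32)
    return {"image": image, "label": np.int64(label)}


def preprocess_buggy(sample):
    image = sample["image"]
    # Bug: the image is returned as both the input and the target, so the
    # loss presumably evaluates log_prob on a [1, 28, 28] image instead of
    # a label, which does not broadcast against the distribution's shape.
    return image, image


def preprocess(sample):
    image = sample["image"]
    # Fix: return (features, target) with the label as the target.
    return image, sample["label"]


# With tf.data this would be applied as, e.g., ds = ds.map(preprocess).
x, y = preprocess(make_fake_sample())
print(x.shape, y)  # (1, 28, 28) 3
```

Only the target changes; the channels_first image itself is passed through untouched, which is consistent with the answer's point that the layout was not the cause of the shape error.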
I think you can then drop the K.cast in your loss as well.
Update: actually, when I run this I get NaNs in the loss, so something else is probably wrong. But at least it gets past the shape error! 🤷‍♂️
Answered By - Chris Suter