Issue
I'm trying to migrate TensorFlow checkpoint weights to PyTorch.
When I extract some weights with cp.load_variable(<CKPT>, <FIELD_NAME>), I get a 4D array that looks like it is ordered as HWCN, for example [1, 1, 512, 1024].
However, data_format is set to NHWC for all convolution blocks.
So, the question is: why is there a mismatch?
Which should I believe? Is the 4D array from cp.load_variable correct, so that all that is left to do is permute the dimensions?
Thanks!
Solution
The weights are not in HWCN order. Convolution weights have no batch dimension (N); if they did, a different weight would be applied to each sample in the batch. The shape is [kernel_height, kernel_width, in_channels, out_channels]. There is no mismatch, because data_format only specifies the layout of the layer's input and output, not the layout of its weights.
In PyTorch, the weight of a convolution is stored as [out_channels, in_channels, kernel_height, kernel_width], so you only need to permute the dimensions, moving axes (0, 1, 2, 3) to the order (3, 2, 0, 1).
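As a minimal sketch of that permutation, the snippet below uses a random NumPy array as a stand-in for the value returned by cp.load_variable (the name tf_kernel and the random data are assumptions for illustration, not part of the original checkpoint):

```python
import numpy as np

# Stand-in for the array returned by cp.load_variable(<CKPT>, <FIELD_NAME>).
# TensorFlow layout: [kernel_height, kernel_width, in_channels, out_channels].
tf_kernel = np.random.rand(1, 1, 512, 1024).astype(np.float32)

# Reorder to the PyTorch layout:
# [out_channels, in_channels, kernel_height, kernel_width].
# ascontiguousarray makes a real copy instead of a strided view.
pt_kernel = np.ascontiguousarray(tf_kernel.transpose(3, 2, 0, 1))

print(pt_kernel.shape)  # (1024, 512, 1, 1)
```

The result can then be wrapped with torch.from_numpy(pt_kernel) and copied into the matching convolution's weight tensor, assuming the layer was created with the same kernel size and channel counts.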
Answered By - Michael Jungo