Issue
This question is different from "tf.cast equivalent in pytorch?".
bitcast does a bitwise reinterpretation (like reinterpret_cast in C++) instead of a "safe" type conversion. This operation is useful when you want to store a bfloat16 tensor with NumPy, for example:
x = torch.ones(224, 224, 3, dtype=torch.bfloat16)
x_np = bitcast(x, torch.uint8).numpy()  # "bitcast" is the desired (not yet known) operation
Currently NumPy doesn't natively support bfloat16, so x.numpy() will raise TypeError: Got unsupported ScalarType BFloat16.
Solution
Use the second overload of torch.Tensor.view, the one that takes a dtype argument. Its semantics are close to those of numpy.ndarray.view.
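A minimal sketch of this approach, assuming a PyTorch version whose Tensor.view(dtype) overload supports dtypes of a different element size (in that case the size of the last dimension is rescaled proportionally):

import torch

x = torch.ones(224, 224, 3, dtype=torch.bfloat16)

# Reinterpret the raw bytes as uint8; the shape becomes (224, 224, 6)
# because uint8 is half the width of bfloat16.
x_np = x.view(torch.uint8).numpy()
print(x_np.shape)  # (224, 224, 6)

# Round-trip: rebuild the bfloat16 tensor from the NumPy byte buffer.
x_back = torch.from_numpy(x_np).view(torch.bfloat16)
print(torch.equal(x_back, x))  # True

Because view only reinterprets the underlying storage, no data is copied and no numeric conversion happens, which is exactly the reinterpret_cast-style behavior asked for in the question.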
Answered By - YouJiacheng