Issue
I am using the YOLOv7 model. The shared pretrained weights are optimised and stored in float16.
How can I convert the dtype of a model's parameters in PyTorch? I want to convert the weights to float32.
weights = torch.load('yolov7-mask.pt')
model = weights['model']
Solution
Load the weights into your model and call .float() on it.
Example:
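cp = torch.load('yolov7-mask.pt')
# 'weight' must match the key that holds the state_dict in your checkpoint
model.load_state_dict(cp['weight'])
# cast all floating-point parameters and buffers to float32
model = model.float()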
This works as long as the model's class is nn.Module (checked with torch version 1.8).
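To see the conversion end to end without the real checkpoint, here is a minimal sketch on a toy network. The make_model helper, its layers, and the in-memory buffer are stand-ins assumed for illustration, not the YOLOv7 architecture or file. The 'weight' key simply mirrors the example above. The sketch saves float16 weights, loads them back, and checks the parameter dtypes before and after .float().

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Toy stand-in network (hypothetical; not the YOLOv7 architecture).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )


# Simulate a checkpoint whose weights were saved as float16.
buffer = io.BytesIO()
torch.save({'weight': make_model().half().state_dict()}, buffer)
buffer.seek(0)

cp = torch.load(buffer)
model = make_model().half()  # model held in float16, like the shared weights
model.load_state_dict(cp['weight'])
dtypes_before = {p.dtype for p in model.parameters()}

# .float() casts floating-point parameters and buffers to float32.
model = model.float()
dtypes_after = {p.dtype for p in model.parameters()}
# Integer buffers (e.g. BatchNorm's num_batches_tracked) are left untouched,
# so only floating-point buffers are inspected here.
buffer_dtypes = {b.dtype for b in model.buffers() if b.is_floating_point()}

print(dtypes_before, dtypes_after, buffer_dtypes)
```

Note that .float() only touches the model's own parameters and buffers. Input tensors still have to be float32 as well before they are passed to the converted model.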
Answered By - Hayoung