Issue
There is a pretrained model in a repository whose file type is .pyth. I searched the web to find out about this file type and which language can read it, but I could not find anything. Since I am working with PyTorch, is it possible to read such a file in PyTorch? Moreover, how are such files normally read and generated?
To be clearer, in the repository of the TimeSformer model, the pretrained models are of this file type. As an example, you can find the following commands in that repository:
import torch
from timesformer.models.vit import TimeSformer
model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time', pretrained_model='/path/to/pretrained/model.pyth')
dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)
pred = model(dummy_video,) # (2, 400)
Solution
The file extension can be anything; it does not change the file's contents. If you run torch.load("file.pyth"), it will load a weight dictionary. You can find this in the code of the repository you linked. They save the model using this code:
path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
with PathManager.open(path_to_checkpoint, "wb") as f:
    torch.save(checkpoint, f)
and the get_path_to_checkpoint function can be found here:
def get_path_to_checkpoint(path_to_job, epoch):
    """
    Get the full path to a checkpoint file.
    Args:
        path_to_job (string): the path to the folder of the current job.
        epoch (int): the number of epoch for the checkpoint.
    """
    name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
    return os.path.join(get_checkpoint_dir(path_to_job), name)
So, they just pass a filename with the extension .pyth to torch.save.
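To see this for yourself, here is a minimal sketch (not taken from the TimeSformer repo) that saves a small dictionary with torch.save under a few different extensions and reads it back with torch.load. The helper name roundtrip_with_extension, the dictionary contents, and the temporary file names are made up for illustration.

```python
import os
import tempfile

import torch


def roundtrip_with_extension(extension):
    """Save a tiny dict with torch.save, reload it, and compare the tensors.

    The dict below is a stand-in for a real checkpoint; its key name is
    arbitrary and not the layout used by any particular repository.
    """
    state = {"weight": torch.arange(4, dtype=torch.float32)}
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "checkpoint" + extension)
        torch.save(state, path)
        # map_location="cpu" lets this run on machines without a GPU.
        loaded = torch.load(path, map_location="cpu")
    return torch.equal(state["weight"], loaded["weight"])


# The extension is only part of the file name; the serialized bytes are the same.
results = {ext: roundtrip_with_extension(ext) for ext in (".pyth", ".pt", ".anything")}
print(results)
```

For a real .pyth checkpoint, calling torch.load(path, map_location="cpu") and printing the keys of the returned object (if it is a dict) is a quick way to find out what was stored in it.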
Answered By - jhso