Issue
To give the data the right dimensions for a PyTorch model, I use the squeeze and unsqueeze functions like this:
inps = torch.FloatTensor(data[0])
tgts = torch.FloatTensor(data[1])
tgts = torch.unsqueeze(tgts, -1)
tgts = torch.unsqueeze(tgts, -1)
tgts = torch.unsqueeze(tgts, -1)
inps = torch.unsqueeze(inps, -1)
inps = torch.unsqueeze(inps, -1)
inps = torch.unsqueeze(inps, -1)
and this:
inps = torch.FloatTensor(data[0])
tgts = torch.FloatTensor(data[1])
tgts = torch.unsqueeze(tgts, 1)
tgts = torch.unsqueeze(tgts, 1)
tgts = torch.unsqueeze(tgts, 1)
inps = torch.unsqueeze(inps, 1)
inps = torch.unsqueeze(inps, 1)
inps = torch.unsqueeze(inps, 1)
But of course, I'm kind of embarrassed to have this repetitive part in my code. Is there a cleaner, more Pythonic way to write it?
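For a 1-D input, both snippets above produce the same shape, (N, 1, 1, 1). A minimal sketch to check this, assuming `data` is a hypothetical pair of equal-length 1-D lists of floats. It uses `torch.tensor(..., dtype=torch.float32)` in place of the legacy `torch.FloatTensor` constructor, and replays the three repeated calls with a loop:

```python
import torch

# Hypothetical data: two 1-D lists of equal length (an assumption, not from the question)
data = ([0.5, 1.5, 2.5, 3.5], [1.0, 0.0, 1.0, 0.0])

inps = torch.tensor(data[0], dtype=torch.float32)  # shape (4,)

# Variant 1: insert a new dimension at the last position, three times
a = inps
for _ in range(3):
    a = torch.unsqueeze(a, -1)

# Variant 2: insert a new dimension at position 1, three times
b = inps
for _ in range(3):
    b = torch.unsqueeze(b, 1)

print(a.shape)            # torch.Size([4, 1, 1, 1])
print(b.shape)            # torch.Size([4, 1, 1, 1])
print(torch.equal(a, b))  # True
```

For inputs with more than one dimension the two variants are no longer equivalent, because -1 appends the new dimensions at the end while 1 inserts them right after the first dimension.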
Solution
You can use torch.Tensor.view like below:
how_many_unsqueeze = 3
extra_dims = (1,) * how_many_unsqueeze
# extra_dims -> (1, 1, 1)
# view returns a new tensor, so assign the result back
inps = inps.view(-1, *extra_dims)  # shape -> (N, 1, 1, 1)
tgts = tgts.view(-1, *extra_dims)  # shape -> (N, 1, 1, 1)
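A self-contained sketch of the view approach, assuming hypothetical 1-D tensors in place of `data[0]` and `data[1]`. It checks that the result matches three chained unsqueeze calls, and shows that view shares memory with the original tensor. Note that `-1` absorbs every existing dimension, so a 2-D input of shape (B, F) would be flattened to (B*F, 1, 1, 1); this approach assumes 1-D inputs.

```python
import torch

# Hypothetical 1-D inputs standing in for data[0] and data[1]
inps = torch.tensor([0.5, 1.5, 2.5], dtype=torch.float32)
tgts = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)

how_many_unsqueeze = 3
extra_dims = (1,) * how_many_unsqueeze  # (1, 1, 1)

# view returns a new tensor that shares storage with the original
inps_v = inps.view(-1, *extra_dims)
tgts_v = tgts.view(-1, *extra_dims)

print(inps_v.shape)  # torch.Size([3, 1, 1, 1])

# Same values and shape as calling unsqueeze three times
expected = inps.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
print(torch.equal(inps_v, expected))  # True

# Shared memory: writing through the view changes the original tensor
inps_v[0, 0, 0, 0] = 9.0
print(inps[0].item())  # 9.0
```

Because no data is copied, view is cheap, but it only works when the requested shape is compatible with the tensor's memory layout (always the case for a freshly created tensor like these).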
You can also use torch.reshape like below. It also covers the case where, after using the tensors as in your question, you need to go back to the original shape.
Instead of unsqueeze:
inps = torch.reshape(inps, (len(data[0]),1,1,1))
tgts = torch.reshape(tgts, (len(data[1]),1,1,1))
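A runnable sketch of the reshape version, assuming a hypothetical `data` pair of 1-D lists. It also shows that `-1` can stand in for `len(...)`, letting PyTorch infer the first dimension. Unlike view, reshape also accepts non-contiguous tensors, returning a view when possible and a copy otherwise.

```python
import torch

# Hypothetical data standing in for the question's `data` (an assumption)
data = ([0.5, 1.5, 2.5, 3.5], [1.0, 0.0, 1.0, 0.0])
inps = torch.tensor(data[0], dtype=torch.float32)
tgts = torch.tensor(data[1], dtype=torch.float32)

# Explicit target shape, as in the answer
inps_r = torch.reshape(inps, (len(data[0]), 1, 1, 1))
# Equivalent for 1-D input: -1 lets PyTorch infer the first dimension
tgts_r = torch.reshape(tgts, (-1, 1, 1, 1))

print(inps_r.shape)  # torch.Size([4, 1, 1, 1])
print(tgts_r.shape)  # torch.Size([4, 1, 1, 1])
```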
Instead of squeeze:
inps = torch.reshape(inps, (len(data[0]),))
tgts = torch.reshape(tgts, (len(data[1]),))
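A short round-trip sketch, again assuming hypothetical 1-D data: expand to (N, 1, 1, 1), then restore the original shape with reshape, and compare against `torch.squeeze` with no `dim` argument.

```python
import torch

# Hypothetical 1-D data (an assumption)
data = ([0.5, 1.5, 2.5, 3.5], [1.0, 0.0, 1.0, 0.0])
inps = torch.tensor(data[0], dtype=torch.float32)

expanded = torch.reshape(inps, (len(data[0]), 1, 1, 1))  # (4, 1, 1, 1)

# Back to the original 1-D shape with reshape
restored = torch.reshape(expanded, (len(data[0]),))
# squeeze() with no dim removes every size-1 dimension
squeezed = torch.squeeze(expanded)

print(restored.shape)                # torch.Size([4])
print(torch.equal(restored, inps))  # True
print(torch.equal(squeezed, inps))  # True
```

The explicit reshape is the safer choice when N can be 1: in that case `torch.squeeze(expanded)` would also drop the batch dimension and return a 0-D tensor, while reshaping to `(len(data[0]),)` keeps shape (1,).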
Answered By - I'mahdi