Issue
I'm trying to do max pooling over the channel dimension:
class ChannelPool(nn.Module):
    def forward(self, input):
        return torch.max(input, dim=1)
but I get the following error:
AttributeError: 'torch.return_types.max' object has no attribute 'dim'
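A minimal reproduction of where the message comes from. The random (N, C, H, W) input is made up for illustration. torch.max(..., dim=1) returns a torch.return_types.max namedtuple rather than a tensor, so anything that then treats the result as a tensor fails. Presumably the culprit is the layer after ChannelPool checking its input's dimensionality; calling .dim() directly reproduces that failure:

```python
import torch

x = torch.randn(2, 3, 4, 4)  # hypothetical (N, C, H, W) batch

out = torch.max(x, dim=1)
print(type(out))                      # a namedtuple type, not a Tensor
print(isinstance(out, torch.Tensor))  # False

try:
    # roughly what a downstream layer does when it expects a Tensor
    out.dim()
except AttributeError as e:
    print(e)  # 'torch.return_types.max' object has no attribute 'dim'
```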
Solution
When called with dim, the torch.max function returns a namedtuple (values, indices) rather than a single tensor. Unpack it and return only the values:
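import torch
import torch.nn as nn

class ChannelPool(nn.Module):
    def forward(self, input):
        # torch.max with dim returns (values, indices); keep only the values
        input_max, max_indices = torch.max(input, dim=1)
        return input_max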
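The same fix can also be written by reading the .values field of the namedtuple directly. The sketch below does that. The class name ChannelMaxPool and its keepdim option are hypothetical additions, not part of the original answer. With keepdim=True the output keeps a size-1 channel axis, so it stays 4D, which may be convenient if the next layer expects (N, C, H, W). The sketch also checks that torch.amax, available in recent PyTorch versions, gives the same values without returning a tuple at all:

```python
import torch
import torch.nn as nn


class ChannelMaxPool(nn.Module):
    """Hypothetical variant: max over channels, optionally keeping dim 1."""

    def __init__(self, keepdim: bool = False):
        super().__init__()
        self.keepdim = keepdim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # .values picks the maxima out of the (values, indices) namedtuple
        return torch.max(input, dim=1, keepdim=self.keepdim).values


x = torch.randn(2, 3, 4, 4)  # hypothetical (N, C, H, W) batch

print(ChannelMaxPool()(x).shape)              # torch.Size([2, 4, 4])
print(ChannelMaxPool(keepdim=True)(x).shape)  # torch.Size([2, 1, 4, 4])

# torch.amax returns only the values, so there is nothing to unpack
assert torch.equal(torch.amax(x, dim=1), ChannelMaxPool()(x))
```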
From the documentation of torch.max:
Returns a namedtuple (values, indices) where values is the maximum value of each row of the input tensor in the given dimension dim. And indices is the index location of each maximum value found (argmax).
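A small worked example of the two fields the documentation describes, using a made-up input with 3 channels and a 1x2 spatial size. At the first position the channel values are 1, 4 and 3, so the maximum is 4 at channel index 1. At the second position they are 5, 2 and 6, so the maximum is 6 at channel index 2:

```python
import torch

# Hypothetical input: batch of 1, 3 channels, 1x2 spatial -> shape (1, 3, 1, 2)
x = torch.tensor([[[[1.0, 5.0]],
                   [[4.0, 2.0]],
                   [[3.0, 6.0]]]])

result = torch.max(x, dim=1)
print(result.values)   # tensor([[[4., 6.]]])
print(result.indices)  # tensor([[[1, 2]]])

# Tuple unpacking, as in the answer, gives the same two tensors
values, indices = result
```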
Answered By - Guglie