Issue
I'm learning PyTorch. While reading the official tutorial, I came across some perplexing code. input is a tensor, and so is target.
def nll(input, target):
    return -input[range(target.shape[0]), target].mean()
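For context, here is a minimal sketch of how nll might be called. It assumes input holds log-probabilities of shape (batch, classes), produced here with log_softmax on made-up random logits, and that target holds one class index per sample. The result is compared with torch.nn.functional.nll_loss, which with its default mean reduction computes the same quantity.

```python
import torch
import torch.nn.functional as F

def nll(input, target):
    # For each row i, take input[i, target[i]], average them, and negate.
    return -input[range(target.shape[0]), target].mean()

torch.manual_seed(0)
logits = torch.randn(4, 3)                # 4 samples, 3 classes (made-up data)
log_probs = F.log_softmax(logits, dim=1)  # log-probabilities for each row
target = torch.tensor([0, 2, 1, 2])       # true class index for each sample

loss = nll(log_probs, target)
ref = F.nll_loss(log_probs, target)       # PyTorch's built-in negative log-likelihood loss
print(loss.item(), ref.item())
```

The two printed numbers should agree up to floating-point rounding.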
I printed pred (the input tensor), target, and -input[range(target.shape[0]), target].
The output shows that this is neither subtracting target from input nor merging the two tensors.
Solution
The code input[range(target.shape[0]), target] simply picks, from each row i of input, the element at the column indicated by the corresponding element of target, that is, target[i].
In other words, if out = input[range(target.shape[0]), target], then out[i] = input[i, target[i]].
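A small sketch of this rule, using a hand-written 3×3 input and target chosen purely for illustration: the advanced-indexing expression pairs row index i with column index target[i], and it gives the same result as an explicit loop.

```python
import torch

input = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.3, 0.2],
                      [0.6, 0.1, 0.3]])
target = torch.tensor([2, 0, 1])

# Advanced indexing: row indices 0, 1, 2 are paired with column indices 2, 0, 1.
out = input[range(target.shape[0]), target]

# Equivalent explicit loop: out[i] = input[i, target[i]].
looped = torch.stack([input[i, target[i]] for i in range(target.shape[0])])

print(out)  # tensor([0.7000, 0.5000, 0.1000])
```

Because both index sequences have length N, the result is a 1-D tensor of length N rather than an N×N block.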
This is very similar to torch.gather.
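As a sketch, assuming a 2-D input and a 1-D int64 target, the same selection can be written with torch.gather along dim=1. gather requires the index to have the same number of dimensions as input, so target is reshaped to (N, 1) and the (N, 1) result is squeezed back to length N.

```python
import torch

input = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.3, 0.2],
                      [0.6, 0.1, 0.3]])
target = torch.tensor([2, 0, 1])  # int64, as torch.gather requires

picked = input[range(target.shape[0]), target]

# gather(input, 1, index)[i, 0] == input[i, index[i, 0]]
gathered = torch.gather(input, 1, target.unsqueeze(1)).squeeze(1)

print(torch.equal(picked, gathered))  # True
```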
Answered By - Shai