Issue
So the output of my network looks like this:
output = tensor([[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.0315, -0.1837],
[ 0.0318, -0.1828],
[ 0.0322, -0.1822],
[ 0.0324, -0.1819],
[ 0.0327, -0.1817],
[ 0.0328, -0.1815],
[ 0.0330, -0.1815],
[ 0.0331, -0.1814],
[ 0.0332, -0.1814],
[ 0.0333, -0.1814],
[ 0.0333, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]]])
Which has a shape of [8, 25, 2].
Now, 8 is my batch size, and I would like to get a data point from every batch at the following locations:
index = tensor([24, 10, 3, 3, 1, 1, 1, 0])
So the value at index 24 from the first batch, the value at index 10 from the second batch, and so on.
Now I'm having trouble figuring out the syntax. I've tried
torch.gather(output, 0, index)
But it keeps telling me that my dimensions don't match. And trying
output[:, index]
just gets me the values at all the indices for each batch. What would be the correct syntax here to get these values?
Solution
To select only one element per batch you need to enumerate the batch indices, which can be done easily with torch.arange.
output[torch.arange(output.size(0)), index]
That essentially creates tuples between the enumerated tensor and your index tensor to access the data, which results in indexing output[0, 24], output[1, 10], etc.
Answered By - Michael Jungo