Issue
I have a script in which x1 and x2 each have size 1x68x8x8:
tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)
max_ids = torch.argmax(x1, 1)
max_ids = max_ids.view(-1, 1)
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1
The above code works, but I run into trouble with torch.gather when I export with an older ONNX version. Hence, I would like to find an alternative solution that replaces torch.gather with other operators but gives the same output as the above code. Could you please give me some suggestions?
Solution
One workaround is to use the equivalent NumPy method, np.take_along_axis. If you include an import numpy as np statement somewhere, you could do the following.
outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))
If that gives you a grad-related error, try
outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))
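As a sanity check, here is a minimal sketch comparing the NumPy workaround against torch.gather on small random tensors. The shapes (4 rows of 16 values) are illustrative, not the 68x64 from the question, and the tensors are converted explicitly with .numpy() for clarity:

```python
import numpy as np
import torch

# Small illustrative tensors; the question's real shapes are (68, 64).
x1 = torch.randn(4, 16)
x2 = torch.randn(4, 16)

# One argmax index per row, reshaped to a column as in the question.
max_ids = torch.argmax(x1, 1).view(-1, 1)

# Original gather vs. the NumPy replacement.
expected = torch.gather(x2, 1, max_ids)
workaround = torch.Tensor(np.take_along_axis(x2.numpy(), max_ids.numpy(), 1))

assert torch.equal(expected, workaround)
```

Both results have shape (4, 1): one selected value per row.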
An approach without numpy: in this case, it seems that max_ids contains exactly one entry per row. Thus, I believe the following will work:
max_ids = torch.argmax(x1, 1) # do not reshape
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel), max_ids]
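To illustrate, here is a small sketch verifying that this row-wise advanced indexing reproduces torch.gather when there is exactly one index per row. The shapes match the question (68 rows of 64 values); note that the indexing result is 1-D, unlike gather's (68, 1) output:

```python
import torch

tmp_batch, tmp_channel = 1, 68
x1 = torch.randn(tmp_batch * tmp_channel, 64)
x2 = torch.randn(tmp_batch * tmp_channel, 64)

max_ids = torch.argmax(x1, 1)  # shape (68,), not reshaped

gathered = torch.gather(x2, 1, max_ids.view(-1, 1))           # shape (68, 1)
indexed = x2[torch.arange(tmp_batch * tmp_channel), max_ids]  # shape (68,)

# The indexing result is 1-D; add a dim back to compare with gather's output.
assert torch.equal(gathered, indexed.unsqueeze(1))
```

If downstream code expects the (68, 1) shape, keep the unsqueeze; otherwise the 1-D result can be used directly.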
Answered By - Ben Grossmann