Issue
Using PyTorch, torch.combinations will only take a 1D tensor as input, but I would like to apply it to each 1D tensor in a multidimensional tensor.
import torch

inp = torch.tensor([[1, 2, 3],
                    [2, 3, 4]])
torch.combinations(inp, r=2)  # fails: only a 1D input is accepted
The result is an error saying it can't be applied to a tensor of that shape, but I want to apply it to [1, 2, 3] and [2, 3, 4] individually. I can't process them one by one, because the idea is to apply this to large sets of data.
inp = torch.tensor([[1,2,3],[2,3,4]])
inp_tuple = torch.unbind(inp)
print(inp_tuple)
(tensor([1, 2, 3]), tensor([2, 3, 4]))
torch.combinations(inp_tuple, r=2)  # fails: a tuple is not accepted
I also tried unbinding the tensor and applying it to the resulting tuple of tensors, but it gives an error saying it can't be applied to a tuple.
Is there any way to get torch.combinations to apply automatically to each individual 1D tensor in a multidimensional tensor, or to each tensor in a tuple of tensors? If not, are there any alternatives for getting all combinations within each individual row of a multidimensional tensor?
Solution
Function torch.combinations returns all possible combinations of size r of the elements contained in the 1D input vector. Multi-dimensional inputs are not supported, probably because there is no guarantee that the different vectors in your input have exactly the same number of unique elements. If one of the vectors had a duplicate element, you could end up with one set of combinations larger than another, and results of different sizes simply cannot be represented by a single homogeneous PyTorch tensor.
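A quick check of the 1D restriction described above, as a minimal sketch assuming a recent PyTorch release: on a 1D tensor of length n, torch.combinations returns one row of width r per combination, i.e. math.comb(n, r) rows, while a 2D input is rejected with a RuntimeError.

```python
import math

import torch

vec = torch.tensor([1, 2, 3])
pairs = torch.combinations(vec, r=2)
print(pairs)  # rows [1, 2], [1, 3], [2, 3]

# One row per combination of size r: math.comb(n, r) rows, r columns.
print(pairs.shape == (math.comb(len(vec), 2), 2))

# A 2D input is rejected, since the function only accepts 1D vectors.
inp = torch.tensor([[1, 2, 3], [2, 3, 4]])
try:
    torch.combinations(inp, r=2)
    rejected = False
except RuntimeError:
    rejected = True
print("2D input rejected:", rejected)
```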
So from here on, I will assume that the input tensor inp is a 2D tensor of shape (N, C), where each of its N vectors contains C unique elements. The example you gave fits this requirement, since both vectors have three unique elements each: {1, 2, 3} and {2, 3, 4}.
>>> inp = torch.tensor([[1,2,3],[2,3,4]])
The idea is to apply torch.combinations to an index tensor, torch.arange(C), whose length equals that of our vectors. We can then use the resulting index combinations to gather values from each of the vectors in our input tensor.
We can retrieve all combinations of these indices with the following:
>>> c = torch.combinations(torch.arange(inp.size(1)), r=2)
>>> c
tensor([[0, 1],
        [0, 2],
        [1, 2]])
Then we need to reshape and expand both inp and c so that they have the same number of dimensions, with matching sizes along the first two:
>>> x = inp[:, None].expand(-1, len(c), -1)
>>> x
tensor([[[1, 2, 3],
         [1, 2, 3],
         [1, 2, 3]],

        [[2, 3, 4],
         [2, 3, 4],
         [2, 3, 4]]])
>>> idx = c[None].expand(len(x), -1, -1)
>>> idx
tensor([[[0, 1],
         [0, 2],
         [1, 2]],

        [[0, 1],
         [0, 2],
         [1, 2]]])
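Neither expand call copies any data: expand returns a view in which the expanded dimension has stride 0, so x and idx stay cheap even for large N. The sketch below, assuming the same inp as above, checks this. It also shows why x has to be expanded at all: torch.gather requires index.size(d) <= input.size(d) for every dimension d other than dim, so gathering from the unexpanded inp[:, None] (size 1 along dim 1) is rejected.

```python
import torch

inp = torch.tensor([[1, 2, 3], [2, 3, 4]])
c = torch.combinations(torch.arange(inp.size(1)), r=2)

x = inp[:, None].expand(-1, len(c), -1)  # shape (N, K, C) with K = len(c)
idx = c[None].expand(len(x), -1, -1)     # shape (N, K, r)

# expand() returns a view: same storage, stride 0 along the expanded dimension.
print(x.shape, x.stride())
print(x.data_ptr() == inp.data_ptr())

# Without expanding inp, dim 1 has size 1 < K, which gather rejects.
try:
    inp[:, None].gather(dim=2, index=idx)
    failed = False
except RuntimeError:
    failed = True
print("gather on unexpanded input failed:", failed)
```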
Finally, we can apply torch.gather on x and idx along dim=2. This returns a 3D tensor out such that:
out[i][j][k] = x[i][j][idx[i][j][k]]
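To make the indexing rule concrete, the loop below applies out[i][j][k] = x[i][j][idx[i][j][k]] element by element and compares the result with torch.gather. It is only a reference sketch for small inputs; these Python loops are exactly what the vectorized call avoids.

```python
import torch

inp = torch.tensor([[1, 2, 3], [2, 3, 4]])
c = torch.combinations(torch.arange(inp.size(1)), r=2)
x = inp[:, None].expand(-1, len(c), -1)
idx = c[None].expand(len(x), -1, -1)

# Apply the gather rule along dim=2 by hand.
manual = torch.empty(idx.shape, dtype=x.dtype)
for i in range(idx.size(0)):
    for j in range(idx.size(1)):
        for k in range(idx.size(2)):
            manual[i, j, k] = x[i, j, idx[i, j, k]]

print(torch.equal(manual, x.gather(dim=2, index=idx)))  # True
```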
Let's make the call to torch.gather:
>>> x.gather(dim=2, index=idx)
tensor([[[1, 2],
         [1, 3],
         [2, 3]],

        [[2, 3],
         [2, 4],
         [3, 4]]])
This is the desired result.
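The steps above can be wrapped in small helpers. The names batched_combinations and gather_combinations are hypothetical (neither is part of PyTorch), and the sketch assumes the vectors to combine lie along the last dimension. Besides the expand + gather version, it shows an equivalent formulation using advanced indexing, inp[..., c], which produces the same (*batch, K, r) result without explicit expand calls and also accepts more than one leading dimension. Both are checked against a per-row loop over torch.combinations.

```python
import torch


def batched_combinations(inp: torch.Tensor, r: int = 2) -> torch.Tensor:
    """Hypothetical helper: combinations of size r along the last dim.

    Returns a tensor of shape (*inp.shape[:-1], K, r), with K = comb(C, r).
    """
    # Index combinations of positions 0..C-1, shape (K, r).
    c = torch.combinations(torch.arange(inp.size(-1), device=inp.device), r=r)
    # Advanced indexing replaces the last dim with c's (K, r) shape.
    return inp[..., c]


def gather_combinations(inp: torch.Tensor, r: int = 2) -> torch.Tensor:
    """Hypothetical helper: the expand + gather approach, for 2D (N, C) input."""
    c = torch.combinations(torch.arange(inp.size(1), device=inp.device), r=r)
    x = inp[:, None].expand(-1, len(c), -1)
    idx = c[None].expand(len(x), -1, -1)
    return x.gather(dim=2, index=idx)


inp = torch.tensor([[1, 2, 3], [2, 3, 4]])
out = batched_combinations(inp, r=2)
print(out.shape)  # torch.Size([2, 3, 2])

# Both vectorized versions agree with a per-row loop over torch.combinations.
loop = torch.stack([torch.combinations(row, r=2) for row in inp])
print(torch.equal(out, gather_combinations(inp)), torch.equal(out, loop))
```

The indexing form lets PyTorch handle the broadcasting internally; the gather form makes the shapes explicit, which can be easier to follow when adapting it.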
Answered By - Ivan