Issue
I am trying to code up something similar to the positional encoding in the Transformer paper. To do so, I need the following:
Given the three matrices below, I want to interleave them at the row level (the first row of each stacked together, then the second rows, and so on), apply the dot product between each group and its transpose, and finally flatten the results and stack them. The following example clarifies this:
x = torch.tensor([[1, 1, 1, 1],
                  [2, 2, 2, 2],
                  [3, 3, 3, 3]])
y = torch.tensor([[0, 0, 0, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]])
z = torch.tensor([[4, 4, 4, 4],
                  [5, 5, 5, 5],
                  [6, 6, 6, 6]])
concat = torch.cat([x, y, z], dim=-1).view(-1, x.shape[-1])
print(concat)
tensor([[1, 1, 1, 1],
        [0, 0, 0, 0],
        [4, 4, 4, 4],
        [2, 2, 2, 2],
        [0, 0, 0, 0],
        [5, 5, 5, 5],
        [3, 3, 3, 3],
        [0, 0, 0, 0],
        [6, 6, 6, 6]])
# Here I group each three rows together, apply the dot product,
# flatten, and stack the results.
concat = torch.stack([
    torch.flatten(
        torch.matmul(
            concat[i:i+3, :],  # 3 is the number of tensors (x, y, z)
            torch.transpose(concat[i:i+3, :], 0, 1))
    )
    for i in range(0, concat.shape[0], 3)
])
print(concat)
tensor([[  4,   0,  16,   0,   0,   0,  16,   0,  64],
        [ 16,   0,  40,   0,   0,   0,  40,   0, 100],
        [ 36,   0,  72,   0,   0,   0,  72,   0, 144]])
This gives the final matrix I want. My question is: is there a way to achieve this without the loop in the final step? I want everything to stay in tensor operations.
Solution
The loop you introduce is only needed to get a "list of slices" of the data, which is practically the same as reshaping it. You are essentially introducing an additional dimension with 3 entries, going from shape [n, k] to [n, 3, k].
For working directly with tensors, you can call .reshape to get that same shape. After that, the rest of your code works almost unchanged; only the transpose has to be adjusted slightly, because of the extra dimension.
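To make the shape change concrete, here is a small sketch (using the example tensors from the question) showing that the reshape groups each triple of interleaved rows back together:

```python
import torch

x = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]])
y = torch.zeros(3, 4, dtype=torch.long)
z = torch.tensor([[4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]])

# The cat + view from the question interleaves the rows into shape [9, 4].
concat = torch.cat([x, y, z], dim=-1).view(-1, x.shape[-1])

# Reshaping to [n, 3, k] puts each triple of rows into its own slice.
grouped = concat.reshape(-1, 3, concat.shape[1])
print(grouped.shape)  # torch.Size([3, 3, 4])
print(grouped[0])     # rows x[0], y[0], z[0] stacked together
```

Note that torch.stack([x, y, z], dim=1) would build this [n, 3, k] tensor in a single step, without the intermediate cat and view.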
All in all, what you want can be achieved with:
concat2 = concat.reshape((-1, 3, concat.shape[1]))
torch.flatten(
    torch.matmul(
        concat2,
        concat2.transpose(1, 2)
    ),
    start_dim=1,
)
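For completeness, here is an equivalent batched formulation, sketched with the example data from the question. It builds the [n, 3, k] tensor directly with torch.stack and uses torch.bmm for the per-group matrix products:

```python
import torch

x = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]])
y = torch.zeros(3, 4, dtype=torch.long)
z = torch.tensor([[4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]])

# Stack along a new middle dimension: shape [3, 3, 4],
# where slice i holds rows x[i], y[i], z[i].
grouped = torch.stack([x, y, z], dim=1)

# Batched Gram matrices ([3, 3, 3]), then flatten each 3x3 to a row of 9.
result = torch.bmm(grouped, grouped.transpose(1, 2)).flatten(start_dim=1)
print(result)
# tensor([[  4,   0,  16,   0,   0,   0,  16,   0,  64],
#         [ 16,   0,  40,   0,   0,   0,  40,   0, 100],
#         [ 36,   0,  72,   0,   0,   0,  72,   0, 144]])
```

torch.matmul would broadcast the same way on these 3D inputs; torch.bmm just makes the batched intent explicit.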
Answered By - Jonas V