Issue
I'm struggling with dimensions and matrix multiplication in PyTorch. I want to multiply matrix A, which has shape (2, 2, 4):
tensor([[[104.7500, 111.3750, 138.2500, 144.8750],
         [104.2500, 110.8750, 137.7500, 144.3750]],
        [[356.8750, 363.5000, 390.3750, 397.0000],
         [356.3750, 363.0000, 389.8750, 396.5000]]])
with matrix B, which has shape (2, 3, 4, 9):
tensor([[[[ 0., 1., 2., 5., 6., 7., 10., 11., 12.],
          [ 2., 3., 4., 7., 8., 9., 12., 13., 14.],
          [ 10., 11., 12., 15., 16., 17., 20., 21., 22.],
          [ 12., 13., 14., 17., 18., 19., 22., 23., 24.]],
         [[ 25., 26., 27., 30., 31., 32., 35., 36., 37.],
          [ 27., 28., 29., 32., 33., 34., 37., 38., 39.],
          [ 35., 36., 37., 40., 41., 42., 45., 46., 47.],
          [ 37., 38., 39., 42., 43., 44., 47., 48., 49.]],
         [[ 50., 51., 52., 55., 56., 57., 60., 61., 62.],
          [ 52., 53., 54., 57., 58., 59., 62., 63., 64.],
          [ 60., 61., 62., 65., 66., 67., 70., 71., 72.],
          [ 62., 63., 64., 67., 68., 69., 72., 73., 74.]]],
        [[[ 75., 76., 77., 80., 81., 82., 85., 86., 87.],
          [ 77., 78., 79., 82., 83., 84., 87., 88., 89.],
          [ 85., 86., 87., 90., 91., 92., 95., 96., 97.],
          [ 87., 88., 89., 92., 93., 94., 97., 98., 99.]],
         [[100., 101., 102., 105., 106., 107., 110., 111., 112.],
          [102., 103., 104., 107., 108., 109., 112., 113., 114.],
          [110., 111., 112., 115., 116., 117., 120., 121., 122.],
          [112., 113., 114., 117., 118., 119., 122., 123., 124.]],
         [[125., 126., 127., 130., 131., 132., 135., 136., 137.],
          [127., 128., 129., 132., 133., 134., 137., 138., 139.],
          [135., 136., 137., 140., 141., 142., 145., 146., 147.],
          [137., 138., 139., 142., 143., 144., 147., 148., 149.]]]])
However, using the plain @ operator to multiply them doesn't give me the desired result.
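To experiment with the exact data, both tensors can be rebuilt as in the sketch below. A is copied from the printout. B looks like 3x3 patches taken with stride 2 from six 5x5 grids holding the values 0..149; that is an inference from the printed numbers, not something the question states, and the `Tensor.unfold` chain is just one way to reproduce them. With these shapes, plain `@` does not merely give an unexpected result: `torch.matmul` broadcasts the leading batch dimensions, and A's batch shape (2,) cannot be broadcast against B's (2, 3), so the call raises.

```python
import torch

# A exactly as printed in the question; shape (2, 2, 4)
A = torch.tensor([[[104.7500, 111.3750, 138.2500, 144.8750],
                   [104.2500, 110.8750, 137.7500, 144.3750]],
                  [[356.8750, 363.5000, 390.3750, 397.0000],
                   [356.3750, 363.0000, 389.8750, 396.5000]]])

# Assumption: B holds 3x3 patches (stride 2) of six 5x5 grids filled with
# 0..149; Tensor.unfold(dim, size, step) reproduces the printed values.
B = (torch.arange(150.).reshape(2, 3, 5, 5)
     .unfold(2, 3, 2)       # windows over rows    -> (2, 3, 2, 5, 3)
     .unfold(3, 3, 2)       # windows over columns -> (2, 3, 2, 2, 3, 3)
     .reshape(2, 3, 4, 9))  # 4 patches x 9 values per (block, submatrix)

# matmul broadcasts the leading "batch" dims: (2,) vs (2, 3). Aligned from
# the right, 2 and 3 clash, so A @ B raises instead of returning a tensor.
try:
    A @ B
    matmul_failed = False
except RuntimeError as err:
    matmul_failed = True
    print("A @ B failed:", err)
```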
What I want is something like this: multiply the first two rows of A (i.e. A[0]) by the first three 4x9 submatrices of B (i.e. B[0], or equivalently B[0, :, :, :]), so that I have two results; then, in the same way, multiply the third and fourth rows of A (A[1]) by the second three 4x9 submatrices of B (B[1]), so that I again have two results. Finally, I want to sum the first results of each multiplication together, and likewise the second results.
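Read the way the answer below interprets it, the requested operation is out[k] = A[0] @ B[0, k] + A[1] @ B[1, k] for each of the three 4x9 submatrices k. A literal, loop-based sketch (random data with the question's shapes, purely for illustration) makes a handy reference to check any vectorized version against:

```python
import torch

A = torch.rand(2, 2, 4)     # same shape as A in the question
B = torch.rand(2, 3, 4, 9)  # same shape as B in the question

# out[k] = sum over n of A[n] @ B[n, k]; each term is (2, 4) @ (4, 9) -> (2, 9)
out = torch.zeros(B.shape[1], A.shape[1], B.shape[3])  # (3, 2, 9)
for n in range(A.shape[0]):      # n-th pair of rows of A / n-th block of B
    for k in range(B.shape[1]):  # k-th 4x9 submatrix inside that block
        out[k] += A[n] @ B[n, k]
```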
I know I need some kind of reshaping, but I find it confusing. Can you help me with a fairly general solution?
Solution
This example would be helpful:
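import torch

# Toy inputs with the same shapes as in the question
a = torch.ones((4, 4)).long()
a = a.reshape(2, 2, 4)               # (2, 2, 4)
b = torch.tensor(list(range(36*6)))
b = b.reshape(2, 3, 4, 9)            # (2, 3, 4, 9)

# a[0] is (2, 4); matmul broadcasts it against each 4x9 matrix in b[0] -> (3, 2, 9)
t1 = a[0] @ b[0, :]
t2 = a[1] @ b[1, :]
result = t1 + t2                     # (3, 2, 9)

# Same computation for any number of blocks along dim 0
accum = torch.zeros((b.shape[1], a.shape[1], b.shape[3]), dtype=b.dtype)
for i in range(a.shape[0]):
    accum = accum + (a[i] @ b[i, :])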
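The loop above can also be replaced by a single broadcasted call. Two options are sketched below on random data with the question's shapes; both use standard PyTorch (`unsqueeze` + `@` + `sum`, or `torch.einsum`), but they are a suggested extension, not part of the original answer. The size names N, K, R, C, M are hypothetical labels for illustration: a has shape (N, R, C) and b has shape (N, K, C, M), and both options work for any values of these.

```python
import torch

N, K, R, C, M = 2, 3, 2, 4, 9  # blocks, submatrices per block, rows, inner dim, cols
a = torch.rand(N, R, C)
b = torch.rand(N, K, C, M)

# Reference: the loop from the answer
accum = torch.zeros(K, R, M)
for i in range(N):
    accum += a[i] @ b[i]       # (R, C) @ (K, C, M) -> (K, R, M)

# Option 1: add a singleton dim so a's batch shape (N, 1) broadcasts
# against b's (N, K); this gives (N, K, R, M), then sum over the N blocks.
via_matmul = (a.unsqueeze(1) @ b).sum(dim=0)

# Option 2: einsum; indices absent from the output (n, j) are summed over.
via_einsum = torch.einsum('nrj,nkjm->krm', a, b)

print(via_matmul.shape)  # torch.Size([3, 2, 9])
```

The einsum string spells out which indices are contracted, which makes it easier to adapt when the layout changes. The matmul version builds the full (N, K, R, M) intermediate before summing over N, which can matter when N is large.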
Answered By - Hayoung