Issue
I have written a for-loop-based and an einsum-based implementation of a matrix multiplication that I want to perform. Can you help me check its correctness?
import torch

w = torch.randn((10,32,32))
x = torch.randn((3,32,32))
x_c = x.clone()
z = torch.zeros(x.shape)
for i in range(x.shape[0]):
    dummy_x = torch.zeros((x.shape[1],w.shape[2]))
    for j in range(w.shape[0]):
        dummy_x += torch.matmul(x[i],w[j])
    z[i]=dummy_x
result = torch.einsum("ijk,lkm->ijm",x_c,w)
# result = torch.einsum("iljm->ijm",result)
torch.eq(result,z)
I ran the code above and checked for equality using torch.eq, but the result was False.
Solution
import torch

# Initialize the tensors
w = torch.randn((10, 32, 32))
x = torch.randn((3, 32, 32))

# For-loop-based matrix multiplication
z = torch.zeros(x.shape)
for i in range(x.shape[0]):
    dummy_x = torch.zeros((x.shape[1], w.shape[2]))
    for j in range(w.shape[0]):
        dummy_x += torch.matmul(x[i], w[j])
    z[i] = dummy_x

# Expand dimensions to make the tensors broadcastable
x_expanded = x.unsqueeze(1)  # Shape: (3, 1, 32, 32)
w_expanded = w.unsqueeze(0)  # Shape: (1, 10, 32, 32)

# Perform batch matrix multiplication and sum over the second dimension
result = torch.matmul(x_expanded, w_expanded).sum(dim=1)  # Shape: (3, 32, 32)

# Check for equality. torch.eq compares element by element and exactly, so
# float32 rounding from a different summation order can make an exact check
# fail even when the math is the same. Compare within a tolerance instead
# (the tolerance values here are an assumption, not a required setting).
are_equal = torch.allclose(result, z, rtol=1e-4, atol=1e-3)
print("Are the results equal: ", are_equal)
Answered By - Adesoji Alu
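As a supplementary check, the einsum from the question can be compared directly against the loop. In "ijk,lkm->ijm" the indices k and l do not appear in the output, so both are summed over, which is the same sum over j of x[i] @ w[j] that the loop computes. The sketch below is only an illustration: the seed, the tolerance values, and the names `shortcut` and `max_diff` are assumptions introduced here, not part of the original post.

```python
import torch

torch.manual_seed(0)  # assumed seed, only to make the run repeatable

w = torch.randn((10, 32, 32))
x = torch.randn((3, 32, 32))

# Reference: the nested-loop version from the question.
z = torch.zeros(x.shape[0], x.shape[1], w.shape[2])
for i in range(x.shape[0]):
    for j in range(w.shape[0]):
        z[i] += torch.matmul(x[i], w[j])

# The einsum from the question: k is contracted and l is summed out,
# because neither index appears in the output subscripts.
result = torch.einsum("ijk,lkm->ijm", x, w)

# Equivalent shortcut: sum the w matrices first, then multiply once.
shortcut = torch.matmul(x, w.sum(dim=0))

# Largest elementwise gap between the einsum and the loop.
max_diff = (result - z).abs().max().item()

print("exactly equal:", torch.equal(result, z))
print("max abs difference:", max_diff)
# Tolerances are assumptions chosen for float32 sums of this size.
print("einsum close to loop:", torch.allclose(result, z, rtol=1e-4, atol=1e-3))
print("shortcut close to loop:", torch.allclose(shortcut, z, rtol=1e-4, atol=1e-3))
```

If the exact comparison reports False while the maximum difference is tiny, the two computations agree up to floating-point rounding. Note also that torch.eq returns a boolean tensor with one entry per element rather than a single True or False.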