Issue
I have a 2D function that takes a matrix (a 2D tensor of shape (28, 28)), and I have a tensor of shape (64, 10, 28, 28): a batch of 64 images that has passed through a Conv2d layer with 10 kernels.
Now I want to apply the 2D function to the last two dimensions of that tensor, the (28, 28) part.
So far I have done this in a very inefficient way:
def activation_func(input):
    # loop over every image and channel, applying the 2D function one slice at a time
    for batch_idx in range(input.shape[0]):
        for channel_idx in range(input.shape[1]):
            input[batch_idx, channel_idx] = func_2d(input[batch_idx, channel_idx])
    return input
which, as I noticed, is highly inefficient. Is there any way of doing this efficiently?
I can post the entire code if necessary.
EDIT:
from torch.nn.functional import relu

def func_2d(input):
    global indices  # yes I know, I will remove this global stuff later
    # indices = [(i, j) for i in range(1, 28, 4) for j in range(1, 28, 4)]
    for x, y in indices:
        relu_decision = relu(input[x, y])  # standard relu: relu(x) = (x > 0) * x
        if not relu_decision:
            # zero out the 4x4 patch around the anchor pixel
            input[x - 1: x + 3, y - 1: y + 3] = 0
    return input
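
With anchors at i, j in {1, 5, ..., 25}, each zeroed patch input[x - 1: x + 3, y - 1: y + 3] covers rows and columns x-1 through x+2, so the patches tile the (28, 28) image as non-overlapping, aligned 4x4 blocks. A quick standalone check of that tiling (just for illustration):

indices = [(i, j) for i in range(1, 28, 4) for j in range(1, 28, 4)]
covered = [(r, c) for x, y in indices
           for r in range(x - 1, x + 3) for c in range(y - 1, y + 3)]
# every pixel is covered exactly once: 7 * 7 anchors * 16 pixels = 784 = 28 * 28
print(len(covered), len(set(covered)))  # 784 784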
Solution
In such cases, I use a Kronecker product trick:
import torch

torch.set_printoptions(linewidth=200)  # so you can better see how the mask is shaped

# simulate an input
input = torch.rand(1, 1, 28, 28) - 0.5
# coordinate grids of the anchor pixels, one per 4x4 block
ids = torch.meshgrid((torch.arange(1, 28, 4), torch.arange(1, 28, 4)))
# note that relu(x) = (x > 0.) * x, so adjust it to your needs
relus = torch.nn.functional.relu(input[(slice(None), slice(None), *ids)]).to(bool)
A = torch.ones(4, 4)
# Kronecker product: each entry of relus is expanded into a 4x4 block,
# so the mask is 0 over every block whose anchor relu was 0
mask = torch.kron(relus, A)
print(mask.shape)  # torch.Size([1, 1, 28, 28])
output = input * mask
print(mask[0, 0])
print(output[0, 0])
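
The same mask construction extends directly to the full batch, since the two slice(None) entries already cover the batch and channel dimensions. A minimal sketch of a vectorized replacement for activation_func, under the shapes from the question (the batched_activation name is mine, just for illustration):

import torch

def batched_activation(input):
    # input: (B, C, 28, 28), e.g. (64, 10, 28, 28) from the question
    ids = torch.meshgrid((torch.arange(1, 28, 4), torch.arange(1, 28, 4)))
    # one relu decision per 4x4 block, taken at its anchor pixel
    relus = torch.nn.functional.relu(input[(slice(None), slice(None), *ids)]).to(bool)
    # expand each decision into a 4x4 block -> (B, C, 28, 28) mask
    mask = torch.kron(relus, torch.ones(4, 4))
    return input * mask

x = torch.rand(64, 10, 28, 28) - 0.5
print(batched_activation(x).shape)  # torch.Size([64, 10, 28, 28])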
Answered By - aretor