Issue
I want to fit a model with a torch.matrix_power operation in the forward method of a neural network class. However, I can only use a batch size of 1, because torch.matrix_power only accepts a single integer power.
k takes only a small set of integer values, e.g. 1 to 10. I tried precomputing the matrix powers to avoid the costly power calculation in each iteration, but I got an error the second time the precomputed Ak matrix was used.
A small self-contained example of what I'm trying to do is below:
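import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self, dim_x):
        super().__init__()
        self.A = nn.Parameter(torch.randn(dim_x, dim_x))

    def forward(self, X, k):
        k = k.item()  # only works when k holds a single value, i.e. batch size 1
        A_pow = torch.matrix_power(self.A, k)
        return X @ A_pow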
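The question does not quote the error raised when the precomputed Ak is reused. Assuming it is autograd's "backward through the graph a second time" RuntimeError, the minimal sketch below reproduces it. The matrix products that build the powers save their inputs for backward, and those saved tensors are freed after the first backward() call. Even with retain_graph=True, the precomputed powers would still hold values computed from the old A once an optimizer updates it. The names and sizes here are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
A = nn.Parameter(torch.randn(3, 3))

# Precompute A^0 .. A^3 once, outside the training loop.
A_k = torch.stack([torch.matrix_power(A, k) for k in range(4)])

x = torch.randn(2, 3)

# First use: backward succeeds, then autograd frees the tensors saved by
# the matmuls that produced A_k.
(x @ A_k[2]).sum().backward()

# Second use of the same precomputed A_k: backward must go through the freed graph again.
error = None
try:
    (x @ A_k[3]).sum().backward()
except RuntimeError as e:
    error = e

print(error)
```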
Solution
I think index_select does what you need. Consider this, based on your example:
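import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self, dim_x: int, max_k: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.A = nn.Parameter(torch.randn(dim_x, dim_x))
        # A^0 .. A^max_k stacked: shape (max_k + 1, dim_x, dim_x). Built once, from the initial A.
        self.A_k = torch.stack([torch.matrix_power(self.A, k) for k in range(max_k + 1)], dim=0)

    def forward(self, X, k):
        # k: 1-D int64 tensor with one power per row of X
        A_pow = torch.index_select(self.A_k, dim=0, index=k)  # (batch, dim_x, dim_x)
        # Per-row product: (batch, 1, dim_x) @ (batch, dim_x, dim_x) -> (batch, dim_x)
        return torch.bmm(X.unsqueeze(1), A_pow).squeeze(1)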
I calculated the powers of A up to max_k ahead of time. During the forward pass, I used index_select to pick, for each sample, the power given by the corresponding entry of the vector k.
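Because A_k above is built once in __init__, it likely hits the same problem as the question: its graph is freed after the first backward(), and it never reflects updates to A. A minimal sketch of a variant, assuming the goal is to train A, keeps the same idea (stack the powers, then index_select one per sample) but rebuilds the stack once per forward pass. That costs max_k matrix products per batch instead of one matrix_power call per sample. The helper name `powers` and the attribute `max_k` are hypothetical.

```python
import torch
import torch.nn as nn


class NN(nn.Module):
    """Sketch: rebuild the stack of powers on every forward pass so gradients reach A."""

    def __init__(self, dim_x: int, max_k: int):
        super().__init__()
        self.A = nn.Parameter(torch.randn(dim_x, dim_x))
        self.max_k = max_k  # hypothetical: largest power that k may request

    def powers(self) -> torch.Tensor:
        # Stack A^0, A^1, ..., A^max_k -> shape (max_k + 1, dim_x, dim_x).
        # Each power reuses the previous one, so this is max_k matmuls per call.
        pows = [torch.eye(self.A.shape[0], dtype=self.A.dtype, device=self.A.device)]
        for _ in range(self.max_k):
            pows.append(pows[-1] @ self.A)
        return torch.stack(pows)

    def forward(self, X: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # X: (batch, dim_x); k: (batch,) int64 powers in [0, max_k].
        A_pow = torch.index_select(self.powers(), dim=0, index=k)  # (batch, dim_x, dim_x)
        # Row-wise product: out[b] = X[b] @ A_pow[b]
        return torch.bmm(X.unsqueeze(1), A_pow).squeeze(1)
```

For example, `NN(dim_x=4, max_k=10)(X, torch.tensor([1, 3, 10, 0, 7]))` with X of shape (5, 4) should give the same rows as applying torch.matrix_power to A separately for each sample. Since the graph is rebuilt on every call, backward() can be run once per batch as usual.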
Answered By - Yakov Dan