Issue
I have a collection of tensors of common shape `(2, ncol)`. Example:

```python
torch.tensor([[1, 2, 3, 7, 8], [3, 3, 1, 8, 7]], dtype=torch.long)
```
For each tensor, I want to determine whether, for each column `[[a], [b]]`, the reversed column `[[b], [a]]` is also in the tensor. In this example, since `ncol` is odd, I can immediately say that this is not the case. But in this other example

```python
torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)
```

I would actually have to perform the check. A naive solution would be:
```python
import torch

test = torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)

def are_column_paired(matrix: torch.Tensor) -> bool:
    ncol = matrix.shape[1]
    # An odd number of columns can never be split into pairs
    if ncol % 2 != 0:
        return False
    column_has_match = torch.zeros(ncol, dtype=torch.bool)
    for i in range(ncol):
        if column_has_match[i]:
            continue
        column = matrix[:, i]
        j = i + 1
        # Scan later columns until column i finds an unmatched reversed partner
        while not column_has_match[i] and j < ncol:
            if column_has_match[j]:
                j = j + 1
                continue
            current_column = matrix[:, j].flip(dims=[0])
            if torch.equal(column, current_column):
                column_has_match[i], column_has_match[j] = True, True
            j = j + 1
    return torch.all(column_has_match).item()
```
But of course this is slow and probably not very Pythonic. How can I make it more efficient?

PS: while `test` here is very small, in the actual use case I expect `ncol` to be O(10^5).
Solution
Here is one possible simple approach. It is likely not the most efficient you can get, but it is much faster than your current solution. The idea is to check whether sorting the columns of the original tensor and of the row-flipped tensor gives identical results. I believe the time complexity of this approach is O(n log n), as opposed to O(n^2) for your approach, where n is the number of columns.
```python
import torch

def are_columns_paired(matrix):
    flipped_matrix = matrix.flip(dims=[0])
    matrix_sorted = matrix[:, matrix[1].argsort()]  # sort by second row
    matrix_sorted = matrix_sorted[:, matrix_sorted[0].sort(stable=True)[1]]  # stable sort by first row, keeping the second-row order when there is a tie
    flipped_matrix = flipped_matrix[:, flipped_matrix[1].argsort()]
    flipped_matrix = flipped_matrix[:, flipped_matrix[0].sort(stable=True)[1]]
    return (matrix_sorted == flipped_matrix).all().item()
```
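Here, for both the original and the flipped matrix, we sort the columns lexicographically: first by the first row and, when there is a tie, by the second row. This works because the final sort on the first row is stable, so columns with equal first-row entries keep the order established by the earlier sort on the second row.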
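When the entries are non-negative integers, the two-pass sort can also be collapsed into a single sort. Encode each column `(a, b)` as the integer key `a * base + b` with `base = max entry + 1`, so that numeric order on the keys matches lexicographic order on the columns, and compare the sorted keys of the original and flipped columns. This is a swapped-in variant, not the answer's code, and `are_columns_paired_keyed` is a hypothetical name. It is a sketch that assumes non-negative entries and that `base**2` fits in int64 (for values up to 999999 that is about 10**12, well within range).

```python
import torch

def are_columns_paired_keyed(matrix: torch.Tensor) -> bool:
    # Assumption: non-negative integer entries, and (max + 1) ** 2 fits in int64.
    a, b = matrix[0], matrix[1]
    base = int(matrix.max().item()) + 1 if matrix.numel() > 0 else 1
    # Encode column (a, b) as one integer; numeric order on the key equals
    # lexicographic order on (a, b) because 0 <= b < base.
    keys = a * base + b
    # The same encoding applied to the reversed columns (b, a).
    flipped_keys = b * base + a
    # Columns and reversed columns form the same multiset
    # exactly when the sorted key sequences are identical.
    return torch.equal(keys.sort().values, flipped_keys.sort().values)

print(are_columns_paired_keyed(torch.tensor([[1, 3, 7, 8], [3, 1, 8, 7]])))  # True
print(are_columns_paired_keyed(torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]])))  # False
```

Note that the second example from the question is not paired: the column `[[2], [3]]` has no `[[3], [2]]` partner. Sorting one int64 vector is typically cheaper than two column-reindexing passes, but the encoding is only safe under the stated value-range assumption.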
I tested both approaches on a randomly generated tensor with `ncol=2000000` and values ranging from 0 to 999999. The code above ran in about 1 second, while the approach from the question had still not finished after an hour.
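One semantic difference between the two approaches is worth checking against your use case. The sort-based check tests whether the columns and the reversed columns form the same multiset. A column with two equal entries, such as `[[5], [5]]`, is its own reverse, so it satisfies the sort-based test by itself, whereas the loop in the question needs a distinct partner for it and rejects any odd `ncol` up front. For example, `[[1, 2, 5], [2, 1, 5]]` is accepted by the sort-based check but rejected by the question's loop. If the stricter pairing semantics are what you want, one sketch is to additionally require that every self-reversed column occurs an even number of times. The helper names `_sort_columns` and `are_columns_strictly_paired` are hypothetical.

```python
import torch

def _sort_columns(m: torch.Tensor) -> torch.Tensor:
    # Lexicographic column sort, as in the answer: sort by row 1, then stable sort by row 0.
    m = m[:, m[1].argsort()]
    return m[:, m[0].sort(stable=True).indices]

def are_columns_strictly_paired(matrix: torch.Tensor) -> bool:
    # Step 1: columns and reversed columns must form the same multiset.
    if not torch.equal(_sort_columns(matrix), _sort_columns(matrix.flip(dims=[0]))):
        return False
    # Step 2: columns [[a], [a]] are their own reverse, so the question's loop
    # can only match them with another identical column; each distinct one
    # must therefore occur an even number of times.
    values = matrix[0][matrix[0] == matrix[1]]
    if values.numel() == 0:
        return True
    _, counts = torch.unique(values, return_counts=True)
    return bool((counts % 2 == 0).all())

print(are_columns_strictly_paired(torch.tensor([[1, 2, 5], [2, 1, 5]])))  # False
print(are_columns_strictly_paired(torch.tensor([[1, 2, 5, 5], [2, 1, 5, 5]])))  # True
```

No explicit odd-`ncol` check is needed here: once both conditions hold, the columns split into pairs, so `ncol` is necessarily even.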
Answered By - GoodDeeds