Issue
I am trying to split a 1D PyTorch tensor every time a run of x consecutive zeros is encountered. If further zeros follow the split point, I want to remove them up to the next non-zero value. Currently I use a for-loop over the zero indices to do this. However, it is slow, especially for large tensors that contain many zeros. How can I optimize this code, ideally with PyTorch-specific functions, for better performance?
My tensors here have two dimensions, but the first dim doesn't matter for this task (ignore it); the code checks the values in split_flow[:, 1].
import torch

def _split_tensor_gpu(split_flow, consecutive_zeros):
    # positions where column 1 is zero
    zero_indices = torch.nonzero(split_flow[:, 1] == 0).view(-1)
    if len(zero_indices) == 0:
        return [split_flow]
    splitted_list = []
    first_index = 0
    zero_counter = 0
    for i in range(1, len(zero_indices)):
        # count how many adjacent zero positions have been seen in a row
        if zero_indices[i] - zero_indices[i - 1] == 1:
            zero_counter += 1
        else:
            zero_counter = 0
        if zero_counter == consecutive_zeros:
            splitted_list.append(split_flow[first_index:zero_indices[i]])
            first_index = zero_indices[i] + 1
        # skip any further zeros that follow the split point
        if zero_counter > consecutive_zeros:
            first_index = zero_indices[i] + 1
    if first_index <= len(split_flow) - 1:
        splitted_list.append(split_flow[first_index:])
    return splitted_list
Solution: The first comment did most of the job but didn't remove the zeros after splitting. Based on it, I adapted the function as follows, which should do the job now:
def _split_tensor_gpu2(tensor_, consecutive_zeros):
    # step 1: identify zero sequences
    # create a mask of zeros and find the difference between consecutive elements
    is_zero = tensor_[:, 1] == 0
    diff = torch.diff(is_zero.float(), prepend=torch.tensor([0.0], device=tensor_.device))
    # start (first zero) and end (one past the last zero) indices of zero sequences
    start_indices = torch.where(diff == 1)[0]
    end_indices = torch.where(diff == -1)[0]
    # a sequence that reaches the end of the tensor has a start but no end yet
    if len(start_indices) > len(end_indices):
        end_indices = torch.cat([end_indices, tensor_.size(0) * torch.ones(1, dtype=torch.long, device=tensor_.device)])
    # step 2: mark split points
    # find sequences with length > consecutive_zeros
    valid_seqs = (end_indices - start_indices) > consecutive_zeros
    # each chunk keeps the first consecutive_zeros zeros of the sequence that ends it
    valid_start_indices = start_indices[valid_seqs] + consecutive_zeros
    valid_end_indices = end_indices[valid_seqs]
    # step 3: split the tensor, dropping the remaining zeros of each sequence
    splits = []
    end_idx = 0
    for i in range(len(valid_start_indices)):
        splits.append(tensor_[end_idx:valid_start_indices[i]])
        end_idx = valid_end_indices[i]
    # add the remaining part of the tensor if any
    if end_idx < tensor_.size(0):
        splits.append(tensor_[end_idx:])
    return splits
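To make the mask-and-diff step concrete, here is a minimal sketch on a small 1D tensor. The input values are made up for illustration, and padding the mask on both sides with `prepend` and `append` is a variation on the function above, not part of it: with both pads, every run of zeros gets a start and an end, so no special case is needed for a run that reaches the end of the tensor.

```python
import torch

# Hypothetical 1D input; in the functions above this would be tensor_[:, 1].
values = torch.tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1])

# 1 where the value is zero, 0 elsewhere
is_zero = (values == 0).long()
pad = torch.zeros(1, dtype=torch.long)

# Padding the mask with 0 on both sides means every run of zeros
# produces both a +1 (start) and a -1 (end) in the difference.
diff = torch.diff(is_zero, prepend=pad, append=pad)

starts = torch.where(diff == 1)[0]   # index of the first zero in each run
ends = torch.where(diff == -1)[0]    # index one past the last zero in each run
lengths = ends - starts

print(starts.tolist())   # [1, 6]
print(ends.tolist())     # [4, 11]
print(lengths.tolist())  # [3, 5]
```

Comparing `lengths` against the threshold then selects the runs that count as split points.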
Solution
You can replace most of the loop with PyTorch's built-in tensor operations, using torch.diff and torch.where to locate the runs of zeros:
import torch

def _split_tensor_gpu(tensor, consecutive_zeros):
    # step 1: identify zero sequences
    # create a mask of zeros and find the difference between consecutive elements
    is_zero = tensor[:, 1] == 0
    diff = torch.diff(is_zero.float(), prepend=torch.tensor([0.0], device=tensor.device))
    # start (first zero) and end (one past the last zero) indices of zero sequences
    start_indices = torch.where(diff == 1)[0]
    end_indices = torch.where(diff == -1)[0]
    # a sequence that reaches the end of the tensor has a start but no end yet
    if len(start_indices) > len(end_indices):
        end_indices = torch.cat([end_indices, tensor.size(0) * torch.ones(1, dtype=torch.long, device=tensor.device)])
    # step 2: mark split points
    # find sequences with length >= consecutive_zeros
    valid_seqs = (end_indices - start_indices) >= consecutive_zeros
    valid_start_indices = start_indices[valid_seqs]
    valid_end_indices = end_indices[valid_seqs]
    # step 3: split the tensor
    # cut right after each valid zero sequence (the zeros stay in the chunk)
    splits = []
    start_idx = 0
    for end_idx in valid_end_indices:
        splits.append(tensor[start_idx:end_idx])
        start_idx = end_idx
    # add the remaining part of the tensor if any
    if start_idx < tensor.size(0):
        splits.append(tensor[start_idx:])
    return splits
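# Example usage
# the function reads column 1, so the values go there; column 0 is just a row index
values = torch.tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=torch.float32)
tensor = torch.stack([torch.arange(len(values), dtype=torch.float32), values], dim=1)
consecutive_zeros = 3
split_tensors = _split_tensor_gpu(tensor, consecutive_zeros)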
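The remaining Python loop can also be handed to `torch.tensor_split`. The following is a rough sketch under some assumptions rather than a drop-in replacement: the helper name `split_after_zero_runs` is hypothetical, it keeps the zeros in each chunk just like the function above (it does not trim them), and it pads the mask on both sides so that a trailing run needs no special handling.

```python
import torch

def split_after_zero_runs(tensor, consecutive_zeros):
    """Hypothetical helper: cut after every run of at least
    `consecutive_zeros` zeros found in column 1."""
    is_zero = (tensor[:, 1] == 0).long()
    pad = torch.zeros(1, dtype=torch.long, device=tensor.device)
    diff = torch.diff(is_zero, prepend=pad, append=pad)
    starts = torch.where(diff == 1)[0]   # first zero of each run
    ends = torch.where(diff == -1)[0]    # one past the last zero of each run
    cut_points = ends[(ends - starts) >= consecutive_zeros]
    if cut_points.numel() == 0:
        return [tensor]
    # tensor_split accepts a list of indices; .tolist() copies them to the host
    chunks = torch.tensor_split(tensor, cut_points.tolist(), dim=0)
    # a run that reaches the end of the tensor leaves an empty last chunk
    return [c for c in chunks if c.size(0) > 0]

# Made-up input: values in column 1, a dummy row index in column 0
values = torch.tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=torch.float32)
tensor = torch.stack([torch.arange(len(values), dtype=torch.float32), values], dim=1)

chunks = split_after_zero_runs(tensor, 3)
print([c.size(0) for c in chunks])  # [4, 7, 1]
```

Note that `.tolist()` forces a device-to-host copy when the tensor lives on the GPU, so whether this is faster than the loop depends on how many split points there are; it is worth timing on real data.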
Answered By - inverted_index