Issue
I have a list of sentences and I want to extract the tokens between two marker items. If the start or end item is missing from a row, I want that row to contain only padding. The sentences are already tokenized and padded with 0 to a fixed length.
I found a way to do this with for loops, but it is extremely slow, so I would like to know how to solve this efficiently, preferably with tensor operations.
import torch
start_value, end_value = 4,9
data = torch.tensor([
[3,4,7,8,9,2,0,0,0,0],
[1,5,3,4,7,2,8,9,10,0],
[3,4,7,8,10,0,0,0,0,0], # does not contain end value
[3,7,5,9,2,0,0,0,0,0], # does not contain start value
])
# expected output
[
[7,8,0,0,0,0,0,0,0,0],
[7,2,8,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0],
]
# or
[
[0,0,7,8,0,0,0,0,0,0],
[0,0,0,0,7,2,8,0,0,0],
[0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0],
]
My current solution uses a for loop. It also returns a list of variable-length tensors rather than the padded, fixed-width tensor shown in the expected output.
def _get_part_from_tokens(
    self,
    data: torch.Tensor,
    s_id: int,
    e_id: int,
) -> list[torch.Tensor]:
    input_ids = []
    for row in data:
        try:
            # index of the first occurrence of each marker in this row
            s_index = (row == s_id).nonzero(as_tuple=True)[0][0]
            e_index = (row == e_id).nonzero(as_tuple=True)[0][0]
        except IndexError:
            # at least one marker is missing
            input_ids.append(torch.tensor([]))
            continue
        if s_index is None or e_index is None or s_index > e_index:
            input_ids.append(torch.tensor([]))
            continue
        ind = torch.arange(s_index + 1, e_index)
        input_ids.append(row.index_select(0, ind))
    return input_ids
Solution
A possible loop-free approach is this:
import torch
# using the provided sample data
start_value, end_value = 4,9
data = torch.tensor([
[3,4,7,8,9,2,0,0,0,0],
[1,5,3,4,7,2,8,9,10,0],
[3,4,7,8,10,0,0,0,0,0], # does not contain end value
[3,7,5,9,2,0,0,0,0,0], # does not contain start value
[3,7,5,8,2,0,0,0,0,0], # does not contain start or end value
])
First, check which rows contain only a start_value or only an end_value, and fill these rows with 0.
# fill 'invalid' rows with 0
starts = (data == start_value)
ends = (data == end_value)
invalid = ((starts.sum(axis=1) - ends.sum(axis=1)) != 0)
data[invalid] = 0
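To make the validity check concrete, the following sketch prints the per-row marker counts and the resulting mask for the sample data. The names `n_starts` and `n_ends` are introduced here for illustration only and do not appear in the answer.

```python
import torch

start_value, end_value = 4, 9
data = torch.tensor([
    [3, 4, 7, 8, 9, 2, 0, 0, 0, 0],
    [1, 5, 3, 4, 7, 2, 8, 9, 10, 0],
    [3, 4, 7, 8, 10, 0, 0, 0, 0, 0],  # no end value
    [3, 7, 5, 9, 2, 0, 0, 0, 0, 0],   # no start value
    [3, 7, 5, 8, 2, 0, 0, 0, 0, 0],   # neither value
])

starts = data == start_value
ends = data == end_value

# Count the markers per row; a row is flagged when the counts differ.
n_starts = starts.sum(dim=1)
n_ends = ends.sum(dim=1)
invalid = n_starts != n_ends

print(n_starts.tolist())  # [1, 1, 1, 0, 0]
print(n_ends.tolist())    # [1, 1, 0, 1, 0]
print(invalid.tolist())   # [False, False, True, True, False]
```

The last row is not flagged at this point because both counts are zero; it is handled by the masking step that follows.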
Then set the values up to (and including) the start_value and from the end_value onward (inclusive) to 0 in each row. This step mainly targets the 'valid' rows; the rows that were already filled with zeros are simply overwritten with zeros again.
# set values in the start and end of 'valid rows' to 0
row_length = data.shape[1]
start_idx = starts.long().argmax(axis=1)
start_mask = (start_idx[:,None] - torch.arange(row_length))>=0
data[start_mask] = 0
end_idx = row_length - ends.long().argmax(axis=1)
end_mask = (end_idx[:,None] + torch.arange(row_length))>=row_length
data[end_mask] = 0
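The masking steps above produce the second layout from the question, with the extracted tokens left in their original columns. If the left-aligned layout is preferred, one possible way to get it is a stable sort on an "is padding" flag. The helper `left_align` below is a hypothetical name, and the sketch assumes that 0 is used only for padding and never as a real token.

```python
import torch

# Masked result for the two valid sample rows plus one all-padding row.
masked = torch.tensor([
    [0, 0, 7, 8, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 7, 2, 8, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
])

def left_align(t: torch.Tensor, pad: int = 0) -> torch.Tensor:
    """Move all non-padding entries to the front of each row, keeping their order."""
    is_pad = (t == pad).long()
    # A stable ascending sort puts the 0 flags (real tokens) first
    # without reordering the tokens among themselves.
    _, order = torch.sort(is_pad, dim=1, stable=True)
    return torch.gather(t, 1, order)

aligned = left_align(masked)
print(aligned.tolist())
# [[7, 8, 0, 0, 0, 0, 0, 0, 0, 0],
#  [7, 2, 8, 0, 0, 0, 0, 0, 0, 0],
#  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
```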
Note: This also works if a row contains neither a start_value nor an end_value (I added such a row to the sample data). In that case argmax returns 0 for both masks, so end_mask covers the whole row and the row is set to zeros. Still, there are many more edge cases one could think of (e.g. multiple start and end values in one row, start value after end value, ...). I am not sure whether they are relevant to the specific problem.
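For readers who want the behavior in these edge cases spelled out, here is a minimal sketch of a variant that builds a single keep-mask and leaves the input untouched. It is not the code from the answer: the function name `extract_between` and the `pad` parameter are assumptions. It uses the first occurrence of each marker, and unlike the count-based check it does not reject rows in which the markers occur an unequal number of times, as long as both are present.

```python
import torch

def extract_between(data: torch.Tensor, start_value: int, end_value: int,
                    pad: int = 0) -> torch.Tensor:
    """Keep tokens strictly between the first start and first end marker of each row."""
    starts = data == start_value
    ends = data == end_value
    has_both = starts.any(dim=1) & ends.any(dim=1)

    # argmax returns the first maximal index, i.e. the first marker position
    # (or 0 if the marker is absent, which has_both guards against).
    start_idx = starts.long().argmax(dim=1)
    end_idx = ends.long().argmax(dim=1)

    cols = torch.arange(data.shape[1], device=data.device)
    keep = (cols > start_idx[:, None]) & (cols < end_idx[:, None]) & has_both[:, None]
    return torch.where(keep, data, torch.full_like(data, pad))

sample = torch.tensor([
    [3, 4, 7, 8, 9, 2, 0, 0, 0, 0],  # normal case
    [3, 9, 7, 4, 2, 0, 0, 0, 0, 0],  # end value before start value
    [4, 1, 4, 2, 9, 3, 9, 0, 0, 0],  # repeated markers
    [3, 7, 5, 8, 2, 0, 0, 0, 0, 0],  # neither marker
])
result = extract_between(sample, 4, 9)
print(result.tolist())
# [[0, 0, 7, 8, 0, 0, 0, 0, 0, 0],
#  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#  [0, 1, 4, 2, 0, 0, 0, 0, 0, 0],
#  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
```

When the end marker precedes the start marker, the two column conditions cannot both hold, so the row comes out as padding only. Whether a repeated start marker inside the span should be kept, as in the third row, depends on the use case.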
Comparison of execution time
Comparing the approaches with timeit on randomly generated data suggests that the loop-free approach is considerably faster than the approach from the question. Converting the data to NumPy first and back to PyTorch afterwards yields some further (very minor) time savings.
Each dot (execution time) in the plot is the minimum value of 3 trials each with 100 repetitions.
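The original benchmark script is not included, so the following is only a rough sketch of how such a comparison could be set up with timeit. The function names, the data shape (100 rows of 50 tokens), and the value range are assumptions chosen to keep the run short; actual timings depend on hardware and data size.

```python
import timeit
import torch

def loop_version(data, s_id, e_id):
    """Standalone adaptation of the loop from the question."""
    out = []
    for row in data:
        s_pos = (row == s_id).nonzero(as_tuple=True)[0]
        e_pos = (row == e_id).nonzero(as_tuple=True)[0]
        if len(s_pos) == 0 or len(e_pos) == 0 or s_pos[0] > e_pos[0]:
            out.append(row.new_empty(0))
            continue
        out.append(row[int(s_pos[0]) + 1:int(e_pos[0])])
    return out

def vectorized_version(data, s_id, e_id):
    """The loop-free steps, applied to a copy so repeated calls see the same input."""
    data = data.clone()
    starts = data == s_id
    ends = data == e_id
    data[starts.sum(dim=1) != ends.sum(dim=1)] = 0
    n = data.shape[1]
    cols = torch.arange(n)
    start_idx = starts.long().argmax(dim=1)
    data[start_idx[:, None] - cols >= 0] = 0
    end_idx = n - ends.long().argmax(dim=1)
    data[end_idx[:, None] + cols >= n] = 0
    return data

torch.manual_seed(0)
random_data = torch.randint(1, 20, (100, 50))  # assumed shape and value range

for name, fn in [("loop", loop_version), ("vectorized", vectorized_version)]:
    # Minimum over 3 trials of 100 repetitions each.
    best = min(timeit.repeat(lambda: fn(random_data, 4, 9), number=100, repeat=3))
    print(f"{name}: {best / 100 * 1e3:.3f} ms per call")
```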
Answered By - rosa b.