Issue
Is there a way to get the histograms of torch tensors in batches?
For example, if x is a tensor of shape (64, 224, 224):

x = batch_histogram(x, bins=256, min=0, max=255)
# x will now have shape (64, 256)
Solution
It is possible to do this with torch.nn.functional.one_hot
in a single line of code:
torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
The rationale is that one_hot respects batches: for each value v in the last dimension of the given tensor, it creates a tensor filled with 0s, except for the v-th component, which is 1. We then sum these one-hot encodings over the second-to-last dimension (which was the last dimension of data_tensor) to obtain how many times each value appears in each row of the data.
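As a quick illustration of the mechanics, here is a small made-up example:

data = torch.tensor([[2, 0, 2], [1, 1, 0]])     # shape (2, 3)
encoded = torch.nn.functional.one_hot(data, 3)  # shape (2, 3, 3)
hist = encoded.sum(dim=-2)                      # shape (2, 3)
# hist is [[1, 0, 2], [1, 2, 0]]: per-row counts of the values 0, 1, 2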
One possibly serious drawback of this method is memory usage, since each value will be expanded into a tensor of size num_classes (so the size of data_tensor is temporarily multiplied by num_classes). This memory usage is temporary, however, since sum collapses this extra dimension again and the result will typically be smaller than data_tensor. I say "typically" because if num_classes is much larger than the size of the last dimension of data_tensor, then the result will be correspondingly larger.
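If that temporary blow-up is a problem, one alternative (not from the original answer, and with a hypothetical name) is to accumulate counts in place with scatter_add_, which never materializes the one-hot tensor; note that num_classes must be given explicitly here:

def batch_histogram_scatter(data_tensor, num_classes):
    # Hypothetical sketch: add 1 at each value's slot in the histogram
    # instead of building the full one-hot encoding first.
    shape = data_tensor.shape[:-1] + (num_classes,)
    hist = data_tensor.new_zeros(shape)  # same dtype/device as the data
    return hist.scatter_add_(-1, data_tensor, torch.ones_like(data_tensor))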
Here's the code with documentation, followed by pytest tests:
import pytest
import torch


def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
            If not provided, data_tensor.max() + 1 is used (an error is thrown if data_tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of data_tensor,
        that is, result[d_1, ..., d_{n-1}, c] = number of times c appears in data_tensor[d_1, ..., d_{n-1}].
    """
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
def test_batch_histogram():
    # 1-D data: a single histogram.
    data = [2, 5, 1, 1]
    expected = [0, 2, 1, 0, 0, 1]
    run_test(data, expected)

    # 2-D data: one histogram per row.
    data = [
        [2, 5, 1, 1],
        [3, 0, 3, 1],
    ]
    expected = [
        [0, 2, 1, 0, 0, 1],
        [1, 1, 0, 2, 0, 0],
    ]
    run_test(data, expected)

    # 3-D data: one histogram per innermost row.
    data = [
        [[2, 5, 1, 1], [2, 4, 1, 1]],
        [[3, 0, 3, 1], [2, 3, 1, 1]],
    ]
    expected = [
        [[0, 2, 1, 0, 0, 1], [0, 2, 1, 0, 1, 0]],
        [[1, 1, 0, 2, 0, 0], [0, 2, 1, 1, 0, 0]],
    ]
    run_test(data, expected)
def test_empty_data():
    # Empty data still yields a zero-filled histogram if num_classes is given.
    data = []
    num_classes = 2
    expected = [0, 0]
    run_test(data, expected, num_classes)

    data = [[], []]
    num_classes = 2
    expected = [[0, 0], [0, 0]]
    run_test(data, expected, num_classes)

    # num_classes cannot be inferred from empty data.
    data = [[], []]
    run_test(data, expected=None, exception=RuntimeError)
def run_test(data, expected, num_classes=-1, exception=None):
    data_tensor = torch.tensor(data, dtype=torch.long)
    if exception is None:
        expected_tensor = torch.tensor(expected, dtype=torch.long)
        actual = batch_histogram(data_tensor, num_classes)
        assert torch.equal(actual, expected_tensor)
    else:
        with pytest.raises(exception):
            batch_histogram(data_tensor, num_classes)
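Applied to the original question (a usage sketch; note the histogram runs over the last dimension only, so the two image dimensions must be flattened first):

x = torch.randint(0, 256, (64, 224, 224))
hist = batch_histogram(x.reshape(64, -1), num_classes=256)
# hist has shape (64, 256)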
Answered By - user118967