Issue
I have a 4D tensor of shape [32,64,64,3], which corresponds to [batch, timeframes, frequency_bins, features], and I do tensor.flatten(start_dim=2) (in PyTorch). I understand the shape will then transform to [32,64,64*3] --> [batch,timeframes,frequency_bins*features], but in terms of the actual ordering of the elements within that new flattened dimension of 64*3, are the first 64 indexes relating to what would have been [:,:,:,0], the second 64 [:,:,:,1], and the final 64 [:,:,:,2]?
Solution
For the sake of understanding, let us first take the simplest case where we have a tensor of rank 2, i.e., a regular matrix. PyTorch performs flattening in what is called row-major order: the "innermost" (last) axis is traversed first and varies fastest, then the next axis out, and so on to the "outermost" (first) axis.
Taking a simple 3x3 array of rank 2, let us call it A[3, 3]:
[[a, b, c],
 [d, e, f],
 [g, h, i]]
Flattening this from innermost to outermost axes would give you [a, b, c, d, e, f, g, h, i]. Let us call this flattened array B[9].
The relation between corresponding elements in A (at index [i, j]) and B (at index k) can easily be derived as:
k = A.size[1] * i + j
This is because to reach the element at [i, j], we first move i rows down, counting A.size[1] (i.e., the width of the array) elements for each row. Once we reach row i, we need to get to column j, thus we add j to obtain the index in the flattened array.
For example, element e is at index [1, 1] in A. In B, it would occupy the index 3 * 1 + 1 = 4, as expected.
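The relation above can be checked with a few lines of plain Python. This is only a sketch of row-major flattening using nested lists, not PyTorch's actual implementation; the name width is introduced here for illustration and plays the role of A.size[1].

```python
# Plain-Python sketch of row-major flattening for the 3x3 example.
A = [["a", "b", "c"],
     ["d", "e", "f"],
     ["g", "h", "i"]]

width = len(A[0])  # illustrative name; plays the role of A.size[1]

# Row-major flattening: walk each row in turn, left to right.
B = [element for row in A for element in row]

# Every element A[i][j] should land at index k = width * i + j in B.
for i in range(len(A)):
    for j in range(width):
        k = width * i + j
        assert B[k] == A[i][j]

print(B.index("e"))  # 3 * 1 + 1 = 4
```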
Let us extend that same idea to a tensor of rank 4, as in your case, where we are flattening only the last two axes.
Again, taking a simple rank 4 tensor A of shape (2, 2, 2, 2) as below:
A =
[[[[ 0,  1],
   [ 2,  3]],
  [[ 4,  5],
   [ 6,  7]]],
 [[[ 8,  9],
   [10, 11]],
  [[12, 13],
   [14, 15]]]]
Let us find a relation between the indices of A and torch.flatten(A, start_dim=2) (let's call the flattened version B).
B =
[[[ 0,  1,  2,  3],
  [ 4,  5,  6,  7]],
 [[ 8,  9, 10, 11],
  [12, 13, 14, 15]]]
Element 12 is at index [1, 1, 0, 0] in A and index [1, 1, 0] in B. Note that the indices at axes 0 and 1, i.e., [1, 1], remain intact even after partial flattening. This is because those axes are not flattened and thus not impacted.
This is fantastic! Thus, we can represent the transformation from A to B as:
B[i, j, _] = A[i, j, _, _]
Our task now reduces to finding a relation between the last axis of B and the last 2 axes of A. But A[i, j, _, _] is a 2x2 array, for which we have already derived the relation k = A.size[1] * i + j. Two things change: A.size[1] becomes A.size[3], as 3 is now the last axis, and the row and column indices are renamed m and n so that they do not clash with the i and j of the untouched axes. The general relation remains the same.
Filling in the blanks, we get the relation between corresponding elements in A and B as:
B[i, j, k] = A[i, j, m, n]
where k = A.size[3] * m + n.
We can verify that this is correct. Element 14 is at [1, 1, 1, 0] in A and moves to [1, 1, 2 * 1 + 0] = [1, 1, 2] in B.
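The same check can be run in PyTorch itself. The sketch below assumes A is filled with the values 0 to 15 via torch.arange (so that each value equals its position in row-major order) and loops over every index to confirm k = A.size(3) * m + n; the variable name last is introduced here for illustration.

```python
import torch

# Rank-4 tensor of shape (2, 2, 2, 2) holding 0..15 in row-major order.
A = torch.arange(16).reshape(2, 2, 2, 2)
B = torch.flatten(A, start_dim=2)  # shape (2, 2, 4)

last = A.size(3)  # size of the last axis of A

# Check B[i, j, k] == A[i, j, m, n] with k = last * m + n for every index.
for i in range(A.size(0)):
    for j in range(A.size(1)):
        for m in range(A.size(2)):
            for n in range(last):
                assert B[i, j, last * m + n] == A[i, j, m, n]

print(A[1, 1, 1, 0].item(), B[1, 1, 2].item())  # both are element 14
```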
EDIT: Added example
Taking @Molem7b5's example of array A with shape (1, 4, 4, 3), from the comments:
Iterating from inner (dim=3) to outer axes (dim=2) of A gives consecutive elements of B. What I mean by this is:
// Using relation: A[:, :, i, j] == B[:, :, 3 * i + j]
// i = 0, all j
A[:, :, 0, 0] == B[:, :, 0]
A[:, :, 0, 1] == B[:, :, 1]
A[:, :, 0, 2] == B[:, :, 2]
// (Note the consecutive order in B.)
// i = 1, all j
A[:, :, 1, 0] == B[:, :, 3]
A[:, :, 1, 1] == B[:, :, 4]
// and so on until
A[:, :, 3, 2] == B[:, :, 11]
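This should give you a better picture of how the flattening occurs. In terms of the original question: the features of each frequency bin stay adjacent in the flattened axis, so the first 64 indexes do not correspond to [:, :, :, 0]; instead, indexes 0 to 2 hold the 3 features of frequency bin 0, indexes 3 to 5 those of bin 1, and so on. When in doubt, extrapolate from the relation.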
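To tie this back to the shape in the question, here is a small sketch using a [32, 64, 64, 3] tensor filled with distinct values (an assumption made so that mismatches are detectable). The second half, which swaps the last two axes with permute before flattening, is not part of the original answer; it is one possible way to get the feature-major grouping the question describes, should that ordering be wanted.

```python
import torch

# Stand-in for the [batch, timeframes, frequency_bins, features] tensor,
# filled with distinct values so that any reordering is detectable.
A = torch.arange(32 * 64 * 64 * 3).reshape(32, 64, 64, 3)
B = A.flatten(start_dim=2)  # shape (32, 64, 192)

# Default ordering: the 3 features of each frequency bin stay adjacent.
assert torch.equal(B[:, :, 0:3], A[:, :, 0, :])
assert torch.equal(B[:, :, 3:6], A[:, :, 1, :])

# So the first 64 entries are NOT the feature-0 slice.
assert not torch.equal(B[:, :, :64], A[:, :, :, 0])

# Assumed alternative: swap the last two axes first, so that each
# feature's 64 frequency bins end up contiguous after flattening.
C = A.permute(0, 1, 3, 2).flatten(start_dim=2)
assert torch.equal(C[:, :, :64], A[:, :, :, 0])
assert torch.equal(C[:, :, 64:128], A[:, :, :, 1])
assert torch.equal(C[:, :, 128:], A[:, :, :, 2])
```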
Answered By - Nikhil Kumar