Issue
As I understand it, you need to call tensor.contiguous() explicitly whenever some function or module needs a contiguous tensor. Otherwise you get exceptions like:
RuntimeError: invalid argument 1: input is not contiguous at .../src/torch/lib/TH/generic/THTensor.c:231
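For example, a minimal snippet that triggers this kind of error (the exact message varies between PyTorch versions; newer versions complain about incompatible size/stride and suggest .reshape() instead):
import torch

x = torch.rand(3, 4)
y = x.t()                  # transpose returns a non-contiguous view of the same data
print(y.is_contiguous())   # False
y.view(-1)                 # raises a RuntimeError about the tensor's layout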
What functions or modules require contiguous input? Is this documented? Or, phrased differently, in what situations do you need to call contiguous()?
For example, does Conv1d require contiguous input? The documentation does not mention it. When the documentation does not mention this, does that always imply that contiguous input is not required?
(I remember that in Theano, any op that received non-contiguous input but required it to be contiguous would simply convert it automatically.)
Solution
After additional digging under the hood through the source code, it seems that view is the only function that explicitly raises an exception when a non-contiguous input is passed.
One would expect any operation using Tensor Views to have the potential to fail with non-contiguous input. In reality, it seems that most or all of these functions are either:
(a.) implemented with support for non-contiguous blocks (see example below), i.e. the tensor iterators can handle multiple pointers to the various chunks of the data in memory, perhaps at the expense of performance, or else
(b.) wrapped in a call to .contiguous() (one such example being torch.tensor.diagflat()). reshape is essentially the contiguous()-wrapped form of view.
By extension, it seems, the main benefit of view over reshape would be the explicit exception when tensors are unexpectedly non-contiguous, versus code silently handling this discrepancy at the cost of performance.
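To make this trade-off concrete, a small sketch along these lines should show the difference (tensor shapes are arbitrary; the exact RuntimeError wording depends on the PyTorch version):
import torch

x = torch.rand(4, 6).t()               # transpose returns a non-contiguous view
print(x.is_contiguous())               # False

try:
    x.view(-1)                         # fails loudly: strides are incompatible with a flat view
except RuntimeError as e:
    print("view raised:", e)

flat = x.reshape(-1)                   # succeeds silently by copying to contiguous memory
print(flat.is_contiguous())            # True
print(flat.data_ptr() == x.data_ptr()) # False -> the data was copied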
This conclusion is based on:
- Testing of all Tensor View ops with non-contiguous inputs.
- Source code analysis of other non-Tensor-View functions of interest (e.g. Conv1d, which includes calls to contiguous as necessary in all non-trivial input cases; see the quick check after this list).
- Inference from PyTorch's design philosophy as a simple, at times slow, easy-to-use framework.
- Cross-posting on PyTorch Discuss.
- Extensive review of web-reported errors involving non-contiguous tensors, all of which revolve around problematic calls to view.
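As a quick runtime check of the Conv1d point above (a sketch with arbitrary layer sizes, not an exhaustive test):
import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
x = torch.rand(2, 10, 4).transpose(1, 2)   # shape (2, 4, 10), non-contiguous
print(x.is_contiguous())                    # False
y = conv(x)                                 # runs with no manual .contiguous() call
print(y.shape)                              # torch.Size([2, 8, 8])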
I did not comprehensively test all PyTorch functions, as there are thousands.
EXAMPLE OF (a.):
import torch
import numpy
import time
# allocation
start = time.time()
test = torch.rand([10000,1000,100])
torch.cuda.synchronize()
end = time.time()
print("Allocation took {} sec. Data is at address {}. Contiguous:
{}".format(end -
start,test.storage().data_ptr(),test.is_contiguous()))
# view of a contiguous tensor
start = time.time()
test.view(-1)
torch.cuda.synchronize()
end = time.time()
print("view() took {} sec. Data is at address {}. Contiguous:
{}".format(end -
start,test.storage().data_ptr(),test.is_contiguous()))
# diagonal() on a contiguous tensor
start = time.time()
test.diagonal()
torch.cuda.synchronize()
end = time.time()
print("diagonal() took {} sec. Data is at address {}. Contiguous:
{}".format(end -
start,test.storage().data_ptr(),test.is_contiguous()))
# Diagonal and a few tensor view ops on a non-contiguous tensor
test = test[::2,::2,::2] # indexing is a Tensor View op resulting in a non-contiguous output
print(test.is_contiguous()) # False
start = time.time()
test = test.unsqueeze(-1).expand([test.shape[0],test.shape[1],test.shape[2],100]).diagonal()
torch.cuda.synchronize()
end = time.time()
print("non-contiguous tensor ops() took {} sec. Data is at
address {}. Contiguous: {}".format(end -
start,test.storage().data_ptr(),test.is_contiguous()))
# reshape, which requires a tensor copy operation to new memory
start = time.time()
test = test.reshape(-1) + 1.0
torch.cuda.synchronize()
end = time.time()
print("reshape() took {} sec. Data is at address {}. Contiguous: {}".format(end - start,test.storage().data_ptr(),test.is_contiguous()))
The following output is produced:
Allocation took 4.269254922866821 sec. Data is at address 139863636672576. Contiguous: True
view() took 0.0002810955047607422 sec. Data is at address 139863636672576. Contiguous: True
diagonal() took 6.532669067382812e-05 sec. Data is at address 139863636672576. Contiguous: True
False
non-contiguous tensor ops() took 0.00011277198791503906 sec. Data is at address 139863636672576. Contiguous: False
reshape() took 0.13828253746032715 sec. Data is at address 94781254337664. Contiguous: True
In block 4, a few Tensor View operations are performed on a non-contiguous input tensor. They run without error, keep the data at the same memory addresses, and run considerably faster than an operation requiring a copy to new memory addresses (such as reshape in block 5). Thus, it seems these operations are implemented in a way that handles non-contiguous inputs without requiring a data copy.
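It is also worth noting that reshape is documented to return a view rather than a copy whenever a view is possible, so the copy cost is only paid when the input layout actually requires one. A small sketch to check this (tensor sizes here are arbitrary):
import torch

a = torch.rand(1000, 100)              # contiguous
b = a.reshape(-1)                      # a view is possible, so no copy
print(b.data_ptr() == a.data_ptr())    # True -> same underlying memory

c = a[::2].reshape(-1)                 # non-contiguous slice: a copy is required
print(c.data_ptr() == a.data_ptr())    # False -> new memory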
Answered By - DerekG