Issue
I was using PyTorch and realized that for a linear layer you could pass not only 1D tensors but multi-dimensional tensors, as long as the last dimension matched the layer's in_features. Multi-dimensional inputs in the PyTorch Linear method?
I tried looping over each item, but is that what PyTorch does? I'm having trouble seeing how you would program the looping with more dimensions; maybe recursion, but that seems messy. What is the most efficient way to implement this?
Solution
I tried looping over each item, but is that what PyTorch does?

The short answer is yes, loops are used, but it's more complicated than you probably think. If input is a 2D tensor (a matrix), then the output of a linear operation is computed as input @ weight.T + bias using an external BLAS library's GEMM operation. Otherwise it uses torch.matmul(input, weight.T) + bias, which uses broadcast semantics to compute a batched version of the operation. Broadcasting is a semantic contract, not an implementation, so how the broadcasting is performed is going to be backend-dependent. Ultimately, some form of looping combined with parallel processing will be used for most of these implementations.
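To make the batched behaviour concrete, here is a small sketch (the shapes are arbitrary, chosen just for illustration, and this is not how PyTorch implements it) showing that a multi-dimensional input produces the same result as applying the 2D case slice by slice:

# Illustration only: a multi-dimensional input to a linear layer gives the
# same result as looping over the leading dimension and applying the 2D case
# to each slice. Shapes here are arbitrary examples.
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(in_features=8, out_features=4)
x = torch.randn(5, 3, 8)          # [batch, seq, in_features]

out = linear(x)                   # shape [5, 3, 4]

# Equivalent (but slow) explicit Python loop over the leading dimension.
looped = torch.stack([linear(x[i]) for i in range(x.shape[0])])

print(out.shape)                                 # torch.Size([5, 3, 4])
print(torch.allclose(out, looped, atol=1e-6))    # True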
To go a little deeper, let's take a look at the PyTorch implementation of the linear layer. This quickly leads down some rabbit holes, since PyTorch uses different backend libraries for performing linear algebra operations efficiently on the available hardware (libraries like oneAPI, Intel MKL, or MAGMA), but perhaps understanding some of the details can help.
Starting at the C++ entrypoint to nn.functional.linear:
Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, bias);
  }
  if (input.dim() == 2 && bias.defined()) {
    // Fused op is marginally faster.
    return at::addmm(bias, input, weight.t());
  }
  auto output = at::matmul(input, weight.t());
  if (bias.defined()) {
    output.add_(bias);
  }
  return output;
}
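Before breaking these branches down, here is a rough Python re-expression of the same dispatch using public ops (ignoring the MKL-DNN branch); this is an illustration of the logic, not the actual implementation, and the helper name is made up:

# Rough Python analogue of the C++ dispatch above, assuming a plain dense
# tensor (the MKL-DNN branch is ignored). Illustration only.
import torch

def linear_sketch(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # Fused multiply-and-add path, handled by a single GEMM call.
        return torch.addmm(bias, input, weight.t())
    output = torch.matmul(input, weight.t())   # broadcasted batched matmul
    if bias is not None:
        output = output + bias                 # broadcasted bias add
    return output

x = torch.randn(5, 3, 8)
w = torch.randn(4, 8)
b = torch.randn(4)
print(torch.allclose(linear_sketch(x, w, b),
                     torch.nn.functional.linear(x, w, b), atol=1e-6))  # True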
There are three cases here:

1. input.is_mkldnn(). This condition occurs if the input tensor is in the MKL-DNN format (Tensor.to_mkldnn) and makes PyTorch use the at::mkldnn_linear function, which in turn calls ideep, which in turn calls the oneDNN library (previously known as Intel MKL-DNN), which ultimately selects a specific general matrix-matrix multiplication (GEMM) routine depending on platform and data types. The simplest implementation is the reference implementation, and from that we can see that they use a parallel-for loop (note that the anonymous function it runs contains a quadruple-nested for-loop; a naive sketch of such a loop nest is shown after this list). In practice the reference implementation probably isn't used; instead, you would probably be calling the x86-optimized version compiled with the Xbyak JIT assembler to produce highly optimized code. I'm not going to pretend to follow all the details of the optimized code, but efficient GEMM is a heavily studied topic that I only have passing knowledge of.

2. input.dim() == 2 && bias.defined(). This condition means that input is a 2D tensor (shape [B, M]) and bias is defined. In this case PyTorch uses the addmm function, which efficiently computes the output as input @ weight.T + bias, where @ is matrix multiplication. There are multiple implementations of addmm registered in PyTorch depending on what types of tensors are being used. The dense-CPU-specific version is here, which eventually calls an external BLAS library's GEMM subroutine. The backend used is likely Intel MKL, but you can check using print(torch.__config__.parallel_info()). Whichever BLAS implementation is being used, it's certainly a highly optimized implementation of matrix multiplication, similar to the oneDNN implementation, probably using multi-threading and optimized compilation.

3. If neither of the previous two conditions is met, then PyTorch uses the torch.matmul function, which performs a broadcasted version of input @ weight.T, where input has shape [..., M]. The result of this operation is a tensor of shape [..., N]. Similar to addmm, there are multiple implementations of this function depending on the tensor types, but an external library will ultimately be used that relies on parallelization and optimized matrix-multiplication subroutines. After the broadcasted matrix multiplication, a broadcasted add_ operation is used to add the bias term (if bias is defined). One common way to implement this batched case without recursion is sketched after this list.
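To connect case 1 back to the original question about loops: a reference (unoptimized) batched GEMM is essentially a loop nest like the one below, with the outer iterations typically run in parallel. This is a conceptual sketch only; the real reference kernel in oneDNN is C++ code that also handles strides, data types, and so on, and the helper name here is made up:

# Naive sketch of a batched-GEMM-style loop nest for the linear operation.
# Conceptual only; real kernels are heavily optimized C++/assembly.
import torch

def naive_linear(input, weight, bias):
    # input: [B, M, K], weight: [N, K], bias: [N]  (example shapes)
    B, M, K = input.shape
    N = weight.shape[0]
    out = torch.empty(B, M, N)
    for b in range(B):                 # loop over the batch dimension
        for m in range(M):             # loop over output rows
            for n in range(N):         # loop over output columns
                acc = bias[n].clone()
                for k in range(K):     # reduction over the shared dimension
                    acc += input[b, m, k] * weight[n, k]
                out[b, m, n] = acc
    return out

x, w, b = torch.randn(2, 3, 4), torch.randn(5, 4), torch.randn(5)
print(torch.allclose(naive_linear(x, w, b),
                     torch.nn.functional.linear(x, w, b), atol=1e-5))  # True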
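As for the "most efficient way to implement this" part of the question: because the weight is always 2D, the batched case never needs recursion. A common strategy, and as far as I can tell roughly what the matmul path reduces to for this shape combination (the exact strategy is backend-dependent), is to flatten all leading dimensions into a single batch dimension, do one big 2D GEMM, and reshape back. The helper name below is made up for illustration:

# Sketch of the flatten-into-one-GEMM strategy for an input of shape [..., K].
# One common way to implement the batched case; illustration only.
import torch

def linear_flattened(input, weight, bias=None):
    K = input.shape[-1]
    N = weight.shape[0]
    flat = input.reshape(-1, K)            # collapse all leading dims
    out = flat @ weight.t()                # a single 2D GEMM
    if bias is not None:
        out = out + bias
    return out.reshape(*input.shape[:-1], N)

x = torch.randn(2, 5, 3, 8)                # arbitrary leading dimensions
w, b = torch.randn(4, 8), torch.randn(4)
print(torch.allclose(linear_flattened(x, w, b),
                     torch.nn.functional.linear(x, w, b), atol=1e-6))  # True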
Answered By - jodag