Issue
I need to interleave PyTorch tensors (similar to NumPy arrays) with rows and columns of zeros, like this:
Input  => [[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]]
Output => [[1, 0, 2, 0, 3],
           [0, 0, 0, 0, 0],
           [4, 0, 5, 0, 6],
           [0, 0, 0, 0, 0],
           [7, 0, 8, 0, 9]]
I am using the accepted answer to this question, which proposes the following:
import numpy as np

def insert_zeros(a, N=1):
    # a : Input array
    # N : number of zeros to be inserted between consecutive rows and cols
    out = np.zeros((N+1)*np.array(a.shape)-N, dtype=a.dtype)
    out[::N+1, ::N+1] = a
    return out
The answer works perfectly, except that I need to perform this many times on many arrays, and the time it takes has become the bottleneck. It is the step-sized (strided) slicing that takes most of the time.
For what it's worth, the tensors I am using it on are 4D (an example size is 32x18x16x16), and I am inserting the alternating rows/cols only in the last two dimensions.
So my question is, is there another implementation with the same functionality but with reduced time?
Solution
I am not familiar with PyTorch, but to speed up the code that you provided, I think the JAX library will help a lot. For example:
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

a = np.arange(10000).reshape(100, 100)
b = jnp.array(a)

@partial(jax.jit, static_argnums=1)  # N is static because it determines the output shape
def new(a, N):
    out = jnp.zeros((N+1)*np.array(a.shape)-N, dtype=a.dtype)
    # JAX arrays are immutable, so .at[...].set() returns an updated copy
    out = out.at[::N+1, ::N+1].set(a)
    return out
Calling new(b, N) improved the runtime about 10 times on GPU in my tests. The gain depends on the array size and N (the larger the sizes, the better the performance). You can see benchmarks in my Colab link, based on the 4 answers proposed so far (JAX beats the others).
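If you want to reproduce such a comparison yourself, here is a minimal timing sketch. It is not the Colab benchmark; the names insert_zeros_np and insert_zeros_jax are only illustrative, and the numbers will depend on your hardware, array size and N. It assumes the 2D case from above. Note that JAX dispatches work asynchronously, so the timed call should end with block_until_ready(), and the first (compiling) call should be kept out of the measurement.

```python
import timeit
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def insert_zeros_np(a, N=1):
    # NumPy reference: allocate zeros, then write the input with a stride of N+1
    out = np.zeros(tuple((N + 1) * s - N for s in a.shape), dtype=a.dtype)
    out[::N + 1, ::N + 1] = a
    return out


@partial(jax.jit, static_argnums=1)  # N is static: it determines the output shape
def insert_zeros_jax(a, N):
    out = jnp.zeros(tuple((N + 1) * s - N for s in a.shape), dtype=a.dtype)
    return out.at[::N + 1, ::N + 1].set(a)


a = np.arange(10000).reshape(100, 100)
b = jnp.array(a)

# Warm-up call: the first call triggers compilation, so it is excluded from timing
insert_zeros_jax(b, 1).block_until_ready()

t_np = timeit.timeit(lambda: insert_zeros_np(a, 1), number=100)
t_jax = timeit.timeit(lambda: insert_zeros_jax(b, 1).block_until_ready(), number=100)
print(f"NumPy: {t_np:.4f} s, JAX: {t_jax:.4f} s (100 calls each)")
```

On a CPU-only machine with small arrays the difference may be small or even reversed, so it is worth timing with your real shapes.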
I believe that JAX can be one of the best libraries for your case if you can adapt it to your problem (which is possible).
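As a rough illustration of that adaptation, here is a sketch for the 4D case from the question, where zeros are inserted only in the last two dimensions. The function name insert_zeros_last2 is hypothetical, and I am assuming the input is (or can be converted to) a JAX array; I have not benchmarked this variant.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=1)  # N is static: it determines the output shape
def insert_zeros_last2(x, N):
    # x : array of shape (..., H, W)
    # N : number of zeros inserted between consecutive rows and cols
    *lead, h, w = x.shape  # shapes are plain Python ints, even under jit
    out_shape = (*lead, (N + 1) * h - N, (N + 1) * w - N)
    out = jnp.zeros(out_shape, dtype=x.dtype)
    # The ellipsis leaves the leading dimensions alone; only the last two are strided
    return out.at[..., ::N + 1, ::N + 1].set(x)


x = jnp.arange(1, 10).reshape(1, 1, 3, 3)
y = insert_zeros_last2(x, 1)
print(y[0, 0])
```

For an input of shape 32x18x16x16 and N=1 this gives an output of shape 32x18x31x31.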
Answered By - Ali_Sh