Issue
I would like to do a simple element-wise division followed by a row-wise average inside a numba jit function with nopython=True.
import numpy as np
from numba import jit,prange,typed
A = np.array([[2,2,2],[1,0,0],[1,2,1]], dtype=np.float32)
B = np.array([[2,0,2],[0,1,0],[1,2,1]],dtype=np.float32)
C = np.array([[2,0,1],[0,1,0],[1,1,2]],dtype=np.float32)
My jit function is:
@jit(nopython=True)
def test(a,b,c):
    mask = a+b >0
    div = np.divide(c, a+b, where=mask)
    result = div.mean(axis=1)
    return result
test_res = test(A,B,C)
However, this throws an error. What would be a workaround? I am trying to do this without a loop; any insight would be appreciated.
Solution
Numba doesn't support every argument of every NumPy function. For example, np.mean() is supported only without the axis argument, and the where argument of np.divide is not supported. You can work around this with explicit loops, for example:
import numpy as np
import numba as nb

@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")  # parallel --> , parallel=True
def test(a, b, c):
    result = np.zeros(c.shape[0])
    c = np.copy(c)  # work on a copy so the caller's array is not modified
    for i in range(c.shape[0]):  # parallel --> for i in nb.prange(c.shape[0]):
        for j in range(c.shape[1]):
            if a[i, j] + b[i, j] > 0:
                c[i, j] = c[i, j] / (a[i, j] + b[i, j])
            # entries where a + b <= 0 keep their original value of c
            result[i] += c[i, j]
    return result / c.shape[1]
The JAX library can also be used to accelerate this:
import jax
import jax.numpy as jnp

@jax.jit
def test_jax(a, b, c):
    mask = a + b > 0
    div = jnp.where(mask, jnp.divide(c, a + b), c)
    return jnp.mean(div, axis=1)
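To check any of these versions, the same computation can be written in plain NumPy, outside numba and JAX. This sketch passes out=np.copy(c) to np.divide so that entries where the mask is False keep the value of c; without out, NumPy leaves those entries uninitialized. That matches the jnp.where call in the JAX version. The name reference is hypothetical.

```python
import numpy as np

def reference(a, b, c):
    s = a + b
    mask = s > 0
    # Start from a copy of c so masked-out entries keep c's value
    div = np.divide(c, s, out=np.copy(c), where=mask)
    return div.mean(axis=1)

A = np.array([[2, 2, 2], [1, 0, 0], [1, 2, 1]], dtype=np.float32)
B = np.array([[2, 0, 2], [0, 1, 0], [1, 2, 1]], dtype=np.float32)
C = np.array([[2, 0, 1], [0, 1, 0], [1, 1, 2]], dtype=np.float32)
print(reference(A, B, C))
```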
Answered By - Ali_Sh