Issue
I have this simple Python function:
import numpy as np

def fast_transform(img, offset, factor):
    rep = (img.shape[0]//2, img.shape[1]//2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out
The function gets an image (as an N×M numpy ndarray) and two 2x2 arrays (offset and factor). It then applies a basic linear transformation to every pixel in the image based on its parity in each dimension:
out[i,j] = (img[i,j] - offset[i%2,j%2]) * factor[i%2,j%2]
As you can see, I used np.tile to try to speed up the function, but this isn't fast enough for my needs (and I think creating the dummy np.tile arrays makes it sub-optimal). I tried to use numba, but it doesn't support np.tile yet.
Can you help me optimize this function as much as possible? I am sure there is some simple way to do it that I am missing.
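For reference, the parity rule above can also be written without np.tile by updating the four parity sub-grids through strided views. This is only a sketch: the name transform_strided is hypothetical, it assumes offset and factor are 2x2 float64 arrays as in the timings further down, and it has not been benchmarked here, so whether it is actually faster should be measured. Unlike the tiled version, it also accepts odd image dimensions.

```python
import numpy as np

def transform_strided(img, offset, factor):
    """Apply (img - offset[i%2, j%2]) * factor[i%2, j%2] without np.tile."""
    # Round to float32 first (as the original does), then widen to float64
    # so the arithmetic matches the float64 result of the tiled version.
    out = img.astype(np.float32).astype(np.float64)
    for a in range(2):
        for b in range(2):
            block = out[a::2, b::2]  # view: rows of parity a, columns of parity b
            block -= offset[a, b]    # in-place, so no N×M temporary is created
            block *= factor[a, b]
    return out

# Small check against the tiled formulation from the question.
rng = np.random.default_rng(0)
img = rng.random((6, 8)) * 50
offset = rng.random((2, 2)) * 40
factor = rng.random((2, 2)) * 30
rep = (img.shape[0] // 2, img.shape[1] // 2)
tiled = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
print(np.allclose(transform_strided(img, offset, factor), tiled))
```

The in-place updates on the views write straight into out, so the only full-size allocations are the two dtype conversions.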
Solution
If you're willing to use another library, JAX can make your NumPy function roughly 7x faster in the timings below (though if your arrays vary in shape, this may not be ideal, since JAX recompiles the function for each new shape):
import numpy as np
from jax import config
config.update("jax_enable_x64", True)  # keep 64-bit floats so the result dtype matches NumPy
import jax.numpy as jnp
import jax

@jax.jit
def fast_transform_jax(img, offset, factor):
    rep = (img.shape[0]//2, img.shape[1]//2)
    out = (img.astype(np.float32) - jnp.tile(offset, rep)) * jnp.tile(factor, rep)
    return out
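The recompilation caveat can be observed with a small sketch. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so appending to a list records each trace. The names transform and trace_log are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_log = []  # filled only while JAX traces the function

@jax.jit
def transform(img, offset, factor):
    trace_log.append(img.shape)  # Python side effect: runs at trace time only
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    return (img.astype(jnp.float32) - jnp.tile(offset, rep)) * jnp.tile(factor, rep)

offset = np.ones((2, 2))
factor = np.full((2, 2), 2.0)

transform(np.zeros((4, 4)), offset, factor)  # first call: traced and compiled
after_first = len(trace_log)
transform(np.ones((4, 4)), offset, factor)   # same shape and dtype: cached version reused
after_repeat = len(trace_log)
transform(np.zeros((6, 4)), offset, factor)  # new shape: traced and compiled again
after_new_shape = len(trace_log)

print(after_first, after_repeat, after_new_shape)
```

If the images come in only a handful of distinct sizes, the compile cost is paid once per size; with many distinct sizes it can dominate the runtime.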
Here are slightly modified versions of the numba functions from @Andrej's answer, changed so that they pass np.allclose against the OP's function:
import numba as nb

@nb.njit
def fast_transform_numba(img, offset, factor):
    img = img.astype(np.float32)
    # float64 output matches the NumPy version when offset and factor are float64
    out = np.empty(img.shape, dtype=np.float64)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out

@nb.njit(parallel=True)
def fast_transform_numba_parallel(img, offset, factor):
    img = img.astype(np.float32)
    out = np.empty(img.shape, dtype=np.float64)
    # numba parallelizes only the outermost prange; the inner one runs as a plain range
    for i in nb.prange(img.shape[0]):
        for j in nb.prange(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out
Timings:
rng = np.random.default_rng()
N, M = 1000, 1000
img = rng.random((N, M)) * 50
offset = rng.random((2, 2)) * 40
factor = rng.random((2, 2)) * 30
assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba(img, offset, factor))
assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba_parallel(img, offset, factor))
assert np.allclose(fast_transform(img, offset, factor), fast_transform_jax(img, offset, factor))
%timeit fast_transform(img, offset, factor)
%timeit fast_transform_numba(img, offset, factor)
%timeit fast_transform_numba_parallel(img, offset, factor)
%timeit fast_transform_jax(img, offset, factor).block_until_ready()
Output:
3.59 ms ± 332 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.39 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
871 µs ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
521 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
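The %timeit lines above only work in IPython or Jupyter. In a plain Python script, a comparable measurement can be sketched with the standard timeit module. The helper best_time is a hypothetical name, and it reports the best per-call time rather than the mean ± std. dev. that %timeit prints, so the numbers are not directly comparable to those above.

```python
import timeit
import numpy as np

def fast_transform(img, offset, factor):
    # NumPy baseline from the question
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    return (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)

def best_time(func, *args, repeat=5, number=20):
    """Return the best per-call time in seconds over `repeat` runs of `number` calls."""
    timer = timeit.Timer(lambda: func(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number

rng = np.random.default_rng(0)
img = rng.random((1000, 1000)) * 50
offset = rng.random((2, 2)) * 40
factor = rng.random((2, 2)) * 30

seconds = best_time(fast_transform, img, offset, factor)
print(f"fast_transform: {seconds * 1e3:.3f} ms per call")
```

For the numba functions, call each one once before timing so that compilation is excluded; for the JAX function, also time the call together with .block_until_ready(), as in the %timeit line above.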
Answered By - Nin17