Issue
Hi, can someone improve this code? It implements an adaptive median filter, and it is very slow on large images.
import numpy as np
def padding(img, pad):
    """Return a float64 copy of the 2-D ``img`` with ``pad`` zeros on every side.

    Parameters
    ----------
    img : array_like
        2-D image.
    pad : int
        Number of zero rows/columns added on each side (may be 0).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(H + 2*pad, W + 2*pad)`` and dtype float64.
    """
    img = np.asarray(img)
    h, w = img.shape
    padded_img = np.zeros((h + 2 * pad, w + 2 * pad))
    # Explicit end indices: the former ``pad:-pad`` slice is empty when pad == 0.
    padded_img[pad:pad + h, pad:pad + w] = img
    return padded_img
def AdaptiveMedianFilter(img, s=3, sMax=7):
    """Apply an adaptive median filter to a single-channel image.

    Every pixel is processed by ``Lvl_A``. Pixels outside the image are
    treated as 0 (zero padding via ``padding``).

    Parameters
    ----------
    img : array_like
        2-D image.
    s : int
        Initial window size.
    sMax : int
        Maximum window size; ``Lvl_A`` grows the window by 2 up to this size.

    Returns
    -------
    numpy.ndarray
        Filtered image with the same shape as ``img`` and dtype float64.

    Raises
    ------
    ValueError
        If ``img`` is not 2-D.
    """
    img = np.asarray(img)
    if img.ndim != 2:
        raise ValueError("Single channel image only")
    H, W = img.shape
    # Pad by the half-width of the largest window Lvl_A can evaluate
    # (Lvl_A still evaluates one window of size ``s`` when s > sMax).
    a = max(s, sMax) // 2
    padded_img = padding(img, a)
    f_img = np.zeros((H, W))
    # Visit exactly the original pixels. The previous loop also filtered one
    # extra padding row and column, whose results were cropped away.
    for i in range(H):
        for j in range(W):
            f_img[i, j] = Lvl_A(padded_img, i + a, j + a, s, sMax)
    return f_img
def Lvl_A(mat, x, y, s, sMax):
    """Return the adaptive-median output for pixel ``mat[x, y]``.

    Level A: starting with a window of size ``s`` centred on (x, y), grow the
    window by 2 until its median lies strictly between its minimum and
    maximum. Level B is then applied to that window: the pixel is kept if it
    is strictly between the window minimum and maximum, otherwise it is
    replaced by the window median. If the window size would exceed ``sMax``
    first, the median of the last evaluated window is returned.

    Parameters
    ----------
    mat : numpy.ndarray
        2-D (normally padded) image.
    x, y : int
        Row and column of the pixel to filter.
    s : int
        Initial window size.
    sMax : int
        Maximum window size.

    Returns
    -------
    scalar
        Either ``mat[x, y]`` or a window median (float64).
    """
    z_xy = mat[x, y]
    while True:
        half = s // 2
        # Clamp the start indices: a negative start would make NumPy slice
        # from the end of the array instead of truncating the window.
        window = mat[max(x - half, 0):x + half + 1,
                     max(y - half, 0):y + half + 1]
        z_min = np.min(window)
        z_med = np.median(window)
        z_max = np.max(window)
        if z_min < z_med < z_max:
            # Level B, inlined so min/median/max are not recomputed. Direct
            # comparisons avoid unsigned-integer wraparound of Zxy - Zmax.
            return z_xy if z_min < z_xy < z_max else z_med
        s += 2
        if s > sMax:
            return z_med
def Lvl_B(window):
    """Level B of the adaptive median filter for a single window.

    Parameters
    ----------
    window : numpy.ndarray
        2-D window whose centre element ``window[h//2, w//2]`` is the pixel
        being filtered.

    Returns
    -------
    scalar
        The centre pixel if it lies strictly between the window minimum and
        maximum (i.e. it does not look like an impulse); otherwise the
        window median.
    """
    h, w = window.shape
    z_min = np.min(window)
    z_med = np.median(window)
    z_max = np.max(window)
    z_xy = window[h // 2, w // 2]
    # Compare directly instead of testing the sign of Zxy - Zmin / Zxy - Zmax:
    # with unsigned integer windows the subtraction wraps and is never negative.
    if z_min < z_xy < z_max:
        return z_xy
    return z_med
Is there any way to improve this code, for example by using a vectorized sliding window? I don't know which NumPy functions to use. P.S.: the image is padded, so the code does not need to check for out-of-bounds indices.
Solution
Numba's `njit` is well suited to this kind of computation. Combined with `parallel=True` and `prange`, it can be much faster. Moreover, you can pass the minimum, maximum and median values to `Lvl_B` rather than recomputing them, as @CrisLuengo pointed out.
Here is the modified code:
import numpy as np
from numba import njit,prange
@njit
def padding(img, pad):
    """Return a float64 copy of the 2-D ``img`` with ``pad`` zeros on every side.

    ``pad`` may be 0. The result has shape ``(H + 2*pad, W + 2*pad)``.
    """
    h, w = img.shape
    padded_img = np.zeros((h + 2 * pad, w + 2 * pad))
    # Explicit end indices: the former ``pad:-pad`` slice is empty when pad == 0.
    padded_img[pad:pad + h, pad:pad + w] = img
    return padded_img
@njit(parallel=True)
def AdaptiveMedianFilter(img, s=3, sMax=7):
    """Apply an adaptive median filter to a single-channel image (numba, parallel rows).

    Every pixel is processed by ``Lvl_A``; pixels outside the image are
    treated as 0. Returns a float64 array with the same shape as ``img``.
    Rows are distributed across threads with ``prange``; each iteration
    writes only its own output row.
    """
    # NOTE(review): numba compiles per input type, so a non-2-D array may
    # fail during compilation before this runtime check is reached.
    if img.ndim != 2:
        raise ValueError("Single channel image only")
    H = img.shape[0]
    W = img.shape[1]
    # Pad by the half-width of the largest window Lvl_A can evaluate
    # (Lvl_A still evaluates one window of size ``s`` when s > sMax).
    a = max(s, sMax) // 2
    padded_img = padding(img, a)
    f_img = np.zeros((H, W))
    # Visit exactly the original pixels. The previous loop also filtered one
    # extra padding row and column, whose results were cropped away.
    for i in prange(H):
        for j in range(W):
            f_img[i, j] = Lvl_A(padded_img, i + a, j + a, s, sMax)
    return f_img
@njit
def Lvl_A(mat, x, y, s, sMax):
    """Return the adaptive-median output for pixel ``mat[x, y]`` (numba).

    Level A: grow the window centred on (x, y) from size ``s`` in steps of 2
    until its median lies strictly between its minimum and maximum. If the
    size would exceed ``sMax`` first, return the last window's median.
    Level B: keep ``mat[x, y]`` if it lies strictly between that window's
    minimum and maximum, otherwise return the window median.
    """
    z_xy = mat[x, y]
    while True:
        half = s // 2
        # Clamp the start indices: a negative start would slice from the
        # end of the array instead of truncating the window.
        window = mat[max(x - half, 0):x + half + 1,
                     max(y - half, 0):y + half + 1]
        z_min = np.min(window)
        z_med = np.median(window)
        z_max = np.max(window)
        if z_min < z_med and z_med < z_max:
            break
        s += 2
        if s > sMax:
            return z_med
    # Level B, inlined so min/median/max are reused rather than recomputed.
    # Direct comparisons avoid unsigned-integer wraparound of Zxy - Zmax.
    if z_min < z_xy and z_xy < z_max:
        return z_xy
    return z_med
@njit
def Lvl_B(window, Zmin, Zmed, Zmax):
    """Level B of the adaptive median filter, using precomputed window statistics.

    Returns the centre pixel ``window[h//2, w//2]`` if it lies strictly
    between ``Zmin`` and ``Zmax``; otherwise returns ``Zmed``.
    """
    h, w = window.shape
    Zxy = window[h // 2, w // 2]
    # Compare directly instead of testing the sign of Zxy - Zmin / Zxy - Zmax:
    # with unsigned integer windows the subtraction wraps and is never negative.
    if Zmin < Zxy and Zxy < Zmax:
        return Zxy
    return Zmed
This code is 500 times faster on my machine with a 256x256 random image.
Note that the first call will not be much faster, because it includes the compilation time.
Note also that the computation could be made even faster by not recomputing the min/max/median from scratch for every pixel, since neighbouring sliding windows share most of their values (see Perreault and Hébert, "Median Filtering in Constant Time", IEEE Transactions on Image Processing, 2007).
Answered By - Jérôme Richard
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.