Issue
I have following pandas dataframe df
column1 column2 list_numbers sublist_column
x y [10,-6,1,-4]
a b [1,3,7,-2]
p q [6,2,-3,-3.2]
The sublist_column should contain the numbers from the "list_numbers" column that add up to 0, within a tolerance of 0.5. I have written the following code:
def return_list(original_lst, target_sum, tolerance):
    memo = dict()
    sublist = []
    for i, x in enumerate(original_lst):
        if memo_func(original_lst, i + 1, target_sum - x, memo, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist

def memo_func(original_lst, i, target_sum, memo, tolerance):
    if i >= len(original_lst):
        if target_sum <= tolerance and target_sum >= -tolerance:
            return 1
        else:
            return 0
    if (i, target_sum) not in memo:
        c = memo_func(original_lst, i + 1, target_sum, memo, tolerance)
        c += memo_func(original_lst, i + 1, target_sum - original_lst[i], memo, tolerance)
        memo[(i, target_sum)] = c
    return memo[(i, target_sum)]
Then I apply the "return_list" function to the "list_numbers" column to populate "sublist_column":
target_sum = 0
tolerance = 0.5
df['sublist_column'] = df['list_numbers'].apply(lambda x: return_list(x, target_sum, tolerance))
The resulting dataframe is:
column1 column2 list_numbers sublist_column
x y [10,-6,1,-4] [10,-6,-4]
a b [1,3,7,-2] []
p q [6,2,-3,-3.2] [6,-3,-3.2] # sum is -0.2 (within the tolerance)
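To make the expected output concrete, here is a pure-Python sketch of the same greedy rule that return_list applies: keep an element if some subset of the elements after it still brings the remaining target within the tolerance. It uses brute force via itertools.combinations, so it is only meant for checking small rows by hand; the names can_complete and greedy_sublist are hypothetical helpers, not part of the question's code.

```python
from itertools import combinations

def can_complete(items, target, tol):
    # True if some subset of `items` (the empty subset included) sums to within tol of target
    return any(
        abs(target - sum(combo)) <= tol
        for r in range(len(items) + 1)
        for combo in combinations(items, r)
    )

def greedy_sublist(values, target=0.0, tol=0.5):
    # Same decision rule as return_list: keep x if the items after x can still finish the sum
    sublist = []
    for i, x in enumerate(values):
        if can_complete(values[i + 1:], target - x, tol):
            sublist.append(x)
            target -= x
    return sublist

for row in ([10, -6, 1, -4], [1, 3, 7, -2], [6, 2, -3, -3.2]):
    print(row, "->", greedy_sublist(row))
```

The three rows reproduce the sublist_column values shown above. Note that the empty subset always sums to 0, which is why the second row yields an empty list rather than an error.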
This gives me the correct result, but it is very slow (it takes 2 hours to run in the Spyder IDE), since my dataframe has roughly 50,000 rows and some of the lists in the "list_numbers" column have more than 15 elements. The running time is especially affected when a list in "list_numbers" has more than 15 elements; e.g. the following list takes almost 15 minutes to process:
[-1572.35,-76.16,-261.1,-7732.0,-1634.0,-52082.42,-3974.15,
-801.65,-30192.79,-671.98,-73.06,-47.72,57.96,-511.18,-391.87,-4145.0,-1008.61,
-17.53,-17.53,-1471.08,-119.26,-2269.7,-2709,-182939.59,-19.48,-516,-6875.75,-138770.16,-71.11,-295.84,-348.09,-3460.71,-704.01,-678,-632.15,-21478.76]
How can I significantly improve the running time?
Solution
Step 1: using Numba
Based on the comments, memo_func appears to be the main bottleneck. You can use Numba to speed up its execution. Numba compiles Python code to native code using a just-in-time (JIT) compiler. The JIT can perform tail-call optimizations, and native function calls are significantly faster than CPython ones. Here is an example:
import numpy as np
import numba as nb

# Count the subsets of original_arr[i:] whose sum is within tolerance of target_sum
@nb.njit('(float64[::1], int64, float64, float64)')
def memo_func(original_arr, i, target_sum, tolerance):
    if i >= len(original_arr):
        if -tolerance <= target_sum <= tolerance:
            return 1
        return 0
    c = memo_func(original_arr, i + 1, target_sum, tolerance)
    c += memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    return c

# Keep each item if the items after it can still complete the sum
@nb.njit('(float64[::1], float64, float64)')
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist
Using memoization does not seem to speed things up, and it is somewhat cumbersome to implement in Numba. In fact, there are much better ways to improve the algorithm.
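One plausible reason memoization does not help (my reading, not something the answer states): the memo key is (i, target_sum) with a float target_sum, and with real-valued data two different subsets almost never leave exactly the same remainder, so the cache rarely gets a hit. The sketch below walks the question's recursion without a cache and counts total calls versus distinct keys; count_states is a hypothetical helper.

```python
def count_states(values, target=0.0):
    # Walk memo_func's recursion tree without a cache, recording every (i, target_sum) key
    calls = 0
    seen = set()

    def walk(i, t):
        nonlocal calls
        calls += 1
        seen.add((i, t))
        if i >= len(values):
            return
        walk(i + 1, t)              # skip values[i]
        walk(i + 1, t - values[i])  # take values[i]

    walk(0, target)
    return calls, len(seen)

# Repeated values: many subsets share a remainder, so a cache would hit often
print(count_states([1.0] * 10))                     # (2047, 66)
# All subset sums distinct: every key is new, so a cache never hits
print(count_states([2.0 ** k for k in range(10)]))  # (2047, 2047)
```

Lists like the ones in the question typically behave like the second case, where a dictionary lookup only adds overhead.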
Note that you need to convert the lists to NumPy arrays before calling the functions:
lst = [-850.85,-856.05,-734.09,5549.63,77.59,-39.73,23.63,13.93,-6455.54,-417.07,176.72,-570.41,3621.89,-233.47,-471.54,-30.33,-941.49,-1014.6,1614.5]
tolerance = 0.5
result = return_list(np.array(lst, np.float64), 0.0, tolerance)
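The explicit signatures above use float64[::1], i.e. a one-dimensional, C-contiguous float64 array. A minimal conversion helper, assuming each cell of list_numbers holds a plain Python list of numbers; to_f64 is a hypothetical name. np.ascontiguousarray copies only when the input is not already contiguous float64.

```python
import numpy as np

def to_f64(values):
    # Build a 1-D, C-contiguous float64 array, matching the float64[::1] signatures
    arr = np.ascontiguousarray(values, dtype=np.float64)
    assert arr.ndim == 1, "expected a flat list of numbers"
    return arr

arr = to_f64([10, -6, 1, -4])                # ints are converted to float64
print(arr.dtype, arr.flags['C_CONTIGUOUS'])  # float64 True
strided = to_f64(np.arange(10.0)[::2])       # a strided view is copied into contiguous memory
print(strided.flags['C_CONTIGUOUS'])         # True
```

With pandas, the same helper could be used as df['list_numbers'].apply(lambda x: return_list(to_f64(x), 0.0, tolerance)).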
Step 2: tail call optimization
Calling many functions to process the right part of the input list is not efficient. The JIT can reduce the number of calls, but it cannot remove them completely. You can unroll the deepest levels of the recursion, which is where most of the calls happen. For example, when 6 items remain, you can use the following code:
# Inside memo_func, where n = original_arr.size
if n - i == 6:
    c = 0
    s0 = target_sum
    v0, v1, v2, v3, v4, v5 = original_arr[i:]
    # Each loop level either skips or takes one of the 6 remaining items
    for s1 in (s0, s0 - v0):
        for s2 in (s1, s1 - v1):
            for s3 in (s2, s2 - v2):
                for s4 in (s3, s3 - v3):
                    for s5 in (s4, s4 - v4):
                        for s6 in (s5, s5 - v5):
                            c += np.int64(-tolerance <= s6 <= tolerance)
    return c
This is pretty ugly but far more efficient, since the JIT can unroll all the loops and produce very fast code. Still, this is not enough for large lists.
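The six nested loops enumerate all 2**6 = 64 ways to take or skip the last six items, which is exactly what the last six levels of the recursion do. A pure-Python sketch (no Numba, so no speed benefit) checking that equivalence; count_recursive and count_unrolled6 are hypothetical names, and itertools.product stands in for the hand-written loops.

```python
from itertools import product

def count_recursive(arr, i, target, tol):
    # Reference: the plain recursion from memo_func, without a memo
    if i >= len(arr):
        return int(-tol <= target <= tol)
    return (count_recursive(arr, i + 1, target, tol)
            + count_recursive(arr, i + 1, target - arr[i], tol))

def count_unrolled6(arr, i, target, tol):
    # Flat enumeration for exactly six remaining items: one take/skip flag per item
    assert len(arr) - i == 6
    c = 0
    for flags in product((False, True), repeat=6):
        s = target
        for value, take in zip(arr[i:], flags):
            if take:
                s -= value  # subtract in the same order as the recursion
        c += int(-tol <= s <= tol)
    return c

arr = [1.0, -1.0, 2.0, -2.0, 3.0, -3.0]
print(count_recursive(arr, 0, 0.0, 0.5), count_unrolled6(arr, 0, 0.0, 0.5))  # 10 10
```

Both versions find the 10 subsets of [1, -1, 2, -2, 3, -3] that sum to 0; the difference in the Numba version is only how many function calls are needed to get there.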
Step 3: better algorithm
For large input lists, the problem is the exponential complexity of the algorithm. This problem looks very much like a relaxed variant of subset-sum, which is known to be NP-complete, and such problems are very hard to solve. The best practical exact algorithms known so far for NP-complete problems run in exponential time. In short, for any sufficiently large input, no known algorithm can find an exact solution in a reasonable time (e.g. less than a human lifetime).
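To put numbers on that growth for the version without memoization: return_list calls memo_func once per element, and the call for element i explores every subset of the n-i-1 elements after it, so the total number of leaves is 2**(n-1) + ... + 2 + 1 = 2**n - 1. A small sketch of that count (naive_leaf_count is a hypothetical helper; it assumes no cache hits):

```python
def naive_leaf_count(n):
    # memo_func started with r remaining items branches twice per item -> 2**r leaves
    return sum(2 ** (n - i - 1) for i in range(n))

for n in (19, 36):  # sizes of the small and big benchmark lists
    print(n, naive_leaf_count(n))
# 19 524287
# 36 68719476735
```

Going from 19 to 36 elements multiplies the work by roughly 2**17 (about 131,000).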
That being said, there are heuristics and strategies to improve the complexity of the current algorithm. One efficient approach is a meet-in-the-middle algorithm. Applied to your use case, the idea is to generate a large set of target sums, sort them, and then use a binary search to count the matching values. This works here because the condition -tolerance <= partial_sum1 - partial_sum2 <= tolerance, where partial_sum1 is the target sum minus a subset sum of the last items (precomputed and sorted) and partial_sum2 is a subset sum of the first items, is equivalent to -tolerance + partial_sum2 <= partial_sum1 <= tolerance + partial_sum2.
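The counting idea can be sketched in plain Python with the standard bisect module. This is a simplified version under my own assumptions: it splits the list into two equal halves and materializes both halves' subset sums, whereas the Numba code precomputes a tail of at most 16 items and enumerates the head recursively. subset_sums, count_mitm and count_brute are hypothetical names.

```python
from bisect import bisect_left, bisect_right
from itertools import combinations

def subset_sums(items):
    # All 2**len(items) subset sums (duplicates kept), built one item at a time
    sums = [0.0]
    for x in items:
        sums += [s + x for s in sums]
    return sums

def count_mitm(arr, target, tol):
    # Count subsets of arr whose sum is within tol of target, meet-in-the-middle style
    half = len(arr) // 2
    head, tail = arr[:half], arr[half:]
    # partial_sum1 values: target minus each tail subset sum, sorted once
    partial_sums1 = sorted(target - s for s in subset_sums(tail))
    count = 0
    for partial_sum2 in subset_sums(head):
        # Binary search for -tol + partial_sum2 <= partial_sum1 <= tol + partial_sum2
        lo = bisect_left(partial_sums1, partial_sum2 - tol)
        hi = bisect_right(partial_sums1, partial_sum2 + tol)
        count += hi - lo
    return count

def count_brute(arr, target, tol):
    # Reference: check all 2**len(arr) subsets directly
    return sum(abs(target - sum(c)) <= tol
               for r in range(len(arr) + 1) for c in combinations(arr, r))

print(count_mitm([1, -1, 2, -2, 3, -3], 0.0, 0.5))  # 10
```

For n items this does about 2**(n/2) work per half plus 2**(n/2) binary searches, instead of 2**n leaf checks.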
The resulting code is unfortunately quite big and not trivial, but this is the price to pay for trying to efficiently solve a complex problem like this one. Here it is:
import numpy as np
import numba as nb

# Generate all the target sums based on in_arr and put the result in out_sum
@nb.njit('(float64[::1], float64[::1], float64)', cache=True)
def gen_all_comb(in_arr, out_sum, target_sum):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        assert out_sum.size == 64
        v0, v1, v2, v3, v4, v5 = in_arr
        s0 = target_sum
        cur = 0
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                out_sum[cur] = s6
                                cur += 1
    else:
        # First half of out_sum: in_arr[0] skipped; second half: in_arr[0] taken
        assert out_sum.size % 2 == 0
        mid = out_sum.size // 2
        gen_all_comb(in_arr[1:], out_sum[:mid], target_sum)
        gen_all_comb(in_arr[1:], out_sum[mid:], target_sum - in_arr[0])
# Find the number of items in sorted_arr where:
# lower_bound <= item <= upper_bound
@nb.njit('(float64[::1], float64, float64)', cache=True)
def count_between(sorted_arr, lower_bound, upper_bound):
    assert lower_bound <= upper_bound
    lo_pos = np.searchsorted(sorted_arr, lower_bound, side='left')
    hi_pos = np.searchsorted(sorted_arr, upper_bound, side='right')
    return hi_pos - lo_pos
# Count the pairs (subset of in_arr, value in sorted_target_sums) such that:
# -tolerance <= value - (s0 + subset_sum) <= tolerance
@nb.njit('(float64[::1], float64[::1], float64, float64)', cache=True)
def multi_search(in_arr, sorted_target_sums, tolerance, s0):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        v0, v1, v2, v3, v4, v5 = in_arr
        c = 0
        for s1 in (s0, s0 + v0):
            for s2 in (s1, s1 + v1):
                for s3 in (s2, s2 + v2):
                    for s4 in (s3, s3 + v3):
                        for s5 in (s4, s4 + v4):
                            for s6 in (s5, s5 + v5):
                                lo = -tolerance + s6
                                hi = tolerance + s6
                                c += count_between(sorted_target_sums, lo, hi)
        return c
    else:
        c = multi_search(in_arr[1:], sorted_target_sums, tolerance, s0)
        c += multi_search(in_arr[1:], sorted_target_sums, tolerance, s0 + in_arr[0])
        return c
@nb.njit('(float64[::1], int64, float64, float64)', cache=True)
def memo_func(original_arr, i, target_sum, tolerance):
    n = original_arr.size
    remaining = n - i
    tail_size = min(max(remaining//2, 7), 16)
    # Tail call: for very small list (trivial case)
    if remaining <= 0:
        return np.int64(-tolerance <= target_sum <= tolerance)
    # Tail call: for big lists (better algorithm)
    elif remaining >= tail_size*2:
        partial_sums = np.empty(2**tail_size, dtype=np.float64)
        gen_all_comb(original_arr[-tail_size:], partial_sums, target_sum)
        partial_sums.sort()
        return multi_search(original_arr[-remaining:-tail_size], partial_sums, tolerance, 0.0)
    # Tail call: for medium-sized list (unrolling)
    elif remaining == 6:
        c = 0
        s0 = target_sum
        v0, v1, v2, v3, v4, v5 = original_arr[i:]
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                c += np.int64(-tolerance <= s6 <= tolerance)
        return c
    # Recursion
    c = memo_func(original_arr, i + 1, target_sum, tolerance)
    c += memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    return c
@nb.njit('(float64[::1], float64, float64)', cache=True)
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist
Note that the code takes a few seconds to compile because it is quite big. Thanks to cache=True, it should not need to be recompiled on every run.
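To make the dispatch in the final memo_func easier to follow, here is a plain-Python sketch that only reports which branch runs for a given number of remaining items. memo_func_branch is a hypothetical helper: it mirrors the tail_size formula and the branch order above, but it does not compute any sums.

```python
def memo_func_branch(remaining):
    # Same tail_size formula and branch order as the final memo_func
    tail_size = min(max(remaining // 2, 7), 16)
    if remaining <= 0:
        return ("base",)
    if remaining >= tail_size * 2:
        # head items go to multi_search, tail items to gen_all_comb
        return ("meet-in-the-middle", remaining - tail_size, tail_size)
    if remaining == 6:
        return ("unrolled",)
    return ("recurse",)

for r in (0, 6, 13, 14, 18, 35):
    print(r, memo_func_branch(r))
```

For instance, the first memo_func call on the 36-element list has 35 remaining items, so it splits into a 19-item head and a 16-item tail, while any size from 7 to 13 recurses one item at a time until 6 items remain and the unrolled loops take over.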
Benchmark
Here are the tested inputs:
target_sum = 0
tolerance = 0.5
small_lst = [-850.85,-856.05,-734.09,5549.63,77.59,-39.73,23.63,13.93,-6455.54,-417.07,176.72,-570.41,3621.89,-233.47,-471.54,-30.33,-941.49,-1014.6,1614.5]
big_lst = [-1572.35,-76.16,-261.1,-7732.0,-1634.0,-52082.42,-3974.15,-801.65,-30192.79,-671.98,-73.06,-47.72,57.96,-511.18,-391.87,-4145.0,-1008.61,-17.53,-17.53,-1471.08,-119.26,-2269.7,-2709,-182939.59,-19.48,-516,-6875.75,-138770.16,-71.11,-295.84,-348.09,-3460.71,-704.01,-678,-632.15,-21478.76]
Here are the timings with the small list on my machine:
Naive python algorithm: 173.45 ms
Naive algorithm using Numba: 7.21 ms
Tail call optimization + Numba: 0.33 ms
Efficient algorithm + optim + Numba: 0.16 ms
Here are the timings with the big list on my machine:
Naive python algorithm: >20000 s [estimation & out of memory]
Naive algorithm using Numba: ~900 s [estimation]
Tail call optimization + Numba: 42.61 s
Efficient algorithm + optim + Numba: 0.05 s
Thus, the final implementation is about 1000 times faster on the small input and more than 400,000 times faster on the large input! It also uses far less RAM, so it can actually run on a basic PC.
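The speed-up figures follow directly from the timing lists above; a quick check (the naive big-list time is an estimated lower bound, so that ratio is a lower bound too):

```python
# Timings quoted in the benchmark above
small_naive_ms, small_final_ms = 173.45, 0.16
big_naive_s, big_tail_s, big_final_s = 20000, 42.61, 0.05

print(round(small_naive_ms / small_final_ms))  # 1084
print(round(big_naive_s / big_final_s))        # 400000
print(round(big_tail_s / big_final_s))         # 852
```

The last ratio shows that, on the big list, the algorithmic change alone accounts for roughly a factor of 850 over the unrolled recursion.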
It is worth noting that the execution time can be reduced even further by using multiple threads, possibly reaching a speed-up above 1,000,000, though this may be slower on small inputs and makes the code more complex.
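The answer does not show its multi-threaded version. A different and simpler option (my substitution, not the author's approach) is to process independent dataframe rows concurrently. This sketch assumes the per-row worker releases the GIL, for example a Numba function compiled with nogil=True; with a pure-Python worker, such as the built-in sum used as a stand-in here, threads give the same results but no speed-up. solve_rows is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor

def solve_rows(rows, worker, max_workers=4):
    # Run worker(row) for every row on a thread pool; pool.map keeps the input order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(worker, rows))

rows = [[10, -6, 1, -4], [1, 3, 7, -2], [6, 2, -3, -3.2]]
# Stand-in worker; in practice something like
# lambda row: return_list(np.array(row, np.float64), 0.0, tolerance)
print(solve_rows(rows, sum))
```

Parallelizing across rows keeps each row's computation unchanged, which is why it is simpler than splitting a single search across threads; it only helps when there are many rows, as with the 50,000 rows in the question.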
Answered By - Jérôme Richard