Issue
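I need to find the rectangle with the maximum sum in a large matrix of integers. There is an O(n^3)-time algorithm (fix a pair of columns, collapse the rows between them into a 1-D array of row sums, and run Kadane's algorithm on that array), as described here and here, for example (the original links were not preserved).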
Both implementations work, but they are slow, partly because they are written in pure Python. How much can the code be sped up for, say, an 800 × 800 matrix? It currently takes 56 seconds on my PC.
Here is my sample code, which is based on code from GeeksforGeeks:
import numpy as np
def kadane(arr, start, finish, n):
    """Return ``(best_sum, start, finish)`` for the best contiguous run of ``arr[:n]``.

    Kadane's scan: keep a running sum, restart it whenever it drops below
    zero, and remember the largest non-negative running sum seen together
    with the inclusive index range that produced it (first one wins on ties).

    If every value in ``arr[:n]`` is negative, no running sum ever survives,
    and the answer is the single largest element (first occurrence on ties).

    The incoming ``start`` and ``finish`` values are never read; both are
    overwritten before being returned.  With ``n == 0`` the result is
    ``(arr[0], 0, 0)``, or ``IndexError`` if ``arr`` is empty.
    """
    running = 0
    best = float('-inf')
    # Remains -1 exactly when every element of arr[:n] is negative.
    finish = -1
    run_start = 0
    for idx in range(n):
        running += arr[idx]
        if running < 0:
            # A negative prefix can only lower what follows it; drop it.
            running = 0
            run_start = idx + 1
            continue
        if running > best:
            best = running
            start, finish = run_start, idx
    if finish != -1:
        return best, start, finish

    # All-negative input: fall back to the largest single element.
    best = arr[0]
    start = finish = 0
    for idx in range(1, n):
        if arr[idx] > best:
            best = arr[idx]
            start = finish = idx
    return best, start, finish
# The main function that finds maximum subarray_sum rectangle in M
def findMaxsubarray_sum(M):
    """Find, print and return the maximum-sum rectangle of a 2-D matrix.

    For every column pair ``left <= right`` the rows of
    ``M[:, left:right + 1]`` are summed into ``temp``, one column at a time.
    ``kadane`` then picks the best contiguous run of rows.  Total work is
    O(num_cols**2 * num_rows).

    Parameters
    ----------
    M : array_like, 2-D
        Matrix of numbers.  Integer/bool input is accumulated in int64;
        floating-point input is accumulated in floating point.

    Returns
    -------
    tuple
        ``(top, left, bottom, right, max_sum)`` with inclusive bounds; these
        are the same values that are printed.  For a matrix with no columns
        the bounds are ``None`` and ``max_sum`` is ``-inf``.
    """
    M = np.asarray(M)
    num_rows, num_cols = M.shape
    # The old np.int_ buffer rejected float matrices (an in-place add cannot
    # cast float to int) and is only 32 bits on Windows before NumPy 2.0.
    # Widen to int64, or to float when the input is floating point.
    acc_dtype = np.result_type(M.dtype, np.int64)

    # Best rectangle found so far.
    max_subarray_sum = float('-inf')
    finalLeft = finalRight = finalTop = finalBottom = None
    start, finish = 0, 0
    for left in range(num_cols):
        # Row sums of the columns left..right, grown one column at a time.
        temp = np.zeros(num_rows, dtype=acc_dtype)
        for right in range(left, num_cols):
            temp += M[:, right]
            subarray_sum, start, finish = kadane(temp, start, finish, num_rows)
            # Strict ">" keeps the first rectangle found on ties.
            if subarray_sum > max_subarray_sum:
                max_subarray_sum = subarray_sum
                finalLeft = left
                finalRight = right
                finalTop = start
                finalBottom = finish

    print("(Top, Left)", "(", finalTop, finalLeft, ")")
    print("(Bottom, Right)", "(", finalBottom, finalRight, ")")
    print("Max subarray_sum is:", max_subarray_sum)
    return finalTop, finalLeft, finalBottom, finalRight, max_subarray_sum
# Benchmark on a random 800x800 matrix with entries in [-3, 3].
# `%timeit` is IPython/Jupyter-only syntax and a SyntaxError in a plain .py
# file, so the call is timed with the standard-library timeit module instead
# (in a notebook, `%timeit findMaxsubarray_sum(square)` still works).
import timeit

# For a reproducible matrix, call np.random.seed(40) first.
square = np.random.randint(-3, 4, (800, 800))
elapsed = timeit.timeit(lambda: findMaxsubarray_sum(square), number=1)
print(f"findMaxsubarray_sum took {elapsed:.3f} s")
Can numba or pythran or parallelization or just better use of numpy be used to speed this up a lot? Ideally I would like it to take under a second.
There is claimed to be a faster algorithm but I don't know how hard it would be to implement.
Test cases
[[ 3 0 2]
[-3 -3 -1]
[-2 1 -1]]
The correct answer is the rectangle covering the top row with score 5.
[[-1 3 0]
[ 0 0 -2]
[ 0 2 1]]
The correct answer is the rectangle covering the second column with score 5.
[[ 2 2 -1]
[-1 -1 0]
[ 3 1 1]]
The correct answer is the rectangle covering the first two columns with score 6.
Solution
With a minor update so that it compiles with numba, you can get the 800x800 matrix down to 0.4 seconds (140x faster). With a much heavier rewrite of the functions, making them more amenable to numba's multithreading, you can get the 800x800 matrix down to 0.05 seconds (>1000x faster).
With fairly minimal alterations, and using numba as you already suggested, that code runs on an 800x800 matrix on my machine in about 0.4 seconds.
I swapped out your `float('-inf')` for `-np.inf` to make numba happy, then added `numba.njit(cache=True)` to each function, and got that result. I may experiment a little more to see whether a bit more performance can be squeezed out.
Here is the code as I ran it:
import numpy as np
import time
import numba
# Debug switch for the @numba.njit functions below; False keeps them compiled.
# NOTE(review): presumably setting this to True makes them run as plain
# Python, and it likely has to be set before the decorators below execute.
# The NUMBA_DISABLE_JIT environment variable may be the more reliable
# switch; confirm against the numba docs.
numba.config.DISABLE_JIT = False
def _gt(s=0.0):
return time.perf_counter() - s
@numba.njit(cache=True)
def kadane(arr, start, finish, n):
    """Return ``(max_sum, start, finish)`` for the best contiguous run of ``arr[:n]``.

    Kadane's scan with inclusive ``start``/``finish`` indices; on ties the
    first run found wins.  The incoming ``start`` and ``finish`` arguments
    are never read (both are overwritten before use); they are kept so
    existing call sites still match.  If every entry of ``arr[:n]`` is
    negative, the largest single entry is returned instead.

    The running maximum starts at ``-np.inf`` (a float).  This appears to be
    why the compiled version reports the sum as a float (the post's output
    shows ``1991.0``).
    """
    # Running sum of the current candidate run.
    subarray_sum = 0
    # Best sum seen; -np.inf rather than float('-inf') so numba accepts it.
    max_subarray_sum = -np.inf
    # Stays -1 only if every entry is negative (no running sum survives).
    finish = -1
    # Index where the current candidate run begins.
    local_start = 0
    for i in range(n):
        subarray_sum += arr[i]
        if subarray_sum < 0:
            # A negative running sum can only hurt later runs; restart after i.
            subarray_sum = 0
            local_start = i + 1
        elif subarray_sum > max_subarray_sum:
            max_subarray_sum = subarray_sum
            start = local_start
            finish = i
    # At least one non-negative number was seen.
    if finish != -1:
        return max_subarray_sum, start, finish
    # Special case: every number in arr[:n] is negative, so the best run is
    # the single largest element (first occurrence on ties).
    max_subarray_sum = arr[0]
    start = finish = 0
    for i in range(1, n):
        if arr[i] > max_subarray_sum:
            max_subarray_sum = arr[i]
            start = finish = i
    return max_subarray_sum, start, finish
# The main function that finds maximum subarray_sum rectangle in M
@numba.njit(cache=True)
def findMaxsubarray_sum(M):
    """Print the maximum-sum rectangle of the 2-D array ``M``.

    For every column pair ``left <= right`` the rows between them are summed
    into ``temp`` (one column added per step), and ``kadane`` picks the best
    contiguous run of rows.  The first rectangle reaching the maximum wins
    (strict ``>``).  Prints the inclusive top/left and bottom/right corners
    and the maximum sum; returns nothing.
    """
    num_rows, num_cols = M.shape
    # Best rectangle found so far; the bounds stay None if M has no columns.
    max_subarray_sum, finalLeft = -np.inf, None
    finalRight, finalTop, finalBottom = None, None, None
    # Passed to kadane, which ignores and overwrites them.
    start = 0
    finish = 0
    for left in range(num_cols):
        # Row sums of columns left..right, grown one column at a time.
        temp = np.zeros(num_rows, dtype=np.int_)
        for right in range(left, num_cols):
            temp += M[:, right]
            subarray_sum, start, finish = kadane(temp, start, finish, num_rows)
            if subarray_sum > max_subarray_sum:
                max_subarray_sum = subarray_sum
                finalLeft = left
                finalRight = right
                finalTop = start
                finalBottom = finish
    print("(Top, Left)", "(", finalTop, finalLeft, ")")
    print("(Bottom, Right)", "(", finalBottom, finalRight, ")")
    print("Max subarray_sum is:", max_subarray_sum)
def _main():
    """Time findMaxsubarray_sum twice on the same seeded 800x800 matrix.

    The first pass may include numba compilation; the second pass gives a
    truer picture of the running speed.
    """
    n = 800
    for _ in range(2):
        # Re-seeding on every pass gives both passes the identical matrix.
        generator = np.random.default_rng(seed=42)
        matrix = generator.integers(-3, 4, (n, n))
        t0 = _gt()
        findMaxsubarray_sum(matrix)
        print(f'Run time: {n},{_gt(t0):8.6f}\n')


if __name__ == '__main__':
    _main()
And here is the result I got:
(Top, Left) ( 26 315 )
(Bottom, Right) ( 256 798 )
Max subarray_sum is: 1991.0
Run time: 800,1.262665
(Top, Left) ( 26 315 )
(Bottom, Right) ( 256 798 )
Max subarray_sum is: 1991.0
Run time: 800,0.379572
And with a major rewrite (taking out most of the variables, getting rid of the `if` statements in the `findMaxsubarray_sum` function, and parallelizing the outermost loop, `for left in ...`), I got the average run time down to 0.05 seconds for the 800x800 matrix. I can't figure out how to optimize the loop in `kadane`, and none of my test cases hit the all-negatives fall-through, so I didn't touch that block.
This one was tested with 25 randomly generated test cases, using the original code as the baseline. (The issue in the prior version was how the variable `out_sum_arr` was created: I had used `numpy.empty`, thinking it would be faster. But the slots for column pairs with left > right are never written, and `np.argmax` scans them too, so the array really needs to be filled with a valid (minimal) value. It therefore now uses `numpy.full`.)
Here is that code:
import numpy as np
import time
import numba
# Debug switch for the @numba.njit functions below; False keeps them compiled.
# NOTE(review): presumably True makes them run as plain Python, which is
# handy for checking the parallel version against the serial one.  It likely
# needs to be set before the decorators execute, and the NUMBA_DISABLE_JIT
# environment variable may be the more reliable switch; confirm against the
# numba docs.
numba.config.DISABLE_JIT = False
def _gt(s=0.0):
return time.perf_counter() - s
@numba.njit(cache=True)
def kadane(arr, n):
    """Return ``(max_sum, start, finish)`` for the best contiguous run of ``arr[:n]``.

    Kadane's scan over the first ``n`` entries of ``arr``; ``start`` and
    ``finish`` are inclusive indices, and on ties the first run found wins.
    If every entry is negative (``finish`` never leaves -1), the largest
    single entry is returned instead.  With ``n == 0`` this reads ``arr[0]``.
    """
    # Running sum of the current candidate run.
    subarray_sum = 0
    # Starting value for the best sum.  The elif below is only reached when
    # subarray_sum >= 0, so any negative starting value behaves the same.
    max_subarray_sum = np.int32(-2147483648)
    # Just some initial value to check
    # for all negative values case
    finish = -1
    # Index where the current candidate run begins.
    local_start = 0
    for i in range(n):
        subarray_sum += arr[i]
        if subarray_sum < 0:
            # A negative running sum cannot help any later run; restart after i.
            subarray_sum = 0
            local_start = i + 1
        elif subarray_sum > max_subarray_sum:
            max_subarray_sum = subarray_sum
            # `start` is first bound here (or in the fallback below); each
            # return path only runs after one of those assignments.
            start = local_start
            finish = i
    # There is at-least one
    # non-negative number
    if finish != -1:
        return max_subarray_sum, start, finish
    # NOTE: per the post, none of the author's random test cases reached this
    # fallback, so it is less exercised than the loop above.
    # Special Case: When all numbers
    # in arr[] are negative
    max_subarray_sum = arr[0]
    start = finish = 0
    # Find the maximum element in array
    for i in range(1, n):
        if arr[i] > max_subarray_sum:
            max_subarray_sum = arr[i]
            start = finish = i
    return max_subarray_sum, start, finish
# The main function that finds maximum subarray_sum rectangle in M
@numba.njit(cache=True, parallel=True)
def findMaxsubarray_sum(M):
    """Return ``(top, left, bottom, right, max_sum)`` for the best rectangle in M.

    Every column pair ``left <= right`` is scored independently; the outer
    ``left`` loop is a ``numba.prange``.  Each pair writes only its own slot
    ``left * num_cols + right`` of the output arrays, so no two iterations
    write the same element.  ``np.argmax`` then takes the first maximal
    slot.  That is the same left-then-right order, with a strict ``>``, as
    the serial version.  Bounds are inclusive indices.
    """
    num_rows, num_cols = M.shape
    num_pairs = num_cols * num_cols
    out_pos_arr = np.empty((num_pairs, 4), dtype=np.int32)
    # Sums are kept in int64.  The previous int32 array silently wrapped any
    # rectangle sum outside the int32 range, which could make argmax pick the
    # wrong rectangle.  Slots with left > right are never written and keep
    # the int64 minimum, so argmax cannot select them.
    sentinel = np.int64(-9223372036854775807 - 1)
    out_sum_arr = np.full((num_pairs,), sentinel, dtype=np.int64)
    # Set the left column (in parallel).
    for left in numba.prange(num_cols):
        # Row sums of columns left..right, accumulated in 64 bits.
        temp = np.zeros(num_rows, dtype=np.int64)
        for right in range(left, num_cols):
            temp += M[:, right]
            subarray_sum, start, finish = kadane(temp, num_rows)
            slot = left * num_cols + right
            out_sum_arr[slot] = subarray_sum
            # Store the coordinates directly rather than allocating a
            # temporary array for every (left, right) pair.
            out_pos_arr[slot, 0] = left
            out_pos_arr[slot, 1] = right
            out_pos_arr[slot, 2] = start
            out_pos_arr[slot, 3] = finish
    max_pos = np.argmax(out_sum_arr)
    finalLeft, finalRight, finalTop, finalBottom = out_pos_arr[max_pos]
    max_subarray_sum = out_sum_arr[max_pos]
    return finalTop, finalLeft, finalBottom, finalRight, max_subarray_sum
def _main(n=800, loop_count=10, verbose=False):
    """Benchmark findMaxsubarray_sum on a seeded ``n`` x ``n`` int32 matrix.

    Runs ``loop_count`` passes over the same matrix (the RNG is re-seeded
    each pass) and prints each pass's run time.  It then prints the average
    over all passes except the first, which may include numba compilation.

    Parameters
    ----------
    n : int
        Matrix side length (the post benchmarks 800; 1700 was also tried).
    loop_count : int
        Number of timed passes.  Must be at least 2 so that something is left
        to average after discarding the first pass.
    verbose : bool
        Also print the rectangle found on each pass (replaces the former
        ``if False:`` block).

    Raises
    ------
    ValueError
        If ``loop_count`` is less than 2.
    """
    if loop_count < 2:
        raise ValueError("loop_count must be at least 2")
    run_sum = 0.0
    for i in range(loop_count):
        rng = np.random.default_rng(seed=42)
        # To exercise kadane's all-negative fallback, use
        # rng.integers(-30, -4, (n, n)) instead.
        square = rng.integers(-3, 4, (n, n), dtype=np.int32)
        s = _gt()
        finalTop, finalLeft, finalBottom, finalRight, max_subarray_sum = findMaxsubarray_sum(square)
        run_time = _gt(s)
        print(f'Run time: {n},{run_time:8.6f}')
        # Don't count numba compilation time
        if i > 0:
            run_sum += run_time
        if verbose:
            print("(Top, Left)", "(", finalTop, finalLeft, ")")
            print("(Bottom, Right)", "(", finalBottom, finalRight, ")")
            print("Max subarray_sum is:", max_subarray_sum)
            print()
    print(f'Average speed: {run_sum / (loop_count - 1):.5f}')


if __name__ == '__main__':
    _main()
And here is the output it gave me:
Run time: 800,2.169412
Run time: 800,0.051767
Run time: 800,0.046097
Run time: 800,0.048518
Run time: 800,0.047188
Run time: 800,0.050306
Run time: 800,0.050326
Run time: 800,0.049614
Run time: 800,0.050655
Run time: 800,0.049150
Average speed: 0.04929
Answered By - BitsAreNumbersToo
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.