Issue
I need to create an array of all length-N permutations of the digits 0-9 (N is an input, 1 <= N <= 10).
I've tried this:
np.array(list(itertools.permutations(range(10), n)))
For n=6:
%timeit np.array(list(itertools.permutations(range(10), 6)))
on my machine gives:
68.5 ms ± 881 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
But it is simply not fast enough; I need it to be below 40 ms.
Note: I cannot upgrade from NumPy 1.22.3 on this machine.
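For scale: the result has 10!/(10 - N)! rows, i.e. 151,200 rows for N = 6 and 3,628,800 rows for N = 10, so any per-tuple Python overhead is paid that many times. A small sketch to tabulate the sizes, assuming one byte per digit (np.uint8 or np.int8); the helper name permutation_sizes is only for illustration:

```python
import math

def permutation_sizes(n_digits: int = 10):
    """Return (N, rows, bytes) for each N, assuming one byte per digit."""
    sizes = []
    for n in range(1, n_digits + 1):
        rows = math.perm(n_digits, n)  # n_digits! / (n_digits - n)!
        sizes.append((n, rows, rows * n))
    return sizes

for n, rows, nbytes in permutation_sizes():
    print(f"N={n:2d}: {rows:>9,d} rows, {nbytes / 1e6:6.2f} MB")
```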
Solution
Following the link provided by @KellyBundy, here is a much faster method that fills a preallocated NumPy array block by block:
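import math

import numpy as np

def permutations_(n, k):
    """Return all k-permutations of range(n) as a uint8 array, in itertools order."""
    if k == 0:
        return np.empty((1, 0), np.uint8)
    shape = (math.perm(n, k), k)  # math.perm requires Python 3.8+
    out = np.zeros(shape, np.uint8)
    # Seed: the last column lists the 1-permutations of range(n - k + 1).
    out[:n - k + 1, -1] = np.arange(n - k + 1, dtype=np.uint8)
    start = n - k + 1
    for col in reversed(range(1, k)):
        # Rows [:start], columns [col:] hold all permutations of range(n - col).
        block = out[:start, col:]
        length = start
        for i in range(1, n - col + 1):
            stop = start + length
            # Copy the block with values >= i shifted up by one (so digit i is
            # skipped), and prefix each copied row with i.
            out[start:stop, col:] = block + (block >= i)
            out[start:stop, col - 1] = i
            start = stop
        # Prefix 0: shift the original block in place (column col - 1 is already 0).
        block += 1  # block is a sub-view on `out`
    return out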
Simple test:
In [125]: %timeit permutations_(10, 6)
3.73 ms ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [128]: np.array_equal(permutations_(10, 6), np.array(list(permutations(range(10), 6))))
Out[128]: True
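The key step in permutations_ is the relabeling `block + (block >= i)`: given all permutations of range(m), adding 1 to every value >= i turns them into permutations of range(m + 1) that avoid i, and prefixing i completes them. A slower but easier-to-follow sketch of the same recursion, using np.vstack instead of writing into one preallocated array (the names extend_permutations and permutations_concat are hypothetical, for illustration only):

```python
import numpy as np

def extend_permutations(perms: np.ndarray, m: int) -> np.ndarray:
    """Given all j-permutations of range(m) (rows in lexicographic order),
    return all (j + 1)-permutations of range(m + 1), also lexicographic."""
    blocks = []
    for i in range(m + 1):
        # Relabel so the tail never contains i: values >= i move up by one.
        tail = perms + (perms >= i)
        head = np.full((len(perms), 1), i, dtype=perms.dtype)
        blocks.append(np.hstack([head, tail]))
    return np.vstack(blocks)

def permutations_concat(n: int, k: int) -> np.ndarray:
    """All k-permutations of range(n), built one column at a time."""
    if k == 0:
        return np.empty((1, 0), np.uint8)
    # Start from the 1-permutations of range(n - k + 1) ...
    perms = np.arange(n - k + 1, dtype=np.uint8).reshape(-1, 1)
    # ... and extend k - 1 times until the digits cover range(n).
    for m in range(n - k + 1, n):
        perms = extend_permutations(perms, m)
    return perms
```

Because each prefix i is emitted in increasing order and the relabeling is monotone, the rows come out in the same order as itertools.permutations. The preallocated version avoids the temporary arrays that vstack creates.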
Old answer
Using itertools.chain.from_iterable to flatten the tuples into a single iterator, and then building the array lazily with np.fromiter, more than halves the time:
In [94]: from itertools import chain, permutations
In [95]: %timeit np.array(list(permutations(range(10), 6)), np.int8)
63.2 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [96]: %timeit np.fromiter(chain.from_iterable(permutations(range(10), 6)), np.int8).reshape(-1, 6)
28.4 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
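np.fromiter also accepts a count argument. Since the total number of items is known in advance (math.perm(10, N) * N), passing it lets NumPy allocate the output once instead of resizing it on demand. This is an untimed sketch, not something measured against the numbers above; the name perms_fromiter is hypothetical:

```python
import math
from itertools import chain, permutations

import numpy as np

def perms_fromiter(k: int, n_digits: int = 10) -> np.ndarray:
    """All k-permutations of range(n_digits) via np.fromiter with a known count."""
    rows = math.perm(n_digits, k)
    flat = np.fromiter(
        chain.from_iterable(permutations(range(n_digits), k)),
        dtype=np.int8,
        count=rows * k,  # total digits yielded, so NumPy can preallocate
    )
    # rows x k, in the same order as itertools.permutations
    return flat.reshape(rows, k)
```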
@KellyBundy proposed an even faster solution in the comments, relying on the fast iteration inside the bytes constructor and on the buffer protocol (np.frombuffer wraps the bytes object without copying it). It seems that numpy.fromiter wastes a lot of time in its own iteration:
In [98]: %timeit np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8).reshape(-1, 6)
11.3 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
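%timeit is an IPython magic. To compare these approaches from a plain Python script, the standard-library timeit module can be used instead. A minimal sketch; the helper time_ms and the candidate labels are my own, and it reports mean and standard deviation per call in the spirit of %timeit:

```python
import statistics
import timeit
from itertools import chain, permutations

import numpy as np

K = 6  # permutation length used in the timings above

CANDIDATES = {
    "list + np.array": lambda: np.array(list(permutations(range(10), K)), np.int8),
    "chain + np.fromiter": lambda: np.fromiter(
        chain.from_iterable(permutations(range(10), K)), np.int8
    ).reshape(-1, K),
    "chain + bytes + np.frombuffer": lambda: np.frombuffer(
        bytes(chain.from_iterable(permutations(range(10), K))), np.int8
    ).reshape(-1, K),
}

def time_ms(func, number=10, repeat=7):
    """Mean and standard deviation of the per-call time in milliseconds."""
    totals = timeit.repeat(func, number=number, repeat=repeat)
    per_call = [t / number * 1000 for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)
```

For example, `for name, f in CANDIDATES.items(): print(name, time_ms(f))` prints one line per approach; the absolute numbers will depend on the machine.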
However, note that the array produced this way is read-only (thanks to @MichaelSzczesny for the reminder):
In [109]: ar = np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8).reshape(-1, 6)
In [110]: ar[0, 0] = 1
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In [110], line 1
----> 1 ar[0, 0] = 1
ValueError: assignment destination is read-only
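The read-only flag comes from the buffer: a bytes object is immutable, so the array viewing it is not writeable. Two ways around it, sketched here without timings: build a bytearray (a mutable buffer) instead of bytes, or call .copy() on the read-only array, which costs one extra pass over the data. The name perms_writable is hypothetical:

```python
from itertools import chain, permutations

import numpy as np

def perms_writable(k: int, n_digits: int = 10, use_copy: bool = False) -> np.ndarray:
    """k-permutations of range(n_digits) as a writable int8 array."""
    digits = chain.from_iterable(permutations(range(n_digits), k))
    if use_copy:
        # Read-only view over immutable bytes, then an owned, writable copy.
        return np.frombuffer(bytes(digits), np.int8).reshape(-1, k).copy()
    # bytearray is mutable, so the array viewing it is writable without a copy.
    return np.frombuffer(bytearray(digits), np.int8).reshape(-1, k)

ar = perms_writable(6)
ar[0, 0] = 1  # no ValueError this time
```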
Answered By - Mechanic Pig