Issue
The processing in the loop quantizes the phase of each feature map to compute its entropy, and it is currently a double loop over batch and channel.
It would be very helpful if you could tell me how to speed this up, and I would also be happy to hear how to parallelize it on CPU or GPU for a further speedup.
I am willing to convert to numpy etc if you need.
Will appreciate any help. Thanks,
import torch
import time
QDIVNUM = 100 # q means quantization
def calc_entropy(latent):
    l_batch, l_ch, l_height, l_width = latent.shape[0], latent.shape[1], latent.shape[2], latent.shape[3]
    qwidth = 2*torch.pi/QDIVNUM  # width of one quantization bin for the phase
    qthres = torch.arange(0, 2*torch.pi+2*torch.pi/QDIVNUM, qwidth) - torch.pi  # If QDIVNUM=4 then qthres=[-3.1416, -1.5708, 0, 1.5708, 3.1416]
    qboxprob = torch.zeros(l_batch, l_ch, QDIVNUM)
    sum_entropy = 0
    ### double loop ###
    for i in range(l_batch):
        for j in range(l_ch):
            qboxnum = torch.zeros(QDIVNUM)
            loopc, one_entropy = 0, 0
            # get phase
            allphs = torch.angle(latent[i,j,:,:])
            while loopc < QDIVNUM:
                # Count how many phase values fall into the current quantization range.
                qboxnum[loopc] = torch.count_nonzero(torch.logical_and(qthres[loopc] <= allphs, allphs < qthres[loopc+1]))
                loopc += 1
            # calc probability distribution
            qboxprob[i][j] = qboxnum / (l_height*l_width)
            # calc entropy
            for m in range(QDIVNUM):
                if not (qboxprob[i][j][m] == 0):
                    one_entropy += -1 * qboxprob[i][j][m] * torch.log(qboxprob[i][j][m])
            sum_entropy = sum_entropy + one_entropy
    mean_entropy = sum_entropy / (l_batch * l_ch)
    return mean_entropy
start = time.time()
latent = torch.rand(2,16,32,32) + 1j * torch.rand(2,16,32,32)  # optionally .to('cpu') or .to('cuda')
calc_entropy(latent)
print(time.time()-start)
# 0.145s
Solution
I am willing to convert to numpy etc if you need.
I am much more familiar with NumPy than I am with Torch, so that's what my answer will use. :)
I see the following optimization opportunities in your code:
Use histogram
In this code:
while loopc < QDIVNUM:
    # Count how many phase values fall into the current quantization range.
    qboxnum[loopc] = torch.count_nonzero(torch.logical_and(qthres[loopc] <= allphs, allphs < qthres[loopc+1]))
    loopc += 1
This loop is equivalent to np.histogram(allphs, qthres). The histogram() function will also be much faster.
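As a rough sketch of that replacement for one (batch, channel) slice (the np.linspace construction of the bin edges and the random test data here are mine, not taken from the question):

import numpy as np

QDIVNUM = 100
# Equally spaced bin edges from -pi to pi; the question builds qthres with
# torch.arange, so np.linspace is an equivalent stand-in here.
qthres = np.linspace(-np.pi, np.pi, QDIVNUM + 1)
# Phases of one (batch, channel) slice; random data stands in for latent[i, j].
allphs = np.angle(np.random.randn(32, 32) + 1j * np.random.randn(32, 32))

# One call replaces the while-loop: counts of phase values per quantization bin.
# np.histogram returns (counts, bin_edges); only the counts are needed.
qboxnum, _ = np.histogram(allphs, bins=qthres)
assert qboxnum.sum() == allphs.size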
Conditional log
In this code:
for m in range(QDIVNUM):
    if not (qboxprob[i][j][m] == 0):
        one_entropy += -1 * qboxprob[i][j][m] * torch.log(qboxprob[i][j][m])
I see that this checks for 0, and only evaluates the log if qboxprob is nonzero.
This can be vectorized. See numpy: Efficiently avoid 0s when taking log(matrix)
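A minimal sketch of that vectorized conditional log, using the out/where arguments of np.log (the same mechanism the final code below relies on):

import numpy as np

qboxprob = np.array([0.0, 0.25, 0.5, 0.25])

# Evaluate log() only where qboxprob is nonzero; elsewhere the output buffer
# keeps its initial value of 0, so zero-probability bins contribute nothing.
logged = np.log(qboxprob, out=np.zeros_like(qboxprob), where=(qboxprob != 0))
entropy = -np.sum(qboxprob * logged)
print(entropy)  # about 1.0397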
Misc vectorization
This code can be changed to compute angles for every member of latent:
allphs = torch.angle(latent[i,j,:,:])
If you change it to allphs = np.angle(latent), it will compute angles for every member of latent in one call.
In this code:
qboxprob[i][j] = qboxnum / (l_height*l_width)
You can change qboxnum to an array holding the counts for every batch and channel, then do qboxnum / (l_height*l_width) to divide every element of qboxnum by the product of the height and width. Etc.
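Roughly, those two changes look like this (the qboxnum array below is only placeholder zeros standing in for the per-channel bin counts, which the next section computes properly):

import numpy as np

# Toy stand-in for the question's latent tensor.
latent = np.random.randn(2, 16, 32, 32) + 1j * np.random.randn(2, 16, 32, 32)
l_batch, l_ch, l_height, l_width = latent.shape
QDIVNUM = 100

# Phases for every batch and channel in a single call, shape (2, 16, 32, 32).
allphs = np.angle(latent)

# With qboxnum as a (l_batch, l_ch, QDIVNUM) array of bin counts, one
# broadcasted division turns every count into a probability at once.
qboxnum = np.zeros((l_batch, l_ch, QDIVNUM))
qboxprob = qboxnum / (l_height * l_width)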
Vectorized histogram
Profiling the new code, I notice it spends lots of time inside np.histogram(). This is because I need to call it once for every batch and every channel. Unfortunately, histogram does not accept any kind of axis argument to vectorize it.
Fortunately, someone on SO has already written a vectorized version. See Calculate histograms along axis.
I noticed that your histogram bins are equally spaced, so I slightly modified their version to find the bin number using division rather than np.searchsorted(), which I found was slightly faster.
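To illustrate that change (a sketch of the idea, not code taken from the linked answer): for equally spaced bins, the bin index can be computed with a division instead of np.searchsorted().

import numpy as np

QDIVNUM = 100
lo, hi = -np.pi, np.pi
bin_width = (hi - lo) / QDIVNUM
edges = np.linspace(lo, hi, QDIVNUM + 1)
phases = np.random.uniform(lo, hi, size=1000)

# Bin index via division, valid because the bins are equally spaced ...
idx_div = ((phases - lo) / bin_width).astype(np.int64)
# ... versus the general-purpose searchsorted lookup.
idx_ss = np.searchsorted(edges, phases, side='right') - 1

# The two agree except possibly for values that land exactly on a bin edge,
# where floating-point rounding can shift the result by one bin.
print((idx_div == idx_ss).mean())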
Code
import torch
import time
import numpy as np
latent = torch.complex(torch.rand(2,16,32,32) - 0.5, torch.rand(2,16,32,32) - 0.5)
QDIVNUM = 100
def hist_laxis_div(data, n_bins, lo, hi):
    # Source: https://stackoverflow.com/a/44155607/530160 plus modifications
    data2D = data.reshape(-1, data.shape[-1])
    # commented out for performance
    # assert lo <= np.min(data) <= np.max(data) <= hi
    all_range = hi - lo
    bin_width = all_range / n_bins
    data2D = (data2D - lo) / bin_width
    dtype = np.int16
    assert n_bins < np.iinfo(dtype).max, "Too many bins for datatype!"
    idx = data2D.astype(dtype)
    assert (idx >= 0).all()
    assert (idx < n_bins).all()
    # We need to use bincount to get bin-based counts. To have unique IDs for
    # each row and not get confused by the ones from other rows, we need to
    # offset each row by a scale (using row length for this).
    scaled_idx = n_bins*np.arange(data2D.shape[0])[:,None] + idx
    # Allocate array for longest possible counts so that we can reshape it
    limit = n_bins*data2D.shape[0]
    # Get the counts and reshape to multi-dim
    counts = np.bincount(scaled_idx.ravel(), minlength=limit+1)[:-1]
    counts.shape = data.shape[:-1] + (n_bins,)
    return counts

def calc_entropy(latent):
    l_batch, l_ch, l_height, l_width = latent.shape
    # Combine last two axes of latent array
    latent = latent.reshape(l_batch, l_ch, -1)
    allphs = np.angle(latent)
    qboxnum = hist_laxis_div(allphs, QDIVNUM, -np.pi, np.pi)
    qboxprob = qboxnum / (l_height * l_width)
    # Take log of qboxprob where qboxprob is not zero. If it is zero, output
    # zero for the log. This is safe because in the next step we multiply
    # by qboxprob, so if it was zero, the output should be zero anyway.
    logged = np.log(qboxprob, out=np.zeros_like(qboxprob), where=(qboxprob != 0))
    # Multiply each element of qboxprob by logged, and sum them up.
    sum_entropy = -1 * np.dot(qboxprob.flatten(), logged.flatten())
    mean_entropy = sum_entropy / (l_batch * l_ch)
    return mean_entropy
print(calc_entropy(latent))
Timings
Baseline (using torch)
215 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Baseline (with `import numpy as torch`)
38.4 ms ± 699 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Optimized NumPy version
1.3 ms ± 12.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Testing notes
I tested this and found it was equivalent to the original for a random complex tensor.
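A sketch of such an equivalence check (calc_entropy_orig is a placeholder name for the question's loop-based torch function, kept alongside the new calc_entropy from the Code section):

import torch

# Same kind of random complex tensor as in the question.
latent = torch.complex(torch.rand(2, 16, 32, 32) - 0.5, torch.rand(2, 16, 32, 32) - 0.5)

# The loop-based version accumulates in float32 and the NumPy version in
# float64, so allow a small tolerance when comparing the two results.
old = float(calc_entropy_orig(latent))
new = float(calc_entropy(latent))
assert abs(old - new) < 1e-4, (old, new)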
Answered By - Nick ODell