Issue
import numpy as np
from numpy import asarray
from matplotlib import pyplot as plt
import torch
# generate a signal
fs = 50 # sampling freq
ts = np.arange(0, 10, 1/fs) # times at which signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts) # 2 hz
s2 = np.sin(2 * np.pi * 3 * ts) # 3 hz
s3 = np.sin(2 * np.pi * 6 * ts) # 6 hz
s = s1 + s2 + s3 # aggregate signal
# generate specgram
spectrum, freqs, t, im = plt.specgram(s, Fs=fs, xextent=((0, len(s)/fs)))
# convert matplotlib image to torch tensor
# bypassing the numpy part would be even better!
torch_tensor = torch.from_numpy(asarray(im, np.float32))
print(torch_tensor)
>>> TypeError: float() argument must be a string or a number, not 'AxesImage'
I should add that the 'spectrum' variable is kind of what I am looking for, except that I am a little confused by it since it has only two columns for time, and I think the specgram image has many more than two timesteps. If there is a way to use the spectrum variable to represent the whole image as a torch tensor, then that would also work for me.
Solution
plt.specgram returns the spectrogram data in the spectrum variable, so that is the variable to pass to torch.from_numpy. The fourth return value, im, is the AxesImage artist rather than the data, which is why converting it raises the TypeError. Additionally, specgram displays 10*log10(spectrum) by default, so you may want to apply the same operation when comparing the image shown by specgram with a plot of your tensor. See code below:
import numpy as np
from matplotlib import pyplot as plt
import torch
# generate a signal
fs = 50 # sampling freq
ts = np.arange(0, 10, 1/fs) # times at which signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts) # 2 hz
s2 = np.sin(2 * np.pi * 3 * ts) # 3 hz
s3 = np.sin(2 * np.pi * 6 * ts) # 6 hz
s = s1 + s2 + s3 # aggregate signal
# generate specgram
ax1=plt.subplot(121)
ax1.set_title('Specgram image')
spectrum, freqs, t, im = ax1.specgram(s, Fs=fs, xextent=((0, len(s)/fs)))
ax1.axis('tight')
torch_tensor = torch.from_numpy(spectrum)
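# plot the torch tensor on the same dB scale that specgram uses
ax2=plt.subplot(122)
ax2.set_title('Torch tensor image')
# origin must be 'upper' or 'lower'; convert back to numpy for imshow
ax2.imshow(10*torch.log10(torch_tensor).numpy(),origin='lower',extent=[0,10,0,25])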
ax2.axis('tight')
plt.show()
The output is a figure with two side-by-side panels, 'Specgram image' and 'Torch tensor image'.
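On the question of why spectrum has only two time columns: the number of columns depends on the window length (NFFT) and the overlap (noverlap), not on what is drawn. The following is a minimal sketch, assuming matplotlib.mlab.specgram is used to compute the same data without drawing anything. The values NFFT=64 and noverlap=32 are illustrative choices, not something from the original answer, and the signal is rebuilt with exactly 500 samples.

```python
import numpy as np
import torch
from matplotlib import mlab

fs = 50                                  # sampling frequency
ts = np.arange(10 * fs) / fs             # 500 sample times covering 10 s
s = sum(np.sin(2 * np.pi * f * ts) for f in (2, 3, 6))  # 2, 3 and 6 Hz

# specgram defaults (NFFT=256, noverlap=128): only two windows fit in 500 samples
spec_default, freqs_default, t_default = mlab.specgram(
    s, NFFT=256, Fs=fs, noverlap=128)

# shorter windows (illustrative values) trade frequency resolution for time columns
spec_fine, freqs_fine, t_fine = mlab.specgram(
    s, NFFT=64, Fs=fs, noverlap=32)

tensor_default = torch.from_numpy(spec_default)
tensor_fine = torch.from_numpy(spec_fine)

print(tensor_default.shape)  # torch.Size([129, 2])
print(tensor_fine.shape)     # torch.Size([33, 14])
```

Rows are frequency bins (NFFT // 2 + 1 for a real signal) and columns are time windows, so passing smaller NFFT and noverlap values to plt.specgram should yield a tensor with more time steps at the cost of coarser frequency resolution.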
Answered By - jylls