Issue
I am plotting a 15 x 15 grid of images after training my VAE model, but the code raises a shape-mismatch error. The code is as follows:
n = 15 # figure with 15x15 digits
digit_size = 32
figure = np.zeros((digit_size * n, digit_size * n))
# We will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
I know that my model outputs images of size 32 x 32 x 3 (3072 values), but figure only reserves a 32 x 32 cell for each image, which is why the error occurs. I don't know how to fit a 32 x 32 x 3 image into the part below:
figure[i * digit_size: (i + 1) * digit_size,
       j * digit_size: (j + 1) * digit_size] = digit
Thanks
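A minimal NumPy-only sketch of the mismatch described in the question, using random data as a hypothetical stand-in for one decoded image (no VAE needed): a (32, 32, 3) array cannot be assigned into a (32, 32) slice, while a slice that also has a channel axis accepts it.

```python
import numpy as np

digit_size = 32
cell_2d = np.zeros((digit_size, digit_size))          # one cell of the original 2-D figure
cell_rgb = np.zeros((digit_size, digit_size, 3))      # one cell with an RGB channel axis
digit = np.random.rand(digit_size, digit_size, 3)     # stand-in for one decoded RGB image

try:
    cell_2d[:, :] = digit      # (32, 32, 3) cannot be broadcast into (32, 32)
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print("ValueError:", err)

cell_rgb[:, :, :] = digit      # shapes match, so this assignment succeeds
```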
Solution
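Try whether this works; I can't check it myself since you did not provide complete code (including the VAE model/generator). The change is to allocate figure with a third axis of size 3 (the RGB channels) and to slice that axis as well, so each decoded 32 x 32 x 3 image fits into its cell: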
import numpy as np
import matplotlib.pyplot as plt

n = 15  # figure with 15x15 digits
digit_size = 32
figure = np.zeros((digit_size * n, digit_size * n, 3))  # third axis holds the RGB channels
# We will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size,
               :] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
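As a variant (not the answerer's code), the same grid can be built by decoding all n x n latent points in one batch and tiling the results with reshape and transpose. This sketch assumes decode maps an (m, 2) array of latent points to m images, flat or already shaped; with a Keras model that would presumably be decoder.predict. build_latent_grid and fake_decode are hypothetical names, and fake_decode is a stand-in that encodes each latent point as pixel colours so the tiling can be checked without a trained VAE.

```python
import numpy as np


def build_latent_grid(decode, n=15, digit_size=32, channels=3, span=15.0):
    """Decode an n x n grid of 2-D latent points and tile the images into one array."""
    grid_x = np.linspace(-span, span, n)
    grid_y = np.linspace(-span, span, n)
    # Row i uses yi = grid_x[i], column j uses xi = grid_y[j], matching the loop above.
    yy, xx = np.meshgrid(grid_x, grid_y, indexing="ij")
    z = np.stack([xx.ravel(), yy.ravel()], axis=1)  # shape (n*n, 2), ordered by (i, j)
    images = np.asarray(decode(z)).reshape(n, n, digit_size, digit_size, channels)
    # (i, j, row, col, c) -> (i, row, j, col, c), then merge grid and pixel axes.
    return images.transpose(0, 2, 1, 3, 4).reshape(n * digit_size, n * digit_size, channels)


def fake_decode(z):
    """Hypothetical stand-in for decoder.predict: one flat 32*32*3 image per latent point."""
    m = z.shape[0]
    base = (z + 15.0) / 30.0                  # map [-15, 15] to [0, 1]
    img = np.empty((m, 32, 32, 3))
    img[..., 0] = base[:, 0, None, None]      # red encodes the first latent coordinate
    img[..., 1] = base[:, 1, None, None]      # green encodes the second
    img[..., 2] = 0.5
    return img.reshape(m, -1)                 # flat output of shape (m, 3072)


figure = build_latent_grid(fake_decode)
print(figure.shape)  # (480, 480, 3)
```

The resulting array can be passed to plt.imshow(figure) exactly as in the answer. Batching trades one predict call per point for a single call, at the cost of holding all n x n decoded images in memory at once.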
Answered By - rv_normal