Issue
I'm new to PyTorch and am having trouble with a basic autoencoder for the MNIST dataset. An autoencoder is a neural network trained to reproduce its input at the output after passing it through narrower layers in between, so that it learns a lower-dimensional representation of a high-dimensional input space.
The problem is that the trained model produces the same output for every input image. I'm not sure where the error is; I adapted an online tutorial for this, and I believe different images shouldn't all map to one output. Could anyone help me find out whether there is a simple bug or a wrong setting that I didn't spot?
Here is a code snippet that reproduces my problem.
# AE one-block
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import datasets
class AE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 9)
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(9, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        x = self.flatten(x)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
# Model initialization
model = AE()
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-1, weight_decay = 1e-8)
tensor_transform = transforms.ToTensor()
dataset = datasets.MNIST(root = "./data",
                         train = True,
                         download = True,
                         transform = tensor_transform)
loader = torch.utils.data.DataLoader(dataset = dataset,
                                     batch_size = 32,
                                     shuffle = True)
# Train
epochs = 5
outputs = []
losses = []
for epoch in range(epochs):
    tic = time.monotonic()
    print(f'Epoch = {epoch}')
    for (image, _) in loader:
        image = image.reshape(-1, 28*28)
        reconstructed = model(image)
        loss = loss_function(reconstructed, image)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss)
    outputs.append((epochs, image, reconstructed))
    toc = time.monotonic()
    print(f'Time taken = {round(toc - tic, 2)}')
# Calculate difference between image outputs
im_batches = [image_batch for (image_batch, _) in loader]
only_image = model(im_batches[0]).detach().numpy()[0]
diff_total = 0
for i in range(len(im_batches)):
    im_out = model(im_batches[i]).detach().numpy()
    diff = np.linalg.norm(im_out - only_image)
print(f'Difference between outputs = {diff_total}')
# Show image outputs
im_out1 = model(im_batches[0]).detach().numpy()
im_out2 = model(im_batches[1]).detach().numpy()
for i in range(3):
    plt.imshow(im_out1[i].reshape(28, 28))
    plt.show()
for i in range(3):
    plt.imshow(im_out2[i].reshape(28, 28))
    plt.show()
The script prints a "difference between outputs" of zero, which suggests that all output images from the first 10 batches are identical. Looking directly at the first few images also shows that every output looks like a strange amalgamation of all the MNIST digits.
Solution
Decrease the learning rate of the optimizer from 1e-1 (which is really big) down to 1e-4 (which is fairly standard) and increase the number of epochs from 5 to 10; the outputs will then no longer be the same. Also, the "difference between outputs" print uses the variable diff_total, which never changes: the loop computes diff but never adds it to diff_total. So even with epochs == 0, when the model outputs are random, the printed "difference between outputs" will still be 0. Finally, it is better for memory consumption to do losses.append(loss.item()), which stores a plain Python float instead of a tensor.
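To make the three changes concrete, here is a minimal sketch rather than a drop-in replacement for the code in the question. It assumes random tensors as a stand-in for MNIST (so it runs without downloading anything) and a smaller, hypothetical autoencoder; only the learning rate, the loss.item() call and the diff_total accumulation are the points being illustrated.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Assumption: random "images" in [0, 1] stand in for flattened MNIST digits.
images = torch.rand(256, 28 * 28)
batches = images.split(32)  # 8 batches of 32 samples each

# Hypothetical, smaller autoencoder than the one in the question.
model = nn.Sequential(
    nn.Linear(28 * 28, 64), nn.ReLU(),
    nn.Linear(64, 9),                    # bottleneck
    nn.Linear(9, 64), nn.ReLU(),
    nn.Linear(64, 28 * 28), nn.Sigmoid(),
)
loss_function = nn.MSELoss()
# Change 1: learning rate 1e-4 instead of 1e-1.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-8)

epochs = 10
losses = []
for epoch in range(epochs):
    for batch in batches:
        reconstructed = model(batch)
        loss = loss_function(reconstructed, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Change 2: store a Python float, not the loss tensor.
        losses.append(loss.item())

# Change 3: actually accumulate the per-batch difference into diff_total.
with torch.no_grad():
    reference = model(batches[0])[0]  # first reconstructed sample
    diff_total = 0.0
    for batch in batches:
        out = model(batch)
        diff_total += torch.linalg.norm(out - reference).item()

print(f'Difference between outputs = {diff_total}')
```

With the accumulation in place, a printed value of 0 would really mean that every reconstruction equals the reference one, so the number becomes a usable check after training on the real dataset.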
Answered By - draw