Issue
Hi, I am trying to understand how the following PyTorch AutoEncoder code works. The code below uses the MNIST dataset, whose images are 28x28. My question is: how were the nn.Linear(128, 3) parameters chosen?
I have a dataset of 512x512 images and I would like to modify this AutoEncoder's code to support it.
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop; it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
Solution
I am assuming the input image data are in this shape: x.shape == [bs, 1, h, w], where bs is the batch size. Then, x is first viewed as [bs, h*w], i.e. [bs, 28*28]. This means all pixels in an image are flattened into a 1D vector.
Then in the encoder:
- nn.Linear(28*28, 128) takes the flattened input of size [bs, 28*28] and outputs an intermediate result of size [bs, 128]
- nn.Linear(128, 3): [bs, 128] -> [bs, 3]
Then in the decoder:
- nn.Linear(3, 128): [bs, 3] -> [bs, 128]
- nn.Linear(128, 28*28): [bs, 128] -> [bs, 28*28]
The final output is then matched against the input.
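As a quick sanity check, you can run a dummy batch through the encoder and decoder and watch those shapes appear (a minimal sketch; the batch size of 4 is an arbitrary example):

import torch
from torch import nn

encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

bs = 4                                # arbitrary example batch size
x = torch.randn(bs, 1, 28, 28)        # fake MNIST batch
x = x.view(x.size(0), -1)             # flatten: [bs, 1, 28, 28] -> [bs, 784]
z = encoder(x)                        # [bs, 784] -> [bs, 3]
x_hat = decoder(z)                    # [bs, 3] -> [bs, 784]
print(x.shape, z.shape, x_hat.shape)  # [4, 784], [4, 3], [4, 784]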
If you want to use the exact same architecture for your 512x512 images, simply change every occurrence of 28*28 in the code to 512*512. However, this is quite an infeasible choice, for these reasons:
- For MNIST images, nn.Linear(28*28, 128) contains 28x28x128+128 = 100480 parameters, while for your images nn.Linear(512*512, 128) contains 512x512x128+128 = 33554560 parameters (see the snippet after this list). The layer is too large, and it may lead to overfitting.
- The intermediate data [bs, 3] uses only 3 floats to encode a 512x512 image. I don't think you can recover anything with such compression.
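You can verify those parameter counts directly in PyTorch; here is a small sketch using only the two layers discussed above:

from torch import nn

mnist_layer = nn.Linear(28 * 28, 128)
big_layer = nn.Linear(512 * 512, 128)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(mnist_layer))  # 100480   = 28*28*128 + 128 (weights + biases)
print(count(big_layer))    # 33554560 = 512*512*128 + 128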
I'd suggest looking up convolutional architectures for your purpose.
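As a rough starting point, a convolutional autoencoder for 512x512 grayscale images could look like the sketch below. The channel counts, kernel sizes, and strides are illustrative assumptions, not a tuned architecture:

import torch
from torch import nn

class ConvAutoEncoder(nn.Module):
    # Illustrative sketch only; layer widths are assumptions, not tuned values.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # [bs, 1, 512, 512] -> [bs, 16, 256, 256]
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> [bs, 32, 128, 128]
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # -> [bs, 64, 64, 64]
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # -> [bs, 32, 128, 128]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> [bs, 16, 256, 256]
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> [bs, 1, 512, 512]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Shape check with a dummy batch:
x = torch.randn(2, 1, 512, 512)
print(ConvAutoEncoder()(x).shape)  # torch.Size([2, 1, 512, 512])

Each stride-2 convolution halves the spatial resolution, so the bottleneck keeps spatial structure and shares weights across the image instead of squeezing everything through a 3-float vector via a huge fully connected layer.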
Answered By - ihdv