Issue
I modified the example from the Pytorch VAE example to be a convolutional network. I then wanted to implement this in FastAI.
class convVAE(nn.Module):
def __init__(self, dim_z=20):
super(convVAE, self).__init__()
self.cv1 = nn.Conv2d(1, 32, 3, stride=2)
self.cv2 = nn.Conv2d(32, 64, 3, stride=2)
self.fc31 = nn.Linear(2304, dim_z)
self.fc32 = nn.Linear(2304, dim_z)
self.fc4 = nn.Linear(dim_z, 2304)
self.cv5 = nn.ConvTranspose2d(64, 32, 3, stride=2)
self.cv6 = nn.ConvTranspose2d(32, 1, 3, stride=2, output_padding=1)
def encode(self, x):
h1 = F.leaky_relu(self.cv1(x))
h2 = F.leaky_relu(self.cv2(h1)).view(-1, 2304)
return self.fc31(h2), self.fc32(h2)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return mu + eps*std
def decode(self, z):
h5 = F.leaky_relu(self.fc4(z)).view(-1, 64, 6, 6)
h6 = F.leaky_relu(self.cv5(h5))
return torch.sigmoid(self.cv6(h6))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z).view(-1, 784), mu, logvar
def get_loss(res,y):
y_hat, mu, logvar = res
BCE = F.binary_cross_entropy(
y.view(-1, 784),
y_hat,
reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar -
mu.pow(2) - logvar.exp())
return BCE + KLD
block = DataBlock(
blocks=(ImageBlock(cls=PILImageBW),ImageBlock(cls=PILImageBW)),
get_items=get_image_files,
splitter=RandomSplitter(valid_pct=0.2, seed=42),
get_y=(lambda x: x),
batch_tfms=aug_transforms(mult=2., do_flip=False))
path = untar_data(URLs.MNIST)
loaders = block.dataloaders(path/“training”,num_workers=0,bs=32)
loaders.train.show_batch(max_n=4, nrows=1)
mdl = convVAE(5)
learn = Learner(loaders, mdl, loss_func = convVAE.get_loss)
learn.fit(1, cbs=ShortEpochCallback())
The gradient is not computing from the loss, as the parameters all become NaN after one step. The loss function does compute but was relatively large O(1e6). The model and loss function work in the native Pytorch implementation.
EDIT: SOLUTION APPEARS TO HAVE BEEN DUE TO def init(.)
instead of def __init__(.)
facepalm
Solution
There is a mistake in your BCE calculation:
BCE = F.binary_cross_entropy(
y.view(-1, 784), # this should be your model prediction
y_hat, # this should be the ground truth
reduction='sum')
A simple fix is to swap the two arguments.
Answered By - TQCH
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.