Issue
I'm training a UNet whose class looks like this:
def _conv_bn_relu(in_channels, out_channels, kernel_size=3):
    """Return the layers ``[Conv2d, BatchNorm2d, ReLU]`` for one conv stage.

    The convolution uses stride 1 and padding ``kernel_size // 2``, so for the
    kernel sizes used in this file (1 and 3) the spatial size is preserved.
    """
    return [
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


class UNet(nn.Module):
    """UNet-style encoder/decoder producing a single-channel map.

    Four encoder stages (each followed by a 2x max-pool), a 1x1-convolution
    bottleneck, and four decoder stages (each preceded by a 2x bilinear
    upsampling and a channel-wise concatenation with the matching encoder
    output). Because of the four pool/upsample pairs and the skip
    concatenations, the input height and width must be divisible by 16.

    The last decoder stage ends with ``BatchNorm2d(1)`` and no activation.
    """

    def __init__(self):
        super().__init__()

        # Size comments below assume a 256x256 input.

        # encoder (downsampling)
        self.enc_conv0 = nn.Sequential(
            *_conv_bn_relu(3, 64),
            *_conv_bn_relu(64, 64),
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256 -> 128
        self.enc_conv1 = nn.Sequential(
            *_conv_bn_relu(64, 128),
            *_conv_bn_relu(128, 128),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128 -> 64
        self.enc_conv2 = nn.Sequential(
            *_conv_bn_relu(128, 256),
            *_conv_bn_relu(256, 256),
            *_conv_bn_relu(256, 256),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64 -> 32
        self.enc_conv3 = nn.Sequential(
            *_conv_bn_relu(256, 512),
            *_conv_bn_relu(512, 512),
            *_conv_bn_relu(512, 512),
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32 -> 16

        # bottleneck: 1x1 convolutions, 512 -> 1024 -> 512 channels
        self.bottleneck_conv = nn.Sequential(
            *_conv_bn_relu(512, 1024, kernel_size=1),
            *_conv_bn_relu(1024, 512, kernel_size=1),
        )

        # decoder (upsampling); every dec_conv input is the upsampled tensor
        # concatenated with an encoder output of equal width, hence the "* 2".
        self.upsample0 = nn.UpsamplingBilinear2d(scale_factor=2)  # 16 -> 32
        self.dec_conv0 = nn.Sequential(
            *_conv_bn_relu(512 * 2, 256),
            *_conv_bn_relu(256, 256),
            *_conv_bn_relu(256, 256),
        )
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2)  # 32 -> 64
        self.dec_conv1 = nn.Sequential(
            *_conv_bn_relu(256 * 2, 128),
            *_conv_bn_relu(128, 128),
            *_conv_bn_relu(128, 128),
        )
        self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2)  # 64 -> 128
        self.dec_conv2 = nn.Sequential(
            *_conv_bn_relu(128 * 2, 64),
            *_conv_bn_relu(64, 64),
        )
        self.upsample3 = nn.UpsamplingBilinear2d(scale_factor=2)  # 128 -> 256
        # Output stage: drops to one channel in its first convolution and
        # finishes with BatchNorm2d and no ReLU.
        self.dec_conv3 = nn.Sequential(
            *_conv_bn_relu(64 * 2, 1),
            *_conv_bn_relu(1, 1),
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the single-channel output."""
        # encoder: keep every stage output for the skip connections
        skip0 = self.enc_conv0(x)
        skip1 = self.enc_conv1(self.pool0(skip0))
        skip2 = self.enc_conv2(self.pool1(skip1))
        skip3 = self.enc_conv3(self.pool2(skip2))

        # bottleneck
        out = self.bottleneck_conv(self.pool3(skip3))

        # decoder: upsample, concatenate along the channel axis, convolve
        out = self.dec_conv0(torch.cat([self.upsample0(out), skip3], 1))
        out = self.dec_conv1(torch.cat([self.upsample1(out), skip2], 1))
        out = self.dec_conv2(torch.cat([self.upsample2(out), skip1], 1))
        out = self.dec_conv3(torch.cat([self.upsample3(out), skip0], 1))  # no activation
        return out
Train method:
def train(model, opt, loss_fn, score_fn, epochs, data_tr, data_val):
    """Train ``model`` for ``epochs`` epochs, validating and plotting after each.

    Relies on the module-level names ``device``, ``iou_pytorch``,
    ``clear_output``, ``plt`` and ``np``.

    Args:
        model: network to optimise; called as ``model(batch)``.
        opt: optimizer; ``zero_grad()`` and ``step()`` run once per batch.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        score_fn: callable invoked as ``score_fn(model, iou_pytorch, data)``;
            its result is stored unchanged.
        epochs: number of passes over ``data_tr``.
        data_tr: sized iterable of ``(X_batch, Y_batch)`` training batches.
        data_val: sized iterable of ``(X_val, Y_val)`` validation batches.

    Returns:
        Tuple ``(losses_train, losses_val, scores_train, scores_val)`` with
        one entry per epoch. Loss entries are plain Python floats.
    """
    torch.cuda.empty_cache()

    losses_train, losses_val = [], []
    scores_train, scores_val = [], []

    for epoch in range(epochs):
        print('* Epoch %d/%d' % (epoch + 1, epochs))

        # ---- training pass ----
        model.train()
        avg_loss = 0.0
        for X_batch, Y_batch in data_tr:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)

            opt.zero_grad()
            Y_pred = model(X_batch)
            loss = loss_fn(Y_pred, Y_batch)
            loss.backward()
            opt.step()

            # .item() yields a float; accumulating the tensor itself would
            # keep every batch's autograd graph alive for the whole epoch.
            avg_loss += loss.item() / len(data_tr)

        print('loss: %f' % avg_loss)
        losses_train.append(avg_loss)
        scores_train.append(score_fn(model, iou_pytorch, data_tr))

        # ---- validation pass ----
        model.eval()
        avg_loss_val = 0.0
        X_val = Y_hat = None  # stay None when data_val yields no batches
        with torch.no_grad():
            for X_val, Y_val in data_val:
                Y_hat = model(X_val.to(device)).cpu()
                avg_loss_val += loss_fn(Y_hat, Y_val).item() / len(data_val)

        print('loss_val: %f' % avg_loss_val)
        losses_val.append(avg_loss_val)
        scores_val.append(score_fn(model, iou_pytorch, data_val))
        torch.cuda.empty_cache()

        # ---- visualise the last validation batch ----
        clear_output(wait=True)
        if Y_hat is not None:
            # Show at most 5 samples; the last batch may hold fewer.
            for k in range(min(5, len(X_val))):
                plt.subplot(2, 6, k + 1)
                # rollaxis moves axis 0 to the end; assumes X_val[k] is a
                # channel-first (C, H, W) image -- TODO confirm.
                plt.imshow(np.rollaxis(X_val[k].numpy(), 0, 3), cmap='gray')
                plt.title('Real')
                plt.axis('off')

                plt.subplot(2, 6, k + 7)
                plt.imshow(Y_hat[k, 0], cmap='gray')
                plt.title('Output')
                plt.axis('off')
            plt.suptitle('%d / %d - loss: %f' % (epoch + 1, epochs, avg_loss))
            plt.show()

    return (losses_train, losses_val, scores_train, scores_val)
However, when I run the training, both train_loss and val_loss are NaN, and I also get a warning. In addition, when I plot the segmented picture next to the target one, the output picture is not shown. I tried different loss functions, but the result is still the same. There is probably something wrong with my class.
Could you please help me? Thanks in advance.
Solution
I am not sure if this is the cause of your error, but your last convolution block (self.dec_conv3) looks odd. I would reduce to 1 channel only at the very last convolution, rather than performing two more convolutions with 1 input and 1 output channel. Also, ending with a batchnorm can only produce normalized outputs, which could be far from what you really want:
# Replacement for the dec_conv3 assignment in UNet.__init__.
# Input has 64*2 channels: the upsampled decoder output concatenated with e0.
self.dec_conv3 = nn.Sequential(
    # Keep 32 feature channels through the first two conv stages.
    nn.Conv2d(in_channels=64*2, out_channels=32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    # Reduce to 1 channel only here; no BatchNorm or ReLU follows, so the
    # output is the raw result of this convolution.
    nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1)
)
It would be interesting to know whether your loss is NaN already at the first iteration or only after a few iterations. Maybe you are using a loss function that divides by zero?
Answered By - MarcoM
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.