Issue
I'm trying to use the U-Net segmentation model at https://github.com/khanhha/crack_segmentation and incorporate it into my pipeline. However, I noticed that whenever I run 'inference_unet.py' for the first time in a session, it downloads a .pth file for VGG16:
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\hedey/.cache\torch\hub\checkpoints\vgg16-397923af.pth
It's not practical to download that file every time I make an inference, especially since this will be part of an application. How can I avoid having to download that file every time?
Here is the code at 'inference_unet.py':
import sys
import os
import numpy as np
from pathlib import Path
import cv2 as cv
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from unet.unet_transfer import UNet16, input_size
import matplotlib.pyplot as plt
import argparse
from os.path import join
from PIL import Image
import gc
from utils import load_unet_vgg16, load_unet_resnet_101, load_unet_resnet_34
from tqdm import tqdm
def evaluate_img(model, img):
input_width, input_height = input_size[0], input_size[1]
img_1 = cv.resize(img, (input_width, input_height), cv.INTER_AREA)
X = train_tfms(Image.fromarray(img_1))
X = Variable(X.unsqueeze(0)).cuda() # [N, 1, H, W]
mask = model(X)
mask = F.sigmoid(mask[0, 0]).data.cpu().numpy()
mask = cv.resize(mask, (img_width, img_height), cv.INTER_AREA)
return mask
def evaluate_img_patch(model, img):
input_width, input_height = input_size[0], input_size[1]
img_height, img_width, img_channels = img.shape
if img_width < input_width or img_height < input_height:
return evaluate_img(model, img)
stride_ratio = 0.1
stride = int(input_width * stride_ratio)
normalization_map = np.zeros((img_height, img_width), dtype=np.int16)
patches = []
patch_locs = []
for y in range(0, img_height - input_height + 1, stride):
for x in range(0, img_width - input_width + 1, stride):
segment = img[y:y + input_height, x:x + input_width]
normalization_map[y:y + input_height, x:x + input_width] += 1
patches.append(segment)
patch_locs.append((x, y))
patches = np.array(patches)
if len(patch_locs) <= 0:
return None
preds = []
for i, patch in enumerate(patches):
patch_n = train_tfms(Image.fromarray(patch))
X = Variable(patch_n.unsqueeze(0)).cuda() # [N, 1, H, W]
masks_pred = model(X)
mask = F.sigmoid(masks_pred[0, 0]).data.cpu().numpy()
preds.append(mask)
probability_map = np.zeros((img_height, img_width), dtype=float)
for i, response in enumerate(preds):
coords = patch_locs[i]
probability_map[coords[1]:coords[1] + input_height, coords[0]:coords[0] + input_width] += response
return probability_map
def disable_axis():
plt.axis('off')
plt.gca().axes.get_xaxis().set_visible(False)
plt.gca().axes.get_yaxis().set_visible(False)
plt.gca().axes.get_xaxis().set_ticklabels([])
plt.gca().axes.get_yaxis().set_ticklabels([])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-img_dir',type=str, help='input dataset directory')
parser.add_argument('-model_path', type=str, help='trained model path')
parser.add_argument('-model_type', type=str, choices=['vgg16', 'resnet101', 'resnet34'])
parser.add_argument('-out_viz_dir', type=str, default='', required=False, help='visualization output dir')
parser.add_argument('-out_pred_dir', type=str, default='', required=False, help='prediction output dir')
parser.add_argument('-threshold', type=float, default=0.2 , help='threshold to cut off crack response')
args = parser.parse_args()
if args.out_viz_dir != '':
os.makedirs(args.out_viz_dir, exist_ok=True)
for path in Path(args.out_viz_dir).glob('*.*'):
os.remove(str(path))
if args.out_pred_dir != '':
os.makedirs(args.out_pred_dir, exist_ok=True)
for path in Path(args.out_pred_dir).glob('*.*'):
os.remove(str(path))
if args.model_type == 'vgg16':
model = load_unet_vgg16(args.model_path)
elif args.model_type == 'resnet101':
model = load_unet_resnet_101(args.model_path)
elif args.model_type == 'resnet34':
model = load_unet_resnet_34(args.model_path)
print(model)
else:
print('undefind model name pattern')
exit()
channel_means = [0.485, 0.456, 0.406]
channel_stds = [0.229, 0.224, 0.225]
paths = [path for path in Path(args.img_dir).glob('*.*')]
for path in tqdm(paths):
#print(str(path))
#train_tfms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(channel_means, channel_stds)])
train_tfms = transforms.Compose([transforms.ToTensor()])
img_0 = Image.open(str(path))
img_0 = np.asarray(img_0)
if len(img_0.shape) != 3:
print(f'incorrect image shape: {path.name}{img_0.shape}')
continue
img_0 = img_0[:,:,:3]
img_height, img_width, img_channels = img_0.shape
#img_height, img_width = img_0.shape
prob_map_full = evaluate_img(model, img_0)
if args.out_pred_dir != '':
#cv.imwrite(filename=join(args.out_pred_dir, f'{path.stem}.jpg'), img=(prob_map_full * 255).astype(np.uint8))
cv.imwrite(filename=join(args.out_pred_dir, f'{path.stem}.jpg'), img=(prob_map_full).astype(np.uint8))
if args.out_viz_dir != '':
# plt.subplot(121)
# plt.imshow(img_0), plt.title(f'{img_0.shape}')
if img_0.shape[0] > 2000 or img_0.shape[1] > 2000:
img_1 = cv.resize(img_0, None, fx=0.2, fy=0.2, interpolation=cv.INTER_AREA)
else:
img_1 = img_0
# plt.subplot(122)
# plt.imshow(img_0), plt.title(f'{img_0.shape}')
# plt.show()
prob_map_patch = evaluate_img_patch(model, img_1)
#plt.title(f'name={path.stem}. \n cut-off threshold = {args.threshold}', fontsize=4)
prob_map_viz_patch = prob_map_patch.copy()
prob_map_viz_patch = prob_map_viz_patch/ prob_map_viz_patch.max()
prob_map_viz_patch[prob_map_viz_patch < args.threshold] = 0.0
fig = plt.figure()
st = fig.suptitle(f'name={path.stem} \n cut-off threshold = {args.threshold}', fontsize="x-large")
ax = fig.add_subplot(231)
ax.imshow(img_1)
ax = fig.add_subplot(232)
ax.imshow(prob_map_viz_patch)
ax = fig.add_subplot(233)
ax.imshow(img_1)
ax.imshow(prob_map_viz_patch, alpha=0.4)
prob_map_viz_full = prob_map_full.copy()
prob_map_viz_full[prob_map_viz_full < args.threshold] = 0.0
ax = fig.add_subplot(234)
ax.imshow(img_0)
ax = fig.add_subplot(235)
ax.imshow(prob_map_viz_full)
ax = fig.add_subplot(236)
ax.imshow(img_0)
ax.imshow(prob_map_viz_full, alpha=0.4)
plt.savefig(join(args.out_viz_dir, f'{path.stem}.jpg'), dpi=500)
plt.close('all')
gc.collect()
Here is the code at 'utils.py':
import json
from datetime import datetime
from pathlib import Path
import random
import numpy as np
import torch
import tqdm
from unet.unet_transfer import UNet16, UNetResNet
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def cuda(x):
#return x.cuda(async=True) if torch.cuda.is_available() else x
return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
def write_event(log, step, **data):
data['step'] = step
data['dt'] = datetime.now().isoformat()
log.write(json.dumps(data, sort_keys=True))
log.write('\n')
log.flush()
def check_crop_size(image_height, image_width):
"""Checks if image size divisible by 32.
Args:
image_height:
image_width:
Returns:
True if both height and width divisible by 32 and False otherwise.
"""
return image_height % 32 == 0 and image_width % 32 == 0
def create_model(device, type ='vgg16'):
assert type == 'vgg16' or type == 'resnet101'
if type == 'vgg16':
model = UNet16(pretrained=True)
elif type == 'resnet101':
model = UNetResNet(pretrained=True, encoder_depth=101, num_classes=1)
else:
assert False
model.eval()
return model.to(device)
def load_unet_vgg16(model_path):
model = UNet16(pretrained=True)
#model = UNet16(pretrained=False)
checkpoint = torch.load(model_path)
if 'model' in checkpoint:
model.load_state_dict(checkpoint['model'])
elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
else:
raise Exception('undefind model format')
model.cuda()
model.eval()
return model
def load_unet_resnet_101(model_path):
#model = UNetResNet(pretrained=True, encoder_depth=101, num_classes=1)
model = UNetResNet(pretrained=True, encoder_depth=101, num_classes=8)
checkpoint = torch.load(model_path)
if 'model' in checkpoint:
model.load_state_dict(checkpoint['model'])
elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
else:
raise Exception('undefind model format')
model.cuda()
model.eval()
return model
def load_unet_resnet_34(model_path):
model = UNetResNet(pretrained=True, encoder_depth=34, num_classes=1)
checkpoint = torch.load(model_path)
if 'model' in checkpoint:
model.load_state_dict(checkpoint['model'])
elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
else:
raise Exception('undefind model format')
model.cuda()
model.eval()
return model
def train(args, model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=None, fold=None,
num_classes=None):
lr = args.lr
n_epochs = n_epochs or args.n_epochs
optimizer = init_optimizer(lr)
root = Path(args.model_path)
model_path = root / 'model_{fold}.pt'.format(fold=fold)
if model_path.exists():
state = torch.load(str(model_path))
epoch = state['epoch']
step = state['step']
model.load_state_dict(state['model'])
print('Restored model, epoch {}, step {:,}'.format(epoch, step))
else:
epoch = 1
step = 0
save = lambda ep: torch.save({
'model': model.state_dict(),
'epoch': ep,
'step': step,
}, str(model_path))
report_each = 10
log = root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8')
valid_losses = []
for epoch in range(epoch, n_epochs + 1):
model.train()
random.seed()
tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
losses = []
tl = train_loader
try:
mean_loss = 0
for i, (inputs, targets) in enumerate(tl):
inputs = cuda(inputs)
with torch.no_grad():
targets = cuda(targets)
outputs = model(inputs)
#print(outputs.shape, targets.shape)
loss = criterion(outputs, targets)
optimizer.zero_grad()
batch_size = inputs.size(0)
loss.backward()
optimizer.step()
step += 1
tq.update(batch_size)
losses.append(loss.item())
mean_loss = np.mean(losses[-report_each:])
tq.set_postfix(loss='{:.5f}'.format(mean_loss))
if i and i % report_each == 0:
write_event(log, step, loss=mean_loss)
write_event(log, step, loss=mean_loss)
tq.close()
save(epoch + 1)
valid_metrics = validation(model, criterion, valid_loader, num_classes)
write_event(log, step, **valid_metrics)
valid_loss = valid_metrics['valid_loss']
valid_losses.append(valid_loss)
except KeyboardInterrupt:
tq.close()
print('Ctrl+C, saving snapshot')
save(epoch)
print('done.')
return
Here is the code at 'unet_transfer.py':
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision
input_size = (448, 448)
class Interpolate(nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.mode = mode
self.scale_factor = scale_factor
self.align_corners = align_corners
def forward(self, x):
x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
mode=self.mode, align_corners=self.align_corners)
return x
def conv3x3(in_, out):
return nn.Conv2d(in_, out, 3, padding=1)
class ConvRelu(nn.Module):
def __init__(self, in_, out):
super().__init__()
self.conv = conv3x3(in_, out)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
return x
class DecoderBlockV2(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
super(DecoderBlockV2, self).__init__()
self.in_channels = in_channels
if is_deconv:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
#self.block = nn.ModuleList(
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
padding=1),
nn.ReLU(inplace=True)
)
else:
self.block = nn.Sequential(
Interpolate(scale_factor=2, mode='bilinear'),
ConvRelu(in_channels, middle_channels),
ConvRelu(middle_channels, out_channels),
)
def forward(self, x):
return self.block(x)
class UNet16(nn.Module):
def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
#def __init__(self, num_classes=8, num_filters=32, pretrained=False, is_deconv=False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network used
True - encoder pre-trained with VGG16
:is_deconv:
False: bilinear interpolation is used in decoder
True: deconvolution is used in decoder
"""
super().__init__()
self.num_classes = num_classes
self.pool = nn.MaxPool2d(2, 2)
#print(torchvision.models.vgg16(pretrained=pretrained))
self.encoder = torchvision.models.vgg16(pretrained=pretrained).features
#self.encoder = torchvision.models.vgg16(pretrained=False).features
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder[0],
self.relu,
self.encoder[2],
self.relu)
self.conv2 = nn.Sequential(self.encoder[5],
self.relu,
self.encoder[7],
self.relu)
self.conv3 = nn.Sequential(self.encoder[10],
self.relu,
self.encoder[12],
self.relu,
self.encoder[14],
self.relu)
self.conv4 = nn.Sequential(self.encoder[17],
self.relu,
self.encoder[19],
self.relu,
self.encoder[21],
self.relu)
self.conv5 = nn.Sequential(self.encoder[24],
self.relu,
self.encoder[26],
self.relu,
self.encoder[28],
self.relu)
self.center = DecoderBlockV2(512, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec5 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlockV2(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec2 = DecoderBlockV2(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec1 = ConvRelu(64 + num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(self.pool(conv1))
conv3 = self.conv3(self.pool(conv2))
conv4 = self.conv4(self.pool(conv3))
conv5 = self.conv5(self.pool(conv4))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
if self.num_classes > 1:
x_out = F.log_softmax(self.final(dec1), dim=1)
else:
x_out = self.final(dec1)
#x_out = F.sigmoid(x_out)
return x_out
class UNetResNet(nn.Module):
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
pretrained=False, is_deconv=False):
super().__init__()
self.num_classes = num_classes
self.dropout_2d = dropout_2d
if encoder_depth == 34:
self.encoder = torchvision.models.resnet34(pretrained=pretrained)
bottom_channel_nr = 512
elif encoder_depth == 101:
self.encoder = torchvision.models.resnet101(pretrained=pretrained)
bottom_channel_nr = 2048
elif encoder_depth == 152:
self.encoder = torchvision.models.resnet152(pretrained=pretrained)
bottom_channel_nr = 2048
else:
raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
#self.conv1 = nn.Sequential(self.encoder.conv1,
# self.encoder.bn1,
# self.encoder.relu,
# self.pool)
self.conv1 = nn.Sequential(nn.Conv2d(1,64,kernel_size=(7,7),stride=(2,2),padding=(3,3),bias=False), # 1 Here is for grayscale images, replace by 3 if you need RGB/BGR
nn.BatchNorm2d(64),
nn.ReLU(),
self.pool
)
self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4
self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
is_deconv)
self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
is_deconv)
self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
is_deconv)
self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec0 = ConvRelu(num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
#self.final = nn.Conv2d(num_filters, 1, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
pool = self.pool(conv5)
center = self.center(pool)
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(dec2)
dec0 = self.dec0(dec1)
return self.final(F.dropout2d(dec0, p=self.dropout_2d))
'''
class UNetResNet(nn.Module):
"""PyTorch U-Net model using ResNet(34, 101 or 152) encoder.
UNet: https://arxiv.org/abs/1505.04597
ResNet: https://arxiv.org/abs/1512.03385
Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/
Args:
encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
num_classes (int): Number of output classes.
num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
dropout_2d (float, optional): Probability factor of dropout layer before output layer. Defaults to 0.2.
pretrained (bool, optional):
False - no pre-trained weights are being used.
True - ResNet encoder is pre-trained on ImageNet.
Defaults to False.
is_deconv (bool, optional):
False: bilinear interpolation is used in decoder.
True: deconvolution is used in decoder.
Defaults to False.
"""
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
pretrained=False, is_deconv=False):
super().__init__()
self.num_classes = num_classes
self.dropout_2d = dropout_2d
if encoder_depth == 34:
self.encoder = torchvision.models.resnet34(pretrained=pretrained)
bottom_channel_nr = 512
elif encoder_depth == 101:
self.encoder = torchvision.models.resnet101(pretrained=pretrained)
bottom_channel_nr = 2048
elif encoder_depth == 152:
self.encoder = torchvision.models.resnet152(pretrained=pretrained)
bottom_channel_nr = 2048
else:
raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
self.encoder.bn1,
self.encoder.relu,
self.pool)
self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4
self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
is_deconv)
self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
is_deconv)
self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
is_deconv)
self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec0 = ConvRelu(num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
self.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
pool = self.pool(conv5)
center = self.center(pool)
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(dec2)
dec0 = self.dec0(dec1)
return self.final(F.dropout2d(dec0, p=self.dropout_2d))
'''
Solution
In 'utils.py', the 'load_unet_vgg16' function receives a path to a checkpoint and loads the model weights from it, so initializing the encoder with pretrained ImageNet VGG16 weights beforehand (the step that triggers the download) is unnecessary. The function 'load_unet_vgg16' could be written as follows:
def load_unet_vgg16(model_path, pretrained=True):
    model = UNet16(pretrained=pretrained)
    checkpoint = torch.load(model_path)
    if 'model' in checkpoint:
        model.load_state_dict(checkpoint['model'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])  # load the saved weights, not 'check_point'
    else:
        raise Exception('undefined model format')
    model.cuda()
    model.eval()
    return model
Then pass the path to your checkpoint and 'pretrained=False', since you are only running inference and the checkpoint already contains the trained weights, so the ImageNet VGG16 file never needs to be downloaded.
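A minimal usage sketch, assuming a CUDA-capable machine; the checkpoint file name below is a placeholder, not part of the original repository, so substitute the path to your own trained model:

from utils import load_unet_vgg16

# Hypothetical checkpoint path -- replace with the path to your trained crack-segmentation weights.
MODEL_PATH = 'models/crack_unet_vgg16.pt'

# pretrained=False skips torchvision's VGG16 download entirely; the checkpoint
# already contains the trained encoder weights, so nothing is fetched over the network.
model = load_unet_vgg16(MODEL_PATH, pretrained=False)

With this change the only file read at startup is your local checkpoint, which is what you want inside an application.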
Answered By - Hatem