Issue
I'm interested in visualizing information from a particular layer of the model. In this instance I'm using a PyTorch ResNet18 model, the source code for which can be found here.
Essentially, the idea is to extract the information each layer holds about any input image the model is trained on, and to reconstruct the input image from the feature maps of a particular conv layer. For example, if the Nth filter of a convolutional layer responds to a dog's ear, I'd like to be able to see which CNN layer corresponds to which attribute of the image.
A given input image $x$ is encoded in each layer of the CNN by the filter responses to that image. A layer with $N_l$ distinct filters has $N_l$ feature maps, each of size $M_l$, where $M_l$ is the height times the width of the feature map.
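To make the $N_l \times M_l$ description concrete, here is a minimal sketch (PyTorch only) that passes a random image through a single convolution configured like ResNet18's first layer (3 input channels, 64 filters, 7×7 kernel, stride 2, padding 3, no bias) and flattens its output into an $N_l \times M_l$ matrix. The helper `filter_response_map` is hypothetical, not a PyTorch API: it rescales one filter's map to [0, 1] and upsamples it to the input resolution so it can be overlaid on the image to see which regions that filter responds to. The random weights and input are placeholders; the maps only become meaningful with a trained layer and a real image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# One conv layer with the same configuration as ResNet18's conv1 (random weights here).
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

x = torch.rand(1, 3, 224, 224)  # placeholder for one preprocessed input image
with torch.no_grad():
    fmaps = conv(x)  # shape (1, N_l, H, W) = (1, 64, 112, 112)

n_l = fmaps.shape[1]                   # number of filters = number of feature maps
m_l = fmaps.shape[2] * fmaps.shape[3]  # height * width of each feature map
feature_matrix = fmaps[0].reshape(n_l, m_l)  # one row per filter: N_l x M_l


def filter_response_map(fmaps: torch.Tensor, k: int, size=(224, 224)) -> torch.Tensor:
    """Hypothetical helper: filter k's response for the first image in the batch,
    min-max scaled to [0, 1] and resized to the input resolution."""
    fmap = fmaps[0, k]
    fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min() + 1e-8)
    # interpolate expects a (batch, channels, H, W) tensor
    up = F.interpolate(fmap[None, None], size=size, mode="bilinear", align_corners=False)
    return up[0, 0]
```

With real data you would replace `torch.rand` with a batch from `dataloaders['val']` and `conv` with a layer of the trained model.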
Passing the data via the PyTorch DataLoaders:
import os

import torch
from torchvision import datasets, transforms

# Normalize uses the standard ImageNet per-channel mean and std
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = '/content/drive/MyDrive/Colab Notebooks/Animal Data/'
# ImageFolder expects data_dir/train/<class_name>/... and data_dir/val/<class_name>/...
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
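Because the goal is to view the input image next to what a layer sees, the `Normalize` step above has to be undone before display. `Normalize` computes `(x - mean) / std` per channel, so the inverse is `x * std + mean`. A minimal sketch, assuming (3, H, W) image tensors; the helper names `denormalize` and `to_hwc` are hypothetical:

```python
import torch

# Same per-channel statistics as the Normalize transforms above.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def denormalize(img: torch.Tensor) -> torch.Tensor:
    """Invert transforms.Normalize for a (3, H, W) tensor and clamp to the displayable [0, 1] range."""
    return (img * STD.to(img.device) + MEAN.to(img.device)).clamp(0.0, 1.0)


def to_hwc(img: torch.Tensor) -> torch.Tensor:
    """Denormalize and reorder (3, H, W) to (H, W, 3), the layout matplotlib's imshow expects."""
    return denormalize(img).permute(1, 2, 0).cpu()
```

For example, assuming matplotlib is installed: `inputs, labels = next(iter(dataloaders['val']))`, then `plt.imshow(to_hwc(inputs[0]))`.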
I'm training the model here:
import copy
import time

import torch


# Uses the module-level dataloaders, dataset_sizes and device defined above.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
I understand that I'll likely have to work with the model itself, but I'm lost on how to begin. Any tips or sources are highly appreciated.
Solution
I suggest searching for and reading about PyTorch hooks. A hook lets you observe the input and output of any layer in the network, and you can then call a function to construct whatever you want from them. You can start with the documentation, which you can find here. The idea is to attach (hook) a function to a layer: that function receives the layer's input and output, and inside it you write the code that builds what you want.
Answered By - Hatem