Issue
I am using the following code example:
Using the ViTMAE autoencoder, I want to display the reconstructed image. How can I display it?
from transformers import AutoImageProcessor, ViTMAEModel
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the request time and fail fast on HTTP errors instead of handing PIL
# an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# last_hidden_state holds hidden-state features, not pixel values, so it cannot
# be shown as an image directly. Reconstructing pixels needs
# ViTMAEForPreTraining, whose `logits` can be unpatchified back into images.
last_hidden_states = outputs.last_hidden_state
Solution
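I ran into the same issue. `ViTMAEModel` only returns hidden-state features, not pixels. To obtain a reconstruction you need `ViTMAEForPreTraining`, whose `logits` can be turned back into an image with `unpatchify`. The official ViTMAE documentation points to the ViT_MAE_visualization_demo.ipynb notebook, and the code below is adapted from it.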
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
import requests
from PIL import Image
# Per-channel normalization statistics of the image processor. show_image uses
# them to undo the normalization applied to `pixel_values`. The processor must
# be loaded here, before its attributes are read; otherwise this line raises
# NameError because the processor is otherwise only created further below.
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
imagenet_mean = np.array(feature_extractor.image_mean)
imagenet_std = np.array(feature_extractor.image_std)
def show_image(image, title='', mean=None, std=None):
    """Draw one normalized image on the current matplotlib axes.

    The image is de-normalized with ``image * std + mean``, scaled to
    ``[0, 255]``, clipped, and shown without axis ticks.

    Args:
        image: Channels-last tensor of shape ``[H, W, 3]`` holding normalized
            pixel values.
        title: Title drawn above the image.
        mean: Per-channel mean used for normalization. Defaults to the
            module-level ``imagenet_mean``.
        std: Per-channel std used for normalization. Defaults to the
            module-level ``imagenet_std``.

    Raises:
        ValueError: If ``image`` is not a 3-D array with 3 channels last.
    """
    # Raise instead of assert so the check survives `python -O`.
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"expected an [H, W, 3] image, got shape {tuple(image.shape)}")
    mean = imagenet_mean if mean is None else mean
    std = imagenet_std if std is None else std
    # Clip because de-normalized values (e.g. model reconstructions) may fall
    # outside the displayable 0..255 range.
    pixels = torch.clip((image * std + mean) * 255, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')
def visualize(pixel_values, model):
    """Run a ViTMAE pre-training model and plot its reconstruction.

    Shows four panels for the first image of the batch:
    1. the original input;
    2. the input with masked patches blanked out;
    3. the model's reconstruction;
    4. the reconstruction with the visible (unmasked) input patches pasted back in.

    Args:
        pixel_values: Normalized image batch in channels-first layout
            (``nchw``), e.g. the ``pixel_values`` returned by the image
            processor.
        model: A model exposing ``logits`` and ``mask`` outputs, an
            ``unpatchify`` method, and ``config.patch_size`` (such as
            ``ViTMAEForPreTraining``).
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        outputs = model(pixel_values)

    # Turn the per-patch predictions back into a channels-last image batch.
    y = model.unpatchify(outputs.logits)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Expand the per-patch mask to pixel resolution (1 = removed, 0 = kept).
    mask = outputs.mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # Move the input to the CPU as well so all tensors below share a device.
    x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()

    # Masked input, and reconstruction with the visible input patches pasted in.
    im_masked = x * (1 - mask)
    im_paste = x * (1 - mask) + y * mask

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    # Size this figure directly instead of mutating the global plt.rcParams.
    plt.figure(figsize=(24, 24))
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        show_image(panel, title)
    plt.show()
# Image processor that resizes and normalizes the input image.
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"
# Bound the request time and fail fast on HTTP errors instead of handing PIL
# an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
visualize(pixel_values, model)
Answered By - Dedog Xu
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.