Issue
I'm trying to convert the Torchvision MNIST train and test datasets into NumPy arrays but can't find documentation to actually perform the conversion.
My goal would be to take an entire dataset and convert it into a single NumPy array, preferably without iterating through the entire dataset.
I've looked at "How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?" but it doesn't address my issue.
So my question is: utilizing torch.utils.data.DataLoader, how would I go about converting the datasets (train/test) into two NumPy arrays such that all of the examples are present?
Note: I've left the batch size as the default of 1 for now; I could set it to 60,000 for train and 10,000 for test, but I'd prefer to not use magic numbers of that sort.
Thank you.
Solution
If I understand you correctly, you want the whole MNIST training set (60,000 images, each a 1x28x28 array where the 1 is the color channel) as a single NumPy array of shape (60000, 1, 28, 28)? Setting batch_size to len(train_dataset) makes the loader return everything in one batch without hard-coding the dataset size:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Transform to normalized Tensors
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./MNIST/', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
train_dataset_array = next(iter(train_loader))[0].numpy()
# test_dataset_array = next(iter(test_loader))[0].numpy()
This is the result:
>>> train_dataset_array
array([[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
...,
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]],
[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
...,
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296],
[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
-0.42421296, -0.42421296]]]], dtype=float32)
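Printed values alone don't show that every example made it into the array, so a shape check is a quicker confirmation. Here is a minimal sketch of the same single-batch technique that runs without downloading MNIST. The TensorDataset of 100 blank images and the names images_array and labels_array are stand-ins I made up for illustration; with the real train_dataset the shapes should be (60000, 1, 28, 28) and (60000,).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST (assumption): 100 blank 1x28x28 images, labels 0-9.
images = torch.zeros(100, 1, 28, 28)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

# batch_size=len(dataset) makes the loader yield a single batch with every example.
loader = DataLoader(dataset, batch_size=len(dataset))

# Unpack images and labels from the same batch, so the dataset is read only once.
batch_images, batch_labels = next(iter(loader))
images_array = batch_images.numpy()
labels_array = batch_labels.numpy()

print(images_array.shape)  # (100, 1, 28, 28)
print(labels_array.shape)  # (100,)
```

Unpacking both tensors from one batch also avoids building a second iterator just to fetch the labels.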
Edit: You can also get the labels with next(iter(train_loader))[1].numpy(). Alternatively, you can use train_dataset.data.numpy() and train_dataset.targets.numpy(), but note that this returns the raw uint8 images of shape (60000, 28, 28); they will not be transformed by transform as they are when using the DataLoader.
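If you take the raw-data route but still want values comparable to the DataLoader output, the transform can be reproduced by hand in NumPy. This is a rough sketch, assuming ToTensor scales uint8 pixels to [0, 1] by dividing by 255 and Normalize then subtracts the mean and divides by the standard deviation. The helper normalize_like_transform and the synthetic raw array are hypothetical; in practice raw would be train_dataset.data.numpy().

```python
import numpy as np

# Synthetic stand-in (assumption) for train_dataset.data.numpy(): uint8, shape (N, 28, 28).
raw = np.zeros((5, 28, 28), dtype=np.uint8)
raw[0, 14, 14] = 255  # one fully white pixel, for illustration

MEAN, STD = 0.1307, 0.3081  # same constants as in the Normalize call above


def normalize_like_transform(raw_images):
    """Hypothetical helper: mimic ToTensor + Normalize on a uint8 image stack."""
    scaled = raw_images.astype(np.float32) / 255.0  # scale pixels to [0, 1]
    normalized = (scaled - MEAN) / STD              # standardize
    return normalized[:, np.newaxis, :, :]          # add channel axis: (N, 1, 28, 28)


result = normalize_like_transform(raw)
print(result.shape, result.dtype)
```

A black pixel (0) maps to roughly -0.4242, which matches the background value in the printed array above.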
Answered By - Andreas K.