Issue
I have a training and a validation batch dataset:
import tensorflow as tf

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode='categorical',  # labels are one-hot encoded, one column per class (multiclass classification)
    validation_split=0.2,      # fraction of the dataset held out for validation
    subset="training",         # this call returns the training subset
    seed=1337,                 # the same seed in both calls keeps the training/validation split consistent
    image_size=img_size,       # size the images are resized to after loading
    batch_size=batch_size,     # number of images per batch
)
valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode='categorical',
    validation_split=0.2,
    subset="validation",       # this call returns the validation subset
    seed=1337,
    image_size=img_size,
    batch_size=batch_size,
)
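With label_mode='categorical', each label arrives as a one-hot vector with one column per class instead of as an integer index. The sketch below mimics that with plain NumPy rather than TensorFlow, assuming three hypothetical class names:

```python
import numpy as np

# Hypothetical class names, standing in for train_ds.class_names
class_names = ["cat", "dog", "horse"]

# Integer labels, as label_mode='int' (the default) would yield them
int_labels = np.array([2, 0, 1])

# One-hot labels, as label_mode='categorical' yields them:
# row i has 1.0 in column int_labels[i] and 0.0 everywhere else
one_hot_labels = np.eye(len(class_names), dtype="float32")[int_labels]

print(one_hot_labels)
# [[0. 0. 1.]
#  [1. 0. 0.]
#  [0. 1. 0.]]
```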
I was trying to display 9 images to show what they look like, which I managed, but I can't seem to plot their respective labels.
Here is the code:
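import matplotlib.pyplot as plt

class_names = train_ds.class_names
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):  # take a single batch
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # images are float32 with values in [0, 255]; cast to uint8 for imshow
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.axis("off")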
This displays the nine images in a 3x3 grid, but without any titles.
If I try to add the labels with plt.title(class_names[labels[i]]), I get the following error:
TypeError: only integer scalar arrays can be converted to a scalar index
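The error comes from using a vector as a list index. Python list indexing asks the index object for an integer, and a NumPy array only provides one when it is an integer scalar; an eager TensorFlow tensor behaves the same way because it converts itself to a NumPy array for this. A minimal NumPy-only reproduction, with hypothetical class names and a hand-written one-hot label:

```python
import numpy as np

class_names = ["cat", "dog", "horse"]  # hypothetical class names

# A single one-hot label, shaped like labels[i] with label_mode='categorical'
label = np.array([0.0, 1.0, 0.0], dtype="float32")

error_message = None
try:
    class_names[label]  # list indexing needs an integer, not a vector
except TypeError as err:
    error_message = str(err)
    print("TypeError:", error_message)
```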
I have tried solutions from other posts, such as plt.title(class_names[labels[i][0]]), but without any success.
When I print labels[i], I get the one-hot encoding of the label. Maybe this is why?
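That is indeed the cause: each labels[i] is a vector of length num_classes, not an integer. np.argmax returns the position of the largest entry, which for a one-hot vector is the class index. A NumPy-only sketch, again with hypothetical class names:

```python
import numpy as np

class_names = ["cat", "dog", "horse"]  # hypothetical class names
label = np.array([0.0, 1.0, 0.0], dtype="float32")  # one-hot label for "dog"

# argmax gives the index of the 1.0 entry; int() turns the NumPy scalar into a plain int
class_index = int(np.argmax(label))
print(class_names[class_index])  # dog
```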
The resulting working code, which converts each one-hot label back to a class index with np.argmax:
import numpy as np

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        # labels[i] is one-hot, so argmax recovers the integer class index
        plt.title(class_names[np.argmax(labels[i])])
        plt.axis("off")
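Instead of taking the argmax label by label, the whole batch can be decoded at once with axis=1, which takes the argmax of each row. The sketch below uses a hypothetical NumPy batch; with the real dataset the equivalent would be np.argmax(labels.numpy(), axis=1):

```python
import numpy as np

class_names = ["cat", "dog", "horse"]  # hypothetical class names

# A hypothetical batch of one-hot labels, shaped (batch_size, num_classes)
labels = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
], dtype="float32")

# axis=1 takes the argmax of each row, giving one class index per image
indices = np.argmax(labels, axis=1)
titles = [class_names[i] for i in indices]
print(titles)  # ['cat', 'horse', 'dog']
```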
Solution
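Try the code below. Because label_mode='categorical' yields one-hot labels, each label has to be converted to an integer class index before it can index class_names:
import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        # One-hot label -> integer index; with integer labels (label_mode='int')
        # class_names[int(labels[i])] would work directly
        plt.title(class_names[np.argmax(labels[i])])
        plt.axis("off")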
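If the same plotting loop has to work with both integer and one-hot labels, one option is a small helper that checks the label's shape. label_to_name below is a hypothetical helper written for illustration, not a TensorFlow or Keras function; it only handles label_mode='int' and label_mode='categorical':

```python
import numpy as np

def label_to_name(label, class_names):
    """Hypothetical helper: map one label to its class name.

    Accepts an integer index (label_mode='int') or a one-hot
    vector (label_mode='categorical').
    """
    label = np.asarray(label)
    if label.ndim == 0:
        # Integer label: use it directly as the index
        return class_names[int(label)]
    # One-hot label: the position of the largest entry is the class index
    return class_names[int(np.argmax(label))]

class_names = ["cat", "dog", "horse"]  # hypothetical class names
print(label_to_name(2, class_names))                # horse
print(label_to_name([0.0, 1.0, 0.0], class_names))  # dog
```

Inside the plotting loop this would be used as plt.title(label_to_name(labels[i], class_names)), since np.asarray also accepts eager TensorFlow tensors.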
Answered By - user17651088