Issue
Edited to address the comments:
- added the lines at the beginning where the data is imported from MNIST
- added the full error message from the Jupyter notebook as text
I am trying to run some very simple code in Python (in a Jupyter notebook, if it matters):
from sklearn.datasets import fetch_openml
x, y = fetch_openml('mnist_784', version=1, return_X_y=True, data_home='./data/')
y = y.astype(int)
fig, ax = plt.subplots(2, 4, figsize=(20, 8))
for a in ax.ravel():
    j = np.random.choice(len(y))
    sns.heatmap(x[j].reshape(28,28), ax=a, cbar=False, cmap='gray_r')
    a.set_title(f'Label: {y[j]}')
    a.set_xticks([])
    a.set_yticks([])
and I get the error below. I don't think this is a problem with the code, since it was taken directly from the lecturer's notes. Could anyone help me troubleshoot this and explain what is going on, please?
Full error message:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/opt/anaconda3/lib/python3.8/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
3079 try:
-> 3080 return self._engine.get_loc(casted_key)
3081 except KeyError as err:
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
KeyError: 46220
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
<ipython-input-6-02155e9f4730> in <module>
2 for a in ax.ravel():
3 j = np.random.choice(len(y))
----> 4 sns.heatmap(x[j].reshape(28,28), ax=a, cbar=False, cmap='gray_r')
5 a.set_title(f'Label: {y[j]}')
6 a.set_xticks([])
~/opt/anaconda3/lib/python3.8/site-packages/pandas/core/frame.py in __getitem__(self, key)
3022 if self.columns.nlevels > 1:
3023 return self._getitem_multilevel(key)
-> 3024 indexer = self.columns.get_loc(key)
3025 if is_integer(indexer):
3026 indexer = [indexer]
~/opt/anaconda3/lib/python3.8/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
3080 return self._engine.get_loc(casted_key)
3081 except KeyError as err:
-> 3082 raise KeyError(key) from err
3083
3084 if tolerance is not None:
KeyError: 46220
Solution
I suppose that with the line below you were trying to access row j of the pandas DataFrame x:
sns.heatmap(x[j].reshape(28,28), ax=a, cbar=False, cmap='gray_r')
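However, on a DataFrame, x[j] selects the column labelled j, not row j. The MNIST columns are labelled pixel1 to pixel784, so the lookup fails with KeyError: 46220. Since scikit-learn 0.24, fetch_openml returns a DataFrame by default (as_frame='auto'), so code written for older versions, where x was a NumPy array and x[j] selected a row, now fails this way. To access the values of a row by position, use x.iloc[j].values instead.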
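To see the difference between the two kinds of indexing without downloading MNIST, here is a minimal sketch on a tiny synthetic DataFrame. The shape and values are assumptions chosen for illustration; only the string column labels mimic the real data.

```python
import numpy as np
import pandas as pd

# Tiny stand-in for the MNIST DataFrame (assumption: 3 rows, 4 "pixels"),
# with string column labels like the real data's "pixel1".."pixel784".
x = pd.DataFrame(
    np.arange(12).reshape(3, 4),
    columns=[f"pixel{i}" for i in range(1, 5)],
)

j = 1

# x[j] looks up a *column* labelled 1, which does not exist -> KeyError.
try:
    x[j]
except KeyError as err:
    print("KeyError:", err)

# x.iloc[j] selects the *row* at position 1; .values gives a NumPy array.
row = x.iloc[j].values
print(row.reshape(2, 2))
```

x.loc[j] would select by index label instead; it coincides with x.iloc[j] only when the index is the default 0..n-1 range, so iloc is the safer choice for positional access.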
The complete code is:
from sklearn.datasets import fetch_openml
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
x, y = fetch_openml('mnist_784', version=1, return_X_y=True, data_home='./data/')
y = y.astype(int)
fig, ax = plt.subplots(2, 4, figsize=(20, 8))
for a in ax.ravel():
    j = np.random.choice(len(y))
    sns.heatmap(x.iloc[j].values.reshape(28,28), ax=a, cbar=False, cmap='gray_r')
    a.set_title(f'Label: {y[j]}')
    a.set_xticks([])
    a.set_yticks([])
This produces a 2 × 4 grid of randomly chosen MNIST digits, each titled with its label.
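An alternative is to convert x to a NumPy array once, before the plotting loop, so that the original x[j].reshape(28, 28) indexing works again; passing as_frame=False to fetch_openml achieves the same at load time. The sketch below assumes you may receive either a DataFrame or an array depending on your scikit-learn version. The helper name to_array and the synthetic data are hypothetical, used here so the example runs without downloading MNIST.

```python
import numpy as np
import pandas as pd

def to_array(x):
    """Hypothetical helper: return x as a NumPy array.

    fetch_openml returns a DataFrame by default in scikit-learn >= 0.24
    and a NumPy array with as_frame=False; this normalises both cases.
    """
    if isinstance(x, pd.DataFrame):
        return x.to_numpy()
    return np.asarray(x)

# Synthetic stand-in for MNIST (assumption: 5 random 28x28 "images").
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(5, 784))
x_frame = pd.DataFrame(pixels, columns=[f"pixel{i}" for i in range(1, 785)])

# Convert once, outside any plotting loop, then index rows positionally.
X = to_array(x_frame)
j = rng.choice(len(X))
image = X[j].reshape(28, 28)
print(image.shape)  # (28, 28)
```

Converting once matters with the real data: MNIST has 70,000 rows, and calling to_numpy() inside the loop would copy the whole frame on every iteration.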
Answered By - ClaudiaR