Issue
I am very confused by the following behavior. Take this program:
import tensorflow_datasets as tfds
# %% Train dataset
(ds_train_original, ds_test_original), ds_info = tfds.load(
"mnist",
split=["train", "test"],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
iterator = iter(ds_train_original)
el = iterator.get_next()[0]
el[0].ref() == el[0].ref() # <- this should be True
The last line, IMO, should return True. However, it returns False.
I cannot understand why.
According to the ref() documentation:
Returns a hashable reference object to this Tensor. The primary use case for this API is to put tensors in a set/dictionary.
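Here is a minimal sketch of that set/dictionary use case. It assumes TensorFlow 2.x in eager mode, where tensors themselves are unhashable, so ref() objects serve as the dictionary keys. Note that the keys are matched by object identity, not by tensor value:

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# Two tensors with identical contents but distinct Python objects.
a = tf.constant([1.0, 2.0])
b = tf.constant([1.0, 2.0])

# A tf.Tensor cannot be a dict key in TF2, so use its ref() instead.
seen = {a.ref(): "a", b.ref(): "b"}

print(len(seen))      # 2: equal values, but different objects -> different keys
print(seen[a.ref()])  # a fresh ref() of the *same* object finds the entry

# deref() recovers the original tensor from a key.
for key, name in seen.items():
    print(name, key.deref().numpy())
```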
My understanding is that you should be able to use ref() to check two Tensors for equality. However, the problem no longer occurs once I have stored the ref in a variable. For example, this returns True:
a_ref = el[0].ref()
a_deref = a_ref.deref()
another_ref = a_deref.ref()
a_ref == another_ref
So the "problem" seems confined to calling ref() directly on elements taken from the iterator.
Can anybody explain what is happening, and why el[0].ref() == el[0].ref() is False?
Solution
After I posted an issue on GitHub, it seems the only viable solution is to compare the samples' values, since ref() only creates references to tensor objects and does not compare their contents.
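To see why the two calls in the question differ, here is a small sketch. It assumes TensorFlow 2.x eager execution and uses a zero tensor as a stand-in for an MNIST image. Indexing a tensor with [0] runs a slicing op and returns a new tensor object on every call, and ref() equality checks whether two references wrap the same object. So el[0].ref() == el[0].ref() compares two different objects, while in the a_ref example a_deref is the very object that a_ref wraps:

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# Stand-in for one MNIST image (28x28x1); any tensor behaves the same way.
image = tf.zeros([28, 28, 1])

row_a = image[0]  # slicing creates a new tensor object...
row_b = image[0]  # ...and a second, distinct object with the same values

print(row_a is row_b)              # False: two separate objects
print(row_a.ref() == row_b.ref())  # False: ref() equality is by identity
print(row_a.ref() == row_a.ref())  # True: both refs wrap the same object
print(bool(tf.reduce_all(row_a == row_b)))  # True: the values are equal
```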
Thus the solution is:
import tensorflow_datasets as tfds
# %% Train dataset
(ds_train_original, ds_test_original), ds_info = tfds.load(
"mnist",
split=["train", "test"],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
iterator = iter(ds_train_original)
el = iterator.get_next()[0]
(el[0].numpy() == el[0].numpy()).all()
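If the goal is to put samples in a set or dictionary by value (the case ref() cannot cover), one option is to key on the NumPy contents. The helper value_key below is hypothetical, not part of TensorFlow or TFDS, and it assumes the tensors have already been converted with .numpy():

```python
import numpy as np


def value_key(arr: np.ndarray) -> tuple:
    """Hypothetical helper: build a hashable key from an array's contents.

    dtype and shape are included so that, e.g., a (2, 2) array and a (4,)
    array with the same raw bytes get different keys.
    """
    arr = np.ascontiguousarray(arr)
    return (arr.dtype.str, arr.shape, arr.tobytes())


a = np.array([[1, 2], [3, 4]], dtype=np.uint8)
b = np.array([[1, 2], [3, 4]], dtype=np.uint8)
c = a.reshape(4)  # same bytes, different shape

unique = {value_key(a), value_key(b), value_key(c)}
print(len(unique))           # a and b share a key; c differs by shape
print(np.array_equal(a, b))  # value comparison, independent of object identity
```

For the MNIST case this would be value_key(el[0].numpy()). For one-off comparisons, np.array_equal(x, y) is a stricter alternative to (x == y).all(), since it returns False when the shapes differ instead of broadcasting or raising.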
Answered By - Malcolm