Issue
The manual for the Dataset class in TensorFlow shows how to shuffle the data and how to batch it. However, it is not apparent how to reshuffle the data each epoch. I've tried the code below, but the data comes out in exactly the same order in the second epoch as in the first. Does anybody know how to shuffle between epochs using a Dataset?
import tensorflow as tf
n_epochs = 2
batch_size = 3
data = tf.contrib.data.Dataset.range(12)
data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()
sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))
print("new epoch")
# attempt to reshuffle before the second epoch
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))
Solution
My environment: Python 3.6, TensorFlow 1.4.
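TensorFlow has moved Dataset into tf.data (it previously lived in tf.contrib.data).
Be careful about where you place the shuffle call. In your code, the dataset has already been repeated and batched, and the iterator has already been created, before data.shuffle is called, so the shuffle never affects next_batch. Here are two working examples of shuffling a dataset.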
shuffle all elements
# shuffle all elements
import tensorflow as tf
n_epochs = 2
batch_size = 3
# buffer_size is smaller than the 12 elements, so the shuffle is only approximate;
# use buffer_size >= 12 for a uniform shuffle of the whole dataset
buffer_size = 5
dataset = tf.data.Dataset.range(12)
# shuffle before batch and repeat: individual elements are reshuffled each epoch
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))
OUTPUT:
epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10 2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11 9 10]
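To see why buffer_size matters in the example above, here is a simplified pure-Python model of a bounded shuffle buffer. It is only a sketch of the general idea, not TensorFlow's actual implementation; the function name shuffle_buffer and the seed parameter are hypothetical and exist only for this illustration. In this model the first element produced can only come from the first buffer_size inputs, which is consistent with the first batch in the output above containing only small values.

```python
import random


def shuffle_buffer(items, buffer_size, seed=None):
    """Yield items in a locally shuffled order (illustrative model only)."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random slot and refill it with the new item.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    for item in buffer:
        yield item


partial = list(shuffle_buffer(range(12), buffer_size=5, seed=0))
full = list(shuffle_buffer(range(12), buffer_size=12, seed=0))
print("buffer_size=5: ", partial)
print("buffer_size=12:", full)
```

With buffer_size=5, the element at output position k can be at most k + 4, so large values never appear early. With buffer_size=12 the whole range sits in the buffer before anything is emitted, so any ordering is possible.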
shuffle the order of the batches, not the elements within a batch
# shuffle the order of the batches, not the elements within a batch
import tensorflow as tf
n_epochs = 2
batch_size = 3
buffer_size = 5
dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
# repeat comes before shuffle, so batches from different epochs can be mixed together
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))
OUTPUT:
epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]
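Note that in the output above [6 7 8] appears twice under "epoch 1" and [ 9 10 11] twice under "epoch 2". That happens because repeat is applied before shuffle, so batches from both epochs pass through the same shuffle buffer. If every epoch should contain each batch exactly once, one option is to order the pipeline as batch, then shuffle, then repeat, i.e. dataset.batch(batch_size).shuffle(buffer_size).repeat(n_epochs). The plain-Python sketch below models both orderings; it does not use TensorFlow, and all helper names (batch, shuffle, repeat_then_shuffle, shuffle_then_repeat) are hypothetical stand-ins rather than the library's implementation.

```python
import random


def batch(items, batch_size):
    """Group consecutive items into lists of batch_size."""
    items = list(items)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def shuffle(items, buffer_size, rng):
    """Bounded-buffer shuffle (simplified model, not TensorFlow's code)."""
    buffer, out = [], []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        out.append(buffer[idx])  # emit a random slot
        buffer[idx] = item       # refill it with the new item
    rng.shuffle(buffer)          # drain the remainder in random order
    return out + buffer


def repeat_then_shuffle(data, batch_size, buffer_size, n_epochs, rng):
    """batch -> repeat -> shuffle: one buffer spans all epochs."""
    batches = batch(data, batch_size) * n_epochs
    return shuffle(batches, buffer_size, rng)


def shuffle_then_repeat(data, batch_size, buffer_size, n_epochs, rng):
    """batch -> shuffle -> repeat: each epoch is shuffled on its own."""
    batches = batch(data, batch_size)
    out = []
    for _ in range(n_epochs):
        out.extend(shuffle(batches, buffer_size, rng))
    return out


rng = random.Random(0)
mixed = repeat_then_shuffle(range(12), 3, 5, 2, rng)
separate = shuffle_then_repeat(range(12), 3, 5, 2, rng)
print("repeat then shuffle:", mixed)
print("shuffle then repeat:", separate)
```

In the second ordering, the first four batches and the last four batches each cover the whole dataset, so epoch boundaries are preserved. In the first ordering only the combined eight batches are guaranteed to cover the data twice.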
Answered By - William