Issue
I have a data input pipeline that has:
- input datapoints of types that are not castable to a tf.Tensor (dicts and whatnot)
- preprocessing functions that cannot understand TensorFlow types and need to work with those datapoints, some of which do data augmentation on the fly
I've been trying to fit this into a tf.data pipeline, and I'm stuck on running the preprocessing for multiple datapoints in parallel. So far I've tried:
- using Dataset.from_generator(gen) and doing the preprocessing in the generator; this works, but it processes each datapoint sequentially, no matter what arrangement of prefetch and fake map calls I patch on it (see the sketch after this list). Is it impossible to prefetch in parallel?
- encapsulating the preprocessing in a tf.py_function so I could map it in parallel over my Dataset, but
  - this requires some pretty ugly (de)serialization to fit exotic types into string tensors, and
  - apparently the execution of the py_function is handed over to the (single-process) Python interpreter, so I'd be stuck with the Python GIL, which would not help me much
- I saw that you could do some tricks with interleave, but I haven't found one that avoids the issues of the first two ideas.
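For concreteness, here is roughly what that first (sequential) attempt looks like; training_set and preprocess stand in for my actual data and preprocessing function:
import tensorflow as tf

def gen():
    # everything here runs in a single Python thread, so datapoints come out one by one
    for datapoint in training_set:      # training_set / preprocess are placeholder names
        x, y = preprocess(datapoint)    # arbitrary Python objects in, numpy arrays out
        yield x, y

dataset = tf.data.Dataset.from_generator(gen, output_types=(tf.uint8, tf.float32))
dataset = dataset.prefetch(1)  # only overlaps the single producer with the consumer, no parallelism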
Am I missing anything here? Am I forced to either modify my preprocessing so that it can run in a graph, or is there a way to multiprocess it?
Our previous way of doing this was using keras.Sequence, which worked well, but there are just too many people pushing the upgrade to the tf.data API. (Hell, even trying keras.Sequence with tf 2.2 yields: WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.)
Note: I'm using tf 2.2rc3
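For reference, this is roughly the keras.Sequence pattern I mean (MySequence, preprocess and the batch size are placeholder names); the parallelism came from fit() itself via workers and use_multiprocessing:
import math
import numpy as np
import tensorflow as tf

class MySequence(tf.keras.utils.Sequence):
    def __init__(self, training_set, batch_size):
        self.training_set = training_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.training_set) / self.batch_size)

    def __getitem__(self, idx):
        items = self.training_set[idx * self.batch_size:(idx + 1) * self.batch_size]
        xs, ys = zip(*(preprocess(item) for item in items))  # plain Python preprocessing
        return np.stack(xs), np.stack(ys)

# model.fit(MySequence(training_set, 32), workers=8, use_multiprocessing=True)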
Solution
I came across the same problem and found a (relatively) easy solution.
It turns out that the proper way to do this is indeed to first create a tf.data.Dataset object using the from_generator(gen) method, and then apply your custom Python processing function (wrapped within a py_function) with the map method. There is also a trick to avoid the serialization / deserialization of the input that you mentioned.
The trick is to use a generator which only generates the indexes of your training set. Each yielded index is passed to the wrapped py_function, which can in turn evaluate your original dataset at that index. You can then process your datapoint and return the processed data to the rest of your tf.data pipeline.
import tensorflow as tf

def func(i):
    i = i.numpy()  # decode the index from the EagerTensor object
    x, y = processing_function(training_set[i])  # arbitrary Python preprocessing / augmentation
    return x, y  # numpy arrays of types uint8, float32

z = list(range(len(training_set)))  # the index generator
dataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)  # use tf.int64 if the training set has more than 256 items
dataset = dataset.map(lambda i: tf.py_function(func=func, inp=[i],
                                               Tout=[tf.uint8, tf.float32]),
                      num_parallel_calls=12)
dataset = dataset.batch(1)
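As a side note, the hard-coded num_parallel_calls=12 can also be left for TensorFlow to tune (tf.data.experimental.AUTOTUNE in tf 2.2), and a trailing prefetch lets the parallel preprocessing run ahead of the training loop, for example:
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)  # overlap preprocessing with training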
Note that in practice, depending on the model you train on this dataset, you will probably need to apply another map to your dataset after the batch:
def _fixup_shape(x, y):
    x.set_shape([None, None, None, nb_channels])  # batch, height, width, channels
    y.set_shape([None, nb_classes])               # batch, classes
    return x, y

dataset = dataset.map(_fixup_shape)
This is a known issue, which seems to be due to the inability of the from_generator method to infer the shape properly in some cases; hence you need to set the expected output shapes explicitly. For more information, see:
- https://github.com/tensorflow/tensorflow/issues/32912
- the related question "as_list() is not defined on an unknown TensorShape on y_t_rank = len(y_t.shape.as_list())" (related to metrics)
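Once the shapes are fixed, the resulting dataset can be passed directly to Keras; a minimal usage sketch, assuming an already-compiled model named model:
model.fit(dataset, epochs=10)  # no generator / Sequence plumbing needed on the fit() side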
Answered By - A. Cordier