Issue
Note: this question has an accompanying, documented Colab notebook.
TensorFlow's documentation can, at times, leave a lot to be desired. Some of the older docs for lower-level APIs seem to have been expunged, and most newer documents point towards using higher-level APIs such as TensorFlow's subset of Keras, or Estimators. This would not be so problematic if the higher-level APIs did not so often rely closely on their lower levels. Case in point: Estimators (especially the input_fn when using TensorFlow Records).
Over the following Stack Overflow posts:
- Tensorflow v1.10: store images as byte strings or per channel?
- Tensorflow 1.10 TFRecordDataset - recovering TFRecords
- Tensorflow v1.10+ why is an input serving receiver function needed when checkpoints are made without it?
- TensorFlow 1.10+ custom estimator early stopping with train_and_evaluate
- TensorFlow custom estimator stuck when calling evaluate after training
and with the gracious assistance of the TensorFlow / Stack Overflow community, we have moved closer to doing what the TensorFlow "Creating Custom Estimators" guide has not: demonstrating how to make an estimator one might actually use in practice (rather than a toy example), e.g. one which:
- has a validation set for early stopping if performance worsens,
- reads from TF Records, because many datasets are larger than the 1 GB TensorFlow recommends keeping in memory, and
- saves its best version whilst training.
While I still have many questions regarding this (from the best way to encode data into a TF Record, to what exactly the serving_input_fn expects), there is one question that stands out more prominently than the rest:
How to predict with the custom estimator we just made?
Under the documentation for predict, it states:
input_fn: A function that constructs the features. Prediction continues until input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration). See Premade Estimators for more information. The function should construct and return one of the following:
- A tf.data.Dataset object: Outputs of Dataset object must have same constraints as below.
- features: A tf.Tensor or a dictionary of string feature name to Tensor. features are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
- A tuple, in which case the first item is extracted as features.
Most likely, if one is using estimator.predict, they are using data in memory such as a dense tensor (because a held-out test set would likely go through evaluate).
So, in the accompanying Colab, I create a single dense example, wrap it up in a tf.data.Dataset, and call predict, only to get a ValueError.
I would greatly appreciate it if someone could explain to me how I can:
- load my saved estimator
- given a dense, in memory example, predict the output with the estimator
Solution
to_predict = random_onehot((1, SEQUENCE_LENGTH, SEQUENCE_CHANNELS))\
.astype(tf_type_string(I_DTYPE))
pred_features = {'input_tensors': to_predict}
pred_ds = tf.data.Dataset.from_tensor_slices(pred_features)
predicted = est.predict(lambda: pred_ds, yield_single_examples=True)
next(predicted)
ValueError: Tensor("IteratorV2:0", shape=(), dtype=resource) must be from the same graph as Tensor("TensorSliceDataset:0", shape=(), dtype=variant).
When you use the tf.data.Dataset module, it actually defines an input graph which is independent from the model graph. What happens here is that you first created a small graph by calling tf.data.Dataset.from_tensor_slices(), then the Estimator API created a second graph by calling dataset.make_one_shot_iterator() automatically. These two graphs cannot communicate, so the error above is thrown.
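To see that restriction in isolation, here is a minimal, self-contained sketch (plain constants standing in for the dataset and the model) showing that an op cannot consume tensors from two different graphs:

```python
import tensorflow as tf

# Two independent graphs, mirroring the situation above: the dataset tensors
# live in one graph while the estimator builds its iterator in another.
g1 = tf.Graph()
with g1.as_default():
    a = tf.constant(1)  # 'a' belongs to g1

g2 = tf.Graph()
with g2.as_default():
    b = tf.constant(2)  # 'b' belongs to g2
    try:
        a + b  # building an op that mixes tensors from g1 and g2
    except ValueError as err:
        print(err)  # "... must be from the same graph as ..."
```

This is the same ValueError as above, just without the estimator machinery in the way.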
To circumvent this, you should never create a dataset outside of estimator.train/evaluate/predict. This is why everything data related is wrapped inside input functions.
def predict_input_fn(data, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices(data)
    return dataset.batch(batch_size).prefetch(None)
predicted = est.predict(lambda: predict_input_fn(pred_features), yield_single_examples=True)
next(predicted)
Now, the graph is not created outside of the predict call.
I also added dataset.batch(), because the rest of your code expects batched data and was throwing a shape error. Prefetch just speeds things up.
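As for the other half of the question, loading the saved estimator back: a minimal sketch (the model_fn body and the model_dir path are placeholders, not taken from the accompanying notebook). An Estimator is re-created from its model_fn plus its model_dir, and predict() restores the latest checkpoint found in model_dir automatically (or a specific one via its checkpoint_path argument):

```python
import tensorflow as tf

# Placeholder model_fn: in practice this must be the same model_fn that was
# used during training, so the restored variables match the rebuilt graph.
def model_fn(features, labels, mode, params):
    ...  # same architecture as at training time

# Pointing a fresh Estimator at the training model_dir is all that is needed;
# the next predict()/evaluate() call restores the latest checkpoint from it.
est = tf.estimator.Estimator(model_fn=model_fn,
                             model_dir='path/to/model_dir')
```

From there, est.predict(lambda: predict_input_fn(pred_features)) works as shown above.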
Answered By - Olivier Dehaene