Issue
Today I suddenly started getting this error, for no apparent reason, while running model.fit(). This used to work before. I am using TF 2.3.0, more specifically its Keras module.
The function is called on the validation data inside a generator, which is fed into model.predict().
Basically, I load a checkpoint, resume training the network, and make a prediction on the validation data.
The error keeps occurring even when I train a model from scratch and erase all the related data. It is as if something had been hardcoded somewhere, since I was able to run model.fit() up until a few hours ago.
I saw several proposed solutions, but none of those variations really works for me, as they lead to even trickier error messages.
I even tried installing a different version of TF, thinking that the problem was due to an old version, but the error still occurs.
Solution
I will answer my own question, as this one was particularly tricky and none of the solutions I found on the internet worked for me, probably because they are outdated.
I'll write down just the relevant part to add to the code; feel free to add more technical explanations.
I like using args for passing variables, but it can work without:
import tensorflow as tf
from tensorflow.python.keras.backend import set_session
from tensorflow.keras.models import load_model
# custom generator (assumed here to be a callable defined in a local generator module)
from generator import generator

def main(args):
    # open a new session and keep a reference to the default TF graph
    args.sess = tf.compat.v1.Session()
    args.graph = tf.compat.v1.get_default_graph()
    set_session(args.sess)
    # define the training generator
    train_generator = generator(args.train_data)
    # load the model and store it on args so predict_output() can reach it
    args.model = load_model(args.model_path)
    args.model.fit(train_generator)
Then, in the model prediction function:
# In my specific case, the predict_output() function is
# called inside the generator function
def predict_output(args, x):
    # reuse the graph and session that were created in main()
    with args.graph.as_default():
        set_session(args.sess)
        y = args.model.predict(x)
    return y
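As a rough illustration of how the pieces fit together, here is a minimal sketch of a generator that calls predict_output() on each validation batch. StubModel, validation_generator, and graph_context are hypothetical names that do not appear in the original code. The stub stands in for the loaded Keras model, and contextlib.nullcontext stands in for args.graph.as_default(), so the sketch runs without TensorFlow and only shows the call structure.

```python
from contextlib import nullcontext
from types import SimpleNamespace


class StubModel:
    """Hypothetical stand-in for a loaded Keras model."""

    def predict(self, x):
        # Pretend prediction: double every input value.
        return [2 * v for v in x]


def predict_output(args, x):
    # In the real code this would be `with args.graph.as_default():`
    # followed by `set_session(args.sess)`.
    with args.graph_context():
        return args.model.predict(x)


def validation_generator(args, batches):
    # Yield each batch together with the model's prediction for it.
    for x in batches:
        yield x, predict_output(args, x)


# args carries the shared state, as in the answer above.
args = SimpleNamespace(model=StubModel(), graph_context=nullcontext)
results = list(validation_generator(args, [[1, 2], [3]]))
print(results)
```

The point of the sketch is only that everything the prediction function needs (model, graph, session) travels on args, so the generator never has to rely on whatever graph happens to be the default at the time it runs.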
Answered By - Phys