Issue
Assume there is a model given as an h5 file, i.e., I cannot change the code that builds the model's architecture:
from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model
inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)
model.save('model.h5', include_optimizer=False)
Now I'd like to remove the training=True part, i.e., I want the BatchNormalization layer to behave as if it had been attached to the model without this flag.
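For context on why the flag matters: in training mode a BatchNormalization layer normalizes with the statistics of the current batch, while in inference mode it uses its stored moving mean and variance. The following is a minimal NumPy sketch of the forward pass, not the Keras implementation. It assumes a freshly initialized layer (gamma=1, beta=0, moving_mean=0, moving_variance=1) and Keras's default epsilon=1e-3, and it skips the moving-average update entirely:

```python
import numpy as np


def batch_norm(x, training, gamma=1.0, beta=0.0,
               moving_mean=0.0, moving_variance=1.0, epsilon=1e-3):
    """Toy batch-normalization forward pass over axis 0 (the batch axis).

    Defaults mirror a freshly initialized Keras BatchNormalization layer.
    The update of the moving statistics during training is omitted.
    """
    if training:
        # Training mode: normalize with the statistics of the current batch.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
    else:
        # Inference mode: normalize with the stored moving statistics.
        mean = moving_mean
        var = moving_variance
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta


x = np.asarray([[1.0, 2.0, 3.0, 4.0]])
print(batch_norm(x, training=True))   # batch of one: every feature becomes 0
print(batch_norm(x, training=False))  # the input scaled by 1 / sqrt(1.001)
```

With a batch of a single sample, the batch mean equals the input and the batch variance is zero, so training mode maps every feature to 0. Inference mode returns the input almost unchanged.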
My current attempt looks as follows:
import numpy as np
from tensorflow.keras.models import load_model
model = load_model('model.h5')
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]
model.predict(np.asarray([[1, 2, 3, 4]]))
But the model.predict call fails with the following error (I'm using TensorFlow 2.5.0):
ValueError: Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements. Structure: ((<KerasTensor: shape=(None, 4) dtype=float32 (created by layer 'input_1')>,), {}), flat_sequence: [<tf.Tensor 'model/Cast:0' shape=(None, 4) dtype=float32>, True].
How can this be fixed/worked around?
(When using node.call_kwargs["training"] = False instead of del node.call_kwargs["training"], model.predict does not crash, but it behaves as if nothing had changed, i.e., the modified flag is ignored.)
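A plausible explanation for both symptoms, inferred from the error message and treated here as an assumption rather than documented Keras behavior, is this: each node flattens `(call_args, call_kwargs)` into a cached list once at construction, and it later re-packs that stale list into the *current* structure of `call_args`/`call_kwargs`. The pure-Python toy below reproduces that behavior. `ToyNode`, `flatten`, and `pack_sequence_as` are hypothetical stand-ins, not the Keras or `tf.nest` implementations:

```python
def flatten(structure):
    """Flatten nested tuples/lists/dicts into a list of leaves (dict keys sorted)."""
    if isinstance(structure, (tuple, list)):
        return [leaf for item in structure for leaf in flatten(item)]
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    return [structure]


def pack_sequence_as(structure, flat):
    """Refill `structure` with the leaves in `flat`; leaf counts must match."""
    expected = len(flatten(structure))
    if expected != len(flat):
        raise ValueError(
            f"Could not pack sequence. Structure had {expected} elements, "
            f"but flat_sequence had {len(flat)} elements.")
    leaves = iter(flat)

    def refill(s):
        if isinstance(s, tuple):
            return tuple(refill(item) for item in s)
        if isinstance(s, list):
            return [refill(item) for item in s]
        if isinstance(s, dict):
            return {key: refill(s[key]) for key in sorted(s)}
        return next(leaves)

    return refill(structure)


class ToyNode:
    """Hypothetical stand-in for a Keras node that caches its flat arguments."""

    def __init__(self, call_args, call_kwargs):
        self.call_args = call_args
        self.call_kwargs = call_kwargs
        # Cached once at construction and never refreshed afterwards.
        self._flat_arguments = flatten((call_args, call_kwargs))

    def map_arguments(self, input_tensor):
        flat = list(self._flat_arguments)
        flat[0] = input_tensor  # swap the placeholder for the real input
        # Re-pack the (possibly stale) cached leaves into the current structure.
        return pack_sequence_as((self.call_args, self.call_kwargs), flat)


# Symptom 1: setting training=False is silently ignored.
node = ToyNode(("placeholder",), {"training": True})
node.call_kwargs["training"] = False
print(node.map_arguments("x"))  # (('x',), {'training': True})

# Symptom 2: deleting the key makes the structure and the cache disagree.
node = ToyNode(("placeholder",), {"training": True})
del node.call_kwargs["training"]
try:
    node.map_arguments("x")
except ValueError as err:
    print(err)  # Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements.

# Building a new node from the edited arguments recomputes the cache.
rebuilt = ToyNode(node.call_args, node.call_kwargs)
print(rebuilt.map_arguments("x"))  # (('x',), {})
```

Under this assumption, anything that constructs fresh nodes from the edited call_kwargs, such as serializing the model and loading it again, would avoid the stale cache.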
Solution
I found that simply saving and reloading the model after modifying the call_kwargs helps.
import numpy as np
from tensorflow.keras.models import load_model
model = load_model('model.h5')
# Removing training=True
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]
# The following two lines are the solution.
model.save('model_modified.h5', include_optimizer=False)
model = load_model('model_modified.h5')
model.predict(np.asarray([[1, 2, 3, 4]]))
And all is fine. :)
Answered By - Tobias Hermann