Issue
I have a subclassed model with some custom attributes like this:
import tensorflow as tf
from tensorflow.keras import layers

class MyModel(tf.keras.Model):
    def __init__(self, *args, my_var, **kwargs):
        super().__init__(*args, **kwargs)
        # custom attribute that should survive cloning
        self.my_var = my_var

    def my_func(self):
        pass

    def get_config(self):
        config = super().get_config()
        config.update({"my_var": self.my_var})
        return config
Now I define the model and clone it with clone_model:
x_in = layers.Input(shape=(100, 100, 3))
x_out = layers.Conv2D(filters=16, kernel_size=3, activation="relu")(x_in)
model = MyModel(inputs=x_in, outputs=x_out, my_var="my_var")
cloned = tf.keras.models.clone_model(model)
print(cloned.my_var)
The model is copied fine, but my_var is missing from the clone.
Is there any way to copy this type of model properly, with all its attributes (my_var and my_func)?
Solution
Instead of clone_model, rebuild the model from its own class and config:
cloned = model.__class__.from_config(model.get_config())
as shown in the docs: https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model#example
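Below is a minimal sketch of that pattern. It assumes the model is rewritten as a fully subclassed model, with layers created in __init__ and the forward pass in call, rather than the functional-style subclass in the question. The attribute name conv and the (100, 100, 3) input shape are illustrative. get_config returns only the constructor arguments, and from_config passes them back to the constructor, so the clone is a real MyModel that has both my_var and my_func. The linked doc describes overriding from_config as optional. It is probably the safer choice here, though: depending on the TensorFlow/Keras version, the inherited functional from_config may rebuild the model without forwarding extra constructor arguments such as my_var.

```python
import tensorflow as tf
from tensorflow.keras import layers


class MyModel(tf.keras.Model):
    def __init__(self, my_var, **kwargs):
        super().__init__(**kwargs)
        self.my_var = my_var
        # Layers are created here so the constructor alone can rebuild the model
        self.conv = layers.Conv2D(filters=16, kernel_size=3, activation="relu")

    def call(self, inputs):
        return self.conv(inputs)

    def my_func(self):
        return self.my_var

    def get_config(self):
        # Only the constructor arguments are needed to rebuild this model
        return {"my_var": self.my_var}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Pass the saved constructor arguments straight back to __init__
        return cls(**config)


model = MyModel(my_var="my_var")
model(tf.zeros((1, 100, 100, 3)))  # call once so the weights get created

cloned = model.__class__.from_config(model.get_config())
cloned(tf.zeros((1, 100, 100, 3)))  # build the clone's weights too
cloned.set_weights(model.get_weights())  # optional: copy the trained weights

print(cloned.my_var)
```

Like clone_model, from_config produces freshly initialized weights. The set_weights call copies the original weights over; drop it if fresh weights are what you want.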
Answered By - Alexandre Catalano