Issue
I created a custom Keras `Layer` with the purpose of wrapping another `Layer` so that I can have some additional computation before and/or after it. Here is some pseudo code:
```python
from tensorflow.keras.layers import Dense, Layer


class Wrapper(Layer):
    def __init__(self, layer: Layer, **kwargs):
        super(Wrapper, self).__init__(**kwargs)
        self._layer = layer

    def build(self, input_shape):
        self._layer.build(input_shape)
        # e.g. the following is just an example
        self._after_layer = Dense(10)

    def call(self, x, **kwargs):
        y = self._layer(x)
        y = self._after_layer(y)
        return y

    def compute_output_shape(self, input_shape):
        return self._after_layer.compute_output_shape(self._layer.compute_output_shape(input_shape))

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'layer': self._layer,  # a raw Layer object: this is what breaks saving
        })
        return config
```
Now you can use the wrapper above as follows:

```python
y = Wrapper(SomeOtherLayer(...))(x)
```
Everything works like a charm, but having a `Layer` as input makes it impossible to save the model. Trying to save raises `TypeError: Cannot convert ... to a TensorFlow DType`, and it is triggered by the fact that I added `'layer': self._layer` to the `Layer` config.
Is there any workaround, or a better way to achieve the same as above while also being able to save and load the model?
Solution
Since you are wrapping a TF layer, you must serialize it in `get_config`.
You can also implement the classmethod `from_config`, which should be able to recreate the wrapper from the output of `get_config`. For this, you will need to deserialize the wrapped layer, because the default `from_config` simply passes the config to the constructor and would hand `__init__` a dictionary instead of a layer. This is useful when you are saving the architecture and not only the weights.
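As a minimal sketch of the two helpers used below, this snippet shows what `tf.keras.layers.serialize` and `tf.keras.layers.deserialize` do with a single built-in layer (the `Dense` layer and the name `inner_dense` are arbitrary choices for illustration). A raw layer object cannot be written out as JSON, whereas its serialized form is a plain dictionary that can be, and `deserialize` creates a new, equivalent but untrained layer from it.

```python
import json

import tensorflow as tf

# An example layer to be wrapped (any built-in Keras layer would do).
inner = tf.keras.layers.Dense(4, activation="relu", name="inner_dense")

# A raw layer object is not plain data, so json cannot encode it.
try:
    json.dumps({"layer": inner})
    raw_ok = True
except TypeError:
    raw_ok = False
print("raw layer JSON-serializable:", raw_ok)

# serialize() describes the layer's class and constructor arguments as a dict.
serialized = tf.keras.layers.serialize(inner)
as_json = json.dumps({"layer": serialized})
print(serialized["class_name"])

# deserialize() creates a new, equivalent (untrained) layer from that dict.
restored = tf.keras.layers.deserialize(json.loads(as_json)["layer"])
print(type(restored).__name__, restored.units, restored.name)
```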
Here is the full working code:
```python
import tensorflow as tf


class Wrapper(tf.keras.layers.Layer):
    def __init__(self, layer: tf.keras.layers.Layer, **kwargs):
        super(Wrapper, self).__init__(**kwargs)
        self._layer = layer

    def build(self, input_shape):
        # The wrapped layer may already be built (newer Keras versions can
        # build it while deserializing), so only build it if needed.
        if not self._layer.built:
            self._layer.build(input_shape)
        # e.g. the following is just an example
        self._after_layer = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        y = self._layer(x)
        y = self._after_layer(y)
        return y

    def compute_output_shape(self, input_shape):
        return self._after_layer.compute_output_shape(self._layer.compute_output_shape(input_shape))

    def get_config(self):
        config = super().get_config().copy()
        # Store a plain-data description of the wrapped layer, not the object itself.
        config["layer"] = tf.keras.layers.serialize(self._layer)
        return config

    @classmethod
    def from_config(cls, config):
        # Recreate the wrapped layer from its serialized description first.
        layer = tf.keras.layers.deserialize(config.pop("layer"))
        return cls(layer, **config)
```
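To check that a saved architecture can actually be loaded again, the following sketch round-trips a small functional model through `model.to_json()` and `tf.keras.models.model_from_json`, copies the weights over, and compares the outputs. The model shape (8 inputs, an inner `Dense(4)`) is an arbitrary choice for illustration. The class is repeated so the snippet runs on its own; its `build` only builds the wrapped layer if it is not built yet, since newer Keras versions can already build it during deserialization. Loading needs `custom_objects={"Wrapper": Wrapper}` so that Keras knows which Python class the `"Wrapper"` entries refer to.

```python
import numpy as np
import tensorflow as tf


class Wrapper(tf.keras.layers.Layer):
    def __init__(self, layer: tf.keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self._layer = layer

    def build(self, input_shape):
        # Skip building the wrapped layer if deserialization already built it.
        if not self._layer.built:
            self._layer.build(input_shape)
        self._after_layer = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        return self._after_layer(self._layer(x))

    def compute_output_shape(self, input_shape):
        return self._after_layer.compute_output_shape(
            self._layer.compute_output_shape(input_shape))

    def get_config(self):
        config = super().get_config().copy()
        config["layer"] = tf.keras.layers.serialize(self._layer)
        return config

    @classmethod
    def from_config(cls, config):
        layer = tf.keras.layers.deserialize(config.pop("layer"))
        return cls(layer, **config)


# A small functional model around the wrapper (sizes chosen arbitrarily).
inputs = tf.keras.Input(shape=(8,))
outputs = Wrapper(tf.keras.layers.Dense(4, activation="relu"))(inputs)
model = tf.keras.Model(inputs, outputs)

x = np.ones((2, 8), dtype="float32")
expected = np.asarray(model(x))  # a forward pass makes sure all weights exist

# Save only the architecture, then rebuild it from the JSON string.
json_config = model.to_json()
restored = tf.keras.models.model_from_json(
    json_config, custom_objects={"Wrapper": Wrapper})

restored(x)  # create the restored model's weights before copying values
restored.set_weights(model.get_weights())
print(np.allclose(expected, np.asarray(restored(x))))
```

The same `custom_objects` argument is accepted by `tf.keras.models.load_model` when loading a full saved model.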
Answered By - M. Perier--Dulhoste