Issue
How can I get the summary of a custom TensorFlow model?
class Discriminator_block(tf.keras.layers.Layer):
    def __init__(self, num_strides):
        super(Discriminator_block, self).__init__(name='discriminator block')
        self.num_strides = num_strides
        self.conv1 = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), strides=(num_strides, num_strides), padding='same', data_format='channels_first', activation=None)
        self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
        self.leaky = keras.layers.advanced_activations.LeakyReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leaky(x)
        return x
I have written my own discriminator block using TensorFlow, and I want to see my model's summary.
So I added
Discriminator_block.summary()
but I get an error:
'Discriminator_block' object has no attribute 'summary'
What mistake did I make in my code?
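The error can be reproduced without the rest of the block: summary() is defined on tf.keras.Model, and although every Model is a Layer, a class that subclasses only tf.keras.layers.Layer does not inherit it. A minimal check, assuming TensorFlow 2.x:

```python
import tensorflow as tf

# summary() lives on tf.keras.Model, not on the Layer base class.
layer_has_summary = hasattr(tf.keras.layers.Layer, "summary")
model_has_summary = hasattr(tf.keras.Model, "summary")

# Model subclasses Layer, so a Model can be used wherever a Layer is expected,
# but a plain Layer subclass never gains Model-only methods such as summary().
model_is_layer = issubclass(tf.keras.Model, tf.keras.layers.Layer)

print(layer_has_summary, model_has_summary, model_is_layer)
```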
Solution
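To print a summary of your custom block, you need to change a few things in your code:
- The most important change is to subclass tf.keras.Model instead of tf.keras.layers.Layer. summary() is a method of Model; a plain Layer does not have it.
- Rename forward to call. Keras runs call() when the model is applied to an input; forward is the PyTorch convention and is never invoked by Keras.
- Create an instance of your model.
- Call the model once on an input (a random tensor is enough) so that it builds its weights; after that you can print the model's summary.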
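If you would rather keep the block as a tf.keras.layers.Layer (for example, to stack several of them inside a larger discriminator), one alternative not used in this answer is to wrap the layer in a tf.keras.Sequential model that starts with a tf.keras.Input. The class name DiscriminatorBlock, the (28, 28, 3) input shape and the channels_last layout below are assumptions for illustration.

```python
import tensorflow as tf

class DiscriminatorBlock(tf.keras.layers.Layer):
    """Conv -> BatchNorm -> LeakyReLU, kept as a Layer (hypothetical variant)."""

    def __init__(self, num_strides, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(
            filters=16, kernel_size=(3, 3),
            strides=(num_strides, num_strides), padding="same")
        # channels_last input, so normalize over the last (channel) axis.
        self.bn1 = tf.keras.layers.BatchNormalization(axis=-1)
        self.leaky = tf.keras.layers.LeakyReLU()

    def call(self, x, training=None):
        # Keras invokes call(), not forward(), when the layer is applied.
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        return self.leaky(x)

# The Input fixes the shape, so the wrapper is built immediately
# and summary() can be called without a dummy forward pass.
wrapper = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 3)),
    DiscriminatorBlock(num_strides=1, name="discriminator_block"),
])
wrapper.summary()
```

The wrapper's summary lists the whole block as a single row with its total parameter count: 448 for the convolution (3·3·3·16 weights plus 16 biases) and 64 for batch normalization over 16 channels, 512 in total.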
import tensorflow as tf

class Discriminator_block(tf.keras.Model):
    def __init__(self, num_strides):
        # Underscore instead of a space: TensorFlow graph scope names
        # (e.g. inside tf.function) may not contain spaces.
        super(Discriminator_block, self).__init__(name='discriminator_block')
        self.num_strides = num_strides
        # No input_shape here: a subclassed model infers it on the first call.
        self.conv1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3,3),
                                            strides=(num_strides, num_strides),
                                            padding='same',
                                            activation=None)  # LeakyReLU is applied below
        # The input is channels_last, so normalize over the last (channel) axis.
        self.bn1 = tf.keras.layers.BatchNormalization(axis=-1)
        self.leaky = tf.keras.layers.LeakyReLU()

    # Keras calls call(), not forward(), when the model is applied.
    def call(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leaky(x)
        return x

block = Discriminator_block(num_strides = 1)
# Calling the model once builds its weights; summary() needs them.
_ = block(tf.random.normal(shape=[2, 28, 28, 3]))
block.summary()

Model: "discriminator_block"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv2d_1 (Conv2D)           multiple                  448
 batch_normalization_1 (Batc  multiple                 64
 hNormalization)
 leaky_re_lu_1 (LeakyReLU)   multiple                  0
=================================================================
Total params: 512
Trainable params: 480
Non-trainable params: 32
_________________________________________________________________
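The Output Shape column reads "multiple" because a subclassed model is not tied to a single input shape. A common workaround, sketched here under the assumption of a (28, 28, 3) channels_last input, is to trace the model's call() on a symbolic tf.keras.Input and summarize the resulting functional model, which lists each sub-layer with its concrete output shape. The class is repeated so the snippet runs on its own.

```python
import tensorflow as tf

# Same block as in the answer above (channels_last, BatchNormalization on axis -1).
class Discriminator_block(tf.keras.Model):
    def __init__(self, num_strides):
        super().__init__(name="discriminator_block")
        self.conv1 = tf.keras.layers.Conv2D(
            filters=16, kernel_size=(3, 3),
            strides=(num_strides, num_strides), padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization(axis=-1)
        self.leaky = tf.keras.layers.LeakyReLU()

    def call(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        return self.leaky(x)

block = Discriminator_block(num_strides=1)

# Apply block.call() to a symbolic Input with a known shape. Each sub-layer
# is called on the symbolic tensor, so the functional wrapper records every
# layer together with its concrete output shape.
inputs = tf.keras.Input(shape=(28, 28, 3))
shaped = tf.keras.Model(inputs=inputs, outputs=block.call(inputs))
shaped.summary()
```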
Answered By - I'mahdi