Issue
I am trying to build a small feature that dynamically plots the loss or accuracy while a TensorFlow model trains. At the end of each batch, I plot the accuracy history recorded so far across all epochs (the code still needs some corrections, but it works for now).
I have one problem. When I run the following code in a Jupyter notebook cell, I get the desired behavior: a plot that updates as training progresses. However, when training ends, the final plot appears twice, and I can't figure out why.
```python
from IPython.display import display, clear_output
import tensorflow as tf
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

class CustomCallback(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.epoch = 0  # Initialize the epoch counter
        self.accuracies = []
        self.fig, self.ax = plt.subplots()
        self.line, = self.ax.plot([], [])
        self.ax.set_xlim(0, 30)
        self.ax.set_ylim(0, 1)
        self.displayed = False
        display(self.fig)

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch  # Update the current epoch at the beginning of each epoch

    def on_train_batch_end(self, batch, logs=None):
        # Record the latest accuracy and redraw the line with the full history
        accuracy = logs['accuracy']
        self.accuracies.append(accuracy)
        self.line.set_data(range(1, len(self.accuracies) + 1), self.accuracies)
        self.ax.relim()
        self.ax.autoscale_view()
        # Replace the previous frame with the updated figure
        clear_output(wait=True)
        display(self.fig)

custom_callback = CustomCallback()

model = Sequential()
model.add(tf.keras.layers.Dense(units=16, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.35))
# sigmoid (not tanh) keeps the output in [0, 1], as binary_crossentropy expects
model.add(tf.keras.layers.Dense(units=1, activation='sigmoid'))
model.compile(optimizer=tf.keras.optimizers.Adam(), loss="binary_crossentropy", metrics=["accuracy"])

X = np.random.randn(10**2, 10**4)
y = np.random.randint(2, size=10**2)
abc = model.fit(X, y, epochs=7, batch_size=32, validation_split=0.025, verbose=False, callbacks=[custom_callback])
```
Solution
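It's because Jupyter notebook already shows the figure inline: with the inline backend, a figure created through pyplot that is still open when the cell finishes is rendered automatically. The display() calls in on_train_batch_end keep replacing each other through clear_output(wait=True), but the automatic render at the end of the cell is appended after the last of them, so the final plot appears twice. For example, the following code shows the same line plot twice in a Jupyter notebook: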
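```python
import matplotlib.pyplot as plt
from IPython.display import display

fig, ax = plt.subplots()
ax.plot(range(3))
display(fig)  # explicit copy; the inline backend adds a second one when the cell ends
```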
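A minimal sketch of the mechanism, assuming a plain Python process where matplotlib's headless Agg backend stands in for the notebook (the variable names are illustrative). pyplot keeps a registry of open figures, exposed by `plt.get_fignums()`; roughly speaking, the inline backend renders the figures still in that registry when a cell finishes. Closing a figure removes it from the registry:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; in Jupyter the inline backend is active instead
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(range(3))

# pyplot still tracks the figure after it has been drawn or displayed.
still_open = plt.get_fignums()

# Closing the figure unregisters it, so there is nothing left to render.
plt.close(fig)
after_close = plt.get_fignums()

print(still_open, after_close)
```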
To avoid the duplicate, turn off interactive mode by calling plt.ioff() right after importing matplotlib. Alternatively, close the figure at the end of training by adding the following method to the class:
```python
def on_train_end(self, logs=None):
    # Close the figure so it is not rendered again when the cell finishes
    plt.close(self.fig)
```
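The effect of the closing fix can be checked without TensorFlow or a notebook. The sketch below is hypothetical: `LivePlot` is a plain class that mirrors the hook names of the question's `CustomCallback`, the `show` parameter stands in for the `clear_output(wait=True)` plus `display(fig)` pair, and the accuracy values are made up. It runs on matplotlib's headless Agg backend and confirms that, once `on_train_end` calls `plt.close`, pyplot no longer tracks the figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt


class LivePlot:
    """Hypothetical stand-in for the question's CustomCallback.

    Same hook names as tf.keras.callbacks.Callback, but plain Python so the
    hooks can be driven by hand.
    """

    def __init__(self, show=None):
        self.show = show  # called with the figure after every redraw

    def on_train_begin(self, logs=None):
        self.accuracies = []
        self.fig, self.ax = plt.subplots()
        (self.line,) = self.ax.plot([], [])
        self.ax.set_ylim(0, 1)  # fixed y range; the x axis is left to autoscale

    def on_train_batch_end(self, batch, logs=None):
        self.accuracies.append(logs["accuracy"])
        self.line.set_data(range(1, len(self.accuracies) + 1), self.accuracies)
        self.ax.relim()           # recompute data limits from the line
        self.ax.autoscale_view()  # rescale the x axis to the new limits
        if self.show is not None:
            self.show(self.fig)

    def on_train_end(self, logs=None):
        # Unregister the figure from pyplot once training is over
        plt.close(self.fig)


# Drive the hooks by hand with made-up accuracy values.
frames = []
cb = LivePlot(show=frames.append)
cb.on_train_begin()
for i, acc in enumerate([0.50, 0.55, 0.62, 0.70]):
    cb.on_train_batch_end(i, logs={"accuracy": acc})
cb.on_train_end()

print(len(frames), plt.get_fignums())  # number of redraws, figures still open
```

In the notebook, the same `on_train_end` goes into the `tf.keras.callbacks.Callback` subclass. Leaving out a fixed `set_xlim(0, 30)` here is a deliberate choice: `set_xlim` switches off x-axis autoscaling, so `relim()` and `autoscale_view()` would otherwise leave the x range unchanged.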
Answered By - cottontail