Issue
I'm using Amazon SageMaker Studio Lab to train a model on a certain dataset.
The code is as follows (it saves the History object in the history variable):
import tensorflow as tf
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger

# On every run after the runtime ends, resume from the last saved model
model = tf.keras.models.load_model('best_model.hdf5')
model.compile(optimizer=SGD(learning_rate=0.0001, momentum=0.9),
              loss='categorical_crossentropy', metrics=['accuracy'])
checkpointer = ModelCheckpoint(filepath='best_model.hdf5', verbose=1, save_best_only=True)
csv_logger = CSVLogger('history.log')
# model.fit accepts generators directly; fit_generator is deprecated in TF 2.x
history = model.fit(train_generator,
                    steps_per_epoch=nb_train_samples // batch_size,
                    validation_data=validation_generator,
                    validation_steps=nb_validation_samples // batch_size,
                    epochs=30,
                    verbose=1,
                    callbacks=[csv_logger, checkpointer])
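CSVLogger opens its file in write mode by default, so each new session that runs the code above overwrites history.log. Passing append=True to CSVLogger makes later sessions add rows to the existing file instead. For logs that were already saved separately, one file per session, here is a minimal pandas sketch of merging them. The helper name load_logs is hypothetical, and it assumes every file keeps its own header row.

```python
import pandas as pd


def load_logs(paths):
    """Merge several CSVLogger files (one per session) into one DataFrame.

    Hypothetical helper. Assumes every file has its own header row and an
    'epoch' column; the existing 'epoch' values are replaced.
    """
    frames = [pd.read_csv(path) for path in paths]
    history = pd.concat(frames, ignore_index=True)
    # CSVLogger writes one row per completed epoch, so renumbering the rows
    # gives a continuous epoch count across sessions
    history["epoch"] = range(len(history))
    return history
```

For example, load_logs(['history_run1.log', 'history_run2.log']) (hypothetical file names) returns one DataFrame whose accuracy, val_accuracy, loss and val_loss columns can be plotted directly. Reading the files one by one also avoids the repeated header lines that end up as data rows when logs are joined as raw text.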
I had to pause several times because the runtime ended, and at each pause I saved the .log file. Now, after appending those .log files together, I'm trying to plot them with the standard accuracy and loss plotting functions:
import matplotlib.pyplot as plt

def plot_accuracy(history, title):
    plt.title(title)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train_accuracy', 'validation_accuracy'], loc='best')
    plt.show()

def plot_loss(history, title):
    plt.title(title)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'validation_loss'], loc='best')
    plt.show()
The problem is that I can't manage to get a working History object back.
Things I tried:
I read that I could reload the model and read its history attribute, but that didn't work ("'NoneType' object has no attribute 'history'"):
model = tf.keras.models.load_model('best_model.hdf5')
history = model.history
Another attempt was loading the file with pandas, which raised "'DataFrame' object has no attribute 'history'":
history = pd.read_csv('history.log', sep=',', engine='python')
And this attempt just created a new CSVLogger object, which raised "'CSVLogger' object has no attribute 'history'":
history = CSVLogger('history.log')
I'd appreciate any help on how to recover the History object so I can plot those results (if that's even possible).
Thanks.
Solution
Instead of recreating the History object, I read the .log file with the pandas read_csv method, built DataFrames with the wanted columns, and plotted them. Code below:
import pandas as pd
import matplotlib.pyplot as plt

# Define the plotting helpers before calling them; each plots every column it is given
def plot_accuracy(history, title):
    plt.title(title)
    plt.plot(history)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train_accuracy', 'validation_accuracy'], loc='best')
    plt.show()

def plot_loss(history, title):
    plt.title(title)
    plt.plot(history)
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'validation_loss'], loc='best')
    plt.show()

# history.log is the CSVLogger output: one row per epoch, an 'epoch' column plus one column per metric
history = pd.read_csv('history.log')
history_acc = history[["accuracy", "val_accuracy"]]
history_loss = history[["loss", "val_loss"]]

plot_accuracy(history_acc, 'plot title...')
plot_loss(history_loss, 'plot title...')
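If you would rather keep the plot functions from the question unchanged, note that they only read history.history[...]. A minimal sketch, assuming a single CSVLogger file: convert the DataFrame into a dict of per-metric lists with DataFrame.to_dict(orient='list') and attach it as a .history attribute. The helper name history_from_log is hypothetical, and the result is only a stand-in exposing .history, not a real Keras History object.

```python
from types import SimpleNamespace

import pandas as pd


def history_from_log(path):
    """Build a stand-in for a Keras History object from a CSVLogger file.

    Hypothetical helper: only the .history attribute is reproduced, as a
    dict mapping each metric name to its list of per-epoch values.
    """
    log = pd.read_csv(path)
    # The 'epoch' column is a counter, not a metric, so leave it out
    metrics = log.drop(columns="epoch", errors="ignore")
    return SimpleNamespace(history=metrics.to_dict(orient="list"))
```

With that, history = history_from_log('history.log') followed by plot_accuracy(history, 'plot title...') works with the question's original functions, because history.history['accuracy'] is now a plain list of values.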
Hope this helps anyone who runs into the same issue in the future.
Answered By - SA.93