Issue
I trained a model (from the TensorFlow tutorial) in Jupyter, saved it, and then successfully loaded it back after restarting the kernel. Here's the code:
import os
import tensorflow as tf
from tensorflow import keras

# Directory where the checkpoints will be saved
checkpoint_dir = '/home/charlie-chin/william_model/training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)
model.save('/home/charlie-chin/william_model')
model = keras.models.load_model('/home/charlie-chin/william_model', custom_objects={'loss':loss})
checkpoint_num = 10
model.load_weights(tf.train.Checkpoint("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num)))
Everything worked except the last two lines, which raised this error:
ValueError: `Checkpoint` was expecting root to be a trackable object (an object derived from `Trackable`), got /home/charlie-chin/william_model/training_checkpoints/ckpt_1. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.
I checked the path, and it is correct. Here's the full error output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [39], in <cell line: 4>()
1 checkpoint_num = 10
2 # model.load_weights(tf.train.load_checkpoint("./william_model/training_checkpoints/ckpt_"))
3 # model.load_weights(tf.train.Checkpoint("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num)+".data-00000-of-00001"))
----> 4 model.load_weights(tf.train.Checkpoint("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num)))
File ~/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/util.py:2107, in Checkpoint.__init__(self, root, **kwargs)
2105 if root:
2106 trackable_root = root() if isinstance(root, weakref.ref) else root
-> 2107 _assert_trackable(trackable_root, "root")
2108 attached_dependencies = []
2110 # All keyword arguments (including root itself) are set as children
2111 # of root.
File ~/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/util.py:1546, in _assert_trackable(obj, name)
1543 def _assert_trackable(obj, name):
1544 if not isinstance(
1545 obj, (base.Trackable, def_function.Function)):
-> 1546 raise ValueError(
1547 f"`Checkpoint` was expecting {name} to be a trackable object (an "
1548 f"object derived from `Trackable`), got {obj}. If you believe this "
1549 "object should be trackable (i.e. it is part of the "
1550 "TensorFlow Python API and manages state), please open an issue.")
ValueError: `Checkpoint` was expecting root to be a trackable object (an object derived from `Trackable`), got /home/charlie-chin/william_model/training_checkpoints/ckpt_10. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.
Solution
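The error comes from `tf.train.Checkpoint`, not from `load_weights`: its first argument (`root`) must be a trackable object to save or restore, such as the model itself, so passing a path string fails the trackable check shown in the traceback. You don't need `tf.train.Checkpoint` here. According to the TensorFlow documentation, `model.load_weights` accepts the checkpoint path (the common prefix of the checkpoint files) directly, so you should be able to load a checkpoint like this: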
checkpoint_num = 10
model.load_weights("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num))
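If you are unsure which epochs were actually written, a small helper can scan the checkpoint directory before loading. This is a sketch, assuming the layout that `ModelCheckpoint(save_weights_only=True)` produces with the "ckpt_{epoch}" template and no `.h5` extension: each saved epoch leaves a `ckpt_<n>.index` file plus `ckpt_<n>.data-...` shard(s), and `load_weights` expects the shared prefix `ckpt_<n>`, not either file name. The functions `available_epochs`, `checkpoint_path` and `restore_epoch` are hypothetical helpers, not part of TensorFlow.

```python
import os
import re

# Assumed file layout: TensorFlow-format weight checkpoints write an ".index"
# file per save, e.g. "ckpt_10.index" for the "ckpt_{epoch}" template.
_INDEX_RE = re.compile(r"^ckpt_(\d+)\.index$")


def available_epochs(checkpoint_dir):
    """Return the epoch numbers that have a saved checkpoint, sorted numerically."""
    epochs = []
    for name in os.listdir(checkpoint_dir):
        match = _INDEX_RE.match(name)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


def checkpoint_path(checkpoint_dir, epoch):
    """Build the checkpoint prefix (without the .index / .data-... suffix)."""
    return os.path.join(checkpoint_dir, f"ckpt_{epoch}")


def restore_epoch(model, checkpoint_dir, epoch=None):
    """Load the weights saved for `epoch` (default: the highest epoch found).

    `model` is any object with a Keras-style `load_weights(path)` method.
    Returns the prefix that was passed to `load_weights`.
    """
    epochs = available_epochs(checkpoint_dir)
    if not epochs:
        raise FileNotFoundError(f"no ckpt_*.index files in {checkpoint_dir}")
    if epoch is None:
        epoch = epochs[-1]
    elif epoch not in epochs:
        raise FileNotFoundError(f"no checkpoint for epoch {epoch}; found {epochs}")
    path = checkpoint_path(checkpoint_dir, epoch)
    # Pass the plain path string; do not wrap it in tf.train.Checkpoint(...).
    model.load_weights(path)
    return path
```

With the directory from the question, `restore_epoch(model, checkpoint_dir, 10)` would call `model.load_weights(".../training_checkpoints/ckpt_10")`. The epochs are sorted as integers so that `ckpt_10` counts as newer than `ckpt_2`, which a plain string sort would get wrong.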
Answered By - claudia