Issue
I have a Keras TextVectorization
layer which uses a custom standardization function.
def custom_standardization(input_string, preserve=None, add=None):
    """Lowercase *input_string* and strip punctuation from it.

    Args:
        input_string: A string tensor to standardize.
        preserve: Punctuation characters to keep (default: '[' and ']').
        add: Extra characters to strip beyond string.punctuation
            (default: the inverted question mark '¿' used in Spanish).

    Returns:
        The lowercased tensor with the computed character class removed.
    """
    # Bind defaults inside the body to avoid mutable default arguments,
    # which are shared across calls in Python.
    preserve = ['[', ']'] if preserve is None else preserve
    add = ['¿'] if add is None else add
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item
    for item in preserve:
        strip_chars = strip_chars.replace(item, '')
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')
# Build the target-language vectorizer. NOTE: passing a plain Python
# function as `standardize` is exactly what later breaks from_config(),
# because the callable ends up in get_config() and cannot be serialized.
target_vectorization = keras.layers.TextVectorization(max_tokens=vocab_size,
output_mode='int',
output_sequence_length=sequence_length + 1,
standardize=custom_standardization)
# Learn the target vocabulary from the training corpus.
target_vectorization.adapt(train_spanish_texts)
I want to save the adapted configuration for an inference model to make use of.
One way, described in an existing answer, is to save the weights and the config separately into a pickle file and reload them later.
However, target_vectorization.get_config()
returns
{'name': 'text_vectorization_5',
'trainable': True,
...
'standardize': <function __main__.custom_standardization(input_string, preserve=['[', ']'], add=['¿'])>,
...
'vocabulary_size': 15000}
which is being saved into the pickle file.
Trying to load this config using keras.layers.TextVectorization.from_config(pickle.load(open('ckpts/spanish_vectorization.pkl', 'rb'))['config'])
results in TypeError: Could not parse config: <function custom_standardization at 0x2a1973a60>,
because the pickle file carries no information about the custom standardization function.
What is a good way to save the TextVectorization weights and configuration for an inference model to make use of, in this scenario?
Solution
The solution is to define a wrapper class around the TextVectorization object and make the custom standardizer a method of that class. In addition, callable objects must be excluded from the configuration before saving it to the pickle file. Here is the fixed code:
@keras.utils.register_keras_serializable(package='custom_layers', name='TextVectorizer')
class TextVectorizer(layers.Layer):
    """English - Spanish text vectorizer.

    Wraps keras.layers.TextVectorization so the custom standardization
    lives on the class (re-created at load time) instead of being stored
    in the config, where a plain callable cannot survive pickling.
    """

    def __init__(self, max_tokens=None, output_mode='int',
                 output_sequence_length=None,
                 standardize='lower_and_strip_punctuation',
                 vocabulary=None, config=None):
        super().__init__()
        if config:
            # Restoring from a saved config (callables were stripped to
            # None by get_config before pickling).
            self.vectorization = layers.TextVectorization.from_config(config)
        else:
            self.max_tokens = max_tokens
            self.output_mode = output_mode
            self.output_sequence_length = output_sequence_length
            self.vocabulary = vocabulary
            if standardize != 'lower_and_strip_punctuation':
                # Any non-default sentinel value (e.g. 'spanish') selects
                # the custom standardizer method below.
                self.vectorization = layers.TextVectorization(
                    max_tokens=self.max_tokens,
                    output_mode=self.output_mode,
                    output_sequence_length=self.output_sequence_length,
                    vocabulary=self.vocabulary,
                    standardize=self.standardize)
            else:
                self.vectorization = layers.TextVectorization(
                    max_tokens=self.max_tokens,
                    output_mode=self.output_mode,
                    output_sequence_length=self.output_sequence_length,
                    vocabulary=self.vocabulary)

    def standardize(self, input_string, preserve=None, add=None):
        """Lowercase *input_string* and strip punctuation, keeping *preserve*.

        Returns a string tensor (not a Python str).
        """
        # Bind defaults in the body to avoid mutable default arguments.
        preserve = ['[', ']'] if preserve is None else preserve
        add = ['¿'] if add is None else add
        strip_chars = string.punctuation
        for item in add:
            strip_chars += item
        for item in preserve:
            strip_chars = strip_chars.replace(item, '')
        lowercase = tf.strings.lower(input_string)
        return tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')

    def __call__(self, *args, **kwargs):
        # Delegate directly to the wrapped layer.
        return self.vectorization.__call__(*args, **kwargs)

    def get_config(self):
        # Drop callables (e.g. the standardize function): they cannot be
        # round-tripped through pickle, which is what broke the original
        # save/load flow.
        return {key: (None if callable(value) else value)
                for key, value in self.vectorization.get_config().items()}

    @classmethod
    def from_config(cls, config):
        # classmethod so Keras deserialization (and plain callers) can
        # invoke it on the class; the original bare function only worked
        # by accident when accessed through the class.
        return cls(config=config)

    def set_weights(self, weights):
        self.vectorization.set_weights(weights)

    def get_weights(self):
        # Explicit delegation, symmetric with set_weights(); the training
        # script calls get_weights() when checkpointing.
        return self.vectorization.get_weights()

    def adapt(self, dataset):
        self.vectorization.adapt(dataset)

    def get_vocabulary(self):
        return self.vectorization.get_vocabulary()
To adapt and save weights [Training Phase]:
vocab_size = 15000
sequence_length = 20

# Source (English) vectorizer uses the default standardization.
source_vectorization = TextVectorizer(max_tokens=vocab_size,
                                      output_mode='int',
                                      output_sequence_length=sequence_length)
# Target (Spanish) vectorizer: any non-default `standardize` value selects
# the custom standardizer. NOTE(review): `sequence_length + 1` is
# presumably for the shifted decoder target — confirm against the model.
target_vectorization = TextVectorizer(max_tokens=vocab_size,
                                      output_mode='int',
                                      output_sequence_length=sequence_length + 1,
                                      standardize='spanish')

train_english_texts = [pair[0] for pair in train_pairs]
train_spanish_texts = [pair[1] for pair in train_pairs]
source_vectorization.adapt(train_english_texts)
target_vectorization.adapt(train_spanish_texts)

# Use context managers so the checkpoint files are always flushed and
# closed (the original passed open(...) straight to pickle.dump and
# never closed the handles).
with open('ckpts/english_vectorization.pkl', 'wb') as f:
    pickle.dump({'config': source_vectorization.get_config(),
                 'weights': source_vectorization.get_weights()}, f)
with open('ckpts/spanish_vectorization.pkl', 'wb') as f:
    pickle.dump({'config': target_vectorization.get_config(),
                 'weights': target_vectorization.get_weights()}, f)
To load and use them [Inference Phase]:
# Rebuild each vectorizer from its pickled config, then restore weights.
# SECURITY NOTE: pickle.load executes arbitrary code from the file —
# only load checkpoints you produced yourself.
with open('ckpts/english_vectorization.pkl', 'rb') as f:
    vectorization_data = pickle.load(f)
source_vectorization = TextVectorizer.from_config(vectorization_data['config'])
source_vectorization.set_weights(vectorization_data['weights'])

with open('ckpts/spanish_vectorization.pkl', 'rb') as f:
    vectorization_data = pickle.load(f)
target_vectorization = TextVectorizer.from_config(vectorization_data['config'])
target_vectorization.set_weights(vectorization_data['weights'])
Answered By - Suprateem Banerjee
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.