Issue
I use TensorFlow to build a neural network that classifies images of COVID-19 rapid tests into three classes (Negative, Positive, Empty).
During training, the TensorBoard logs show a validation accuracy of around 90%. But when I test the trained network on the same images it was trained on, the classification performance is much worse (~60%). I observed the same behavior when I trained the network with different images (see the section "What I have tried").
During training, the images are converted to grayscale and resized before being fed into the model. The batch size is 16.
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# cv2.resize expects dsize as (width, height); with square 256x256 images the order happens not to matter
image = cv2.resize(image, (width, height))
To augment the small dataset I have (~450 images), I am using keras.preprocessing.image.ImageDataGenerator with these parameters:
- width_shift_range=0.1
- brightness_range=[0.9, 1.1]
- height_shift_range=0.1
- zoom_range=0.2
- horizontal_flip=True
- rotation_range=10
- shear_range=0.2
- fill_mode="nearest"
- samplewise_center=True
- samplewise_std_normalization=True
I am converting the model to TFLite because we need it for mobile platforms. I am using this code snippet:
import tensorflow as tf

model = tf.keras.models.load_model(model_path)
converter = tf.lite.TFLiteConverter.from_keras_model(model)  # takes the loaded Keras model object, not a path
# converter.optimizations = [tf.lite.Optimize.DEFAULT]  # optional optimizations (e.g. quantization)
tflite_model = converter.convert()
# Save the model.
with open('rapid_test_strip_cleaned_model.tflite', 'wb') as f:
    f.write(tflite_model)
What I have tried:
- cropping the images to the strip of the cassette, then training and testing the network again
- checking in the testing (inference) script that the labels are correct
- checking that the images are converted to grayscale and resized correctly before being fed into the network during testing
- testing the model before converting it to TFLite, using tensorflow.keras.models
Model:
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# INIT_LR and EPOCHS are defined elsewhere in the training script
img_width, img_height = (256, 256)
model = Sequential()
# Keras expects (rows, cols, channels), i.e. (height, width, 1) for grayscale
inputShape = (img_height, img_width, 1)
model.add(Conv2D(32, (3, 3), activation="relu", input_shape=inputShape))
# to prevent overfitting
model.add(Dropout(0.25))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, (3, 3), activation="relu"))
# to prevent overfitting
model.add(Dropout(0.25))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(128, (3, 3), activation="relu"))
# to prevent overfitting
model.add(Dropout(0.25))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(512, activation="relu"))
# to prevent overfitting
model.add(Dropout(0.5))
model.add(Dense(3, activation="softmax"))
# newer Keras optimizers (TF 2.11+) no longer support `decay`; a learning-rate schedule is the usual replacement
opt = Adam(learning_rate=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt,
              metrics=["accuracy"])
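To see where the model's capacity sits, the shapes can be traced by hand. This is a small sketch (the helper names conv_valid and max_pool are hypothetical) that assumes the Keras defaults used above: padding="valid", stride 1 for the convolutions, and 2x2 pooling with stride 2. model.summary() should report the same numbers.

```python
def conv_valid(size, kernel=3):
    # Conv2D with padding="valid" and stride 1 shrinks each side by kernel - 1
    return size - kernel + 1


def max_pool(size, pool=2, stride=2):
    # MaxPooling2D with padding="valid": floor((size - pool) / stride) + 1
    return (size - pool) // stride + 1


size, channels, params = 256, 1, 0
for filters in (32, 64, 128):
    params += 3 * 3 * channels * filters + filters  # conv kernel weights + biases
    size = max_pool(conv_valid(size))               # Dropout does not change the shape
    channels = filters

flat = size * size * channels   # output length of Flatten()
params += flat * 512 + 512      # Dense(512)
params += 512 * 3 + 3           # Dense(3)
print(size, flat, params)       # 30 115200 59077123
```

So Flatten passes 115,200 values into Dense(512), and that one layer holds roughly 59 million of the network's ~59.1 million parameters, which is worth keeping in mind with a training set of ~450 images.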
[Image: TensorBoard graph of the training. The straight line is from another training run.]
Testing/Inference script:
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# the input shape is [batch, height, width, channels]
height = input_details[0]['shape'][1]
width = input_details[0]['shape'][2]
labels = ["positive", "negative", "initial"]
# load image into numpy array
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# cv2.resize expects dsize as (width, height)
image = cv2.resize(image, (width, height))
input_arr = img_to_array(image)    # (height, width) -> (height, width, 1), float32
input_arr = np.array([input_arr])  # add the batch dimension
# normalize values
input_arr = input_arr / 255.0
interpreter.set_tensor(input_details[0]['index'], input_arr)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
output_data = interpreter.get_tensor(output_details[0]['index'])
results = np.squeeze(output_data)
top_k = results.argsort()[-5:][::-1]
print(labels[top_k[0]])
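The last lines of the script only use the single best class, so a slightly more direct sketch picks it with np.argmax and also returns its softmax score. decode_prediction and LABELS are hypothetical names; the label order is copied from the script and is assumed to match the class indices used in training.

```python
import numpy as np

# Label order copied from the inference script (assumed to match training)
LABELS = ["positive", "negative", "initial"]


def decode_prediction(output_data, labels=LABELS):
    """Return (label, probability) for a softmax output of shape (1, num_classes)."""
    probs = np.squeeze(output_data)  # (1, 3) -> (3,)
    idx = int(np.argmax(probs))      # highest score, like argsort()[-5:][::-1][0] (ties aside)
    return labels[idx], float(probs[idx])


label, prob = decode_prediction(np.array([[0.1, 0.7, 0.2]], dtype=np.float32))
print(label, round(prob, 2))  # negative 0.7
```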
Where might the problem be?
Solution
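I found the issue thanks to a tip from @Djinn: it was the normalization in the inference script. During training, ImageDataGenerator standardized each image (samplewise_center=True, samplewise_std_normalization=True), but the inference script divided the pixels by 255, so at test time the network received inputs on a different scale than it was trained on. The inference normalization is supposed to be input_arr = (input_arr - np.mean(input_arr)) / np.std(input_arr).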
Answered By - Sohrab
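As a sketch that is not part of the original answer, the fix can be wrapped in one helper so that inference mirrors the training normalization. preprocess_for_inference is a hypothetical name, and the 1e-6 epsilon reflects my understanding of how ImageDataGenerator implements samplewise_std_normalization; it only prevents division by zero for blank images.

```python
import numpy as np


def preprocess_for_inference(gray_image):
    """Turn a resized 2-D grayscale image into a (1, H, W, 1) float32 batch,
    standardized per image like samplewise_center + samplewise_std_normalization."""
    x = gray_image.astype(np.float32)[np.newaxis, :, :, np.newaxis]  # add batch and channel axes
    x -= np.mean(x)          # samplewise_center: zero mean per image
    x /= np.std(x) + 1e-6    # samplewise_std_normalization: unit std per image
    return x


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(256, 256), dtype=np.uint8)  # stand-in for the cv2 output
fixed = preprocess_for_inference(image)
old = (image.astype(np.float32) / 255.0)[np.newaxis, :, :, np.newaxis]  # the original normalization
print(fixed.shape, fixed.dtype)  # (1, 256, 256, 1) float32
print(float(fixed.mean()), float(fixed.std()), float(old.mean()))
```

In the inference script, this would replace the img_to_array, np.array and division-by-255 lines with input_arr = preprocess_for_inference(image). Because the per-image mean and standard deviation are removed, the result barely depends on whether pixels are in 0-255 or 0-1; what matters is that training and inference apply the same step.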