Issue
I am working on a feature that prompts users to attach an image when their message resembles past conversations in which images were shared. If a user attempts to send a message without an image in such a context, a prediction mechanism I implemented with scikit-learn should flag it. However, the predicted probability for the message 'how are you?' is 0.66, while it should ideally be below 0.5.
Here is the code:
import re

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

messages = {
    'Did you receive the image yesterday?': 0,
    "I'm going to send you a picture of my cat.": 1,
    "I'm at the beach and I'm taking a photo of the sunset.": 1,
    "I'm going to send you a video of my dog playing fetch.": 1,
    "I'm going to send you a screenshot of my computer screen.": 1,
    'Please find the attachment': 1,
    'attached file': 1,
    'attached image': 1,
    'attached': 1,
    'yesterday?': 1,
    'did you receive': 0,
}

pd_messages = pd.DataFrame({'text': messages.keys(),
                            'has_image': messages.values()})
features = pd_messages['text']
labels = pd_messages['has_image']

tfidf_vectorizer = TfidfVectorizer()
features_tfidf = tfidf_vectorizer.fit_transform(features)

model = LogisticRegression(solver='liblinear')
model.fit(features_tfidf, labels)

message = "how are you?"
message = re.sub(r'[^\w\s]', '', message.lower())
message_tfidf = tfidf_vectorizer.transform([message])
prediction = model.predict_proba(message_tfidf)[:, 1]
print(prediction)
Even after adding 'how are you?': 0 to the messages dictionary, the prediction is still above 0.5. Why is this happening?
Solution
There are a few issues with your example:
- The training set is heavily imbalanced: 9 of the 11 samples are labeled 1, which biases logistic regression toward predicting the majority class.
- The training dataset is extremely small.
- TF-IDF simply ignores words that never appeared in training, so a message like 'how are you?' carries almost no signal and the prediction falls back toward the (imbalanced) class prior.
Possible solutions:
- Generate more class-0 samples until your training set is balanced, or use balancing techniques such as minority oversampling (e.g. SMOTE) or majority undersampling.
# Add more examples with has_image=0
messages['how are you?'] = 0
messages['hello'] = 0
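If collecting more class-0 data is not an option, scikit-learn's class_weight='balanced' option reweights each sample's loss inversely to its class frequency, so the two class-0 samples are not drowned out by the nine class-1 samples. A minimal sketch, reusing the question's data:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

texts = [
    'Did you receive the image yesterday?',
    "I'm going to send you a picture of my cat.",
    "I'm at the beach and I'm taking a photo of the sunset.",
    "I'm going to send you a video of my dog playing fetch.",
    "I'm going to send you a screenshot of my computer screen.",
    'Please find the attachment',
    'attached file',
    'attached image',
    'attached',
    'yesterday?',
    'did you receive',
]
labels = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]

vec = TfidfVectorizer()
X = vec.fit_transform(texts)

# 'balanced' scales sample weights by n_samples / (n_classes * class_count),
# effectively equalizing the class priors seen by the model
model = LogisticRegression(solver='liblinear', class_weight='balanced')
model.fit(X, labels)

prob = model.predict_proba(vec.transform(['how are you']))[:, 1]
print('P(image):', prob[0])
```

This removes the intercept's bias toward class 1; it does not fix the tiny vocabulary, so combine it with more training data.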
- Increase the number of training samples.
- You might want to try word embeddings, which map semantically similar words to nearby vectors, so messages containing words the model has never seen can still get a meaningful representation.
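As a toy illustration of the embedding idea (the 3-dimensional vectors below are made-up numbers; in practice you would load pretrained vectors such as word2vec or GloVe), a sentence can be represented as the average of its word vectors:

```python
import numpy as np

# Made-up toy vectors standing in for pretrained embeddings
word_vectors = {
    'send':    np.array([0.9, 0.1, 0.0]),
    'picture': np.array([0.8, 0.2, 0.1]),
    'how':     np.array([0.0, 0.9, 0.3]),
    'are':     np.array([0.1, 0.8, 0.4]),
    'you':     np.array([0.2, 0.7, 0.5]),
}

def sentence_vector(text):
    """Average the vectors of known words; zeros if none are known."""
    vecs = [word_vectors[w] for w in text.lower().split() if w in word_vectors]
    return np.mean(vecs, axis=0) if vecs else np.zeros(3)

# Unlike TF-IDF, words that never appeared in your training texts can
# still map to informative vectors, so 'how are you' gets a usable
# representation instead of a near-empty one
print(sentence_vector('how are you'))
```

The resulting dense vectors can then be fed to the same LogisticRegression in place of the TF-IDF matrix.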
Answered By - DataJanitor