Issue
I'm playing around with a Keras-based multi-label classifier. I created a function that loads the training and test data, and I process and split X/Y inside the function itself. I'm getting an error when I run my model, but I'm not sure what it means.
Here's my code:
def KerasClassifer(df_train, df_test):
    X_train = df_train[columnType].copy()
    y_train = df_train[variableToPredict].copy()
    labels = y_train.unique()
    print(X_train.shape[1])
    # using keras to do classification
    from tensorflow import keras
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, Dropout, Activation
    from tensorflow.keras.optimizers import SGD
    model = Sequential()
    model.add(Dense(5000, activation='relu', input_dim=X_train.shape[1]))
    model.add(Dropout(0.1))
    model.add(Dense(600, activation='relu'))
    model.add(Dropout(0.1))
    model.add(Dense(len(labels), activation='sigmoid'))
    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='binary_crossentropy',
                  optimizer=sgd)
    model.fit(X_train, y_train, epochs=5, batch_size=2000)
    preds = model.predict(X_test)
    preds[preds>=0.5] = 1
    preds[preds<0.5] = 0
    score = model.evaluate(X_test, y_test, batch_size=2000)
    score
Here are some attributes of my data (if it helps):
x train shape (392436, 109)
y train shape (392436,)
len of y labels 18
How can I fix the code to avoid this error?
Solution
If you have 18 categories, the shape of y_train should be (392436, 18). You can use tf.one_hot for that:
import tensorflow as tf
y_train = tf.one_hot(y_train, depth=len(labels))
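One caveat worth keeping in mind: tf.one_hot expects integer class indices, so if the label column holds strings or arbitrary codes, they need to be mapped to 0..n-1 first. Here is a minimal NumPy sketch of that mapping and of the resulting shape. The raw_labels array is a made-up stand-in for the real column, not the asker's data:

```python
import numpy as np

# Hypothetical stand-in for the single label column (assumption, not the real data)
raw_labels = np.array(["cat", "dog", "bird", "dog", "cat"])

# np.unique returns the sorted distinct classes and, with return_inverse=True,
# the integer index of each sample's class within that sorted array
classes, class_ids = np.unique(raw_labels, return_inverse=True)

# Picking rows of an identity matrix yields one one-hot row per sample
y_one_hot = np.eye(len(classes), dtype="float32")[class_ids]

print(classes)          # the distinct labels, sorted
print(y_one_hot.shape)  # (n_samples, n_classes)
```

The integer class_ids are also what tf.one_hot would take as input, so the same mapping step applies if you stay with the TensorFlow call.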
And if you're taking your values from one column, I suspect this is not "multi-label", but multi-class. Can a sample really belong to multiple categories? If not, you will need to change a few other things too. For instance, you would need softmax activation:
model.add(Dense(len(labels), activation='softmax'))
And also categorical crossentropy loss:
model.compile(loss='categorical_crossentropy', optimizer=sgd)
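To see why softmax and categorical crossentropy go together, here is a rough NumPy sketch of what the two compute. It illustrates the math only and is not Keras's actual implementation; the logits and targets are invented values. It also shows that in the multi-class case the predicted class is usually taken with argmax, whereas thresholding at 0.5 as in the question is meant for independent sigmoid outputs:

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise max for numerical stability, then normalize
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    # Mean negative log-probability assigned to the true class
    y_prob = np.clip(y_prob, eps, 1.0)
    return -(y_true * np.log(y_prob)).sum(axis=1).mean()

# Invented raw outputs for 2 samples and 3 classes
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.1, 3.0]])
# One-hot targets: sample 0 is class 0, sample 1 is class 2
y_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0]])

probs = softmax(logits)                  # each row sums to 1
loss = categorical_crossentropy(y_true, probs)
pred_classes = probs.argmax(axis=1)      # one class index per sample

print(probs.sum(axis=1))
print(pred_classes)
```

If you would rather keep y_train as integer class indices instead of one-hot rows, Keras also provides the 'sparse_categorical_crossentropy' loss, which accepts that format directly.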
Answered By - Nicolas Gervais