Scikit FTW!

I haven’t learned much about why the ELMo+LSTM combination isn’t working so well, but I did learn something nifty about scikit! If you change this little bit of code:

classification = lreg.predict(self.elmo_text)
return int(classification[0]) * 100

to look like this:

classification = lreg.predict_proba(self.elmo_text)
# probabilities are [[probability_false, probability_true]]
return classification[0][1] * 100

You get a nice looking histogram for ELMo! And just judging from typing in a few random sentences, it is way more accurate than LSTM. This approach is the clear winner so far.
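To see why the histogram changes so much, here is a minimal sketch in plain Python (no scikit involved) of how the two return values differ. The probability rows are made up for illustration, and `hard_score` and `soft_score` are hypothetical names, not anything from the project. Picking the most probable class is, as far as I can tell, what `predict` boils down to for a two-class model like this one.

```python
# Made-up rows shaped like predict_proba output:
# [[probability_false, probability_true]]
rows = [[0.9, 0.1], [0.35, 0.65], [0.52, 0.48]]

def hard_score(row):
    # what the first version returned: the winning class (0 or 1) times 100
    label = max(range(len(row)), key=lambda i: row[i])
    return int(label) * 100

def soft_score(row):
    # what the second version returns: probability of the "true" class times 100
    return row[1] * 100

hard = [hard_score(r) for r in rows]
soft = [soft_score(r) for r in rows]
print(hard)  # [0, 100, 0]
print(soft)
```

The first version can only ever land in the 0 bucket or the 100 bucket, so a near coin-flip like the third row looks just as certain as the first. The second version spreads the scores across the whole 0 to 100 range.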

While trying to fix ELMo+LSTM I discovered Keras, which seems kind of nifty… you can create a model with just a few lines of code, like this:

# imports the snippet relies on
import keras
from keras.layers import Input, Dense
from keras.models import Model
# build model
elmo_input = Input(shape=(1024,), dtype="float")
dense = Dense(256, activation='relu', kernel_regularizer=keras.regularizers.l2(0.001))(elmo_input)
pred = Dense(1, activation='softmax')(dense)
model = Model(inputs=[elmo_input], outputs=pred)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

That looks nice, doesn’t it? Each layer in the model is a line of code, and each line has just a few parameters. It feels like you could actually learn what they mean. Unfortunately, feeding the ELMo vectors into this gives really lousy results. Like, so bad I’m not even going to post it. It just gives all 0’s for any input. Not sure why.
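One thing that stands out in that model, offered as a guess and not a diagnosis: softmax normalizes across the units of a layer, so with a single output unit it returns 1.0 no matter what goes in. Categorical cross-entropy against a constant 1.0 is always zero, which would leave the model nothing to learn. I can't say whether that accounts for the all-0's specifically, but it would explain getting the same answer for every input. The sketch below shows the effect in plain Python; the `softmax` and `sigmoid` helpers are written out by hand here and are not Keras calls.

```python
import math

def softmax(logits):
    # subtract the max for numerical stability, then normalize to sum to 1
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    # squashes a single number into the range (0, 1)
    return 1.0 / (1.0 + math.exp(-z))

# a spread of raw outputs the last layer might produce
logits = [-5.0, -0.3, 0.0, 2.0, 40.0]

single_unit_softmax = [softmax([z])[0] for z in logits]
single_unit_sigmoid = [sigmoid(z) for z in logits]

print(single_unit_softmax)  # [1.0, 1.0, 1.0, 1.0, 1.0]
print(single_unit_sigmoid)
```

The usual pairing for a single yes/no output is `activation='sigmoid'` on the last layer with `loss='binary_crossentropy'`, so that would be the first thing to try.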

The scikit model, on the other hand, goes like this:

# create and train classification model
from sklearn.linear_model import LogisticRegression
lreg = LogisticRegression()
lreg.fit(xtrain, ytrain)

Bam! That’s it. It does really well with the ELMo vectors, but it’s black box magic. I have no idea what’s going on inside LogisticRegression(), and if you look at the documentation… holy hell what a bunch of gobbledygook. I feel like it will be a while before I understand what’s going on in there. But perhaps if I can get one of the other packages to approximate this LogisticRegression thing, I can figure it out.
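For what it's worth, the core idea of logistic regression is small: a fitted model is one weight per input feature plus a bias, and a prediction is the weighted sum of the features squashed through a sigmoid. Below is a rough plain-Python sketch of that idea on tiny made-up data. `train_logreg` and `predict_proba_true` are hypothetical names, and the plain gradient-descent loop is a stand-in for illustration, not what scikit actually runs (its `LogisticRegression()` uses a different solver and adds L2 regularization by default).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict_proba_true(weights, bias, x):
    # weighted sum of the features, squashed into the range (0, 1)
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return sigmoid(z)

def train_logreg(xs, ys, lr=0.5, epochs=2000):
    # plain batch gradient descent on the log loss (illustrative only)
    n_features = len(xs[0])
    weights = [0.0] * n_features
    bias = 0.0
    for _ in range(epochs):
        grad_w = [0.0] * n_features
        grad_b = 0.0
        for x, y in zip(xs, ys):
            err = predict_proba_true(weights, bias, x) - y
            for j in range(n_features):
                grad_w[j] += err * x[j]
            grad_b += err
        weights = [w - lr * g / len(xs) for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / len(xs)
    return weights, bias

# made-up 2-feature data: the label is 1 when the first feature is large
xtrain = [[0.1, 0.9], [0.2, 0.4], [0.8, 0.5], [0.9, 0.1]]
ytrain = [0, 0, 1, 1]

weights, bias = train_logreg(xtrain, ytrain)
low = predict_proba_true(weights, bias, [0.15, 0.65]) * 100
high = predict_proba_true(weights, bias, [0.85, 0.3]) * 100
print(round(low), round(high))
```

After `lreg.fit(...)`, the fitted weights and bias live in `lreg.coef_` and `lreg.intercept_`. With 1024-dimensional ELMo vectors that is 1024 weights plus one bias. If this picture is right, the closest Keras equivalent would be a single `Dense(1, activation='sigmoid')` layer with no hidden layer in front of it.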
