Last active
July 10, 2021 14:05
-
-
Save urigoren/3fe048a5082a6cc04b7bf305a050eedb to your computer and use it in GitHub Desktop.
Bag-of-words baseline for conditional text classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy as clone

from sklearn.base import ClassifierMixin
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
class ConditionedTextClassifier(ClassifierMixin):
    """Text classifier that routes each sample to a per-condition model.

    Every sample in ``X`` is a string ``"<condition><condition_sep><text>"``.
    The prefix before the first ``condition_sep`` selects the model; only the
    text after it is passed to that model's ``fit``/``predict``.

    Parameters
    ----------
    conditions : iterable of str
        Condition values to build a model for. They are compared with ``==``
        against the string prefix of each sample, so they should be strings.
    model : estimator
        Template object with ``fit(X, y)`` and ``predict(X)``; every condition
        gets its own deep copy.
    condition_sep : str, default ' <s> '
        Separator between the condition prefix and the text.
    """

    def __init__(self, conditions, model, condition_sep=' <s> '):
        self.condition_sep = condition_sep
        # Deep copies, so fitting one condition's model never affects another
        # condition's model or the template passed in.
        self.conditions = {c: clone(model) for c in conditions}

    def _filter_condition(self, X, y=None, c=None):
        """Select the samples of ``X`` whose condition prefix equals ``c``.

        Samples that do not contain ``condition_sep`` belong to no condition
        and are skipped. ``X`` and ``y`` are paired with ``zip``, so extra
        items of the longer one are ignored.

        Returns
        -------
        tuple
            ``(indices, texts, labels)``: aligned tuples where ``indices`` are
            positions in ``X``, ``texts`` are the parts after the separator and
            ``labels`` are the matching ``y`` values (all ``None`` when ``y`` is
            None). Three empty lists when nothing matches.

        Raises
        ------
        ValueError
            If ``c`` is None.
        """
        if c is None:
            raise ValueError("condition cannot be None")
        if y is None:
            y = [None] * len(X)
        matches = []
        for i, (sample, label) in enumerate(zip(X, y)):
            # partition splits on the first separator only (like split(sep, 1));
            # an empty `sep` means the separator is absent, so the sample
            # cannot match any condition. This also avoids indexing a missing
            # text part.
            cond, sep, text = sample.partition(self.condition_sep)
            if sep and cond == c:
                matches.append((i, text, label))
        if not matches:
            return [], [], []
        ind, texts, labels = zip(*matches)
        return ind, texts, labels

    def fit(self, X, y):
        """Fit each condition's model on the samples carrying that condition.

        Samples with no separator or with a condition not in ``conditions``
        are ignored. A condition with no matching samples keeps an unfitted
        model.

        Returns
        -------
        self
        """
        for c, model in self.conditions.items():
            _, X_c, y_c = self._filter_condition(X, y, c)
            if X_c:
                model.fit(X_c, y_c)
        return self

    def predict(self, X):
        """Predict one label per sample of ``X``, in input order.

        Returns
        -------
        list
            Labels from each condition's model, aligned with ``X``.

        Raises
        ------
        ValueError
            If any sample has no separator or a condition that is not in
            ``conditions``. Dropping such samples would return fewer labels
            than inputs, misaligned with ``X``.
        """
        missing = object()  # sentinel: no model has predicted this slot yet
        ret = [missing] * len(X)
        for c, model in self.conditions.items():
            ind_c, X_c, _ = self._filter_condition(X, c=c)
            if X_c:
                for i, label in zip(ind_c, model.predict(X_c)):
                    ret[i] = label
        unmatched = [i for i, label in enumerate(ret) if label is missing]
        if unmatched:
            raise ValueError(
                f"{len(unmatched)} sample(s) have no separator "
                f"{self.condition_sep!r} or a condition not in "
                f"{list(self.conditions)}; first offending index: {unmatched[0]}"
            )
        return ret
# Default per-condition model: binary bag-of-words features fed into a
# liblinear logistic regression. ConditionedTextClassifier deep-copies this
# template once per condition.
base_model = Pipeline([
    # binary=True records term presence instead of counts; a float max_df
    # drops terms that occur in more than 70% of the training documents.
    ("vec", CountVectorizer(min_df=1, max_df=0.7, binary=True)),
    # The dual formulation is only implemented for liblinear with the
    # (default) l2 penalty; scikit-learn recommends it when there are fewer
    # samples than features, as is common for bag-of-words inputs.
    ("model", LogisticRegression(dual=True, solver='liblinear')),
])
if __name__ == "__main__":
    # Minimal self-contained demo so the script runs end to end. Replace with
    # a real dataset whose samples follow "<condition> <s> <text>".
    X_train = [
        "movie <s> great fun film",
        "movie <s> boring dull film",
        "movie <s> great acting",
        "movie <s> dull plot",
        "food <s> tasty fresh meal",
        "food <s> bland cold meal",
        "food <s> tasty dessert",
        "food <s> cold soup",
    ]
    y_train = ["pos", "neg", "pos", "neg", "pos", "neg", "pos", "neg"]
    X_test = ["movie <s> great film", "food <s> cold meal"]
    y_test = ["pos", "neg"]
    # Every condition prefix that appears in the training data, using the
    # classifier's default separator.
    conditions = sorted({s.split(" <s> ", 1)[0] for s in X_train})

    bow_model = ConditionedTextClassifier(conditions, base_model)
    bow_model.fit(X_train, y_train)
    y_pred = bow_model.predict(X_test)
    print(accuracy_score(y_test, y_pred))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment