"""Advanced sklearn Pipeline / GridSearchCV Workflow (gist by @ZaxR)."""
import re
from collections import Counter

import nltk
import numpy as np
import pandas as pd
from google.cloud import storage
from nltk import pos_tag  # tag a token list to get parts of speech
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer
from nltk.tokenize import word_tokenize
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import FeatureHasher
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import (GridSearchCV, ShuffleSplit, StratifiedShuffleSplit,
                                     cross_val_score, train_test_split)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, LabelEncoder, StandardScaler
from sklearn.tree import DecisionTreeClassifier

nltk.download('punkt')
nltk.download('stopwords')

# TODO: consider filtering tokens with str.isalnum() instead of a regex
def keep_alphanumeric(ser, regex=r'[^A-Za-z]+'):
    """Replace runs of characters matching `regex` with spaces in a pd.Series.

    Note: the default pattern keeps letters only, so digits are dropped too.
    """
    p = re.compile(regex)
    # preserve the original index so downstream .apply calls stay aligned
    return pd.Series([p.sub(' ', x) for x in ser.tolist()], index=ser.index)

def tokenize(ser):
    """Tokenize each string in a pd.Series into a list of tokens."""
    return ser.apply(word_tokenize)

def remove_stopwords(ser, stopwords=None):
    """Remove stopwords from a pd.Series of token lists, given a stopword collection."""
    if stopwords is None:
        stopwords = set(nltk.corpus.stopwords.words('english'))  # set for O(1) lookups
    # NOTE: tokens are not lowercased upstream, so capitalized stopwords (e.g. "The") survive
    return ser.apply(lambda x: [word for word in x if word not in stopwords])

def remove_len_n_words(ser, n):
    """Remove words of length n or shorter from a pd.Series of token lists."""
    return ser.apply(lambda x: [word for word in x if len(word) > n])

def stem(ser):
    """Stem each token in a pd.Series of token lists with the Porter stemmer."""
    stemmer = PorterStemmer()
    return ser.apply(lambda x: [stemmer.stem(word) for word in x])

# def filter_out_part_of_speech(ser, pos_list):
#     """Remove specified parts of speech from a pd.Series of token lists."""
#     # pos_tag takes a token list and returns (token, tag) pairs
#     return ser.apply(lambda x: [token for token, tag in pos_tag(x) if tag not in pos_list])

def sklearn_ingest(ser):
    """Join each token list back into a single string so sklearn vectorizers can ingest it."""
    return [' '.join(x) for x in ser]

def hasher_ingest(ser):
    """Convert a pd.Series of token lists into token-count dicts for a sklearn FeatureHasher."""
    return ser.apply(Counter)
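
# Illustrative sketch (hypothetical toy data): each Counter of token frequencies
# becomes one row of hashed features in a FeatureHasher.
_demo_tokens = pd.Series([["red", "red", "blue"], ["green"]])
_demo_hashed = FeatureHasher(n_features=16, input_type="dict").transform(
    hasher_ingest(_demo_tokens))  # sparse (2, 16) matrix
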
def clean_text_series(ser, inplace=False):
    """Preprocess a pd.Series of strings so it can be fed into the sklearn Pipeline below."""
    # NOTE: the helpers all return new Series, so inplace=True only skips the initial copy
    _ser = ser.copy() if not inplace else ser
    _ser = _ser.fillna("")
    _ser = keep_alphanumeric(_ser)
    _ser = tokenize(_ser)
    _ser = remove_stopwords(_ser)
    _ser = stem(_ser)
    return _ser
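
# Illustrative run (toy data): non-letters are stripped, text is tokenized,
# lowercase stopwords are removed, and tokens are Porter-stemmed.
_demo_clean = clean_text_series(pd.Series(["The 3 dogs are running!", None]))
print(_demo_clean)  # row 0 -> tokens like ['the', 'dog', 'run']; row 1 -> []
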
class DummyEstimator(BaseEstimator):
    """Placeholder for the final Pipeline step; GridSearchCV swaps in real
    classifiers via the 'clf' key of the parameter grids below."""

    def fit(self, X, y=None):
        return self

    def score(self, X, y=None):
        pass

# Both vectorizers are placed in the pipeline; each parameter grid below disables
# one of them by setting its step to None, so only one runs per candidate.
pipe = Pipeline([('len_n', FunctionTransformer(remove_len_n_words, validate=False)),
                 ('stringify_tok', FunctionTransformer(sklearn_ingest, validate=False)),
                 ('count', CountVectorizer()),
                 ('tfidf', TfidfVectorizer()),
                 ('clf', DummyEstimator())])
# parameters to be tested for all models; setting a step name to [None] drops
# that step ('passthrough' in newer sklearn), so each grid tunes exactly one
# of CountVectorizer / TfidfVectorizer
base_params = [{'len_n__kw_args': ({"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}),
                'count__ngram_range': [(1, 1), (1, 2)],
                'count__max_features': [700, 1000, 1500],
                'tfidf': [None]},
               {'len_n__kw_args': ({"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}),
                'tfidf__ngram_range': [(1, 1), (1, 2)],
                'tfidf__max_features': [700, 1000, 1500],
                'count': [None]}]
# parameters specific to each model
model_params = [{'clf': [DecisionTreeClassifier()], 'clf__criterion': ['gini', 'entropy']},
                {'clf': [RandomForestClassifier()], 'clf__criterion': ['gini', 'entropy']}]

# cross each base grid with each model grid; model params overwrite base params
# on any conflicting key
params = [{**base, **d} for base in base_params for d in model_params]
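
# Sanity check: the cross product yields len(base_params) * len(model_params)
# merged grids, each searched independently by GridSearchCV.
print(f"{len(params)} parameter grids; example keys: {sorted(params[0])}")
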
random_state = 42  # arbitrary fixed seed for reproducible CV splits
n_jobs = -1
cv = ShuffleSplit(n_splits=3, random_state=random_state)
grid = GridSearchCV(estimator=pipe,
                    param_grid=params,
                    scoring='accuracy',
                    cv=cv,
                    verbose=1,
                    n_jobs=n_jobs)
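
# Hypothetical usage sketch (X_raw and y are placeholders; the gist includes no
# data). The pipeline expects token lists, i.e. the output of clean_text_series:
# X = clean_text_series(X_raw)
# grid.fit(X, y)
# print(grid.best_params_)           # winning step/parameter combination
# print(grid.best_score_)            # mean CV accuracy of that combination
# best_model = grid.best_estimator_  # pipeline refit on the full training data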