"""
Adapted from Oliver Atanaszov's notebook on transformer fine-tuning
https://github.com/ben0it8/containerized-transformer-finetuning/blob/develop/research/finetune-transformer-on-imdb5k.ipynb
"""
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import os
import numpy as np
import pandas as pd
import torch
"""
Below code is as per the NAACL transfer learning tutorial:
https://github.com/huggingface/naacl_transfer_learning_tutorial
"""
import torch
import torch.nn as nn
class Transformer(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions,
                 num_heads, num_layers, dropout, causal):
        super().__init__()
        # Embeddings and self-attention blocks are built here; see the
        # tutorial repo above for the full implementation.
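
# A hedged sketch of instantiating this body. The hyperparameter values below
# are illustrative assumptions for this sketch, not the tutorial's exact
# pretrained configuration.
model = Transformer(embed_dim=410, hidden_dim=2100, num_embeddings=50000,
                    num_max_positions=256, num_heads=10, num_layers=16,
                    dropout=0.1, causal=True)  # causal mask for language modeling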
"""
Code below is as per the NAACL transfer learning tutorial:
https://github.com/huggingface/naacl_transfer_learning_tutorial
"""
class TransformerWithClfHead(nn.Module):
    def __init__(self, config, fine_tuning_config):
        """Transformer with a classification head."""
        super().__init__()
        self.config = fine_tuning_config
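
# Hypothetical usage, assuming `config` holds the pretrained body's
# hyperparameters and `fine_tuning_config` the classifier settings; both
# come from the tutorial's training scripts and are not defined here.
clf_model = TransformerWithClfHead(config=config,
                                   fine_tuning_config=fine_tuning_config)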
import re

import pandas as pd
import spacy

# Load a spaCy model with only the components we need; sentence splitting is
# handled by the lightweight sentencizer instead of the full parser.
nlp = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner'])
nlp.add_pipe(nlp.create_pipe('sentencizer'))
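
# Quick sanity check that the sentencizer splits sentences as expected:
doc = nlp("The market rallied today. Analysts were surprised.")
print([sent.text for sent in doc.sents])
# ['The market rallied today.', 'Analysts were surprised.']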
def get_stopwords():
    "Return a set of stopwords read in from a file."
    with open(stopwordfile) as f:
        stopwords = []
        for line in f:
            stopwords.append(line.strip("\n"))
    # Convert to set for performance
    stopwords_set = set(stopwords)
    return stopwords_set
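
# The lemmatization functions below reference a module-level `stopwords` set,
# so the file is presumably read once up front. The filename here is an
# assumption for the sketch; `stopwordfile` is not defined in the snippet.
stopwordfile = "stopwords.txt"  # one stopword per line
stopwords = get_stopwords()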
def read_data(inputfile):
    "Read in a tab-separated file with date, headline and news content."
    df = pd.read_csv(inputfile, sep='\t', header=None,
                     names=['date', 'headline', 'content'])
    df['date'] = pd.to_datetime(df['date'], format="%Y-%m-%d")
    return df
df = read_data(inputfile)
df.head(3)
def cleaner(df):
    "Extract relevant text from the DataFrame using a regex."
    # Match alphanumeric, optionally hyphenated tokens of 3 to 50 characters
    pattern = re.compile(r"[A-Za-z0-9\-]{3,50}")
    df['clean'] = df['content'].str.findall(pattern).str.join(' ')
    return df
df_preproc = cleaner(df)
df_preproc.head(3)
def lemmatize(text):
    """Perform lemmatization and stopword removal on the cleaned text.
    Returns a list of lemmas.
    """
    doc = nlp(text)
    lemma_list = [str(tok.lemma_).lower() for tok in doc
                  if tok.is_alpha and tok.text.lower() not in stopwords]
    return lemma_list
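
# One straightforward (if serial and slow) way to apply this row by row:
df_preproc['lemmas'] = df_preproc['clean'].apply(lemmatize)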
def lemmatize_pipe(doc):
    "Lemmatize a spaCy Doc that has already been processed by nlp.pipe."
    lemma_list = [str(tok.lemma_).lower() for tok in doc
                  if tok.is_alpha and tok.text.lower() not in stopwords]
    return lemma_list

def preprocess_pipe(texts):
    "Stream texts through the spaCy pipeline in batches."
    preproc_pipe = []
    for doc in nlp.pipe(texts, batch_size=20):
        preproc_pipe.append(lemmatize_pipe(doc))
    return preproc_pipe
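
# The ProcessPoolExecutor/multiprocessing imports at the top suggest the intent
# is to run this pipeline across worker processes. A minimal sketch under that
# assumption; chunk size and worker count are illustrative, and fork-based
# workers (Linux) inherit the module-level `nlp` object.
def chunker(texts, chunksize):
    "Split a list of texts into chunks, one per worker task."
    return [texts[i:i + chunksize] for i in range(0, len(texts), chunksize)]

def preprocess_parallel(texts, chunksize=1000):
    "Run preprocess_pipe on chunks of texts in parallel and flatten the result."
    workers = max(1, multiprocessing.cpu_count() - 1)
    with ProcessPoolExecutor(max_workers=workers) as executor:
        results = list(executor.map(preprocess_pipe,
                                    chunker(list(texts), chunksize)))
    # Flatten the per-chunk lists back into one list of lemma lists
    return [lemmas for chunk in results for lemmas in chunk]

df_preproc['lemmas'] = preprocess_parallel(df_preproc['clean'])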