This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Adapted from Oliver Atanaszov's notebook on transformer fine-tuning:
https://github.com/ben0it8/containerized-transformer-finetuning/blob/develop/research/finetune-transformer-on-imdb5k.ipynb
"""
from concurrent.futures import ProcessPoolExecutor | |
import multiprocessing | |
import os | |
import numpy as np | |
import pandas as pd | |
import torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
The code below follows the NAACL transfer learning tutorial:
https://github.com/huggingface/naacl_transfer_learning_tutorial
"""
import torch | |
import torch.nn as nn | |
class Transformer(nn.Module): | |
def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout, causal): | |
super().__init__() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
The code below follows the NAACL transfer learning tutorial:
https://github.com/huggingface/naacl_transfer_learning_tutorial
"""
class TransformerWithClfHead(nn.Module): | |
def __init__(self, config, fine_tuning_config): | |
""" Transformer with a classification head. """ | |
super().__init__() | |
self.config = fine_tuning_config |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the small English model with the tagger, parser and NER components
# disabled, then add a sentencizer component to the pipeline.
nlp = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner'])
# NOTE(review): `create_pipe` followed by `add_pipe(component)` looks like the
# spaCy v2 calling convention; spaCy v3 expects `nlp.add_pipe('sentencizer')`.
# Confirm the installed spaCy version before upgrading.
nlp.add_pipe(nlp.create_pipe('sentencizer'))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_stopwords(path=None):
    """Return a set of stopwords read in from a file.

    Args:
        path: Optional path to a text file holding one stopword per line.
            When omitted, the module-level ``stopwordfile`` is used, which
            matches the original zero-argument call.

    Returns:
        A ``set`` of stopword strings (a set gives fast membership tests).
    """
    if path is None:
        path = stopwordfile
    # An explicit encoding keeps the result independent of the platform locale.
    with open(path, encoding="utf-8") as f:
        # Strip all surrounding whitespace (not only "\n") and skip blank
        # lines so stray spaces or an empty line never become set entries.
        return {word for word in (line.strip() for line in f) if word}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def read_data(inputfile):
    """Read in a tab-separated file with date, headline and news content.

    Args:
        inputfile: Path (or other source accepted by ``pandas.read_csv``) of
            a tab-separated file. The file is read without a header row.

    Returns:
        A DataFrame with the columns ``date``, ``headline`` and ``content``,
        where ``date`` has been parsed using the ``%Y-%m-%d`` format.
    """
    df = pd.read_csv(
        inputfile,
        sep='\t',
        header=None,
        names=['date', 'headline', 'content'],
    )
    # Parse with an explicit format rather than letting pandas infer one.
    df['date'] = pd.to_datetime(df['date'], format="%Y-%m-%d")
    return df
# Load the raw articles from `inputfile` and preview the first three rows.
# The bare `head` expression only shows output in an interactive session.
df = read_data(inputfile)
df.head(3)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cleaner(df):
    """Extract relevant text from DataFrame using a regex.

    Adds a ``clean`` column built from ``df['content']``: every run of 3 to
    50 alphanumeric or hyphen characters is kept, and the runs are joined
    with single spaces.

    Args:
        df: DataFrame with a string ``content`` column.

    Returns:
        The same DataFrame object, modified in place with the extra ``clean``
        column.
    """
    # Alphanumeric/hyphenated runs of 3 to 50 characters. Shorter runs are
    # dropped; a longer run is split because a single match is capped at 50.
    pattern = r"[A-Za-z0-9\-]{3,50}"
    df['clean'] = df['content'].str.findall(pattern).str.join(' ')
    return df
# Add the `clean` text column and preview the first three rows.
# The bare `head` expression only shows output in an interactive session.
df_preproc = cleaner(df)
df_preproc.head(3)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def lemmatize(text):
    """Perform lemmatization and stopword removal in the clean text.

    Relies on the module-level ``nlp`` pipeline and ``stopwords`` collection.

    Args:
        text: String to process with ``nlp``.

    Returns:
        A list of lowercased lemmas, one per token that is alphabetic and
        whose lowercased text is not in ``stopwords``.
    """
    doc = nlp(text)
    lemma_list = [
        str(tok.lemma_).lower()
        for tok in doc
        # The stopword test uses the surface form, not the lemma.
        if tok.is_alpha and tok.text.lower() not in stopwords
    ]
    return lemma_list
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def lemmatize_pipe(doc):
    """Return the lowercased lemmas of an already-processed document.

    Relies on the module-level ``stopwords`` collection.

    Args:
        doc: Iterable of tokens exposing ``lemma_``, ``is_alpha`` and
            ``text`` (presumably a spaCy ``Doc`` yielded by ``nlp.pipe``).

    Returns:
        A list of lowercased lemmas, one per token that is alphabetic and
        whose lowercased text is not in ``stopwords``.
    """
    lemma_list = [
        str(tok.lemma_).lower()
        for tok in doc
        # The stopword test uses the surface form, not the lemma.
        if tok.is_alpha and tok.text.lower() not in stopwords
    ]
    return lemma_list
def preprocess_pipe(texts, batch_size=20):
    """Lemmatize many texts by streaming them through ``nlp.pipe``.

    Relies on the module-level ``nlp`` pipeline and on ``lemmatize_pipe``.

    Args:
        texts: Iterable of strings to process.
        batch_size: Batch size forwarded to ``nlp.pipe``. Defaults to 20,
            the value that was previously hard-coded.

    Returns:
        A list with one lemma list per input text, in input order.
    """
    return [
        lemmatize_pipe(doc)
        for doc in nlp.pipe(texts, batch_size=batch_size)
    ]