# Gist by @fancyerii, created July 17, 2020
import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from collections import defaultdict
import matplotlib.pyplot as plt
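
# Overview: this script fine-tunes a pretrained Chinese BERT model for order
# classification. It reads tab-separated train/test files, wraps them in a
# torch Dataset/DataLoader, adds a dropout + linear head on top of BERT's
# pooled output, trains for a few epochs with AdamW and a linear schedule,
# and keeps the checkpoint with the best validation accuracy.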
SEP = "###"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
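
# read_data expects a tab-separated file with a header row: the first column
# is the class label and the next two columns are text fields, which are
# joined with the SEP marker into a single input string.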
def read_data(path):
    result = []
    with open(path, "r", encoding="utf8") as file:
        first_line = True
        for line in file:
            if first_line:
                first_line = False
                continue
            line = line.rstrip("\n")
            tks = line.split("\t", maxsplit=2)
            order_class = tks[0]
            order_content = tks[1] + SEP + tks[2]
            result.append((order_class, order_content))
    return result
train_data = read_data("/home/lili/data/order_data_v2/train_data_v2.csv")
test_data = read_data("/home/lili/data/order_data_v2/test_data_v2.csv")
classes = list(set([pair[0] for pair in train_data]))
class_to_id = {}
for i, cls in enumerate(classes):
    class_to_id[cls] = i
train_data = [(class_to_id[pair[0]], pair[1]) for pair in train_data]
random.shuffle(train_data)
train_texts = [pair[1] for pair in train_data]
train_targets = [pair[0] for pair in train_data]
test_data = [(class_to_id[pair[0]], pair[1]) for pair in test_data]
test_texts = [pair[1] for pair in test_data]
test_targets = [pair[0] for pair in test_data]
PRETRAINED_MODEL_PATH = '/home/lili/data/huggface/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
sample_txt = train_data[0][1]
MAX_LEN = 150
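

# OrderDataset tokenizes one example at a time: each item is padded or
# truncated to MAX_LEN and returned as flat input_ids / attention_mask
# tensors together with the integer class target.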
class OrderDataset(Dataset):
    def __init__(self, texts, targets, tokenizer, max_len):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = self.texts[item]
        target = self.targets[item]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',  # pad_to_max_length=True is deprecated
            truncation=True,       # explicitly truncate inputs longer than max_len
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long)
        }
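
# Helper to build a shuffled DataLoader over an OrderDataset.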
def create_data_loader(texts, targets, tokenizer, max_len, batch_size):
    ds = OrderDataset(
        texts=texts,
        targets=targets,
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=1,
        shuffle=True
    )
BATCH_SIZE = 16
train_data_loader = create_data_loader(train_texts, train_targets, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(test_texts, test_targets, tokenizer, MAX_LEN, BATCH_SIZE)
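

# OrderClassifier: pretrained BERT encoder followed by dropout and a linear
# classification layer over the pooled [CLS] representation.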
class OrderClassifier(nn.Module):
    def __init__(self, n_classes, path):
        super(OrderClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # index 1 is the pooled output; this works for the tuple returned by
        # older transformers versions and for the ModelOutput of newer ones
        pooled_output = outputs[1]
        output = self.drop(pooled_output)
        return self.out(output)
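

# Sanity check: build the model, pull one batch from the training loader and
# move it to the device.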
data = next(iter(train_data_loader))
model = OrderClassifier(len(class_to_id), PRETRAINED_MODEL_PATH)
model = model.to(device)
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
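
# Training setup: AdamW with a linear schedule (no warmup) over all training
# steps, and cross-entropy loss.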
EPOCHS = 3
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
loss_fn = nn.CrossEntropyLoss().to(device)
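

# One training epoch: forward pass, loss, backward pass, optimizer and
# scheduler step for every batch; returns epoch accuracy and mean loss.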
def train_epoch(model, data_loader, loss_fn, optimizer,
                device, scheduler, n_examples):
    model = model.train()
    losses = []
    correct_predictions = 0
    print_interval = 10
    print_counter = 0
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["targets"].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)
        print_counter += 1
        if print_counter % print_interval == 0:
            print("batch {}, loss {}".format(print_counter, loss.item()))
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
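

# Evaluation pass: same as training but without gradient tracking or
# parameter updates; returns accuracy and mean loss on the given loader.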
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
    return correct_predictions.double() / n_examples, np.mean(losses)
best_accuracy = 0
history = defaultdict(list)
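
# Main training loop: train and evaluate once per epoch, record the history
# for plotting, and save the model whenever validation accuracy improves.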
for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)
    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(train_texts)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')
    val_acc, val_loss = eval_model(
        model,
        test_data_loader,
        loss_fn,
        device,
        len(test_texts)
    )
    print(f'Val loss {val_loss} accuracy {val_acc}')
    # store plain floats so the history can be plotted directly
    history['train_acc'].append(train_acc.item())
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc.item())
    history['val_loss'].append(val_loss)
    if val_acc > best_accuracy:
        torch.save(model, 'best_model.bin')
        best_accuracy = val_acc
test_acc, _ = eval_model(
    model,
    test_data_loader,
    loss_fn,
    device,
    len(test_texts)
)
print("test acc {}".format(test_acc.item()))
plt.plot(history['train_acc'], label='train accuracy')
plt.plot(history['val_acc'], label='validation accuracy')
plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1])
plt.show()
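

# Inference sketch (not part of the original gist): a minimal example of
# reloading the saved checkpoint and classifying one string. It assumes the
# full model object was saved above with torch.save(model, 'best_model.bin')
# and reuses tokenizer, classes, MAX_LEN and device from this script;
# sample_order is just a made-up placeholder. On recent PyTorch versions,
# torch.load may need weights_only=False to unpickle a full model object.
loaded_model = torch.load('best_model.bin', map_location=device)
loaded_model.eval()
sample_order = "some order field" + SEP + "more order text"
enc = tokenizer.encode_plus(
    sample_order,
    add_special_tokens=True,
    max_length=MAX_LEN,
    return_token_type_ids=False,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)
with torch.no_grad():
    logits = loaded_model(enc['input_ids'].to(device), enc['attention_mask'].to(device))
pred_id = torch.argmax(logits, dim=1).item()
print("predicted class:", classes[pred_id])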