import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from collections import defaultdict
import matplotlib.pyplot as plt

# Separator inserted between the two content fields of each order record.
SEP = "###"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def read_data(path):
    """Read a tab-separated file, skipping the header line.

    Each remaining line is '<class>\t<field1>\t<field2>' (maxsplit=2, so the second
    content field may itself contain tabs). The two content fields are joined with SEP.
    """
    result = []
    with open(path, "r", encoding="utf8") as file:
        first_line = True
        for line in file:
            if first_line:
                first_line = False
                continue
            line = line.rstrip("\n")
            tks = line.split("\t", maxsplit=2)
            order_class = tks[0]
            order_content = tks[1] + SEP + tks[2]
            result.append((order_class, order_content))
    return result
train_data = read_data("/home/lili/data/order_data_v2/train_data_v2.csv")
test_data = read_data("/home/lili/data/order_data_v2/test_data_v2.csv")

# Map class names to integer ids based on the classes seen in the training set.
classes = list(set([pair[0] for pair in train_data]))
class_to_id = {}
for i, cls in enumerate(classes):
    class_to_id[cls] = i

train_data = [(class_to_id[pair[0]], pair[1]) for pair in train_data]
random.shuffle(train_data)
train_texts = [pair[1] for pair in train_data]
train_targets = [pair[0] for pair in train_data]

test_data = [(class_to_id[pair[0]], pair[1]) for pair in test_data]
test_texts = [pair[1] for pair in test_data]
test_targets = [pair[0] for pair in test_data]

PRETRAINED_MODEL_PATH = '/home/lili/data/huggface/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)

sample_txt = train_data[0][1]
MAX_LEN = 150
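
# Optional sanity check (a minimal sketch, not part of the original script): tokenize the
# sample above to see how long an encoded sequence gets; this is what MAX_LEN has to cover.
# The printed lengths depend on your data, so treat this purely as an illustration.
sample_tokens = tokenizer.tokenize(sample_txt)
sample_ids = tokenizer.encode(sample_txt, add_special_tokens=True)
print(f'sample text: {sample_txt}')
print(f'wordpiece tokens: {len(sample_tokens)}')
print(f'input ids (with [CLS]/[SEP]): {len(sample_ids)}')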
class OrderDataset(Dataset):
    def __init__(self, texts, targets, tokenizer, max_len):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = self.texts[item]
        target = self.targets[item]
        # Note: pad_to_max_length is deprecated in newer transformers releases;
        # padding='max_length' is the current equivalent.
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long)
        }
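
# Optional sanity check (a sketch, not part of the original flow): build a tiny dataset
# and inspect one encoded item to confirm the tensor shapes match MAX_LEN.
_demo_ds = OrderDataset(train_texts[:2], train_targets[:2], tokenizer, MAX_LEN)
_demo_item = _demo_ds[0]
print(_demo_item['input_ids'].shape, _demo_item['attention_mask'].shape)  # both torch.Size([MAX_LEN])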
def create_data_loader(texts, targets, tokenizer, max_len, batch_size):
    ds = OrderDataset(
        texts=texts,
        targets=targets,
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=1,
        shuffle=True
    )

BATCH_SIZE = 16
train_data_loader = create_data_loader(train_texts, train_targets, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(test_texts, test_targets, tokenizer, MAX_LEN, BATCH_SIZE)
class OrderClassifier(nn.Module):
    def __init__(self, n_classes, path):
        super(OrderClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        # Tuple unpacking works on transformers 3.x; on 4.x the model returns a
        # ModelOutput, so use outputs.pooler_output (or pass return_dict=False).
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        output = self.drop(pooled_output)
        return self.out(output)
# Grab one batch for a quick shape check before training.
data = next(iter(train_data_loader))
model = OrderClassifier(len(class_to_id), PRETRAINED_MODEL_PATH)
model = model.to(device)
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
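
# Optional sanity check (a sketch): push the batch grabbed above through the untrained
# model to confirm the logits have shape (BATCH_SIZE, number of classes).
with torch.no_grad():
    sample_outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(f'logits shape: {sample_outputs.shape}')  # expected: (BATCH_SIZE, len(class_to_id))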
EPOCHS = 3
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
loss_fn = nn.CrossEntropyLoss().to(device)
def train_epoch(model, data_loader, loss_fn, optimizer,
                device, scheduler, n_examples):
    model = model.train()
    losses = []
    correct_predictions = 0
    print_interval = 10
    print_counter = 0
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["targets"].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)
        print_counter += 1
        if print_counter % print_interval == 0:
            print("batch {}, loss {}".format(print_counter, loss))
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        #nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
    return correct_predictions.double() / n_examples, np.mean(losses)
best_accuracy = 0
history = defaultdict(list)
for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)
    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(train_texts)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')
    val_acc, val_loss = eval_model(
        model,
        test_data_loader,
        loss_fn,
        device,
        len(test_texts)
    )
    print(f'Val loss {val_loss} accuracy {val_acc}')
    # Store plain Python floats so the history can be plotted below.
    history['train_acc'].append(train_acc.item())
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc.item())
    history['val_loss'].append(val_loss)
    if val_acc > best_accuracy:
        torch.save(model, 'best_model.bin')
        best_accuracy = val_acc
test_acc, _ = eval_model(
    model,
    test_data_loader,
    loss_fn,
    device,
    len(test_texts)
)
print("test acc {}".format(test_acc.item()))
# Plot the accuracy history recorded above (the original legend call had nothing to show).
plt.plot(history['train_acc'], label='train accuracy')
plt.plot(history['val_acc'], label='validation accuracy')
plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1])
plt.show()
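
# A minimal inference sketch (not part of the original script): classify a single order
# string with the fine-tuned model still in memory. `id_to_class` and `predict` are
# assumed helpers introduced here; adjust the example text to your data format.
id_to_class = {i: cls for cls, i in class_to_id.items()}

def predict(text, model, tokenizer, max_len=MAX_LEN):
    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        return_token_type_ids=False,
        pad_to_max_length=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device)
        )
    pred_id = torch.argmax(logits, dim=1).item()
    return id_to_class[pred_id]

# Example usage (content fields joined with SEP, as in read_data):
# print(predict("field one" + SEP + "field two", model, tokenizer))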