Train a BERT text-classification model; for the full walkthrough see https://selfboot.cn/2023/12/06/bert_nlp_classify/
#!pip install torch
#!pip install transformers
#!pip install scikit-learn
#!pip install numpy

import json
from sklearn.model_selection import train_test_split
import random
from datetime import datetime
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score
import torch.nn.functional as F
import sys

# Redirect stdout/stderr to a line-buffered log file so training progress is recorded.
log_file = open('lawer_train_output.log', 'w', buffering=1)
sys.stdout = log_file
sys.stderr = log_file

print("begin")
# Load the positive samples (lawyer-related text), one JSON object per line.
lawer_data = []
with open('./train_lawer.json', 'r') as f:
    for line in f:
        lawer_data.append(json.loads(line))

# Load the negative samples (non-lawyer text).
not_lawer_data = []
with open('./train_notlawer.json', 'r') as f:
    for line in f:
        not_lawer_data.append(json.loads(line))

# Label the samples: 1 for lawyer-related, 0 for everything else.
lawer_data = [(str(d['content']), 1) for d in lawer_data]
not_lawer_data = [(str(d['content']), 0) for d in not_lawer_data]

# Merge the two classes and shuffle.
merged_data = lawer_data + not_lawer_data
random.shuffle(merged_data)

# Separate features and labels.
X, y = zip(*merged_data)

# Split the dataset: 80% for training, 20% for validation.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
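# Optional sanity check (not in the original gist): log the class balance of each
# split, since the imbalance is what motivates the Focal Loss used below.
print(f"train positives: {sum(y_train)}/{len(y_train)}, val positives: {sum(y_val)}/{len(y_val)}")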
tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')

def tokenize(content, max_length=512):
    truncated_content = []
    for t in content:
        t = t if t is not None else ""
        encoded = tokenizer.encode_plus(t,
                                        max_length=max_length,
                                        padding='max_length',
                                        truncation=True,
                                        return_tensors="pt")
        truncated_content.append(encoded['input_ids'])
    input_ids = torch.cat(truncated_content)
    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.type())
    return {"input_ids": input_ids, "attention_mask": attention_mask}

train_encodings = tokenize(X_train)
val_encodings = tokenize(X_val)
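# Optional sanity check (not part of the original script): every text is padded or
# truncated to max_length tokens, so both tensors should be (num_samples, 512).
print(f"train input_ids: {tuple(train_encodings['input_ids'].shape)}, "
      f"val input_ids: {tuple(val_encodings['input_ids'].shape)}")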
train_dataset = TensorDataset(train_encodings['input_ids'], train_encodings['attention_mask'], torch.tensor(y_train))
val_dataset = TensorDataset(val_encodings['input_ids'], val_encodings['attention_mask'], torch.tensor(y_val))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)
# Focal Loss down-weights easy examples and is very effective for highly imbalanced classification.
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.3, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        # cross_entropy with reduction='none' gives -log(p_t) per sample.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # p_t, the model's probability for the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        elif self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss  # reduction='none'
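# For reference, the per-sample loss implemented above is
#     FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)
# where p_t is the predicted probability of the true class; gamma=2 shrinks the loss
# of well-classified samples and alpha rescales the overall magnitude.
# Quick self-check (not in the original gist): for a confident correct prediction the
# focal loss should be much smaller than plain cross-entropy.
_demo_logits = torch.tensor([[4.0, -4.0], [0.1, -0.1]])
_demo_targets = torch.tensor([0, 0])
print("focal vs ce:",
      FocalLoss(gamma=2, alpha=1, reduction='none')(_demo_logits, _demo_targets),
      F.cross_entropy(_demo_logits, _demo_targets, reduction='none'))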
class SingleInputBert(nn.Module):
    def __init__(self):
        super(SingleInputBert, self).__init__()
        self.bert = BertModel.from_pretrained('./bert-base-chinese')
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # pooler_output is the pooled [CLS] representation; apply dropout before classifying.
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)
        return logits
model = SingleInputBert()
model = nn.DataParallel(model)  # use multiple GPUs when available

optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Cuda: {torch.cuda.is_available()}')
model.to(device)

criterion = FocalLoss(gamma=2, alpha=1, reduction='mean')  # use Focal Loss as the loss function
# Reduce the learning rate when the monitored validation metric stops improving
# (mode='max', so the scheduler must be fed a metric that should increase).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2, verbose=True)
def evaluate_model(model, data_loader, device):
    model.eval()  # make sure the model is in evaluation mode
    y_true = []
    y_pred = []
    total_loss = 0.0
    total_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            logits = model(input_ids, attention_mask=attention_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item()
            total_batches += 1
            preds = torch.argmax(logits, dim=-1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    average_loss = total_loss / total_batches
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    return accuracy, precision, recall, f1, average_loss
current_date = datetime.now().strftime("%Y%m%d")
save_path = f"./lawer_{current_date}.pt"

# Early-stopping parameters: keep the checkpoint with the best validation precision.
best_precision = 0.0
patience = 10
no_improve = 0

print("Begin Epoch Training...")
for epoch in range(50):  # train for at most 50 epochs
    print(f"Epoch: {epoch+1}")
    # evaluate_model() switches to eval mode, so re-enable training mode every epoch.
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        logits = model(input_ids, attention_mask=attention_mask)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()

    train_accuracy, train_precision, train_recall, train_f1, train_loss = evaluate_model(model, train_loader, device)
    val_accuracy, val_precision, val_recall, val_f1, val_loss = evaluate_model(model, val_loader, device)
    # The scheduler was created with mode='max', so pass it a metric that should increase.
    scheduler.step(val_precision)

    print(f"Epoch {epoch + 1} - Training Accuracy: {train_accuracy}, Precision: {train_precision}, Recall: {train_recall}, F1 Score: {train_f1}, Loss: {train_loss}")
    print(f"Epoch {epoch + 1} - Validation Accuracy: {val_accuracy}, Precision: {val_precision}, Recall: {val_recall}, F1 Score: {val_f1}, Loss: {val_loss}")

    if val_precision > best_precision:
        best_precision = val_precision
        torch.save(model.state_dict(), save_path)
        print(f"Save model to {save_path}, precision: {val_precision}")
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping due to no improvement.")
            break

log_file.close()
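# --- Inference sketch (not part of the original gist; assumes the same files and
# definitions as above). Because the model was wrapped in nn.DataParallel, the saved
# state_dict keys carry a "module." prefix, so load it into a DataParallel wrapper too.
#
# model = nn.DataParallel(SingleInputBert()).to(device)
# model.load_state_dict(torch.load(save_path, map_location=device))
# model.eval()
# enc = tokenize(["some text to classify"])
# with torch.no_grad():
#     logits = model(enc['input_ids'].to(device),
#                    attention_mask=enc['attention_mask'].to(device))
# pred = torch.argmax(logits, dim=-1).item()  # 1 = lawyer-related, 0 = other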