Chinese word segmentation model (中文分词模型): a BERT + CRF character tagger trained on jieba-segmented tweets.
Training script (the inference script below imports it as module a): defines the BERT_CRF tagger, builds the dataset from the CSV corpus, and runs the training loop.
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import BertTokenizer, BertModel, AdamW
from torch.utils.data import DataLoader, Dataset
import random
import pandas as pd
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from torch.cuda.amp import GradScaler, autocast
import os
class BERT_CRF(nn.Module):
    def __init__(self, bert, num_tags):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, tags=None):
        outputs = self.bert(input_ids, attention_mask)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        mask = attention_mask.bool()
        if tags is not None:
            # Training: return the negative CRF log-likelihood as the loss
            loss = -self.crf(logits, tags, mask=mask, reduction='mean')
            return loss
        else:
            # Inference: return the per-token emission logits for crf.decode()
            return logits

    def decode(self, logits, attention_mask):
        mask = attention_mask.bool()
        return self.crf.decode(logits, mask=mask)
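# For reference (B = batch size, T = padded sequence length):
#   input_ids, attention_mask and tags are (B, T) tensors;
#   forward(..., tags=...) returns a scalar loss (mean negative log-likelihood),
#   forward(...) without tags returns emission logits of shape (B, T, num_tags),
#   and decode(logits, attention_mask) returns one list of tag ids per sequence,
#   covering only the unmasked positions.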
# Tag-to-id mapping: S = single-character word, B/M/E = begin/middle/end of a
# multi-character word, P = punctuation or space. L is defined but never produced
# by the corpus-preparation script below; PAD marks padded tag positions.
tag2id = {"S": 0, "B": 1, "M": 2, "E": 3, 'P': 4, "L": 5, "PAD": 6}
def tags_to_segmented_text(tag_ids, tag2id, text):
    # crf.decode() returns a list of tag-id lists (one per sequence); flatten it
    # and map the ids back to tag letters
    id2tag = {id_: tag for tag, id_ in tag2id.items()}
    tags = [id2tag[tag_id] for tag_id_list in tag_ids for tag_id in tag_id_list]
    segmented_text = ""
    for char, tag in zip(text, tags):
        if tag in ["S", "B"]:
            segmented_text += " / " + char
        elif tag in ["M", "E"]:
            segmented_text += char
        elif tag in ["P", "L"]:
            segmented_text += " / " + char
    return segmented_text.strip()
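# Quick illustration with hand-written (hypothetical) tag ids rather than real
# model output: for the text "你好吗?" the ids [[1, 3, 0, 4]] decode to B, E, S, P,
# so the helper joins the words with " / ":
#   tags_to_segmented_text([[1, 3, 0, 4]], tag2id, "你好吗?")  # -> "/ 你好 / 吗 / ?"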
# Build the dataset
class ChineseSegmentationDataset(Dataset):
    def __init__(self, data, tokenizer, tag2id, max_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.tag2id = tag2id
        self.max_length = max_length
        self.data = self._filter_data(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, tags = self.data[idx]
        # Encode with encode_plus so [CLS]/[SEP] and padding are added
        inputs = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Pad the character tags out to max_length - 2. No placeholder tag is
        # inserted for [CLS]/[SEP], so each character's tag sits one token
        # position earlier than the character itself ([CLS] carries the first
        # tag); the inference script zips text with decoded tags in the same
        # shifted way, so training and decoding stay consistent.
        tags_list = list(tags)
        padded_tags = tags_list[:self.max_length - 2] + ["PAD"] * (self.max_length - len(tags_list) - 2)
        tag_ids = [self.tag2id[tag] for tag in padded_tags]
        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "tags": torch.tensor(tag_ids, dtype=torch.long),
        }

    def _filter_data(self, data):
        filtered_data = []
        for text, tags in data:
            if not text.strip():
                continue
            if len(text) != len(tags):
                print(f"Skipping data: {text} due to inconsistent length between tokens and tags")
                continue
            if len(text) > self.max_length - 2:
                print(f"Skipping data: {text} due to length exceeding max length of {self.max_length}")
                continue
            filtered_data.append((text, tags))
        return filtered_data
def collate_fn(batch):
    # Pad every field in the batch to the longest input_ids length (with the
    # 'max_length' padding above this is simply max_length) and stack
    max_length = max([len(item["input_ids"]) for item in batch])
    input_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_length), dtype=torch.long)
    tags = torch.zeros((len(batch), max_length), dtype=torch.long)
    for i, item in enumerate(batch):
        input_ids[i, :len(item["input_ids"])] = item["input_ids"]
        attention_mask[i, :len(item["attention_mask"])] = item["attention_mask"]
        tags[i, :len(item["tags"])] = item["tags"]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "tags": tags,
    }
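# With batch_size=32 and max_length=256 (the settings below), each batch from the
# DataLoader is a dict of three (32, 256) LongTensors (the final batch may be
# smaller); the tag rows are zero-padded here from length 254 up to 256.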
# Set the random seed for reproducible results
def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Read the data from the CSV file
    csv_file = "data/tweet.js.csv"  # replace with your own CSV file name
    df = pd.read_csv(csv_file, header=None, names=["text", "tags"])

    # Split the data into training and validation sets
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

    # Load the pretrained Chinese BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

    # Convert the DataFrames into the list-of-tuples format the dataset expects
    train_data = list(train_df.itertuples(index=False, name=None))
    val_data = list(val_df.itertuples(index=False, name=None))

    # Build the training and validation datasets
    train_dataset = ChineseSegmentationDataset(train_data, tokenizer, tag2id)
    val_dataset = ChineseSegmentationDataset(val_data, tokenizer, tag2id)

    set_seed(42)

    # Create the BERT_CRF model
    bert_model = BertModel.from_pretrained("bert-base-chinese")
    model = BERT_CRF(bert_model, num_tags=len(tag2id)).to(device)
    model.train()

    # Create the DataLoaders
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    # Optimizer and learning rate
    optimizer = AdamW(model.parameters(), lr=1e-5)
    scaler = GradScaler()

    # Resume from a previous checkpoint if one exists
    if os.path.exists("model_p.pth"):
        model.load_state_dict(torch.load("model_p.pth"))

    # Training loop
    num_epochs = 100
    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(train_dataloader):
            # collate_fn already returns (batch, seq_len) tensors
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            tags = batch["tags"].to(device)

            # Mixed-precision forward/backward pass
            with autocast():
                loss = model(input_ids, attention_mask, tags)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item()}")

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
        torch.save(model.state_dict(), "model_p.pth")

        # Validation every 2 epochs
        if (epoch + 1) % 2 == 0:
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for batch_idx, batch in enumerate(val_dataloader):
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    tags = batch["tags"].to(device)
                    vloss = model(input_ids, attention_mask, tags)
                    val_loss += vloss.item()
            val_loss /= len(val_dataloader)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss.item()}, Val Loss: {val_loss}")
            # Switch back to training mode
            model.train()

    # Save the final model
    torch.save(model.state_dict(), "model.pth")
Inference script: loads the trained weights and segments a sample sentence with the model above.
# BERT_CRF, tag2id and tags_to_segmented_text come from the training script
# above (saved as a.py)
from a import *
from transformers import BertTokenizer, BertModel
import torch
import os

bert_model = BertModel.from_pretrained("bert-base-chinese")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERT_CRF(bert_model, num_tags=len(tag2id)).to(device)

# Load the pretrained Chinese BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

# Put the model in evaluation mode
model.eval()

# Load the trained weights
if os.path.exists("model_p.pth"):
    model.load_state_dict(torch.load("model_p.pth"))
else:
    print("model does not exist")
    os._exit(1)

# Text to segment
text = """下面这个网址是https://www.google.com/,这是一个搜索引擎"""

# Encode the input text with the tokenizer
inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=256, truncation=True)

# Make sure the inputs have the right dtypes and live on the right device
inputs["input_ids"] = inputs["input_ids"].long().to(device)
inputs["attention_mask"] = inputs["attention_mask"].bool().to(device)

# Get the emission logits from the model (no tags, so no loss is computed)
with torch.no_grad():
    logits = model(inputs["input_ids"], inputs["attention_mask"])

# Decode the most likely tag sequence with the CRF
tag_ids = model.decode(logits, inputs["attention_mask"])

# Turn the predicted tag ids back into segmented text
segmented_text = tags_to_segmented_text(tag_ids, tag2id, text)
print(segmented_text)
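# The printed result is the input text with " / " inserted in front of every
# predicted word, e.g. something like "下面 / 这个 / 网址 / 是 / https://www.google.com/ / ..."
# (illustrative only; the actual segmentation depends on the trained weights).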
Corpus-preparation script: segments exported Twitter-archive tweets with jieba, converts them to per-character tags, and writes (text, tags) pairs to CSV.
import jieba
import json
import os
import string
import glob
import csv
# Use tweets exported from Twitter as the segmentation corpus, segmented with jieba.
# Tag scheme (one tag per character):
#   S: single-character word
#   B: beginning of a word
#   M: middle of a word
#   E: end of a word
#   P: punctuation or space
tag2id = {"S": 0, "B": 1, "M": 2, "E": 3, 'P': 4, 'U': 5}  # for reference only; the CSV stores tag letters, and 'U' is never emitted
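# Worked example of the scheme (assuming jieba splits the text as 我 / 爱 / 北京 / 天安门):
#   text: 我 爱 北 京 天 安 门
#   tags: S  S  B  E  B  M  E      -> stored in the CSV as the string "SSBEBME"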
# Read the corpus
def read_corpus(corpus_path):
    data = []
    # Open the exported .js file and drop everything before the leading "[" on
    # the first line (Twitter archive files start with a "window.YTD... = [ {" assignment)
    with open(corpus_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        lines[0] = "[ {"
        lines = "".join(lines)
        # Parse the remainder as JSON
        items = json.loads(lines)
        for item in items:
            text = item["full_text"]
            # Collapse newlines, tabs and runs of spaces into single spaces
            text = text.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
            text = ' '.join(text.split())  # a regex would work here as well
            text = text.strip()
            tags = segment(text)
            if len(tags) == 0:
                continue
            if len(tags) != len(text):
                print("error")
                print(text)
                print(tags)
                continue
            data.append((text, ''.join(tags)))
    return data
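# Each returned pair has len(text) == len(tags), e.g. (hypothetical)
#   ("我爱北京天安门", "SSBEBME")
# and is later written out as one CSV row, which is exactly the format the
# training script reads back with pd.read_csv(..., names=["text", "tags"]).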
# Segment a tweet and produce one tag per character
def segment(text):
    words = list(jieba.cut(text))

    # When an "@" or "#" is followed by a run of tokens up to the next space,
    # merge that run into a single token (mentions and hashtags)
    nw = []
    l = len(words)
    i = 0
    while i < l:
        w = words[i]
        if w == "@" or w == "#":
            s = []
            while i < l and words[i] != ' ':
                s.append(words[i])
                i = i + 1
            nw.append(''.join(s))
            continue
        nw.append(w)
        i = i + 1
    words = nw

    # Generate the character tags
    tags = []
    i = 0
    l = len(words)
    while i < l:
        w = words[i]
        # Single ASCII punctuation character or a space: tag as P
        if (len(w) == 1 and w in string.punctuation) or w == ' ':
            tags.append('P')
            i = i + 1
            continue
        # Any other single character: tag as S
        if len(w) == 1:
            tags.append('S')
            i = i + 1
            continue
        # Link detection: if the next three tokens are ":", "/", "/", the current
        # token starts a URL; tag the whole URL as B, then M's, ending with E
        if l > i + 3 and words[i + 1] == ":" and words[i + 2] == "/" and words[i + 3] == "/":
            # e.g. https://www.google.com  ->  B followed by M's, ending with E
            first = True
            while i < l and words[i] != ' ':  # consume tokens until the next space ends the URL
                for ch in words[i]:
                    if first:
                        tags.append('B')
                        first = False
                    else:
                        tags.append('M')
                i = i + 1
            tags[-1] = 'E'
            continue
        # Multi-character word: first character B, middle characters M, last character E
        first = True
        for ch in w:
            if first:
                tags.append('B')
                first = False
            else:
                tags.append('M')
        tags[-1] = 'E'
        i = i + 1
    return tags
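# Illustration (hypothetical, assuming jieba yields the tokens "@", "user", " ", "你好"):
#   segment("@user 你好")  ->  ['B', 'M', 'M', 'M', 'E', 'P', 'B', 'E']
# The @-merge above first joins "@" and "user" into the single token "@user".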
if __name__ == '__main__':
    # Create a data folder in the current directory
    if not os.path.exists('data'):
        os.mkdir('data')

    corpusPath = "E:\\tw"
    # Collect the tweet-part*.js files from the archive
    file_list = glob.glob(corpusPath + '\\tweet-part*.js')
    file_list.append("E:\\tw\\tweet.js")
    print(file_list)

    for file in file_list:
        # Write the corpus for this file to a CSV (newline='' keeps csv.writer
        # from inserting blank lines on Windows)
        dataFileName = "data/" + os.path.basename(file) + '.csv'
        with open(dataFileName, 'w', encoding='utf-8', newline='') as f:
            data = read_corpus(file)
            writer = csv.writer(f)
            writer.writerows(data)