Seq2Seq with Attention #1
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import random

# Sample training data (Chinese -> English)
data = [
    (['我', '爱', '学习'], ['I', 'love', 'studying']),
    (['今天', '天气', '好'], ['Today', 'weather', 'is', 'good']),
    (['他', '喜欢', '读书'], ['He', 'likes', 'reading']),
    (['我们', '去', '公园'], ['We', 'go', 'to', 'park']),
    (['这', '是', '苹果'], ['This', 'is', 'an', 'apple'])
]
# Build the vocabularies
src_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
trg_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
for chn, eng in data:
    for word in chn:
        if word not in src_vocab:
            src_vocab[word] = len(src_vocab)
    for word in eng:
        if word not in trg_vocab:
            trg_vocab[word] = len(trg_vocab)
# Hyperparameters
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2  # not used: the toy training loop below feeds the whole dataset at once
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
LEARNING_RATE = 0.001
CLIP = 1
# Self-attention module over the encoder outputs
class SelfAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.query = nn.Linear(hid_dim, hid_dim)
        self.key = nn.Linear(hid_dim, hid_dim)
        self.value = nn.Linear(hid_dim, hid_dim)

    def forward(self, encoder_outputs, mask):
        # encoder_outputs: [batch, src_len, hid_dim]; mask: [batch, 1, src_len]
        Q = self.query(encoder_outputs)
        K = self.key(encoder_outputs)
        V = self.value(encoder_outputs)
        # Scaled dot-product attention; padded key positions are masked out
        scores = torch.matmul(Q, K.transpose(1, 2)) / (HID_DIM ** 0.5)
        scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V)
# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, ENC_EMB_DIM)
        self.lstm = nn.LSTM(ENC_EMB_DIM, HID_DIM, N_LAYERS, dropout=ENC_DROPOUT)
        self.attention = SelfAttention(HID_DIM)
        self.dropout = nn.Dropout(ENC_DROPOUT)

    def forward(self, src, src_lengths):
        # src: [src_len, batch]
        embedded = self.dropout(self.embedding(src))
        # pack_padded_sequence requires CPU lengths; enforce_sorted=False avoids
        # having to pre-sort the batch by length
        packed = pack_padded_sequence(embedded, src_lengths.cpu(), enforce_sorted=False)
        outputs, (hidden, cell) = self.lstm(packed)
        outputs, _ = pad_packed_sequence(outputs)
        outputs = outputs.permute(1, 0, 2)  # -> [batch, src_len, hid_dim]
        return self.attention(outputs, self.create_mask(src)), hidden, cell

    def create_mask(self, src):
        # [batch, 1, src_len]: 1 for real tokens, 0 for padding
        return (src != src_vocab['<pad>']).permute(1, 0).unsqueeze(1)
# Decoder
class Decoder(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, DEC_EMB_DIM)
        self.lstm = nn.LSTM(DEC_EMB_DIM + HID_DIM, HID_DIM, N_LAYERS, dropout=DEC_DROPOUT)
        self.fc = nn.Linear(HID_DIM, output_dim)
        self.dropout = nn.Dropout(DEC_DROPOUT)

    def forward(self, input, hidden, cell, context):
        # input: [batch]; context: [batch, hid_dim] (pooled encoder representation)
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        # Feed the target embedding concatenated with the context vector
        output, (hidden, cell) = self.lstm(
            torch.cat((embedded, context.unsqueeze(0)), dim=2),
            (hidden, cell)
        )
        prediction = self.fc(output.squeeze(0))
        return prediction, hidden, cell
# Seq2Seq model
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        context, hidden, cell = self.encoder(src, src_lengths)
        # Mean-pool the attended encoder outputs into a single context vector
        context = context.mean(1)
        input = trg[0, :]  # first decoder input is <sos>
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell, context)
            outputs[t] = output
            # Teacher forcing: feed the gold token or the model's own prediction
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1
        return outputs
# Initialize the model
enc = Encoder(len(src_vocab)).to(DEVICE)
dec = Decoder(len(trg_vocab)).to(DEVICE)
model = Seq2Seq(enc, dec, DEVICE).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab['<pad>'])
# Data preprocessing: add <sos>/<eos>, map words to indices, pad to equal length
def prepare_data(data):
    src_tensors = []
    trg_tensors = []
    for chn, eng in data:
        chn = [src_vocab['<sos>']] + [src_vocab[w] for w in chn] + [src_vocab['<eos>']]
        eng = [trg_vocab['<sos>']] + [trg_vocab[w] for w in eng] + [trg_vocab['<eos>']]
        src_tensors.append(torch.tensor(chn))
        trg_tensors.append(torch.tensor(eng))
    return pad_sequence(src_tensors, padding_value=src_vocab['<pad>']), \
           pad_sequence(trg_tensors, padding_value=trg_vocab['<pad>'])
# Training loop (the toy dataset is small enough to train as a single batch)
for epoch in range(10):
    src, trg = prepare_data(data)
    src, trg = src.to(DEVICE), trg.to(DEVICE)
    src_lengths = torch.sum(src != src_vocab['<pad>'], dim=0)
    optimizer.zero_grad()
    output = model(src, src_lengths, trg)
    # Skip position 0 (<sos>) and flatten; padding is ignored via ignore_index
    loss = criterion(output[1:].view(-1, output.shape[-1]), trg[1:].view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
    optimizer.step()
    print(f'Epoch: {epoch+1:02} | Loss: {loss.item():.3f}')
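
# --- Not part of the original gist: a minimal greedy-decoding sketch for inference. ---
# It assumes the model, vocabularies, and DEVICE defined above; the helper name
# `translate` and the max_len cutoff are illustrative choices, not from the gist.
def translate(model, sentence, max_len=10):
    model.eval()
    with torch.no_grad():
        # Encode the source sentence (single example, batch size 1)
        ids = [src_vocab['<sos>']] + [src_vocab[w] for w in sentence] + [src_vocab['<eos>']]
        src = torch.tensor(ids).unsqueeze(1).to(DEVICE)  # [src_len, 1]
        src_lengths = torch.tensor([len(ids)])
        context, hidden, cell = model.encoder(src, src_lengths)
        context = context.mean(1)  # same mean-pooling as in Seq2Seq.forward
        inv_trg_vocab = {i: w for w, i in trg_vocab.items()}
        input = torch.tensor([trg_vocab['<sos>']]).to(DEVICE)
        words = []
        for _ in range(max_len):
            output, hidden, cell = model.decoder(input, hidden, cell, context)
            top1 = output.argmax(1)
            if top1.item() == trg_vocab['<eos>']:
                break
            words.append(inv_trg_vocab[top1.item()])
            input = top1  # feed the model's own prediction back in
        return words

# Example usage: print(translate(model, ['我', '爱', '学习']))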