@nickfox-taterli
Created January 25, 2025 15:36
Seq2Seq with Attention
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import random
# Example training data (Chinese -> English)
data = [
    (['我', '爱', '学习'], ['I', 'love', 'studying']),
    (['今天', '天气', '好'], ['Today', 'weather', 'is', 'good']),
    (['他', '喜欢', '读书'], ['He', 'likes', 'reading']),
    (['我们', '去', '公园'], ['We', 'go', 'to', 'park']),
    (['这', '是', '苹果'], ['This', 'is', 'an', 'apple'])
]
# Build the vocabularies
src_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
trg_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
for chn, eng in data:
    for word in chn:
        if word not in src_vocab:
            src_vocab[word] = len(src_vocab)
    for word in eng:
        if word not in trg_vocab:
            trg_vocab[word] = len(trg_vocab)
# Hyperparameters
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
LEARNING_RATE = 0.001
CLIP = 1
# Self-attention module (scaled dot-product attention over encoder outputs)
class SelfAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.query = nn.Linear(hid_dim, hid_dim)
        self.key = nn.Linear(hid_dim, hid_dim)
        self.value = nn.Linear(hid_dim, hid_dim)

    def forward(self, encoder_outputs, mask):
        # encoder_outputs: [batch, src_len, hid_dim], mask: [batch, 1, src_len]
        Q = self.query(encoder_outputs)
        K = self.key(encoder_outputs)
        V = self.value(encoder_outputs)
        scores = torch.matmul(Q, K.transpose(1, 2)) / (HID_DIM ** 0.5)
        scores = scores.masked_fill(mask == 0, -1e9)  # block attention to <pad> positions
        attn_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V)
# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, ENC_EMB_DIM)
        self.lstm = nn.LSTM(ENC_EMB_DIM, HID_DIM, N_LAYERS, dropout=ENC_DROPOUT)
        self.attention = SelfAttention(HID_DIM)
        self.dropout = nn.Dropout(ENC_DROPOUT)

    def forward(self, src, src_lengths):
        # src: [src_len, batch]
        embedded = self.dropout(self.embedding(src))
        # lengths must be a CPU tensor; enforce_sorted=False so the batch need not be length-sorted
        packed = pack_padded_sequence(embedded, src_lengths.cpu(), enforce_sorted=False)
        outputs, (hidden, cell) = self.lstm(packed)
        # keep the padded length equal to src's length so the mask shape matches
        outputs, _ = pad_packed_sequence(outputs, total_length=src.shape[0])
        outputs = outputs.permute(1, 0, 2)  # [batch, src_len, hid_dim]
        return self.attention(outputs, self.create_mask(src)), hidden, cell

    def create_mask(self, src):
        return (src != src_vocab['<pad>']).permute(1, 0).unsqueeze(1)
# Decoder
class Decoder(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, DEC_EMB_DIM)
        self.lstm = nn.LSTM(DEC_EMB_DIM + HID_DIM, HID_DIM, N_LAYERS, dropout=DEC_DROPOUT)
        self.fc = nn.Linear(HID_DIM, output_dim)
        self.dropout = nn.Dropout(DEC_DROPOUT)

    def forward(self, input, hidden, cell, context):
        # input: [batch], context: [batch, hid_dim]
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        output, (hidden, cell) = self.lstm(
            torch.cat((embedded, context.unsqueeze(0)), dim=2),
            (hidden, cell)
        )
        prediction = self.fc(output.squeeze(0))
        return prediction, hidden, cell
# Seq2Seq model
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        context, hidden, cell = self.encoder(src, src_lengths)
        input = trg[0, :]  # first decoder input is <sos>
        for t in range(1, trg_len):
            # feed the mean-pooled attention context at every decoding step
            output, hidden, cell = self.decoder(input, hidden, cell, context.mean(1))
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1
        return outputs
# Initialize the model
enc = Encoder(len(src_vocab)).to(DEVICE)
dec = Decoder(len(trg_vocab)).to(DEVICE)
model = Seq2Seq(enc, dec, DEVICE).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab['<pad>'])
# Data preparation
def prepare_data(data):
    src_tensors = []
    trg_tensors = []
    for chn, eng in data:
        chn = [src_vocab['<sos>']] + [src_vocab[w] for w in chn] + [src_vocab['<eos>']]
        eng = [trg_vocab['<sos>']] + [trg_vocab[w] for w in eng] + [trg_vocab['<eos>']]
        src_tensors.append(torch.tensor(chn))
        trg_tensors.append(torch.tensor(eng))
    return pad_sequence(src_tensors, padding_value=src_vocab['<pad>']), \
           pad_sequence(trg_tensors, padding_value=trg_vocab['<pad>'])
# Training loop
for epoch in range(10):
    src, trg = prepare_data(data)
    src_lengths = torch.sum(src != src_vocab['<pad>'], dim=0)  # compute lengths before moving to the GPU
    src, trg = src.to(DEVICE), trg.to(DEVICE)
    optimizer.zero_grad()
    output = model(src, src_lengths, trg)
    # skip position 0 (<sos>) when computing the loss
    loss = criterion(output[1:].view(-1, output.shape[-1]), trg[1:].view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
    optimizer.step()
    print(f'Epoch: {epoch+1:02} | Loss: {loss.item():.3f}')
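
# --- Greedy-decoding sketch (not part of the original gist) ---
# A minimal inference helper, assuming the model and vocabularies defined above;
# `translate_sentence` and `trg_itos` are illustrative names, not from the source.
trg_itos = {idx: word for word, idx in trg_vocab.items()}

def translate_sentence(sentence, max_len=10):
    model.eval()
    with torch.no_grad():
        tokens = [src_vocab['<sos>']] + [src_vocab[w] for w in sentence] + [src_vocab['<eos>']]
        src = torch.tensor(tokens).unsqueeze(1).to(DEVICE)   # [src_len, 1]
        src_lengths = torch.tensor([len(tokens)])
        context, hidden, cell = model.encoder(src, src_lengths)
        input = torch.tensor([trg_vocab['<sos>']]).to(DEVICE)
        result = []
        for _ in range(max_len):
            output, hidden, cell = model.decoder(input, hidden, cell, context.mean(1))
            top1 = output.argmax(1)
            if top1.item() == trg_vocab['<eos>']:
                break
            result.append(trg_itos[top1.item()])
            input = top1
        return result

print(translate_sentence(['我', '爱', '学习']))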