A Seq2Seq implementation with my own improvements
import torch
import torch.nn as nn
import torch.optim as optim
import random
# Define the encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        # Embedding layer
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # LSTM layer
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src len, batch size]
        embedded = self.dropout(self.embedding(src))
        # embedded = [src len, batch size, emb dim]
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = [src len, batch size, hid dim * n directions]
        # hidden = [n layers * n directions, batch size, hid dim]
        # cell = [n layers * n directions, batch size, hid dim]
        return hidden, cell
# Define the decoder
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        # Embedding layer
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # LSTM layer
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        # Fully connected output layer
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # input = [batch size]
        input = input.unsqueeze(0)
        # input = [1, batch size]
        embedded = self.dropout(self.embedding(input))
        # embedded = [1, batch size, emb dim]
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # output = [seq len, batch size, hid dim * n directions]
        # hidden = [n layers * n directions, batch size, hid dim]
        # cell = [n layers * n directions, batch size, hid dim]
        prediction = self.fc_out(output.squeeze(0))
        # prediction = [batch size, output dim]
        return prediction, hidden, cell
# Define the Seq2Seq model
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src = [src len, batch size]
        # trg = [trg len, batch size]
        # teacher_forcing_ratio is the probability of using teacher forcing
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        # Tensor to hold the decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        # Encoder forward pass
        hidden, cell = self.encoder(src)
        # The first decoder input is the start token of the target sequence
        input = trg[0, :]
        for t in range(1, trg_len):
            # Decoder forward pass (one step)
            output, hidden, cell = self.decoder(input, hidden, cell)
            # Store the prediction
            outputs[t] = output
            # Decide whether to use teacher forcing
            teacher_force = random.random() < teacher_forcing_ratio
            # Get the highest-probability predicted token
            top1 = output.argmax(1)
            # With teacher forcing the next input is the ground-truth token, otherwise the prediction
            input = trg[t] if teacher_force else top1
        return outputs
# Training loop
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        # Drop the first (start-token) position and flatten for the loss
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)
# Evaluation loop
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            # Turn off teacher forcing during evaluation
            output = model(src, trg, 0)
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)
# Example hyperparameters
INPUT_DIM = 1000
OUTPUT_DIM = 1000
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the encoder, decoder and Seq2Seq model
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
model = Seq2Seq(enc, dec, device).to(device)
# Define the optimizer and loss function
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
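
# Sanity-check sketch (not part of the original gist): run one forward pass on random
# token indices to confirm the tensor shapes line up. SRC_LEN, TRG_LEN and BATCH_SIZE
# are illustrative values chosen here, not values from the original code.
SRC_LEN, TRG_LEN, BATCH_SIZE = 7, 9, 4
src_demo = torch.randint(0, INPUT_DIM, (SRC_LEN, BATCH_SIZE)).to(device)
trg_demo = torch.randint(0, OUTPUT_DIM, (TRG_LEN, BATCH_SIZE)).to(device)
with torch.no_grad():
    out_demo = model(src_demo, trg_demo, teacher_forcing_ratio=0.0)
# Expected shape: [TRG_LEN, BATCH_SIZE, OUTPUT_DIM]
print(out_demo.shape)  # torch.Size([9, 4, 1000])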
# You need to prepare your own dataset and data iterators; this example does not include
# a concrete implementation (a stand-in sketch with random data is given after the loop below)
# train_iterator = ...
# valid_iterator = ...
# Train the model
# N_EPOCHS = 10
# CLIP = 1
# for epoch in range(N_EPOCHS):
#     train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
#     valid_loss = evaluate(model, valid_iterator, criterion)
#     print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')
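
# --- Stand-in data sketch (an assumption, not the original data pipeline) ---
# train()/evaluate() above only need an iterable of batches exposing .src and .trg as
# [seq len, batch size] LongTensors, so a few random-token batches are enough to exercise
# the commented-out loop end to end. The helper name make_dummy_iterator and the sizes
# below are illustrative; swap in a real dataset/iterator (e.g. torchtext or a DataLoader)
# for actual training.
from types import SimpleNamespace

def make_dummy_iterator(n_batches=5, src_len=7, trg_len=9, batch_size=32):
    batches = []
    for _ in range(n_batches):
        src = torch.randint(0, INPUT_DIM, (src_len, batch_size)).to(device)
        trg = torch.randint(0, OUTPUT_DIM, (trg_len, batch_size)).to(device)
        batches.append(SimpleNamespace(src=src, trg=trg))
    return batches

# Example usage: build dummy iterators, then uncomment the training loop above.
# train_iterator = make_dummy_iterator()
# valid_iterator = make_dummy_iterator(n_batches=2)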