Skip to content

Instantly share code, notes, and snippets.

@nickfox-taterli
Created January 24, 2025 03:32
Show Gist options
  • Save nickfox-taterli/3e82aad6d2a0eac9cfe23fe75aa1bf1a to your computer and use it in GitHub Desktop.
Save nickfox-taterli/3e82aad6d2a0eac9cfe23fe75aa1bf1a to your computer and use it in GitHub Desktop.
Seq2Seq 例子
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4d476dd6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 句子数量: 5\n",
" 中文词汇表大小: 18\n",
" 英文词汇表大小: 20\n",
" 中文词汇到索引的字典: {'处理': 0, '小冰': 1, '深度学习': 2, '复杂': 3, '人工智能': 4, '我': 5, '喜欢': 6, '强大': 7, '非常': 8, '自然': 9, '学习': 10, '语言': 11, '改变': 12, '爱': 13, '神经网络': 14, '咖哥': 15, '很': 16, '世界': 17}\n",
" 英文词汇到索引的字典: {'world': 0, 'are': 1, 'is': 2, 'changed': 3, 'Neural-Nets': 4, 'DL': 5, 'KaGe': 6, 'likes': 7, 'XiaoBing': 8, 'AI': 9, 'complex': 10, 'the': 11, 'love': 12, 'powerful': 13, 'I': 14, 'NLP': 15, '<eos>': 16, '<sos>': 17, 'studying': 18, 'so': 19}\n"
]
}
],
"source": [
"# 构建语料库,每行包含中文、英文(解码器输入)和翻译成英文后的目标输出 3 个句子\n",
"sentences = [\n",
" ['咖哥 喜欢 小冰', '<sos> KaGe likes XiaoBing', 'KaGe likes XiaoBing <eos>'],\n",
" ['我 爱 学习 人工智能', '<sos> I love studying AI', 'I love studying AI <eos>'],\n",
" ['深度学习 改变 世界', '<sos> DL changed the world', 'DL changed the world <eos>'],\n",
" ['自然 语言 处理 很 强大', '<sos> NLP is so powerful', 'NLP is so powerful <eos>'],\n",
" ['神经网络 非常 复杂', '<sos> Neural-Nets are complex', 'Neural-Nets are complex <eos>']]\n",
"word_list_cn, word_list_en = [], [] # 初始化中英文词汇表\n",
"# 遍历每一个句子并将单词添加到词汇表中\n",
"for s in sentences:\n",
" word_list_cn.extend(s[0].split())\n",
" word_list_en.extend(s[1].split())\n",
" word_list_en.extend(s[2].split())\n",
"# 去重,得到没有重复单词的词汇表\n",
"word_list_cn = list(set(word_list_cn))\n",
"word_list_en = list(set(word_list_en))\n",
"# 构建单词到索引的映射\n",
"word2idx_cn = {w: i for i, w in enumerate(word_list_cn)}\n",
"word2idx_en = {w: i for i, w in enumerate(word_list_en)}\n",
"# 构建索引到单词的映射\n",
"idx2word_cn = {i: w for i, w in enumerate(word_list_cn)}\n",
"idx2word_en = {i: w for i, w in enumerate(word_list_en)}\n",
"# 计算词汇表的大小\n",
"voc_size_cn = len(word_list_cn)\n",
"voc_size_en = len(word_list_en)\n",
"print(\" 句子数量:\", len(sentences)) # 打印句子数\n",
"print(\" 中文词汇表大小:\", voc_size_cn) # 打印中文词汇表大小\n",
"print(\" 英文词汇表大小:\", voc_size_en) # 打印英文词汇表大小\n",
"print(\" 中文词汇到索引的字典:\", word2idx_cn) # 打印中文词汇到索引的字典\n",
"print(\" 英文词汇到索引的字典:\", word2idx_en) # 打印英文词汇到索引的字典"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2cf92365",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 原始句子: ['我 爱 学习 人工智能', '<sos> I love studying AI', 'I love studying AI <eos>']\n",
" 编码器输入张量的形状: torch.Size([1, 4])\n",
" 解码器输入张量的形状: torch.Size([1, 5])\n",
" 目标张量的形状: torch.Size([1, 5])\n",
" 编码器输入张量: tensor([[ 5, 13, 10, 4]])\n",
" 解码器输入张量: tensor([[17, 14, 12, 18, 9]])\n",
" 目标张量: tensor([[14, 12, 18, 9, 16]])\n"
]
}
],
"source": [
"import numpy as np # 导入 numpy\n",
"import torch # 导入 torch\n",
"import random # 导入 random 库\n",
"# 定义一个函数,随机选择一个句子和词汇表生成输入、输出和目标数据\n",
"def make_data(sentences):\n",
" # 随机选择一个句子进行训练\n",
" random_sentence = random.choice(sentences)\n",
" # 将输入句子中的单词转换为对应的索引\n",
" encoder_input = np.array([[word2idx_cn[n] for n in random_sentence[0].split()]])\n",
" # 将输出句子中的单词转换为对应的索引\n",
" decoder_input = np.array([[word2idx_en[n] for n in random_sentence[1].split()]])\n",
" # 将目标句子中的单词转换为对应的索引\n",
" target = np.array([[word2idx_en[n] for n in random_sentence[2].split()]])\n",
" # 将输入、输出和目标批次转换为 LongTensor\n",
" encoder_input = torch.LongTensor(encoder_input)\n",
" decoder_input = torch.LongTensor(decoder_input)\n",
" target = torch.LongTensor(target)\n",
" return encoder_input, decoder_input, target \n",
"# 使用 make_data 函数生成输入、输出和目标张量\n",
"encoder_input, decoder_input, target = make_data(sentences)\n",
"for s in sentences: # 获取原始句子\n",
" if all([word2idx_cn[w] in encoder_input[0] for w in s[0].split()]):\n",
" original_sentence = s\n",
" break\n",
"print(\" 原始句子:\", original_sentence) # 打印原始句子\n",
"print(\" 编码器输入张量的形状:\", encoder_input.shape) # 打印输入张量形状\n",
"print(\" 解码器输入张量的形状:\", decoder_input.shape) # 打印输出张量形状\n",
"print(\" 目标张量的形状:\", target.shape) # 打印目标张量形状\n",
"print(\" 编码器输入张量:\", encoder_input) # 打印输入张量\n",
"print(\" 解码器输入张量:\", decoder_input) # 打印输出张量\n",
"print(\" 目标张量:\", target) # 打印目标张量\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3b8b3d1d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 编码器结构: Encoder(\n",
" (embedding): Embedding(18, 128)\n",
" (rnn): RNN(128, 128, batch_first=True)\n",
")\n",
" 解码器结构: Decoder(\n",
" (embedding): Embedding(20, 128)\n",
" (rnn): RNN(128, 128, batch_first=True)\n",
" (out): Linear(in_features=128, out_features=20, bias=True)\n",
")\n"
]
}
],
"source": [
"import torch.nn as nn # 导入 torch.nn 库\n",
"# 定义编码器类,继承自 nn.Module\n",
"class Encoder(nn.Module):\n",
" def __init__(self, input_size, hidden_size):\n",
" super(Encoder, self).__init__() \n",
" self.hidden_size = hidden_size # 设置隐藏层大小 \n",
" self.embedding = nn.Embedding(input_size, hidden_size) # 创建词嵌入层 \n",
" self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建 RNN 层 \n",
" def forward(self, inputs, hidden): # 前向传播函数\n",
" embedded = self.embedding(inputs) # 将输入转换为嵌入向量 \n",
" output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出\n",
" return output, hidden\n",
"# 定义解码器类,继承自 nn.Module\n",
"class Decoder(nn.Module):\n",
" def __init__(self, hidden_size, output_size):\n",
" super(Decoder, self).__init__() \n",
" self.hidden_size = hidden_size # 设置隐藏层大小 \n",
" self.embedding = nn.Embedding(output_size, hidden_size) # 创建词嵌入层\n",
" self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建 RNN 层 \n",
" self.out = nn.Linear(hidden_size, output_size) # 创建线性输出层 \n",
" def forward(self, inputs, hidden): # 前向传播函数 \n",
" embedded = self.embedding(inputs) # 将输入转换为嵌入向量 \n",
" output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出 \n",
" output = self.out(output) # 使用线性层生成最终输出\n",
" return output, hidden\n",
"n_hidden = 128 # 设置隐藏层数量\n",
"# 创建编码器和解码器\n",
"encoder = Encoder(voc_size_cn, n_hidden)\n",
"decoder = Decoder(n_hidden, voc_size_en)\n",
"print(' 编码器结构:', encoder) # 打印编码器的结构\n",
"print(' 解码器结构:', decoder) # 打印解码器的结构"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "82dd33e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"S2S 模型结构: Seq2Seq(\n",
" (encoder): Encoder(\n",
" (embedding): Embedding(18, 128)\n",
" (rnn): RNN(128, 128, batch_first=True)\n",
" )\n",
" (decoder): Decoder(\n",
" (embedding): Embedding(20, 128)\n",
" (rnn): RNN(128, 128, batch_first=True)\n",
" (out): Linear(in_features=128, out_features=20, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"class Seq2Seq(nn.Module):\n",
" def __init__(self, encoder, decoder):\n",
" super(Seq2Seq, self).__init__()\n",
" # 初始化编码器和解码器\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" def forward(self, enc_input, hidden, dec_input): # 定义前向传播函数\n",
" # 使输入序列通过编码器并获取输出和隐藏状态\n",
" encoder_output, encoder_hidden = self.encoder(enc_input, hidden)\n",
" # 将编码器的隐藏状态传递给解码器作为初始隐藏状态\n",
" decoder_hidden = encoder_hidden\n",
" # 使解码器输入(目标序列)通过解码器并获取输出\n",
" decoder_output, _ = self.decoder(dec_input, decoder_hidden)\n",
" return decoder_output\n",
"# 创建 Seq2Seq 架构\n",
"model = Seq2Seq(encoder, decoder)\n",
"print('S2S 模型结构:', model) # 打印模型的结构"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4645634a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0040 cost = 0.540944\n",
"Epoch: 0080 cost = 0.072166\n",
"Epoch: 0120 cost = 0.030407\n",
"Epoch: 0160 cost = 0.026792\n",
"Epoch: 0200 cost = 0.017567\n",
"Epoch: 0240 cost = 0.011450\n",
"Epoch: 0280 cost = 0.012062\n",
"Epoch: 0320 cost = 0.011834\n",
"Epoch: 0360 cost = 0.007469\n",
"Epoch: 0400 cost = 0.007511\n"
]
}
],
"source": [
"# 定义训练函数\n",
"def train_seq2seq(model, criterion, optimizer, epochs):\n",
" for epoch in range(epochs):\n",
" encoder_input, decoder_input, target = make_data(sentences) # 训练数据的创建\n",
" hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态 \n",
" optimizer.zero_grad()# 梯度清零 \n",
" output = model(encoder_input, hidden, decoder_input) # 获取模型输出 \n",
" loss = criterion(output.view(-1, voc_size_en), target.view(-1)) # 计算损失 \n",
" if (epoch + 1) % 40 == 0: # 打印损失\n",
" print(f\"Epoch: {epoch + 1:04d} cost = {loss:.6f}\") \n",
" loss.backward()# 反向传播 \n",
" optimizer.step()# 更新参数\n",
"# 训练模型\n",
"epochs = 400 # 训练轮次\n",
"criterion = nn.CrossEntropyLoss() # 损失函数\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器\n",
"train_seq2seq(model, criterion, optimizer, epochs) # 调用函数训练模型"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9877b5c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"咖哥 喜欢 小冰 -> ['KaGe', 'likes', 'XiaoBing']\n",
"自然 语言 处理 很 强大 -> ['NLP', 'is', 'so', 'powerful', '<eos>']\n"
]
}
],
"source": [
"# 定义测试函数\n",
"def test_seq2seq(model, source_sentence):\n",
" # 将输入的句子转换为索引\n",
" encoder_input = np.array([[word2idx_cn[n] for n in source_sentence.split()]])\n",
" # 构建输出的句子的索引,以 '<sos>' 开始,后面跟 '<eos>',长度与输入句子相同\n",
" decoder_input = np.array([word2idx_en['<sos>']] + [word2idx_en['<eos>']]*(len(encoder_input[0])-1))\n",
" # 转换为 LongTensor 类型\n",
" encoder_input = torch.LongTensor(encoder_input)\n",
" decoder_input = torch.LongTensor(decoder_input).unsqueeze(0) # 增加一维 \n",
" hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态 \n",
" predict = model(encoder_input, hidden, decoder_input) # 获取模型输出 \n",
" predict = predict.data.max(2, keepdim=True)[1] # 获取概率最大的索引\n",
" # 打印输入的句子和预测的句子\n",
" print(source_sentence, '->', [idx2word_en[n.item()] for n in predict.squeeze()])\n",
"# 测试模型\n",
"test_seq2seq(model, '咖哥 喜欢 小冰') \n",
"test_seq2seq(model, '自然 语言 处理 很 强大')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cf88ce0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment