Created
January 24, 2025 03:32
-
-
Save nickfox-taterli/3e82aad6d2a0eac9cfe23fe75aa1bf1a to your computer and use it in GitHub Desktop.
Seq2Seq 例子
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "4d476dd6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 句子数量: 5\n", | |
" 中文词汇表大小: 18\n", | |
" 英文词汇表大小: 20\n", | |
" 中文词汇到索引的字典: {'处理': 0, '小冰': 1, '深度学习': 2, '复杂': 3, '人工智能': 4, '我': 5, '喜欢': 6, '强大': 7, '非常': 8, '自然': 9, '学习': 10, '语言': 11, '改变': 12, '爱': 13, '神经网络': 14, '咖哥': 15, '很': 16, '世界': 17}\n", | |
" 英文词汇到索引的字典: {'world': 0, 'are': 1, 'is': 2, 'changed': 3, 'Neural-Nets': 4, 'DL': 5, 'KaGe': 6, 'likes': 7, 'XiaoBing': 8, 'AI': 9, 'complex': 10, 'the': 11, 'love': 12, 'powerful': 13, 'I': 14, 'NLP': 15, '<eos>': 16, '<sos>': 17, 'studying': 18, 'so': 19}\n" | |
] | |
} | |
], | |
"source": [ | |
"# 构建语料库,每行包含中文、英文(解码器输入)和翻译成英文后的目标输出 3 个句子\n", | |
"sentences = [\n", | |
" ['咖哥 喜欢 小冰', '<sos> KaGe likes XiaoBing', 'KaGe likes XiaoBing <eos>'],\n", | |
" ['我 爱 学习 人工智能', '<sos> I love studying AI', 'I love studying AI <eos>'],\n", | |
" ['深度学习 改变 世界', '<sos> DL changed the world', 'DL changed the world <eos>'],\n", | |
" ['自然 语言 处理 很 强大', '<sos> NLP is so powerful', 'NLP is so powerful <eos>'],\n", | |
" ['神经网络 非常 复杂', '<sos> Neural-Nets are complex', 'Neural-Nets are complex <eos>']]\n", | |
"word_list_cn, word_list_en = [], [] # 初始化中英文词汇表\n", | |
"# 遍历每一个句子并将单词添加到词汇表中\n", | |
"for s in sentences:\n", | |
" word_list_cn.extend(s[0].split())\n", | |
" word_list_en.extend(s[1].split())\n", | |
" word_list_en.extend(s[2].split())\n", | |
"# 去重,得到没有重复单词的词汇表\n", | |
"word_list_cn = list(set(word_list_cn))\n", | |
"word_list_en = list(set(word_list_en))\n", | |
"# 构建单词到索引的映射\n", | |
"word2idx_cn = {w: i for i, w in enumerate(word_list_cn)}\n", | |
"word2idx_en = {w: i for i, w in enumerate(word_list_en)}\n", | |
"# 构建索引到单词的映射\n", | |
"idx2word_cn = {i: w for i, w in enumerate(word_list_cn)}\n", | |
"idx2word_en = {i: w for i, w in enumerate(word_list_en)}\n", | |
"# 计算词汇表的大小\n", | |
"voc_size_cn = len(word_list_cn)\n", | |
"voc_size_en = len(word_list_en)\n", | |
"print(\" 句子数量:\", len(sentences)) # 打印句子数\n", | |
"print(\" 中文词汇表大小:\", voc_size_cn) # 打印中文词汇表大小\n", | |
"print(\" 英文词汇表大小:\", voc_size_en) # 打印英文词汇表大小\n", | |
"print(\" 中文词汇到索引的字典:\", word2idx_cn) # 打印中文词汇到索引的字典\n", | |
"print(\" 英文词汇到索引的字典:\", word2idx_en) # 打印英文词汇到索引的字典" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "2cf92365", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 原始句子: ['我 爱 学习 人工智能', '<sos> I love studying AI', 'I love studying AI <eos>']\n", | |
" 编码器输入张量的形状: torch.Size([1, 4])\n", | |
" 解码器输入张量的形状: torch.Size([1, 5])\n", | |
" 目标张量的形状: torch.Size([1, 5])\n", | |
" 编码器输入张量: tensor([[ 5, 13, 10, 4]])\n", | |
" 解码器输入张量: tensor([[17, 14, 12, 18, 9]])\n", | |
" 目标张量: tensor([[14, 12, 18, 9, 16]])\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np # 导入 numpy\n", | |
"import torch # 导入 torch\n", | |
"import random # 导入 random 库\n", | |
"# 定义一个函数,随机选择一个句子和词汇表生成输入、输出和目标数据\n", | |
"def make_data(sentences):\n", | |
" # 随机选择一个句子进行训练\n", | |
" random_sentence = random.choice(sentences)\n", | |
" # 将输入句子中的单词转换为对应的索引\n", | |
" encoder_input = np.array([[word2idx_cn[n] for n in random_sentence[0].split()]])\n", | |
" # 将输出句子中的单词转换为对应的索引\n", | |
" decoder_input = np.array([[word2idx_en[n] for n in random_sentence[1].split()]])\n", | |
" # 将目标句子中的单词转换为对应的索引\n", | |
" target = np.array([[word2idx_en[n] for n in random_sentence[2].split()]])\n", | |
" # 将输入、输出和目标批次转换为 LongTensor\n", | |
" encoder_input = torch.LongTensor(encoder_input)\n", | |
" decoder_input = torch.LongTensor(decoder_input)\n", | |
" target = torch.LongTensor(target)\n", | |
" return encoder_input, decoder_input, target \n", | |
"# 使用 make_data 函数生成输入、输出和目标张量\n", | |
"encoder_input, decoder_input, target = make_data(sentences)\n", | |
"for s in sentences: # 获取原始句子\n", | |
" if all([word2idx_cn[w] in encoder_input[0] for w in s[0].split()]):\n", | |
" original_sentence = s\n", | |
" break\n", | |
"print(\" 原始句子:\", original_sentence) # 打印原始句子\n", | |
"print(\" 编码器输入张量的形状:\", encoder_input.shape) # 打印输入张量形状\n", | |
"print(\" 解码器输入张量的形状:\", decoder_input.shape) # 打印输出张量形状\n", | |
"print(\" 目标张量的形状:\", target.shape) # 打印目标张量形状\n", | |
"print(\" 编码器输入张量:\", encoder_input) # 打印输入张量\n", | |
"print(\" 解码器输入张量:\", decoder_input) # 打印输出张量\n", | |
"print(\" 目标张量:\", target) # 打印目标张量\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "3b8b3d1d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 编码器结构: Encoder(\n", | |
" (embedding): Embedding(18, 128)\n", | |
" (rnn): RNN(128, 128, batch_first=True)\n", | |
")\n", | |
" 解码器结构: Decoder(\n", | |
" (embedding): Embedding(20, 128)\n", | |
" (rnn): RNN(128, 128, batch_first=True)\n", | |
" (out): Linear(in_features=128, out_features=20, bias=True)\n", | |
")\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch.nn as nn # 导入 torch.nn 库\n", | |
"# 定义编码器类,继承自 nn.Module\n", | |
"class Encoder(nn.Module):\n", | |
" def __init__(self, input_size, hidden_size):\n", | |
" super(Encoder, self).__init__() \n", | |
" self.hidden_size = hidden_size # 设置隐藏层大小 \n", | |
" self.embedding = nn.Embedding(input_size, hidden_size) # 创建词嵌入层 \n", | |
" self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建 RNN 层 \n", | |
" def forward(self, inputs, hidden): # 前向传播函数\n", | |
" embedded = self.embedding(inputs) # 将输入转换为嵌入向量 \n", | |
" output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出\n", | |
" return output, hidden\n", | |
"# 定义解码器类,继承自 nn.Module\n", | |
"class Decoder(nn.Module):\n", | |
" def __init__(self, hidden_size, output_size):\n", | |
" super(Decoder, self).__init__() \n", | |
" self.hidden_size = hidden_size # 设置隐藏层大小 \n", | |
" self.embedding = nn.Embedding(output_size, hidden_size) # 创建词嵌入层\n", | |
" self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建 RNN 层 \n", | |
" self.out = nn.Linear(hidden_size, output_size) # 创建线性输出层 \n", | |
" def forward(self, inputs, hidden): # 前向传播函数 \n", | |
" embedded = self.embedding(inputs) # 将输入转换为嵌入向量 \n", | |
" output, hidden = self.rnn(embedded, hidden) # 将嵌入向量输入 RNN 层并获取输出 \n", | |
" output = self.out(output) # 使用线性层生成最终输出\n", | |
" return output, hidden\n", | |
"n_hidden = 128 # 设置隐藏层数量\n", | |
"# 创建编码器和解码器\n", | |
"encoder = Encoder(voc_size_cn, n_hidden)\n", | |
"decoder = Decoder(n_hidden, voc_size_en)\n", | |
"print(' 编码器结构:', encoder) # 打印编码器的结构\n", | |
"print(' 解码器结构:', decoder) # 打印解码器的结构" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "82dd33e3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"S2S 模型结构: Seq2Seq(\n", | |
" (encoder): Encoder(\n", | |
" (embedding): Embedding(18, 128)\n", | |
" (rnn): RNN(128, 128, batch_first=True)\n", | |
" )\n", | |
" (decoder): Decoder(\n", | |
" (embedding): Embedding(20, 128)\n", | |
" (rnn): RNN(128, 128, batch_first=True)\n", | |
" (out): Linear(in_features=128, out_features=20, bias=True)\n", | |
" )\n", | |
")\n" | |
] | |
} | |
], | |
"source": [ | |
"class Seq2Seq(nn.Module):\n", | |
" def __init__(self, encoder, decoder):\n", | |
" super(Seq2Seq, self).__init__()\n", | |
" # 初始化编码器和解码器\n", | |
" self.encoder = encoder\n", | |
" self.decoder = decoder\n", | |
" def forward(self, enc_input, hidden, dec_input): # 定义前向传播函数\n", | |
" # 使输入序列通过编码器并获取输出和隐藏状态\n", | |
" encoder_output, encoder_hidden = self.encoder(enc_input, hidden)\n", | |
" # 将编码器的隐藏状态传递给解码器作为初始隐藏状态\n", | |
" decoder_hidden = encoder_hidden\n", | |
" # 使解码器输入(目标序列)通过解码器并获取输出\n", | |
" decoder_output, _ = self.decoder(dec_input, decoder_hidden)\n", | |
" return decoder_output\n", | |
"# 创建 Seq2Seq 架构\n", | |
"model = Seq2Seq(encoder, decoder)\n", | |
"print('S2S 模型结构:', model) # 打印模型的结构" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "4645634a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0040 cost = 0.540944\n", | |
"Epoch: 0080 cost = 0.072166\n", | |
"Epoch: 0120 cost = 0.030407\n", | |
"Epoch: 0160 cost = 0.026792\n", | |
"Epoch: 0200 cost = 0.017567\n", | |
"Epoch: 0240 cost = 0.011450\n", | |
"Epoch: 0280 cost = 0.012062\n", | |
"Epoch: 0320 cost = 0.011834\n", | |
"Epoch: 0360 cost = 0.007469\n", | |
"Epoch: 0400 cost = 0.007511\n" | |
] | |
} | |
], | |
"source": [ | |
"# 定义训练函数\n", | |
"def train_seq2seq(model, criterion, optimizer, epochs):\n", | |
" for epoch in range(epochs):\n", | |
" encoder_input, decoder_input, target = make_data(sentences) # 训练数据的创建\n", | |
" hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态 \n", | |
" optimizer.zero_grad()# 梯度清零 \n", | |
" output = model(encoder_input, hidden, decoder_input) # 获取模型输出 \n", | |
" loss = criterion(output.view(-1, voc_size_en), target.view(-1)) # 计算损失 \n", | |
" if (epoch + 1) % 40 == 0: # 打印损失\n", | |
" print(f\"Epoch: {epoch + 1:04d} cost = {loss:.6f}\") \n", | |
" loss.backward()# 反向传播 \n", | |
" optimizer.step()# 更新参数\n", | |
"# 训练模型\n", | |
"epochs = 400 # 训练轮次\n", | |
"criterion = nn.CrossEntropyLoss() # 损失函数\n", | |
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器\n", | |
"train_seq2seq(model, criterion, optimizer, epochs) # 调用函数训练模型" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "9877b5c7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"咖哥 喜欢 小冰 -> ['KaGe', 'likes', 'XiaoBing']\n", | |
"自然 语言 处理 很 强大 -> ['NLP', 'is', 'so', 'powerful', '<eos>']\n" | |
] | |
} | |
], | |
"source": [ | |
"# 定义测试函数\n", | |
"def test_seq2seq(model, source_sentence):\n", | |
" # 将输入的句子转换为索引\n", | |
" encoder_input = np.array([[word2idx_cn[n] for n in source_sentence.split()]])\n", | |
" # 构建输出的句子的索引,以 '<sos>' 开始,后面跟 '<eos>',长度与输入句子相同\n", | |
" decoder_input = np.array([word2idx_en['<sos>']] + [word2idx_en['<eos>']]*(len(encoder_input[0])-1))\n", | |
" # 转换为 LongTensor 类型\n", | |
" encoder_input = torch.LongTensor(encoder_input)\n", | |
" decoder_input = torch.LongTensor(decoder_input).unsqueeze(0) # 增加一维 \n", | |
" hidden = torch.zeros(1, encoder_input.size(0), n_hidden) # 初始化隐藏状态 \n", | |
" predict = model(encoder_input, hidden, decoder_input) # 获取模型输出 \n", | |
" predict = predict.data.max(2, keepdim=True)[1] # 获取概率最大的索引\n", | |
" # 打印输入的句子和预测的句子\n", | |
" print(source_sentence, '->', [idx2word_en[n.item()] for n in predict.squeeze()])\n", | |
"# 测试模型\n", | |
"test_seq2seq(model, '咖哥 喜欢 小冰') \n", | |
"test_seq2seq(model, '自然 语言 处理 很 强大')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7cf88ce0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment