Created
February 3, 2025 07:51
-
-
Save nickfox-taterli/4d8dc953a09795371ce90e40d4297649 to your computer and use it in GitHub Desktop.
Pretrain_Deepseek
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "7e2afcd9-5b6e-4dfd-b115-58c45c700ef3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"12.6\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"print(torch.version.cuda)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "41bf1a0e-e068-40e2-be15-65a328184866", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Users\\TaterLi\\.conda\\envs\\gpt_learn\\Lib\\site-packages\\accelerate\\utils\\modeling.py:1536: UserWarning: Current model requires 128 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b6edd794ae0348f08533d5d921bb4b17", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"设备信息: cuda\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig\n", | |
"\n", | |
"model_name = \"deepseek-ai/deepseek-llm-7b-chat\"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
"model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map=\"auto\")\n", | |
"model.generation_config = GenerationConfig.from_pretrained(model_name)\n", | |
"model.generation_config.pad_token_id = model.generation_config.eos_token_id\n", | |
"\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # 判断是否有可用GPU\n", | |
"model.to(device)\n", | |
"print(\"设备信息:\", device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "dabd736f-318d-4ff2-99ba-dfdce736fcbc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"你好,我是一个人工智能助手,名为DeepSeek Chat。我是基于DeepSeek大语言模型开发的智能对话系统,可以进行自然语言处理和理解,帮助用户解答问题、提供信息和服务。我可以在多个领域中提供帮助,如科技、文化、历史、地理、语言学习等。我可以进行语音识别和合成,也可以与各种智能设备进行交互。我的主要目标是提供高效、准确、智能的问答服务,为用户带来便利和乐趣。\n" | |
] | |
} | |
], | |
"source": [ | |
"messages = [\n", | |
" {\"role\": \"user\", \"content\": \"你是谁\"}\n", | |
"]\n", | |
"input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\")\n", | |
"attention_mask = input_tensor.ne(tokenizer.pad_token_id).float()\n", | |
"outputs = model.generate(input_tensor.to(model.device), attention_mask=attention_mask.to(model.device), max_new_tokens=100)\n", | |
"\n", | |
"result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)\n", | |
"print(result)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "e9ec8551-42f7-44d6-8d8b-4193f95ba007", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"模型信息: LlamaForCausalLM(\n", | |
" (model): LlamaModel(\n", | |
" (embed_tokens): Embedding(102400, 4096)\n", | |
" (layers): ModuleList(\n", | |
" (0-29): 30 x LlamaDecoderLayer(\n", | |
" (self_attn): LlamaAttention(\n", | |
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", | |
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", | |
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", | |
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", | |
" )\n", | |
" (mlp): LlamaMLP(\n", | |
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", | |
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", | |
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n", | |
" (act_fn): SiLU()\n", | |
" )\n", | |
" (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n", | |
" (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n", | |
" )\n", | |
" )\n", | |
" (norm): LlamaRMSNorm((4096,), eps=1e-06)\n", | |
" (rotary_emb): LlamaRotaryEmbedding()\n", | |
" )\n", | |
" (lm_head): Linear(in_features=4096, out_features=102400, bias=False)\n", | |
")\n", | |
"分词器信息: LlamaTokenizerFast(name_or_path='deepseek-ai/deepseek-llm-7b-chat', vocab_size=100000, model_max_length=4096, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<|begin▁of▁sentence|>', 'eos_token': '<|end▁of▁sentence|>', 'pad_token': '<|end▁of▁sentence|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={\n", | |
"\t100000: AddedToken(\"<|begin▁of▁sentence|>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", | |
"\t100001: AddedToken(\"<|end▁of▁sentence|>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", | |
"\t100002: AddedToken(\"ø\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100003: AddedToken(\"ö\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100004: AddedToken(\"ú\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100005: AddedToken(\"ÿ\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100006: AddedToken(\"õ\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100007: AddedToken(\"÷\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100008: AddedToken(\"û\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100009: AddedToken(\"ý\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100010: AddedToken(\"À\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100011: AddedToken(\"ù\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100012: AddedToken(\"Á\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100013: AddedToken(\"þ\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"\t100014: AddedToken(\"ü\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),\n", | |
"}\n", | |
")\n", | |
"词汇表大小: 100015\n", | |
"部分词汇示例: ['oling', 'ĠHart', 'ĠAdams', 'Ġprocedents', 'Ġfactions']\n" | |
] | |
} | |
], | |
"source": [ | |
"vocab = tokenizer.get_vocab() # 获取词汇表\n", | |
"\n", | |
"print(\"模型信息:\", model)\n", | |
"print(\"分词器信息:\",tokenizer)\n", | |
"print(\"词汇表大小:\", len(vocab))\n", | |
"print(\"部分词汇示例:\", (list(vocab.keys())[8000:8005]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "ba0114ae-241a-4088-b4b7-80f4192c3ab7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Example 1:\n", | |
"Input: <|begin▁of▁sentence|>你是谁<|end▁of▁sentence|>\n", | |
"Target: <|begin▁of▁sentence|>我是DS助手<|end▁of▁sentence|>\n", | |
"Example 2:\n", | |
"Input: <|begin▁of▁sentence|>今天天气怎样<|end▁of▁sentence|>\n", | |
"Target: <|begin▁of▁sentence|>天气每天都不错啦<|end▁of▁sentence|>\n" | |
] | |
} | |
], | |
"source": [ | |
"from torch.utils.data import Dataset # 导入Pytorch的Dataset\n", | |
"# 自定义ChatDataset类,继承自Pytorch的Dataset类\n", | |
"class ChatDataset(Dataset):\n", | |
" def __init__(self, file_path, tokenizer, vocab):\n", | |
" self.tokenizer = tokenizer # 分词器\n", | |
" self.vocab = vocab # 词汇表\n", | |
" # 加载数据并处理,将处理后的输入数据和目标数据赋值给input_data和target_data\n", | |
" self.input_data, self.target_data = self.load_and_process_data(file_path)\n", | |
" # 定义加载和处理数据的方法\n", | |
" def load_and_process_data(self, file_path): \n", | |
" with open(file_path, \"r\", encoding=\"utf-8\") as f: # 读取文件内容\n", | |
" lines = f.readlines()\n", | |
" input_data, target_data = [], [] \n", | |
" for i, line in enumerate(lines): # 遍历文件的每一行 \n", | |
" if line.startswith(\"User:\"): # 如以\"User:\"开头,分词,移除\"User: \"前缀,并将张量转换为列表\n", | |
" tokens = self.tokenizer(line.strip()[6:], return_tensors=\"pt\")[\"input_ids\"].tolist()[0]\n", | |
" tokens = tokens + [tokenizer.eos_token_id] # 添加结束符\n", | |
" input_data.append(torch.tensor(tokens, dtype=torch.long)) # 添加到input_data中\n", | |
" elif line.startswith(\"AI:\"): # 如以\"AI:\"开头,分词,移除\"AI: \"前缀,并将张量转换为列表\n", | |
" tokens = self.tokenizer(line.strip()[4:], return_tensors=\"pt\")[\"input_ids\"].tolist()[0]\n", | |
" tokens = tokens + [tokenizer.eos_token_id] # 添加结束符\n", | |
" target_data.append(torch.tensor(tokens, dtype=torch.long)) # 添加到target_data中\n", | |
" return input_data, target_data\n", | |
" # 定义数据集的长度,即input_data的长度\n", | |
" def __len__(self):\n", | |
" return len(self.input_data)\n", | |
" # 定义获取数据集中指定索引的数据的方法\n", | |
" def __getitem__(self, idx):\n", | |
" return self.input_data[idx], self.target_data[idx]\n", | |
"\n", | |
"file_path = \"chat.txt\" # 加载chat.txt数据集\n", | |
"chat_dataset = ChatDataset(file_path, tokenizer, vocab) # 创建ChatDataset对象,传入文件、分词器和词汇表\n", | |
"for i in range(2): # 打印数据集中前2个数据示例\n", | |
" input_example, target_example = chat_dataset[i]\n", | |
" print(f\"Example {i + 1}:\")\n", | |
" print(\"Input:\", tokenizer.decode(input_example))\n", | |
" print(\"Target:\", tokenizer.decode(target_example))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "7d46c771-4105-4a49-beb5-dc7724b653c5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'<|end▁of▁sentence|>'" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tokenizer.pad_token_id\n", | |
"tokenizer.decode(tokenizer.pad_token_id)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f76cb1f0-966c-42ee-ab82-1a870bfb1aec", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Input batch tensor size: torch.Size([1, 21])\n", | |
"Target batch tensor size: torch.Size([1, 21])\n", | |
"Input batch tensor:\n", | |
"tensor([[100000, 9728, 68650, 33042, 100001, 100001, 100001, 100001, 100001]])\n", | |
"<|begin▁of▁sentence|>努力总有回报<|end▁of▁sentence|><|end▁of▁sentence|><|end▁of▁sentence|><|end▁of▁sentence|><|end▁of▁sentence|>\n", | |
"Target batch tensor:\n", | |
"tensor([[100000, 3705, 15037, 1059, 5782, 1059, 26253, 617, 100001]])\n", | |
"<|begin▁of▁sentence|>所以就要多学习多练习了<|end▁of▁sentence|>\n" | |
] | |
} | |
], | |
"source": [ | |
"from torch.utils.data import DataLoader # 导入Dataloader\n", | |
"# 定义pad_sequence函数,用于将一批序列补齐到相同长度\n", | |
"def pad_sequence(sequences, padding_value=0, length=None):\n", | |
" # 计算最大序列长度,如果length参数未提供,则使用输入序列中的最大长度\n", | |
" max_length = max(len(seq) for seq in sequences) if length is None else length \n", | |
" # 创建一个具有适当形状的全零张量,用于存储补齐后的序列\n", | |
" result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long) \n", | |
" # 遍历序列,将每个序列的内容复制到结果张量中\n", | |
" for i, seq in enumerate(sequences):\n", | |
" end = len(seq)\n", | |
" result[i, :end] = seq[:end]\n", | |
" return result\n", | |
"\n", | |
"# 定义collate_fn函数,用于将一个批次的数据整理成适当的形状\n", | |
"def collate_fn(batch):\n", | |
" # 从批次中分离源序列和目标序列\n", | |
" sources, targets = zip(*batch) \n", | |
" # 计算批次中的最大序列长度\n", | |
" max_length = max(max(len(s) for s in sources), max(len(t) for t in targets)) \n", | |
" # 使用pad_sequence函数补齐源序列和目标序列\n", | |
" sources = pad_sequence(sources, padding_value=tokenizer.pad_token_id, length=max_length)\n", | |
" targets = pad_sequence(targets, padding_value=tokenizer.pad_token_id, length=max_length) \n", | |
" # 返回补齐后的源序列和目标序列\n", | |
" return sources, targets\n", | |
"\n", | |
"# 创建Dataloader\n", | |
"chat_dataloader = DataLoader(chat_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)\n", | |
"\n", | |
"# 检查Dataloader输出\n", | |
"for input_batch, target_batch in chat_dataloader:\n", | |
" print(\"Input batch tensor size:\", input_batch.size())\n", | |
" print(\"Target batch tensor size:\", target_batch.size())\n", | |
" break\n", | |
"for input_batch, target_batch in chat_dataloader:\n", | |
" print(\"Input batch tensor:\")\n", | |
" print(input_batch)\n", | |
" texts = tokenizer.batch_decode(input_batch)\n", | |
" # 打印转换后的文本\n", | |
" for text in texts:\n", | |
" print(text)\n", | |
" \n", | |
" print(\"Target batch tensor:\")\n", | |
" print(target_batch)\n", | |
" texts = tokenizer.batch_decode(target_batch)\n", | |
" # 打印转换后的文本\n", | |
" for text in texts:\n", | |
" print(text)\n", | |
" \n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "d9e4604a-a951-47f5-9f60-c72b6229b50b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for i, layer in enumerate(model.model.layers):\n", | |
" if i < 10: # 冻结前 10 层\n", | |
" for param in layer.parameters():\n", | |
" param.requires_grad = False" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "299ad4c5-2d54-417f-809f-98ddfddd1679", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"|===========================================================================|\n", | |
"| PyTorch CUDA memory summary, device ID 0 |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| CUDA OOMs: 2 | cudaMalloc retries: 2 |\n", | |
"|===========================================================================|\n", | |
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Allocated memory | 17875 MiB | 55157 MiB | 2166 GiB | 2148 GiB |\n", | |
"| from large pool | 17856 MiB | 55136 MiB | 2126 GiB | 2109 GiB |\n", | |
"| from small pool | 19 MiB | 138 MiB | 39 GiB | 39 GiB |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Active memory | 17875 MiB | 55157 MiB | 2166 GiB | 2148 GiB |\n", | |
"| from large pool | 17856 MiB | 55136 MiB | 2126 GiB | 2109 GiB |\n", | |
"| from small pool | 19 MiB | 138 MiB | 39 GiB | 39 GiB |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Requested memory | 17875 MiB | 55157 MiB | 2166 GiB | 2148 GiB |\n", | |
"| from large pool | 17856 MiB | 55136 MiB | 2126 GiB | 2108 GiB |\n", | |
"| from small pool | 19 MiB | 138 MiB | 39 GiB | 39 GiB |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| GPU reserved memory | 17976 MiB | 55264 MiB | 92742 MiB | 74766 MiB |\n", | |
"| from large pool | 17880 MiB | 55160 MiB | 92440 MiB | 74560 MiB |\n", | |
"| from small pool | 96 MiB | 150 MiB | 302 MiB | 206 MiB |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Non-releasable memory | 102937 KiB | 119086 KiB | 55617 MiB | 55517 MiB |\n", | |
"| from large pool | 24320 KiB | 24320 KiB | 11454 MiB | 11430 MiB |\n", | |
"| from small pool | 78617 KiB | 99566 KiB | 44163 MiB | 44086 MiB |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Allocations | 725 | 1603 | 511022 | 510297 |\n", | |
"| from large pool | 285 | 853 | 36888 | 36603 |\n", | |
"| from small pool | 440 | 1034 | 474134 | 473694 |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Active allocs | 725 | 1603 | 511022 | 510297 |\n", | |
"| from large pool | 285 | 853 | 36888 | 36603 |\n", | |
"| from small pool | 440 | 1034 | 474134 | 473694 |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| GPU reserved segments | 333 | 905 | 1572 | 1239 |\n", | |
"| from large pool | 285 | 853 | 1421 | 1136 |\n", | |
"| from small pool | 48 | 75 | 151 | 103 |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Non-releasable allocs | 128 | 217 | 275994 | 275866 |\n", | |
"| from large pool | 2 | 4 | 475 | 473 |\n", | |
"| from small pool | 126 | 214 | 275519 | 275393 |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Oversize allocations | 0 | 0 | 0 | 0 |\n", | |
"|---------------------------------------------------------------------------|\n", | |
"| Oversize GPU segments | 0 | 0 | 0 | 0 |\n", | |
"|===========================================================================|\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"import gc\n", | |
"\n", | |
"# 强制垃圾回收\n", | |
"gc.collect()\n", | |
"\n", | |
"# 释放未使用的显存\n", | |
"torch.cuda.empty_cache()\n", | |
"\n", | |
"print(torch.cuda.memory_summary())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "f3f2e1c2-117e-4fa2-9a0c-c14a817c88dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0001, cost = 7.500000\n", | |
"Epoch: 0002, cost = 3.562500\n", | |
"Epoch: 0003, cost = 0.251953\n", | |
"Epoch: 0004, cost = 1.453125\n", | |
"Epoch: 0005, cost = 0.000584\n", | |
"Epoch: 0006, cost = 0.000422\n", | |
"Epoch: 0007, cost = 0.002109\n", | |
"Epoch: 0008, cost = 0.000534\n", | |
"Epoch: 0009, cost = 0.009399\n", | |
"Epoch: 0010, cost = 0.000723\n", | |
"Epoch: 0011, cost = 0.000884\n", | |
"Epoch: 0012, cost = 0.000429\n", | |
"Epoch: 0013, cost = 0.000531\n", | |
"Epoch: 0014, cost = 0.000613\n", | |
"Epoch: 0015, cost = 0.000341\n", | |
"Epoch: 0016, cost = 0.000006\n", | |
"Epoch: 0017, cost = 0.000091\n", | |
"Epoch: 0018, cost = 0.000039\n", | |
"Epoch: 0019, cost = 0.000092\n", | |
"Epoch: 0020, cost = 0.000012\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"\n", | |
"torch.cuda.empty_cache()\n", | |
"\n", | |
"# 定义损失函数,忽略pad_token_id对应的损失值\n", | |
"criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n", | |
"# 定义优化器\n", | |
"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n", | |
"# 进行100个epoch的训练\n", | |
"for epoch in range(20):\n", | |
" # 遍历数据加载器中的批次\n", | |
" for batch_idx, (input_batch, target_batch) in enumerate(chat_dataloader): \n", | |
" optimizer.zero_grad() # 梯度清零 \n", | |
" input_batch, target_batch = input_batch.to(device), target_batch.to(device) # 将输入和目标批次移至设备(CPU或GPU) \n", | |
" outputs = model(input_batch) # 前向传播\n", | |
" logits = outputs.logits # 获取logits \n", | |
" loss = criterion(logits.view(-1, logits.size(-1)), target_batch.view(-1)) # 计算损失\n", | |
" loss.backward() # 反向传播 \n", | |
" optimizer.step()# 更新参数 \n", | |
" print(f'Epoch: {epoch + 1:04d}, cost = {loss:.6f}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "6a988393-7a0b-4108-9106-93bf4d930725", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"我是DS助手\n" | |
] | |
} | |
], | |
"source": [ | |
"messages = [\n", | |
" {\"role\": \"user\", \"content\": \"你是谁\"}\n", | |
"]\n", | |
"input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\")\n", | |
"attention_mask = input_tensor.ne(tokenizer.pad_token_id).float()\n", | |
"outputs = model.generate(input_tensor.to(model.device), attention_mask=attention_mask.to(model.device), max_new_tokens=100)\n", | |
"\n", | |
"result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)\n", | |
"print(result)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d113dea9-6a73-4482-97ae-42a5cfffc05b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment