Last active
May 31, 2025 15:29
nlp-ml-huit.ipynb
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyNIzz4jlceAnBWNB3hONj4T", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/xiupos/f14c3250f306c79cd14ad9e7ce85f36a/nlp-ml-huit.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# NLP/ML Study Session @HUIT\n", | |
| "\n", | |
| "---\n", | |
| "\n" | |
| ], | |
| "metadata": { | |
| "id": "h33YdTr0Cisa" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
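| "In this notebook we build a language model that simply works. All we need is something that runs, so accuracy is not the goal. Let's start by building a model with a simple architecture whose output looks roughly right.\n", | |
| "\n", | |
| "The explanations assume the accompanying [presentation slides](https://speakerdeck.com/xiupos/ml-mian-qiang-hui-at-huit)." | |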
| ], | |
| "metadata": { | |
| "id": "UlPyWRiRcLhb" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Setup" | |
| ], | |
| "metadata": { | |
| "id": "yPEcpnXUCrN6" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "First, we import the libraries we will use. This notebook relies mainly on [PyTorch](https://pytorch.org/)." | |
| ], | |
| "metadata": { | |
| "id": "DyuV_PX95kXu" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "UnKGkrykcKD7", | |
| "outputId": "b697417d-7438-4550-fb02-17a9a67bd1cd" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<torch._C.Generator at 0x7b62253402d0>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 1 | |
| } | |
| ], | |
| "source": [ | |
| "import torch\n", | |
| "from torch import nn\n", | |
| "from torch.nn import functional as F\n", | |
| "import numpy as np\n", | |
| "from matplotlib import pyplot as plt\n", | |
| "import pandas as pd\n", | |
| "# Fix the random seed for reproducibility\n", | |
| "torch.manual_seed(1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Next, we prepare the training data. Any text will do; here we use the [original text](https://sherlock-holm.es/stories/plain-text/advs.txt) of *The Adventures of Sherlock Holmes*. Well-known training corpora include [Tiny Shakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)." | |
| ], | |
| "metadata": { | |
| "id": "lz9mjU5ljFGe" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Download the training data (created from https://sherlock-holm.es/stories/plain-text/advs.txt)\n", | |
| "!wget https://gist.githubusercontent.com/xiupos/b7914ea1e3ab35465c34de45146b15d8/raw/b7623310de06160876ffba6299994b0409c85c81/advs.txt\n", | |
| "lines: str = open('./advs.txt', 'r', encoding='utf-8').read().replace('\\n', ' ')" | |
| ], | |
| "metadata": { | |
| "id": "F85C-xgrgYgg", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "22ad09fe-0429-4065-ee76-f98727cc362c" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "--2024-08-16 08:23:44-- https://gist.githubusercontent.com/xiupos/b7914ea1e3ab35465c34de45146b15d8/raw/b7623310de06160876ffba6299994b0409c85c81/advs.txt\n", | |
| "Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", | |
| "Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.\n", | |
| "HTTP request sent, awaiting response... 200 OK\n", | |
| "Length: 562263 (549K) [text/plain]\n", | |
| "Saving to: ‘advs.txt’\n", | |
| "\n", | |
| "advs.txt 100%[===================>] 549.08K --.-KB/s in 0.03s \n", | |
| "\n", | |
| "2024-08-16 08:23:44 (17.8 MB/s) - ‘advs.txt’ saved [562263/562263]\n", | |
| "\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
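| "The training data is a string. Strings are hard to handle directly in machine learning, so we first map each word to an integer. Treating English words as units is more involved, however, so for simplicity we use individual characters as units. From here on, \"word\" means a character; in the context of language models, the more precise term would be \"token\".\n", | |
| "\n", | |
| "We simply use each character's position in sorted order as its integer. For example, the character at position $μ$ (counting from 0) in sorted order is assigned the integer $μ$. From here on, we call that character word $μ$." | |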
| ], | |
| "metadata": { | |
| "id": "a4_gSeCj36vC" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# List of all distinct characters, sorted\n", | |
| "vocab: list[str] = sorted(set(lines))\n", | |
| "# Dictionary {integer (index): character}\n", | |
| "itos: dict[int, str] = {i:s for i,s in enumerate(vocab)}\n", | |
| "# Dictionary {character: integer (index)}\n", | |
| "stoi: dict[str, int] = {s:i for i,s in enumerate(vocab)}\n", | |
| "\n", | |
| "# The characters and their count\n", | |
| "\"\".join(vocab), len(vocab)" | |
| ], | |
| "metadata": { | |
| "id": "_Gl30hk9ghRk", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "cc390000-472d-4307-fed8-b9bbaf73cae6" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(' !\"&\\'(),-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWYZ[]abcdefghijklmnopqrstuvwxyz£½àâèé',\n", | |
| " 83)" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 3 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
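| "Next, using this mapping, let's define functions that convert a string to a list of integers and back. We also convert the training text to a list of integers and call it `dataset`." | |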
| ], | |
| "metadata": { | |
| "id": "naUorL6u4Ayw" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Encode (function: string -> list of integers)\n", | |
| "def encode(s: str) -> list[int]:\n", | |
| "    return [stoi[c] for c in s]\n", | |
| "\n", | |
| "# Decode (function: list of integers -> string)\n", | |
| "def decode(l: list[int]) -> str:\n", | |
| "    return ''.join([itos[i] for i in l])\n", | |
| "\n", | |
| "# Convert the training text to a list of integers\n", | |
| "dataset = encode(lines)\n", | |
| "\n", | |
| "# Test encoding and decoding (the round trip should return the original string)\n", | |
| "encode(\"Do you know, Watson,\"), decode(encode(\"Do you know, Watson,\"))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "wvkJQt4jhCbY", | |
| "outputId": "f9e32854-e931-40cd-91a1-16fb8db0115e" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "([27, 65, 0, 75, 65, 71, 0, 61, 64, 65, 73, 7, 0, 46, 51, 70, 69, 65, 64, 7],\n", | |
| " 'Do you know, Watson,')" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
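| "For convenience later on, we define a function that builds a training batch from `dataset`. For example, for the text \"Do you know, Watson,\" a batch consists of the following pair of strings:\n", | |
| "\n", | |
| "- x: `Do you know, Wat`\n", | |
| "- y: `o you know, Wats`\n", | |
| "\n", | |
| "As explained below, this time we build a model that \"predicts the next word from a single word\". That is why a batch is a tuple of two sequences shifted by one character: `y[i]` is the character that follows `x[i]`." | |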
| ], | |
| "metadata": { | |
| "id": "ixBGfeS87g69" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Function that builds a batch\n", | |
| "# Returns two sequences of width xnum, shifted by one character\n", | |
| "# ex) ['rip has been upo', 'ip has been upon']\n", | |
| "#       |<-----xnum----->|  |<-----xnum----->|\n", | |
| "def get_batches(xnum: int, data: list[int] = dataset, split=\"train\") -> tuple[list[int], list[int]]:\n", | |
| "    # Split the training data 8:2 (training : validation)\n", | |
| "    if split == \"train\":\n", | |
| "        # Training data\n", | |
| "        batchdata = data[:int(len(data)*0.8)]\n", | |
| "    else:\n", | |
| "        # Validation data\n", | |
| "        batchdata = data[int(len(data)*0.8):]\n", | |
| "\n", | |
| "    # Choose a random start position for the batch\n", | |
| "    idx = torch.randint(0, len(batchdata)-xnum-1, (1,)).item()\n", | |
| "    # Extract the input x and the target y (x shifted by one character)\n", | |
| "    x = batchdata[idx:idx+xnum]\n", | |
| "    y = batchdata[idx+1:idx+xnum+1]\n", | |
| "    return x, y\n", | |
| "\n", | |
| "# Test the batch function\n", | |
| "[decode(l) for l in get_batches(16)]" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "uxSYnZjodkAy", | |
| "outputId": "a584d0f1-12a7-4bf9-e3af-7816953c7f75" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "['rip has been upo', 'ip has been upon']" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 5 | |
| } | |
| ] | |
| }, | |
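The one-character shift can be illustrated without PyTorch. The following is a minimal pure-Python sketch (the helper `shifted_pairs` is hypothetical, not part of the notebook) showing that each position of `x` and `y` forms one (current character, next character) training pair, which is exactly the input/target relationship the model is meant to learn.

```python
def shifted_pairs(text: str, start: int, width: int) -> tuple[str, str, list[tuple[str, str]]]:
    """Return x, y (y = x shifted by one character) and the per-position pairs."""
    # x covers [start, start + width), y covers [start + 1, start + width + 1)
    x = text[start:start + width]
    y = text[start + 1:start + width + 1]
    # Each position i is one training example: input x[i] -> target y[i]
    pairs = list(zip(x, y))
    return x, y, pairs

x, y, pairs = shifted_pairs("Do you know, Watson,", 0, 16)
print(repr(x))    # 'Do you know, Wat'
print(repr(y))    # 'o you know, Wats'
print(pairs[:3])  # [('D', 'o'), ('o', ' '), (' ', 'y')]
```

So a single window of width 16 already supplies 16 next-character examples.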
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "The preparation is now complete." | |
| ], | |
| "metadata": { | |
| "id": "nhGrIeL-9K7S" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## A First Model" | |
| ], | |
| "metadata": { | |
| "id": "1mnExgl0Cvew" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
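| "Now let's build our first model. Here we build a model that \"predicts the next word from a single word\". Repeating this prediction lets us generate text.\n", | |
| "\n", | |
| "First, let's define the functions that convert words to vectors and back. See the presentation slides for details." | |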
| ], | |
| "metadata": { | |
| "id": "XWbnir1y4HPl" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
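| "# Function: list of integers -> character vectors\n", | |
| "def ltov(l: list[int]) -> torch.Tensor:\n", | |
| "    # Convert to one-hot vectors (one row of the identity matrix per character)\n", | |
| "    return torch.eye(len(vocab))[l]\n", | |
| "\n", | |
| "# Function: logits -> list of integers (sampling from the probability distribution)\n", | |
| "def vtol(v: torch.Tensor) -> list[int]:\n", | |
| "    # Compute the probability distribution\n", | |
| "    p = F.softmax(v, dim=-1)\n", | |
| "    # Sample one character per row according to the distribution\n", | |
| "    return torch.multinomial(p, num_samples=1).view(-1).tolist()\n", | |
| "\n", | |
| "# Function: logits -> list of integers (greedy decoding)\n", | |
| "def vtol_greedy(v: torch.Tensor) -> list[int]:\n", | |
| "    # Pick the character with the largest value (equivalently, the highest probability)\n", | |
| "    return torch.argmax(v, dim=-1).tolist()\n", | |
| "\n", | |
| "# As a test, treat the character vectors as logits and convert them back to characters\n", | |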
| "xs, _ = get_batches(16)\n", | |
| "logits = ltov(xs)\n", | |
| "decode(xs), decode(vtol(logits)), decode(vtol_greedy(logits))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "b8wyKO8K9qZG", | |
| "outputId": "b3f6da4e-ed59-43d4-921f-7245d008cb72" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "('bited. The bedro', '½Qz,eKvHs0:6SdLQ', 'bited. The bedro')" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 6 | |
| } | |
| ] | |
| }, | |
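Why does `vtol` return nearly random characters here while `vtol_greedy` recovers the input exactly? A one-hot vector used as logits has one entry equal to 1 and all others 0, so after softmax over 83 characters the correct character only receives probability $e/(e + 82)$. The pure-Python sketch below checks this; `softmax` is a hand-written helper (not PyTorch's), and the vocabulary size 83 is taken from the output earlier in the notebook.

```python
import math

def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab_size = 83  # number of distinct characters in the training text
target = 5       # arbitrary index of the "correct" character

# A one-hot vector interpreted as logits
one_hot = [1.0 if i == target else 0.0 for i in range(vocab_size)]
p = softmax(one_hot)

print(round(p[target], 4))                        # 0.0321
print(max(range(vocab_size), key=p.__getitem__))  # 5 (the greedy choice)
```

Sampling therefore picks the correct character only about 3% of the time, while greedy decoding always picks it.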
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Next, we define the neural network at the core of the model. For simplicity, we use a feedforward neural network with a single hidden layer and ReLU as the activation function. As a formula,\n", | |
| "$$\n", | |
| "\\vec{y} = W_2 \\operatorname*{ReLU}(W_1 \\vec{x} + \\vec{b}_1) + \\vec{b}_2\n", | |
| "$$\n", | |
| "where the matrices $W_1$, $W_2$ and the vectors $\\vec{b}_1$, $\\vec{b}_2$ are the model parameters. In components,\n", | |
| "$$\n", | |
| "\\begin{aligned}\n", | |
| "y_μ\n", | |
| " &= \\sum_ν w_{2μν} \\operatorname*{ReLU}\\left(\\sum_λ w_{1νλ} x_λ + b_{1ν}\\right) + b_{2μ} \\\\\n", | |
| " &= \\sum_ν w_{2μν} \\max\\left(\\sum_λ w_{1νλ} x_λ + b_{1ν}, 0\\right) + b_{2μ}\n", | |
| "\\end{aligned}\n", | |
| "$$\n", | |
| "\n", | |
| "Let's define the model with PyTorch's `nn.Sequential`. `nn.Linear` computes $\\vec{x} ↦ W \\vec{x} + \\vec{b}$, and `nn.ReLU` computes $\\vec{x} ↦ \\operatorname*{ReLU}(\\vec{x})$ element-wise." | |
| ], | |
| "metadata": { | |
| "id": "KjrBDikMruW0" | |
| } | |
| }, | |
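To make the component formula concrete, here is a minimal pure-Python sketch of $y_μ = \sum_ν w_{2μν} \max(\sum_λ w_{1νλ} x_λ + b_{1ν}, 0) + b_{2μ}$ using tiny, made-up parameters (3 characters, 2 hidden units; every number is an arbitrary assumption). It also shows that with a one-hot input, $W_1 \vec{x}$ simply picks out one column of $W_1$.

```python
def linear(W: list[list[float]], b: list[float], x: list[float]) -> list[float]:
    """Compute W x + b, with W stored row by row."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def relu(v: list[float]) -> list[float]:
    """Element-wise max(v, 0)."""
    return [max(vi, 0.0) for vi in v]

def forward(W1, b1, W2, b2, x):
    """y = W2 ReLU(W1 x + b1) + b2 (one hidden layer)."""
    return linear(W2, b2, relu(linear(W1, b1, x)))

# Arbitrary toy parameters: 3 characters in the vocabulary, 2 hidden units
W1 = [[1.0, -2.0, 0.5],
      [0.0, 1.0, -1.0]]   # shape (2, 3)
b1 = [0.0, 0.5]
W2 = [[1.0, 2.0],
      [-1.0, 0.0],
      [0.5, 0.5]]         # shape (3, 2)
b2 = [0.1, 0.0, -0.1]

x = [0.0, 1.0, 0.0]       # one-hot vector for character 1

print(linear(W1, [0.0, 0.0], x))  # [-2.0, 1.0], i.e. column 1 of W1
print(forward(W1, b1, W2, b2, x))
```

Because the input is one-hot, the first layer acts like a lookup table of columns plus a bias; a bias-free linear layer on one-hot inputs is equivalent to an embedding lookup (`nn.Embedding` in PyTorch).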
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Number of units in the hidden layer\n", | |
| "d_model = 128\n", | |
| "\n", | |
| "# Define the model\n", | |
| "model = nn.Sequential(\n", | |
| "    nn.Linear(len(vocab), d_model),\n", | |
| "    nn.ReLU(),\n", | |
| "    nn.Linear(d_model, len(vocab)),\n", | |
| ")\n", | |
| "\n", | |
| "# The model and its parameter sizes (W_1, b_1, W_2, b_2)\n", | |
| "model, [m.numel() for m in model.parameters()]" | |
| ], | |
| "metadata": { | |
| "id": "JUZIrlQsA7uD", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "d4188c8d-d4c4-4f67-e8f4-f01c0a6b1e32" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(Sequential(\n", | |
| " (0): Linear(in_features=83, out_features=128, bias=True)\n", | |
| " (1): ReLU()\n", | |
| " (2): Linear(in_features=128, out_features=83, bias=True)\n", | |
| " ),\n", | |
| " [10624, 128, 10624, 83])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 7 | |
| } | |
| ] | |
| }, | |
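The parameter sizes printed above follow directly from the layer shapes: an `nn.Linear(in_features, out_features)` layer has a weight matrix with `out_features × in_features` entries and a bias with `out_features` entries. A small pure-Python check, assuming 83 characters and `d_model = 128` as above (the helper `linear_param_counts` is hypothetical):

```python
def linear_param_counts(in_features: int, out_features: int) -> list[int]:
    """Number of entries in the weight matrix and the bias of a linear layer."""
    return [out_features * in_features, out_features]

vocab_size, d_model = 83, 128

# First layer (83 -> 128), then second layer (128 -> 83); ReLU has no parameters
counts = linear_param_counts(vocab_size, d_model) + linear_param_counts(d_model, vocab_size)
print(counts)       # [10624, 128, 10624, 83]
print(sum(counts))  # 21459
```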
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Here the number of hidden units is `d_model`. Let's try the model right away: we take a random batch and have the model predict the next word at each position." | |
| ], | |
| "metadata": { | |
| "id": "Jca-Fql-ncdu" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Test the (untrained) model\n", | |
| "xs, _ = get_batches(16)\n", | |
| "logits = model(ltov(xs))\n", | |
| "decode(xs), decode(vtol(logits)), decode(vtol_greedy(logits))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "i3wJcrV34gf4", | |
| "outputId": "12c6eff0-4032-451a-9d6d-6c1accbcf72f" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "('ression upon me,', \"uTCf'G)zewtKLHr!\", 'Vlgggg(gMgg(ggll')" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 8 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
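| "It's a complete mess, which is expected since the model hasn't been trained at all. While we're at it, let's also define a function that generates longer text." | |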
| ], | |
| "metadata": { | |
| "id": "kZ4l65Y7niU8" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Function that generates text\n", | |
| "@torch.no_grad()\n", | |
| "def generate(model: nn.Module, xnum=100) -> str:\n", | |
| "    # List that stores the generated text (as integers)\n", | |
| "    out: list[int] = []\n", | |
| "    # Initial state (choose the first character at random)\n", | |
| "    out += torch.randint(0, len(vocab), (1,)).tolist()\n", | |
| "\n", | |
| "    # Repeat xnum times, adding one character per step\n", | |
| "    for _ in range(xnum):\n", | |
| "        # Apply the model to the last character of the result\n", | |
| "        ys = vtol(model(ltov(out[-1:])))\n", | |
| "        # Append the sampled character to the result\n", | |
| "        out += [ys[-1]]\n", | |
| "\n", | |
| "    return decode(out)\n", | |
| "\n", | |
| "# Try generating some text\n", | |
| "print(generate(model))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "utMCQ2f0rFeq", | |
| "outputId": "6084b85f-2281-426e-8aff-34d045c8ef0e" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "W6)TfOcQKfélo7h]è(0o?eNSRB7U1vh½m;\"l)£JHB9!h£s&Kgm7TfhDy58b:?2m5B'n\" fS£gW[-VeFPK3P9:léiVEVLQtC?V7Geo\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "It's obviously gibberish. But how gibberish, exactly? To train the model, we need to express this \"gibberishness\" as a number." | |
| ], | |
| "metadata": { | |
| "id": "2srvCmyEUizZ" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Evaluating the Model" | |
| ], | |
| "metadata": { | |
| "id": "qoQ6OdN7DPqs" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
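| "Let's evaluate the model quantitatively. In machine learning, a so-called \"loss function\" measures the discrepancy between the training data and the model. Here we use the cross-entropy. See the presentation slides for details." | |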
| ], | |
| "metadata": { | |
| "id": "Wv-EQh70nuTZ" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Use cross-entropy as the loss function\n", | |
| "loss_fn = nn.CrossEntropyLoss()\n", | |
| "\n", | |
| "# Compute the loss on a batch (targets given as one-hot probability vectors)\n", | |
| "xs, ys = get_batches(16)\n", | |
| "loss = loss_fn(model(ltov(xs)), ltov(ys))\n", | |
| "loss.item()" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "OsGfl9doGV-d", | |
| "outputId": "b96f9d5d-bda3-482c-d812-29acdaf0d22e" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "4.4361042976379395" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 10 | |
| } | |
| ] | |
| }, | |
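For reference, the cross-entropy that `nn.CrossEntropyLoss()` computes with its default mean reduction and class-index targets can be sketched in pure Python: apply log-softmax to each row of logits, take minus the log-probability of the correct class, and average over positions. The helper names and toy numbers below are illustrative assumptions, not PyTorch's actual implementation.

```python
import math

def log_softmax(logits: list[float]) -> list[float]:
    """log(softmax(logits)), computed stably via the log-sum-exp trick."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

def cross_entropy(batch_logits: list[list[float]], targets: list[int]) -> float:
    """Mean over positions of -log q[target], where q = softmax(logits)."""
    losses = [-log_softmax(row)[t] for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

# Two positions, vocabulary of 3 characters (toy numbers)
batch_logits = [[2.0, 0.0, 0.0],
                [0.0, 0.0, 0.0]]
targets = [0, 2]
loss = cross_entropy(batch_logits, targets)
print(loss)
```

With all-equal logits over 83 characters this returns exactly $\log 83 ≈ 4.419$, which is why the untrained model's loss of about 4.44 above corresponds to nearly uniform guessing.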
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
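| "How should we interpret this value? The targets $\\vec{y}$ here are one-hot vectors. When $\\vec{y}$ represents word $μ$, we regard it as the probability distribution $p = \\{p_ν\\} = \\{δ_{μν}\\}$. (This is in fact how `nn.CrossEntropyLoss()` treats such targets.) Writing $q = \\{q_ν\\}$ for the distribution predicted by the model (the softmax of its output logits), the loss for a single word is\n", | |
| "$$\n", | |
| "\\begin{aligned}\n", | |
| "\\mathtt{loss}\n", | |
| " &= ⟨- \\log q⟩ \\\\\n", | |
| " &= - \\sum_ν p_ν \\log q_ν \\\\\n", | |
| " &= - \\sum_ν δ_{μν} \\log q_ν \\\\\n", | |
| " &= - \\log q_μ\n", | |
| "\\end{aligned}\n", | |
| "$$\n", | |
| "where $⟨\\cdot⟩$ denotes the expectation with respect to $p$. The loss value `loss` therefore gives the probability that the model generates the correct word $μ$:\n", | |
| "$$\n", | |
| "q_μ = e^{-\\mathtt{loss}}\n", | |
| "$$\n", | |
| "In practice the loss is averaged over several words, so\n", | |
| "$$\n", | |
| "q = e^{-\\mathtt{loss}}\n", | |
| "$$\n", | |
| "is the geometric mean of the probabilities the model assigns to the correct words, and serves as a rough guide to the probability of generating the correct word." | |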
| ], | |
| "metadata": { | |
| "id": "NXygYyddKFyH" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Probability of generating the correct word (geometric mean over the batch)\n", | |
| "q = torch.exp(-loss)\n", | |
| "q.item()" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "jVAZ0LItUlph", | |
| "outputId": "85f4ad5c-c9f2-4ac2-8435-c97c62905624" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "0.011841981671750546" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 11 | |
| } | |
| ] | |
| }, | |
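The two facts used above can be checked numerically in pure Python: with a one-hot $p$, the sum $-\sum_ν p_ν \log q_ν$ collapses to $-\log q_μ$, and $e^{-\mathtt{loss}}$ of a mean loss equals the geometric mean of the per-word probabilities. The distributions below are made-up examples.

```python
import math

# A model distribution q over a 4-character vocabulary (made-up numbers)
q = [0.1, 0.6, 0.2, 0.1]
mu = 1                                                   # index of the correct character
p = [1.0 if nu == mu else 0.0 for nu in range(len(q))]   # one-hot target

# Cross-entropy -sum_nu p_nu log q_nu reduces to -log q_mu
ce = -sum(p_nu * math.log(q_nu) for p_nu, q_nu in zip(p, q))
print(ce, -math.log(q[mu]))  # both ≈ 0.5108

# Probabilities the model gave to the correct character at 3 positions
q_correct = [0.5, 0.2, 0.8]
mean_loss = sum(-math.log(qc) for qc in q_correct) / len(q_correct)
geometric_mean = math.prod(q_correct) ** (1 / len(q_correct))
print(math.exp(-mean_loss), geometric_mean)  # both ≈ 0.4309
```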
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
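| "This still isn't very intuitive. Now, when one of $N$ words is chosen uniformly at random, the probability of choosing the correct word is $\\displaystyle q = \\frac1N$. So $1/q\\ (=e^{\\mathtt{loss}})$ can be read as the number of candidates the model is effectively choosing from at random; this quantity is known as perplexity. Let's compute it." | |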
| ], | |
| "metadata": { | |
| "id": "HrTgVnCeYNqa" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Effective \"number of candidates\" (perplexity) vs. vocabulary size\n", | |
| "1/q.item(), len(vocab)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "mB2S_EfHcQWT", | |
| "outputId": "a3a9ade0-2a92-40ca-a7ec-8270b4434cce" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(84.44532576718426, 83)" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 12 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "So $1/q≃\\mathtt{len(vocab)}$: the model is choosing words almost entirely at random. That is expected, since it has not been trained at all. Let's improve this by training the model." | |
| ], | |
| "metadata": { | |
| "id": "9gdbURiQc1xP" | |
| } | |
| }, | |
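The number-of-candidates reading of $e^{\mathtt{loss}}$ (perplexity) can be checked with a short pure-Python sketch: a model that spreads its probability uniformly over $N$ characters has perplexity exactly $N$, and one that splits it evenly among $k$ candidates that include the correct one has perplexity $k$. The function name `perplexity` and the toy inputs are illustrative.

```python
import math

def perplexity(probs_of_correct: list[float]) -> float:
    """exp(mean cross-entropy), given the probability assigned to each correct token."""
    mean_loss = sum(-math.log(q) for q in probs_of_correct) / len(probs_of_correct)
    return math.exp(mean_loss)

N = 83
# Uniform guessing over the whole vocabulary at 10 positions
print(perplexity([1 / N] * 10))   # ≈ 83
# Probability split evenly among 15 candidates that include the correct one
print(perplexity([1 / 15] * 10))  # ≈ 15
```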
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Training the Model" | |
| ], | |
| "metadata": { | |
| "id": "T2qsjzR5fZgI" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
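| "Now we train the model. Training means adjusting the parameter values so that the value of the loss function decreases. See the presentation slides for details. Here we train with [Adam](https://arxiv.org/abs/1412.6980), an improved variant of stochastic gradient descent." | |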
| ], | |
| "metadata": { | |
| "id": "yzZ3fA-crwz2" | |
| } | |
| }, | |
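Adam keeps exponential moving averages of the gradient and of its square, and scales each step by their bias-corrected ratio. The pure-Python sketch below applies that update rule to a one-dimensional toy problem, minimizing $f(x) = (x - 3)^2$. `beta1=0.9`, `beta2=0.999` and `eps=1e-8` are the defaults from the Adam paper (also PyTorch's defaults); `lr=0.1` is an assumption chosen so the toy converges quickly, whereas `torch.optim.Adam` defaults to `1e-3`.

```python
import math

def adam_minimize(grad, x0: float, lr: float = 0.1, beta1: float = 0.9,
                  beta2: float = 0.999, eps: float = 1e-8, steps: int = 1000) -> float:
    """Minimize a 1-D function with Adam, given its gradient function."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g       # moving average of the gradient
        v = beta2 * v + (1 - beta2) * g * g   # moving average of the squared gradient
        m_hat = m / (1 - beta1 ** t)          # bias corrections for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x

# f(x) = (x - 3)^2 has gradient 2 (x - 3) and its minimum at x = 3
x_min = adam_minimize(lambda x: 2 * (x - 3), x0=0.0)
print(x_min)  # very close to 3.0
```

In the notebook, `torch.optim.Adam` performs the same kind of update independently for every parameter entry of `model`.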
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Training (using Adam as the optimizer)\n", | |
| "def train(model: nn.Module, optim=torch.optim.Adam, epochs=100):\n", | |
| "    # Initialize the optimizer\n", | |
| "    optimizer = optim(model.parameters())\n", | |
| "    # Record of the training progress\n", | |
| "    logs: dict[str, list[float]] = {\"train\": [], \"val\": []}\n", | |
| "\n", | |
| "    # Repeat training for the given number of epochs\n", | |
| "    for epoch in range(epochs):\n", | |
| "        # Temporary record for this epoch (averaged later and appended to logs)\n", | |
| "        log_temp: dict[str, list[float]] = {\"train\": [], \"val\": []}\n", | |
| "\n", | |
| "        # Run training and validation 10 times each\n", | |
| "        for _ in range(10):\n", | |
| "            for split in [\"train\", \"val\"]:\n", | |
| "                if split == \"train\":\n", | |
| "                    # Training mode\n", | |
| "                    model.train()\n", | |
| "                else:\n", | |
| "                    # Evaluation mode\n", | |
| "                    model.eval()\n", | |
| "\n", | |
| "                # Get a batch\n", | |
| "                xs, ys = get_batches(16, split=split)\n", | |
| "                # Compute the loss (gradients are tracked only when training)\n", | |
| "                with torch.set_grad_enabled(split == \"train\"):\n", | |
| "                    loss = loss_fn(model(ltov(xs)), torch.tensor(ys))\n", | |
| "                # Record the \"number of candidates\" (perplexity)\n", | |
| "                log_temp[split] += [torch.exp(loss).item()]\n", | |
| "\n", | |
| "                # Update the parameters\n", | |
| "                if split == \"train\":\n", | |
| "                    # Reset the gradients\n", | |
| "                    optimizer.zero_grad()\n", | |
| "                    # Backpropagation\n", | |
| "                    loss.backward()\n", | |
| "                    # Optimizer step\n", | |
| "                    optimizer.step()\n", | |
| "\n", | |
| "        # Average the 10 training/validation values and record them\n", | |
| "        for split in [\"train\", \"val\"]:\n", | |
| "            logs[split] += [float(np.mean(log_temp[split]))]\n", | |
| "\n", | |
| "    # Return a plot of the training progress\n", | |
| "    return pd.DataFrame(logs).plot()\n", | |
| "\n", | |
| "# Run training\n", | |
| "train(model)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 448 | |
| }, | |
| "id": "uO7evbn_ffcS", | |
| "outputId": "ff1669cb-aa5c-4519-8a99-88c4e1ca634f" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<Axes: >" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 13 | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 640x480 with 1 Axes>" | |
| ], | |
| "image/png": "\n" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
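| "The vertical axis of the plot above is the number of candidates when choosing words at random (the perplexity $e^{\\mathtt{loss}}$) introduced earlier, and the horizontal axis is the epoch. As training proceeds, this number decreases, ending at about 15 words in the original run, down from roughly 83 before training. That is a big improvement! The output should no longer be a random sequence of characters. Let's generate some text right away." | |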
| ], | |
| "metadata": { | |
| "id": "utcBZIaR7-qA" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "print(generate(model))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "KEVrFHxJp_2d", | |
| "outputId": "ff845009-57bc-4ff4-c685-2db4cc32ac46" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "pdN.apine t ul. Md my itie, Mrd, Tgooncaad he weny t syurthes a BW. \"ucofe, mounl, ot b cpan stohiva\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
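| "Before training, the output was a random jumble; after training, the model generates text that, if you squint, might be English. A few real words even appear. This completes our original goal: a language model that works and produces plausible-looking output." | |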
| ], | |
| "metadata": { | |
| "id": "F8lHDleV9Mvy" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## References" | |
| ], | |
| "metadata": { | |
| "id": "_7VywL9-_sog" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
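| "The theory is based on:\n", | |
| "\n", | |
| "- Naoaki Okazaki, Yuki Arase, Jun Suzuki, Yoshimasa Tsuruoka, Yusuke Miyao. 『IT Text 自然言語処理の基礎』 (IT Text: Fundamentals of Natural Language Processing). Ohmsha, 2022.\n", | |
| "\n", | |
| "The implementation is based on:\n", | |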
| "\n", | |
| "- [Llama from scratch (or how to implement a paper without crying) | Brian Kitano](https://blog.briankitano.com/llama-from-scratch/)\n", | |
| "- [CrossEntropyLoss — PyTorch 2.3 documentation](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)" | |
| ], | |
| "metadata": { | |
| "id": "0CuwmNmo_uEV" | |
| } | |
| } | |
| ] | |
| } |