Last active
August 19, 2023 07:55
-
-
Save masuidrive/d810f2b4c0b52b041a83f5c3a80f0289 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
https://github.com/karpathy/llama2.c/blob/master/run.c | |
GPT-4による解説 | |
このプログラムは、Transformerネットワークを実装し、トークン化されたテキスト入力から次の最も適したトークンを予測します。具体的には以下のようになります: | |
先頭の部分は、TransformerWeightsとRunStateという2つのデータ構造とそれらの関連するメモリの管理を含みます。 | |
Configという構造体は、トランスフォーマーネットワークのパラメータを保持します。 | |
次に、指定されたチェックポイントファイルから重みを初期化する関数があります。この関数は、チェックポイントファイルからトランスフォーマーネットワークの重みを読み込み、適切に配置します。 | |
さらに、各種ニューラルネットワークのブロック(関数)が存在します: | |
- 'accum': ベクトル'a'にベクトル'b'の要素を加えます。 | |
- 'rmsnorm': RMS正規化を行います。 | |
- 'softmax': ソフトマックスを適用します。 | |
- 'matmul': 行列の積算を行います。 | |
'transformer'関数は、全てのレイヤーを通じてトークンをプッシュします。これは、各トークンに対してアテンションとフィードフォワードネットワークを通じて情報を伝播します。 | |
'main'関数では、特定の条件(温度とステップ数)で指定されたチェックポイントファイルを基に入力から次のトークンをサンプリングします。ここで、サンプリングは確率的(温度が0でない場合)または確定的(温度が0の場合)に行われます。 | |
最終的には、このモデルをいくつのトークンを処理できるか、単位時間におけるトークン処理数(トークン/秒)を出力します。 | |
*/ | |
/* https://github.com/karpathy/llama2.c/blob/master/run.c */ | |
/* | |
Inference for Llama-2 Transformer model in pure C. | |
Example compile: (see README for more details) | |
$ gcc -O3 -o run run.c -lm | |
Then run with: | |
$ ./run | |
*/ | |
//各々のライブラリをインポートしています。 | |
#include <stdio.h> | |
#include <stdlib.h> | |
#include <time.h> | |
#include <math.h> | |
#include <string.h> | |
#include <unistd.h> | |
#include <fcntl.h> | |
#include <sys/mman.h> | |
// TransformerとRunState構造体、関連するメモリ管理 | |
typedef struct { | |
int dim; // transformerの次元 | |
int hidden_dim; // ffnレイヤー用 | |
int n_layers; // レイヤーの数 | |
int n_heads; // query headの数 | |
int n_kv_heads; // キー/値 headの数 (multiqueryのためにquery headよりも少なくてもOK) | |
int vocab_size; // 語彙のサイズ、通常は256 (バイトレベル) | |
int seq_len; // 最大シーケンス長 | |
} Config; | |
typedef struct { | |
// トークン埋め込みテーブル | |
float* token_embedding_table; // (語彙サイズ, 次元) | |
// rmsnormsのための重み | |
float* rms_att_weight; // (レイヤー, 次元) rmsnormの重み | |
float* rms_ffn_weight; // (レイヤー, 次元) | |
// matmulsのための重み | |
float* wq; // (レイヤー, 次元, 次元) | |
float* wk; // (レイヤー, 次元, 次元) | |
float* wv; // (レイヤー, 次元, 次元) | |
float* wo; // (レイヤー, 次元, 次元) | |
// ffnのための重み | |
float* w1; // (レイヤー, hidden_dim, 次元) | |
float* w2; // (レイヤー, 次元, hidden_dim) | |
float* w3; // (レイヤー, hidden_dim, 次元) | |
// 最終rmsnorm | |
float* rms_final_weight; // (次元,) | |
// RoPE相対位置埋め込みのためのfreq_cis | |
float* freq_cis_real; // (シーケンス長, 次元/2) | |
float* freq_cis_imag; // (シーケンス長, 次元/2) | |
// (オプション)最終レイヤーのlogitsの分類器重み | |
float* wcls; | |
} TransformerWeights; | |
typedef struct { | |
// 現在の活性化の波 | |
float *x; // カレントタイムスタンプの活性化 (次元,) | |
float *xb; // 同様、ただし内部に残留ブランチ (次元,) | |
float *xb2; // 便宜上追加のバッファ (次元,) | |
float *hb; // ffn内の隠れ次元のバッファ (hidden_dim,) | |
float *hb2; // ffn内の隠れ次元のバッファ (hidden_dim,) | |
float *q; // query (次元,) | |
float *k; // key (次元,) | |
float *v; // value (次元,) | |
float *att; // スコア/注目値のバッファ (n_heads, シーケンス長) | |
float *logits; // 出力logits | |
// kvキャッシュ | |
float* key_cache; // (レイヤー, シーケンス長, 次元) | |
float* value_cache; // (レイヤー, シーケンス長, 次元) | |
} RunState; | |
// RunStateの確保と初期化を行う関数 | |
void malloc_run_state(RunState* s, Config* p) { | |
// valgrindを満足させるためにcallocではなくmallocを使用 | |
s->x = calloc(p->dim, sizeof(float)); | |
s->xb = calloc(p->dim, sizeof(float)); | |
s->xb2 = calloc(p->dim, sizeof(float)); | |
s->hb = calloc(p->hidden_dim, sizeof(float)); | |
s->hb2 = calloc(p->hidden_dim, sizeof(float)); | |
s->q = calloc(p->dim, sizeof(float)); | |
s->k = calloc(p->dim, sizeof(float)); | |
s->v = calloc(p->dim, sizeof(float)); | |
s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); | |
s->logits = calloc(p->vocab_size, sizeof(float)); | |
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); | |
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); | |
// すべてのmallocが正常に実行されていることを確認します | |
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q | |
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache | |
|| !s->value_cache) { | |
printf("malloc failed!\n"); | |
exit(1); | |
} | |
} | |
// RunStateのメモリ解放関数 | |
void free_run_state(RunState* s) { | |
free(s->x); | |
free(s->xb); | |
free(s->xb2); | |
free(s->hb); | |
free(s->hb2); | |
free(s->q); | |
free(s->k); | |
free(s->v); | |
free(s->att); | |
free(s->logits); | |
free(s->key_cache); | |
free(s->value_cache); | |
} | |
// -------------------------------------------------------------- | |
// 初期化:checkpointから読み出し | |
// checkpointから重みを初期化する関数 | |
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) { | |
float* ptr = f; | |
w->token_embedding_table = ptr; | |
ptr += p->vocab_size * p->dim; | |
w->rms_att_weight = ptr; | |
ptr += p->n_layers * p->dim; | |
w->wq = ptr; | |
ptr += p->n_layers * p->dim * p->dim; | |
w->wk = ptr; | |
ptr += p->n_layers * p->dim * p->dim; | |
w->wv = ptr; | |
ptr += p->n_layers * p->dim * p->dim; | |
w->wo = ptr; | |
ptr += p->n_layers * p->dim * p->dim; | |
w->rms_ffn_weight = ptr; | |
ptr += p->n_layers * p->dim; | |
w->w1 = ptr; | |
ptr += p->n_layers * p->dim * p->hidden_dim; | |
w->w2 = ptr; | |
ptr += p->n_layers * p->hidden_dim * p->dim; | |
w->w3 = ptr; | |
ptr += p->n_layers * p->dim * p->hidden_dim; | |
w->rms_final_weight = ptr; | |
ptr += p->dim; | |
w->freq_cis_real = ptr; | |
int head_size = p->dim / p->n_heads; | |
ptr += p->seq_len * head_size / 2; | |
w->freq_cis_imag = ptr; | |
ptr += p->seq_len * head_size / 2; | |
w->wcls = shared_weights ? w->token_embedding_table : ptr; | |
} | |
// -------------------------------------------------------------- | |
// ニューラルネットワークのブロック | |
// ベクトルbをベクトルaに累積する関数 | |
void accum(float *a, float *b, int size) { | |
// 全ての要素に対してbをaに加算 | |
for (int i = 0; i < size; i++) { | |
a[i] += b[i]; | |
} | |
} | |
// RMS正規化を行う関数 | |
void rmsnorm(float* o, float* x, float* weight, int size) { | |
// 平方和を計算 | |
float ss = 0.0f; | |
for (int j = 0; j < size; j++) { | |
ss += x[j] * x[j]; | |
} | |
ss /= size; | |
ss += 1e-5f; | |
ss = 1.0f / sqrtf(ss); | |
// 正規化してスケール | |
for (int j = 0; j < size; j++) { | |
o[j] = weight[j] * (ss * x[j]); | |
} | |
} | |
// softmax関数です。配列xの全ての要素に対してsoftmaxを計算します。 | |
void softmax(float* x, int size) { | |
// 最大の値を見つける(数値安定性のため) | |
float max_val = x[0]; | |
for (int i = 1; i < size; i++) { | |
if (x[i] > max_val) { | |
max_val = x[i]; | |
} | |
} | |
// expを計算し、合計を求める | |
float sum = 0.0f; | |
for (int i = 0; i < size; i++) { | |
x[i] = expf(x[i] - max_val); | |
sum += x[i]; | |
} | |
// 正規化する | |
for (int i = 0; i < size; i++) { | |
x[i] /= sum; | |
} | |
} | |
// 行列乗算の計算をします。入力行列xに重み行列wを掛け合わせ、結果をxoutに格納します。 | |
void matmul(float* xout, float* x, float* w, int n, int d) { | |
// W (d,n) @ x (n,) -> xout (d,) | |
#pragma omp parallel for | |
for (int i = 0; i < d; i++) { | |
float val = 0.0f; | |
for (int j = 0; j < n; j++) { | |
val += w[i * n + j] * x[j]; | |
} | |
xout[i] = val; | |
} | |
} | |
// この関数は、指定されたトークンと位置に基づいてTransformerモデルを前方に進行させます。 | |
// 具体的には、埋め込みのコピー、positionの抽出、すべての層についてのループ(注意スコアの計算、softmaxの適用、 | |
// 重み付き和の計算、残差の接続)、最終的なSoftmaxを経て、分類器によるLogitsの計算を行います。 | |
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { | |
// いくつかの便利な変数を定義します。 | |
float *x = s->x; | |
int dim = p->dim; | |
int hidden_dim = p->hidden_dim; | |
int head_size = dim / p->n_heads; | |
// トークンの埋め込みをxにコピーします。 | |
float* content_row = &(w->token_embedding_table[token * dim]); | |
memcpy(x, content_row, dim*sizeof(*x)); | |
// freq_cis_realとfreq_cis_imagの"pos"行を求めます。 | |
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2; | |
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2; | |
// 全てのレイヤーについて順に進めます。 | |
for(int l = 0; l < p->n_layers; l++) { | |
// 注目ベクトルのrmsnorm | |
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); | |
// 該当する位置に対するqkv matmuls | |
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); | |
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim); | |
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim); | |
// 各ヘッドに対してRoPE回転をqとkベクトルに適用します。 | |
for (int h = 0; h < p->n_heads; h++) { | |
// このヘッドのqとkベクトルを取ってきます。 | |
float* q = s->q + h * head_size; | |
float* k = s->k + h * head_size; | |
// freq_cis_realとfreq_cis_imagによるqとkの回転 | |
for (int i = 0; i < head_size; i+=2) { | |
float q0 = q[i]; | |
float q1 = q[i+1]; | |
float k0 = k[i]; | |
float k1 = k[i+1]; | |
float fcr = freq_cis_real_row[i/2]; | |
float fci = freq_cis_imag_row[i/2]; | |
q[i] = q0 * fcr - q1 * fci; | |
q[i+1] = q0 * fci + q1 * fcr; | |
k[i] = k0 * fcr - k1 * fci; | |
k[i+1] = k0 * fci + k1 * fcr; | |
} | |
} | |
// このタイムステップ(pos)のキーと値をkvキャッシュに保存します。 | |
int loff = l * p->seq_len * dim; | |
float* key_cache_row = s->key_cache + loff + pos * dim; | |
float* value_cache_row = s->value_cache + loff + pos * dim; | |
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row)); | |
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row)); | |
// multihead注意処理。全てのヘッドについて繰り返します。 | |
#pragma omp parallel for | |
for (int h = 0; h < p->n_heads; h++) { | |
// このヘッドのクエリーベクトルを取得 | |
float* q = s->q + h * head_size; | |
// このヘッドの注意スコア | |
float* att = s->att + h * p->seq_len; | |
// すべてのタイムステップについて繰り返し実行、現在のものを含む | |
for (int t = 0; t <= pos; t++) { | |
// このヘッドとこのタイムステップのキーベクトルを取得 | |
float* k = s->key_cache + loff + t * dim + h * head_size; | |
// 注意スコアをqとkのドット積として計算します | |
float score = 0.0f; | |
for (int i = 0; i < head_size; i++) { | |
score += q[i] * k[i]; | |
} | |
score /= sqrtf(head_size); | |
// スコアを注意バッファに保存します | |
att[t] = score; | |
} | |
// 0..posのスコアをsoftmax適用して注意の重みを得る。 | |
softmax(att, pos + 1); | |
// 値の加重和を計算し、xbに保存します。 | |
for (int i = 0; i < head_size; i++) { | |
float val = 0.0f; | |
for (int t = 0; t <= pos; t++) { | |
val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; | |
} | |
s->xb[h * head_size + i] = val; | |
} | |
} | |
// 注意の出力を得るための最終的な行列乗算 | |
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim); | |
// レジデュアル接続をxにバック | |
accum(x, s->xb2, dim); | |
// フィードフォワードネットワーク用OFRMノーム | |
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim); | |
// PyTorchのFFNでは次のようになります: self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
// 最初にself.w1(x)とself.w3(x)を計算します | |
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim); | |
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim); | |
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid | |
for (int i = 0; i < hidden_dim; i++) { | |
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i]))); | |
} | |
// w3(x)との要素ごとの積 | |
for (int i = 0; i < hidden_dim; i++) { | |
s->hb[i] = s->hb[i] * s->hb2[i]; | |
} | |
// FFNの出力を得るための最終的なmatmul | |
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim); | |
// レジデュアル接続 | |
accum(x, s->xb, dim); | |
} | |
// 最終的なrmsnorm | |
rmsnorm(x, x, w->rms_final_weight, dim); | |
// ロジットへのクラス分類器 | |
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); | |
} | |
// 確率配列からインデックスをサンプリングします。確率の総和は1でなければなりません。 | |
// これにより、softmax関数の出力(確率)に基づいて次のトークンが選択されます | |
int sample(float* probabilities, int n) { | |
// sample index from probabilities, they must sum to 1 | |
float r = (float)rand() / (float)RAND_MAX; | |
float cdf = 0.0f; | |
for (int i = 0; i < n; i++) { | |
cdf += probabilities[i]; | |
if (r < cdf) { | |
return i; | |
} | |
} | |
return n - 1; // in case of rounding errors | |
} | |
// 配列 v の中で最も大きい値を持つインデックスを返します。 | |
// これは、temperatureが0の場合に使用されます。(すなわち、確率最大のトークンだけが選択されます) | |
int argmax(float* v, int n) { | |
// return argmax of v in elements 0..n | |
int max_i = 0; | |
float max_p = v[0]; | |
for (int i = 1; i < n; i++) { | |
if (v[i] > max_p) { | |
max_i = i; | |
max_p = v[i]; | |
} | |
} | |
return max_i; | |
} | |
// ---------------------------------------------------------------------------- | |
long time_in_ms() { | |
struct timespec time; | |
// Get the current time with nanosecond precision | |
if (clock_gettime(CLOCK_REALTIME, &time) == 0) { | |
return time.tv_sec * 1000 + time.tv_nsec / 1000000; | |
} else { | |
perror("clock_gettime"); | |
return -1; // Return -1 to indicate an error | |
} | |
} | |
// C型の引数解析を行っています | |
// 必要な引数は 'checkpoint' (モデルの重みが保存されたファイル)です | |
// オプション引数には、生成するテキストの多様性を制御するための 'temperature' があり、 | |
// また最大のステップ数も設定できます | |
// 乱数生成器は現在時刻でシードされます。決定論的な動作が望まれる場合は、temperatureを0.0に設定します | |
// | |
// 残りの部分は、モデルとトークナイザを読み込み、メモリ内で必要なスペースを確保し、 | |
// トークン生成プロセスが終了するまでトークンを繰り返し生成します。 | |
int main(int argc, char *argv[]) { | |
// C型の引数解析を行っています | |
char *checkpoint = NULL; // e.g. out/model.bin | |
float temperature = 0.9f; // e.g. 1.0, or 0.0 | |
int steps = 256; // max number of steps to run for, 0: use seq_len | |
// 'checkpoint' is necessary arg | |
if (argc < 2) { | |
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]); | |
return 1; | |
} | |
if (argc >= 2) { | |
checkpoint = argv[1]; | |
} | |
if (argc >= 3) { | |
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline | |
temperature = atof(argv[2]); | |
} | |
if (argc >= 4) { | |
steps = atoi(argv[3]); | |
} | |
// seed rng with time. if you want deterministic behavior use temperature 0.0 | |
srand((unsigned int)time(NULL)); | |
// read in the model.bin file | |
Config config; | |
TransformerWeights weights; | |
int fd = 0; | |
float* data = NULL; | |
long file_size; | |
{ | |
FILE *file = fopen(checkpoint, "rb"); | |
if (!file) { | |
printf("Unable to open the checkpoint file %s!\n", checkpoint); | |
return 1; | |
} | |
// read in the config header | |
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; } | |
// negative vocab size is hacky way of signaling unshared weights. bit yikes. | |
int shared_weights = config.vocab_size > 0 ? 1 : 0; | |
config.vocab_size = abs(config.vocab_size); | |
// figure out the file size | |
fseek(file, 0, SEEK_END); // move file pointer to end of file | |
file_size = ftell(file); // get the file size, in bytes | |
fclose(file); | |
// memory map the Transformer weights into the data pointer | |
fd = open(checkpoint, O_RDONLY); // open in read only mode | |
if (fd == -1) { printf("open failed!\n"); return 1; } | |
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0); | |
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; } | |
float* weights_ptr = data + sizeof(Config)/sizeof(float); | |
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights); | |
} | |
// right now we cannot run for more than config.seq_len steps | |
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } | |
// read in the tokenizer.bin file | |
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); | |
{ | |
FILE *file = fopen("tokenizer.bin", "rb"); | |
if (!file) { | |
printf("Unable to open the tokenizer file tokenizer.bin! Run " | |
"python tokenizer.py to convert tokenizer.model -> tokenizer.bin\n"); | |
return 1; | |
} | |
int len; | |
for (int i = 0; i < config.vocab_size; i++) { | |
if(fread(&len, sizeof(int), 1, file) != 1) { return 1; } | |
vocab[i] = (char *)malloc(len + 1); | |
if(fread(vocab[i], len, 1, file) != 1) { return 1; } | |
vocab[i][len] = '\0'; // add the string terminating token | |
} | |
fclose(file); | |
} | |
// create and init the application RunState | |
RunState state; | |
malloc_run_state(&state, &config); | |
// the current position we are in | |
long start = time_in_ms(); | |
int next; | |
int token = 1; // 1 = BOS token in Llama-2 sentencepiece | |
int pos = 0; | |
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric | |
while (pos < steps) { | |
// forward the transformer to get logits for the next token | |
transformer(token, pos, &config, &state, &weights); | |
// sample the next token | |
if(temperature == 0.0f) { | |
// greedy argmax sampling | |
next = argmax(state.logits, config.vocab_size); | |
} else { | |
// apply the temperature to the logits | |
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; } | |
// apply softmax to the logits to get the probabilities for next token | |
softmax(state.logits, config.vocab_size); | |
// we now want to sample from this distribution to get the next token | |
next = sample(state.logits, config.vocab_size); | |
} | |
printf("%s", vocab[next]); | |
fflush(stdout); | |
// advance forward | |
token = next; | |
pos++; | |
} | |
// report achieved tok/s | |
long end = time_in_ms(); | |
printf("\nachieved tok/s: %f\n", steps / (double)(end-start)*1000); | |
// memory and file handles cleanup | |
free_run_state(&state); | |
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); } | |
free(vocab); | |
if (data != MAP_FAILED) munmap(data, file_size); | |
if (fd != -1) close(fd); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment