import json
import pickle
import struct
import zipfile
import numpy as np
from sentencepiece import SentencePieceProcessor
def rms_norm(x): return x / np.sqrt(np.square(x).mean(-1, keepdims=True) + 1e-6)  # RMSNorm (no mean-centering)
def softmax(x): return np.exp(x - np.max(x, axis=-1, keepdims=True)) / np.sum(np.exp(x - np.max(x, axis=-1, keepdims=True)), axis=-1, keepdims=True)  # Numerically stable softmax
def load_model(model_size = "7B", models_dir = "models"):
    tokenizer = SentencePieceProcessor(model_file=f"{models_dir}/tokenizer.model")
    with open(f'{models_dir}/{model_size}/params.json') as f: params = json.load(f)
    with zipfile.ZipFile(f'{models_dir}/{model_size}/consolidated.00.pth', 'r') as zip:
        # Custom unpickler: rebuilds torch tensors as numpy views over memory-mapped zip entries,
        # so the fp16 checkpoint is never fully loaded into RAM.
        class Unpickler(pickle.Unpickler):
            def find_class(self, mod_name, name):
                if mod_name == 'torch._utils' and name == '_rebuild_tensor_v2':
                    def rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks):
                        return np.lib.stride_tricks.as_strided(storage, size, np.array(stride) * storage.strides, writeable=False)
                    return rebuild_tensor
                if mod_name == 'torch' and name == 'HalfStorage': return np.half
                return super().find_class(mod_name, name)
            def persistent_load(self, saved_id):
                (typename, dtype, key, location, count) = saved_id
                # Parse the zip local file header to find where the raw tensor bytes start.
                with open(f'{models_dir}/{model_size}/consolidated.00.pth', 'br') as f:
                    f.seek(zip.getinfo(f'consolidated/data/{key}').header_offset)
                    (header, _, n_name, n_extra) = struct.unpack('<4s22shh', f.read(30))
                return np.memmap(zip.fp, dtype=dtype, mode='r',
                                 offset=(zip.getinfo(f'consolidated/data/{key}').header_offset + 30 + n_name + n_extra), shape=(count,))
        with zip.open('consolidated/data.pkl') as data_pkl: data = Unpickler(data_pkl).load()
    return (tokenizer, params, data)
def main(prompt = "The meaning of life is", temperature = 0.8, n_tokens_to_generate = 5, model_size = "7B", models_dir = "llama_weights"):
    (tokenizer, params, data) = load_model(model_size, models_dir)
    prev_pos = 0
    tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt)
    cache_k, cache_v = ([None] * params['n_layers']), ([None] * params['n_layers'])  # Per-layer KV cache
    freqs_cis = np.exp(1j * np.outer(np.arange(2 * 512), np.logspace(0, 1.0, base=1e-4,  # RoPE frequencies; max_seq_len = 512
                       num=(params['dim'] // params['n_heads']) // 2, endpoint=False))).astype(np.complex64)
    for cur_pos in range(len(tokens), len(tokens) + n_tokens_to_generate):
        h = data['tok_embeddings.weight'][tokens[prev_pos: cur_pos], :].astype(np.float32)  # Embed tokens
        f = freqs_cis[prev_pos:cur_pos].reshape(-1, 1, (params['dim'] // params['n_heads']) // 2)  # Rotary embedding
        for layer in range(params['n_layers']):
            xn = rms_norm(h) * data[f'layers.{layer}.attention_norm.weight']  # Pre-attention RMSNorm
            xq = (xn @ data[f'layers.{layer}.attention.wq.weight'].T).reshape((-1, params['n_heads'], (params['dim'] // params['n_heads'])))  # Q projection
            xk = (xn @ data[f'layers.{layer}.attention.wk.weight'].T).reshape((-1, params['n_heads'], (params['dim'] // params['n_heads'])))  # K projection
            xv = (xn @ data[f'layers.{layer}.attention.wv.weight'].T).reshape((-1, params['n_heads'], (params['dim'] // params['n_heads'])))  # V projection
            xq = (xq.view(dtype=np.complex64) * f).view(dtype=np.float32)  # Rotary embedding
            xk = (xk.view(dtype=np.complex64) * f).view(dtype=np.float32)  # Rotary embedding
            if prev_pos == 0: cache_k[layer], cache_v[layer] = xk, xv  # KV cache (fill with prompt)
            else: xk, xv = cache_k[layer], cache_v[layer] = np.concatenate((cache_k[layer], xk), axis=0), np.concatenate((cache_v[layer], xv), axis=0)  # KV cache (append new token)
            scores = np.matmul(xk, xq, axes=[(0, 2), (2, 0), (2, 1)]) / np.sqrt(params['dim'] // params['n_heads'])  # Attention scores
            if (cur_pos - prev_pos) > 1: scores += (-1e10 * (1 - np.tri(cur_pos - prev_pos)))  # Causal mask
            h += (np.matmul(softmax(scores), xv, axes=[(1, 2), (0, 2), (0, 2)]).reshape(-1, params['dim'])) @ data[f'layers.{layer}.attention.wo.weight'].T  # Attention output + residual
            xn = rms_norm(h) * data[f'layers.{layer}.ffn_norm.weight']  # Pre-FFN RMSNorm
            x1 = xn @ data[f'layers.{layer}.feed_forward.w1.weight'].T  # MLP (SwiGLU gate)
            h += ((x1 / (1.0 + np.exp(-x1))) * (xn @
                  data[f'layers.{layer}.feed_forward.w3.weight'].T)) @ data[f'layers.{layer}.feed_forward.w2.weight'].T  # MLP + residual
        tokens.append(int(np.argmax((rms_norm(h) * data['norm.weight'])[-1, :] @ data['output.weight'].T)))  # Unembed; greedy argmax (temperature is unused)
        prev_pos = cur_pos
    print(tokenizer.decode(tokens))
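
# Example entry point: a minimal usage sketch, assuming the checkpoint is laid out as
# load_model() expects (llama_weights/tokenizer.model, llama_weights/7B/params.json,
# llama_weights/7B/consolidated.00.pth). The prompt and token count below are illustrative.
if __name__ == '__main__':
    main(prompt="The meaning of life is", n_tokens_to_generate=5)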