Last active
          January 23, 2024 02:28 
        
      - 
      
- 
        Save NaxAlpha/13c80fd0df6f57958e147daec3d90485 to your computer and use it in GitHub Desktop. 
    Softformer - An attention-free, softmax-based transformer for causal language modeling.
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def cum_softmax(x, dim=1): # <- main novelty | |
| z = x.exp() | |
| d = z.cumsum(dim) | |
| return z / d | |
class SoftBlock(nn.Module):
    """Attention-free sequence mixer gated by a cumulative softmax.

    The input is layer-normalised, projected twice (a value path and a gate
    path), the gate is passed through ``cum_softmax`` over the sequence axis,
    and the gated values are summed cumulatively over that same axis.
    """

    def __init__(self, dim):
        super().__init__()
        # Sub-modules are created in the order norm -> proj -> soft.
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)
        self.soft = nn.Linear(dim, dim)

    def forward(self, x):
        # x: [batch, seq, dim]
        normed = self.norm(x)
        values = self.proj(normed)
        # Gate weights are normalised over the sequence axis (dim=1).
        gate = cum_softmax(self.soft(normed), dim=1)
        return torch.cumsum(values * gate, dim=1)
class FeedForward(nn.Module):
    """Position-wise feed-forward stack.

    Pipeline: LayerNorm(dim) -> Linear(dim, 2 * hidden_dim) -> GLU ->
    LayerNorm(hidden_dim) -> Linear(hidden_dim, dim) -> GELU.
    The GLU halves the last dimension, which is why the first linear layer
    emits ``2 * hidden_dim`` features.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, 2 * hidden_dim),
            nn.GLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, dim),
            nn.GELU(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # Run the stages one after another, exactly as nn.Sequential does.
        out = x
        for stage in self.net:
            out = stage(out)
        return out
class Softformer(nn.Module):
    """Residual stack that alternates ``SoftBlock`` and ``FeedForward`` layers.

    ``depth`` pairs of (SoftBlock, FeedForward) are created.  When
    ``hidden_dim`` is left at its ``...`` default, ``2 * dim`` is used.
    """

    def __init__(self, dim, depth, hidden_dim=...):
        super().__init__()
        if hidden_dim is Ellipsis:
            hidden_dim = dim * 2

        stack = []
        for _ in range(depth):
            # Mixer first, then feed-forward, for every level of depth.
            stack.append(SoftBlock(dim))
            stack.append(FeedForward(dim, hidden_dim))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        # x: [batch, seq, dim]
        hidden = x
        for sublayer in self.layers:
            # Each sub-layer is wrapped in a residual connection.
            hidden = hidden + sublayer(hidden)
        return hidden
class SoftRegressor(nn.Module):
    """Causal language model built on ``Softformer``.

    Token and position embeddings are summed, passed through the Softformer
    stack and a Linear + LayerNorm head, and turned into vocabulary logits by
    multiplying with the transposed token-embedding matrix (the input and
    output embeddings share weights).

    Args:
        max_ctx: number of positions in the position-embedding table, i.e. the
            longest sequence ``forward`` accepts.
        vocab_size: number of token ids.
        emb_dim: embedding / model width.
        depth: number of (SoftBlock, FeedForward) pairs.
        hidden_dim: feed-forward hidden width, forwarded to ``Softformer``.
    """

    def __init__(self, max_ctx, vocab_size, emb_dim, depth, hidden_dim=...):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, max_norm=1)
        self.pos = nn.Embedding(max_ctx, emb_dim, max_norm=1)
        self.net = Softformer(emb_dim, depth, hidden_dim)
        self.end = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.LayerNorm(emb_dim),
        )

    def forward(self, x, y=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            x: token ids, [batch, seq].
            y: optional target ids, [batch, seq].

        Returns:
            ``logits`` of shape [batch, seq, vocab] when ``y`` is None,
            otherwise the tuple ``(loss, logits)``.
        """
        _, seq = x.shape
        positions = torch.arange(seq, device=x.device)
        x = self.emb(x) + self.pos(positions)
        x = self.net(x)
        x = self.end(x)
        # Output projection reuses the token-embedding matrix.
        x = x @ self.emb.weight.t()
        if y is None:
            return x
        loss = F.cross_entropy(x.reshape(-1, x.shape[-1]), y.reshape(-1))
        return loss, x

    @torch.no_grad()
    def generate(self, x, max_len=100, temp=1.0):
        """Sample tokens autoregressively until the sequence has ``max_len`` tokens.

        Args:
            x: prompt token ids, [batch, seq].
            max_len: total sequence length (prompt included) to stop at.
            temp: softmax temperature applied to the last-step logits.

        Returns:
            Token ids of shape [batch, max(seq, max_len)].
        """
        # Only the most recent `max_ctx` tokens are fed to the model; without
        # this crop the position lookup fails once the sequence outgrows the
        # position-embedding table.
        max_ctx = self.pos.num_embeddings
        is_train = self.training
        self.eval()
        try:
            while x.shape[1] < max_len:
                logits = self.forward(x[:, -max_ctx:])
                probs = (logits[:, -1, :] / temp).softmax(dim=-1)
                next_token = torch.multinomial(probs, 1)
                x = torch.cat([x, next_token], dim=1)
        finally:
            # Restore the caller's train/eval mode even if sampling raises.
            self.train(is_train)
        return x
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import time | |
| import random | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import wandb | |
| from tqdm import tqdm | |
| from datasets import load_dataset | |
| from transformers import GPT2TokenizerFast | |
| from model import SoftRegressor | |
class Trainer:
    """End-to-end training driver for ``SoftRegressor``.

    Construction streams and shuffles the "the_pile" dataset, tokenizes it
    with the GPT-2 tokenizer into windows of ``max_tokens + 1`` ids, builds
    the model on CUDA and an Adam optimizer, and prints parameter counts.
    """

    def __init__(self):
        self.dataset = load_dataset(
            "the_pile",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=1000)
        self.tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = 512
        self.dataset = self.dataset.map(
            self.tokenize,
            batched=True,
            batch_size=64,
        )
        self.loader = DataLoader(
            self.dataset,
            batch_size=8,
            num_workers=8,
        )
        self.model = model = SoftRegressor(
            max_ctx=self.max_tokens,
            vocab_size=self.tokenizer.vocab_size,
            emb_dim=1024,
            depth=24,
        ).cuda()
        self.opt = torch.optim.Adam(
            model.parameters(),
            lr=1e-4,
            weight_decay=1e-2,
        )

        # Report total, embedding and non-embedding parameter counts.
        num_params = sum(p.numel() for p in model.parameters())
        emb_params = sum(
            p.numel()
            for module in (model.emb, model.pos)
            for p in module.parameters()
        )
        non_emb_params = num_params - emb_params
        print(f"num params: {num_params}")
        print(f"emb params: {emb_params}")
        print(f"non emb params: {non_emb_params}")

    def tokenize(self, examples):
        """Turn a batch of texts into random fixed-length token windows.

        All texts are tokenized, joined into one stream with an EOS token
        after each document, and ``len(examples["text"])`` windows of
        ``max_tokens + 1`` ids are cut at random offsets (one extra token so
        inputs and shifted targets can both be sliced from a window).

        If the joined stream is shorter than one window it is padded with EOS
        tokens; previously this case raised ``ValueError`` from
        ``random.randint`` with an empty range.

        Args:
            examples: batch dict with a "text" list.

        Returns:
            ``{"input_ids": LongTensor[len(texts), max_tokens + 1]}``.
        """
        texts = examples["text"]
        eos = self.tokenizer.eos_token_id
        window = self.max_tokens + 1

        # join with eos
        stream = []
        for ids in self.tokenizer(texts)["input_ids"]:
            stream.extend(ids)
            stream.append(eos)

        if len(stream) < window:
            stream.extend([eos] * (window - len(stream)))

        # sample len(texts) sequences of length max_tokens + 1
        last_start = len(stream) - window
        crops = []
        for _ in texts:
            start = random.randint(0, last_start)
            crops.append(stream[start : start + window])

        # reshape keeps a well-formed (0, window) tensor for an empty batch.
        ids = torch.tensor(crops, dtype=torch.long).reshape(-1, window)
        return {"input_ids": ids}

    @staticmethod
    def _render_samples(texts):
        """Build the HTML page logged to wandb from decoded sample texts."""
        from html import escape

        style = """
        <style>
        html, body {
            padding: 0;
            margin: 0;
            width: 100%;
            height: 100%;
        }
        p {
            font-family: 'Verdana', sans-serif;
        }
        </style>
        """
        # Generated text is arbitrary model output: escape it so stray markup
        # is shown literally instead of being interpreted by the browser.
        body = "<hr>".join(f"<p>{escape(text)}</p>" for text in texts)
        return style + body

    def train(self):
        """Run the training loop over the streaming loader.

        Logs the loss to wandb every step, saves ``model.pt`` every 100 steps,
        and every 1000 steps samples 8 sequences starting from the EOS token
        and logs them as HTML.
        """
        wandb.init(
            project="softformer",
            entity="nax-autify",
        )
        prog = tqdm(self.loader)
        for i, batch in enumerate(prog):
            batch = batch["input_ids"].cuda()
            self.opt.zero_grad()
            # Inputs are tokens [0, n-1); targets are the same window shifted by one.
            loss, _ = self.model(batch[:, :-1], batch[:, 1:])
            loss.backward()
            self.opt.step()

            loss_value = loss.item()  # read once per step
            prog.set_description(f"loss: {loss_value:.3f}")
            wandb.log({"loss": loss_value}, step=i)

            if i % 100 == 0:
                torch.save(self.model.state_dict(), "model.pt")

            if i % 1000 == 0:
                x = torch.tensor([[self.tokenizer.eos_token_id]] * 8).cuda()
                t0 = time.time()
                y = self.model.generate(x, max_len=self.max_tokens).tolist()
                t1 = time.time()
                texts = [self.tokenizer.decode(z) for z in y]
                page = self._render_samples(texts)
                wandb.log({"samples": wandb.Html(page)}, step=i)
                print(f"Generated in {t1-t0:.3f}s")
if __name__ == "__main__":
    # Script entry point: build the trainer, then run the training loop.
    Trainer().train()
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment