Created
December 10, 2024 01:16
-
-
Save MadcowD/1cbfea0b7ee50134a9309442295229d9 to your computer and use it in GitHub Desktop.
femtogpt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Copyright (c) 2024, William Guss""" | |
import math,random,torch,shutil,sys,time | |
from torch import nn | |
from torch.nn import functional as F | |
from typing import Optional | |
class AddPositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a (N, T, D) input.

    Even feature columns receive sin(pos / l^(j/D)); odd columns receive the
    same argument shifted by pi/2, i.e. the cosine at the paired frequency.
    """

    def __init__(self, l):
        super().__init__()
        self.l = l  # base of the geometric frequency progression

    def forward(self, x):
        pos = torch.arange(x.size(1), device=x.device)[:, None]
        idx = torch.arange(x.size(2), device=x.device)[None, :]
        odd = idx % 2  # 1 on odd feature columns, 0 on even ones
        phase = pos / (self.l ** ((idx - odd) / x.size(2))) + (math.pi / 2) * odd
        return x + torch.sin(phase)
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with raw weight tensors.

    Args:
        d_in: input/output feature dimension.
        d_qk: per-head query/key dimension.
        d_v:  per-head value dimension.
        h:    number of heads.
        c:    if True, apply a causal mask (position t attends only to <= t).
        dr:   attention dropout rate.
    """

    def __init__(self, d_in, d_qk, d_v, h=1, c=False, dr=0.0):
        super().__init__()

        def rw(*s):
            # Scale by 1/sqrt(fan_in) to keep pre-activations near unit variance.
            return nn.Parameter(torch.randn(*s) / math.sqrt(s[-1]))

        self.c, self.dr = c, dr
        self.w_q = rw(h, d_qk, d_in)
        self.w_k = rw(h, d_qk, d_in)
        self.w_v = rw(h, d_v, d_in)
        self.w_o = rw(d_v * h, d_in)

    def forward(self, x):
        q = torch.einsum("ntc,hdc->nhtd", x, self.w_q)
        k = torch.einsum("ntc,hdc->nhtd", x, self.w_k)
        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(self.w_q.size(1))
        if self.c:
            t = torch.arange(x.size(1), device=q.device)
            # Mask entries where the key position is strictly after the query.
            a = a.masked_fill(t[None, None, :, None] < t[None, None, None, :], float("-inf"))
        a = F.dropout(a.softmax(dim=3), self.dr, self.training)
        y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2)
        # FIX: w_o was allocated but never used, leaving the output projection
        # dead weight; apply it so the heads are mixed back to d_in features.
        return y @ self.w_o
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a two-layer ReLU MLP,
    each sub-layer wrapped in a residual connection."""

    def __init__(self, dm, dk, dh, h, c, dr):
        super().__init__()
        self.ln1 = nn.LayerNorm((dm,))
        # Per-head value dim dm // h so concatenated heads come back to dm.
        self.att = QKVAttention(dm, dk, dm // h, h, c, dr)
        self.ln2 = nn.LayerNorm((dm,))
        self.fc1 = nn.Linear(dm, dh)
        self.fc2 = nn.Linear(dh, dm)

    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.fc2(F.relu(self.fc1(self.ln2(x))))
        return x
class FemtoGPT(nn.Module):
    """A minimal GPT-style causal language model.

    Args:
        v:  vocabulary size.
        dm: model (embedding) dimension.
        dk: per-head query/key dimension.
        dh: hidden dimension of each block's MLP.
        h:  number of attention heads.
        b:  number of transformer blocks.
        c:  causal-masking flag forwarded to the attention layers.
        dr: dropout rate.
        lm: base of the positional-encoding frequency progression.
    """

    def __init__(self, v, dm, dk, dh, h, b, c, dr=0.0, lm=1e5):
        super().__init__()
        # Token embedding -> dropout -> additive positional encoding.
        self.s = nn.Sequential(nn.Embedding(v, dm), nn.Dropout(dr), AddPositionalEncoding(lm))
        self.t = nn.Sequential(*[TransformerBlock(dm, dk, dh, h, c, dr) for _ in range(b)])
        self.o = nn.Linear(dm, v)  # readout to vocabulary logits
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(0, 2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, x):
        # Shift the sequence right by one (prepend 0, drop last) so the
        # logits at position t predict token t from the tokens before it.
        x = F.pad(x, (1, -1))
        x = self.s(x)
        x = self.t(x)
        return self.o(x)

    def cross_entropy(self, x):
        """Mean next-token cross-entropy of batch x against itself."""
        # F.cross_entropy expects (N, C, T) logits, hence the transpose.
        return F.cross_entropy(self(x).transpose(1, 2), x)

    def inplace_ar(self, x, t_s):
        """Autoregressively sample tokens into x[:, t_s:], in place.

        FIX: wrapped in no_grad — generation previously built an autograd
        graph at every step, wasting time and memory for no benefit.
        """
        with torch.no_grad():
            for t in range(t_s, x.size(1)):
                dist = torch.distributions.categorical.Categorical(logits=self(x)[:, t : t + 1, :])
                x[:, t : t + 1] = dist.sample()
if __name__ == "__main__":
    # Pick the best available accelerator, falling back to CPU.
    # FIX: guard mps/xpu with hasattr — on older PyTorch builds these
    # attributes do not exist and the unguarded elif raised AttributeError
    # before the CPU fallback could be reached.
    if torch.cuda.is_available():
        dev = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        dev = torch.device("mps")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        dev = torch.device("xpu")
    else:
        dev = torch.device("cpu")

    def gen(n, pl=25):
        """Return n strings of the form PREFIX>XIFERP (reversal task)."""
        prefixes = ["".join(random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ") for _ in range(pl)) for _ in range(n)]
        return [x + ">" + x[::-1] for x in prefixes]

    nt, tr = 10000, 1000
    data = gen(nt + tr)
    pl = data[0].find(">")  # position of the separator == prefix length

    # Build char <-> token maps from every character that appears.
    s = set("".join(data))
    c2t = {c: i for i, c in enumerate(s)}
    t2c = {i: c for c, i in c2t.items()}
    v = len(s)

    data = torch.cat([torch.tensor([c2t[c] for c in st])[None, :] for st in data])
    train, test = data[:nt], data[nt:]

    d = 128; b = 4; h = 4
    model = FemtoGPT(v, d, d // h, d, h, b, True, 0.1)
    e, B = 11, 100  # epochs, batch size
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    train, test = train.to(dev), test.to(dev)
    model.to(dev)
    print("nb_parameters", sum(p.numel() for p in model.parameters()), "device", dev)

    for ep in range(e):
        model.train()
        at = 0.0
        for inp in train.split(B):
            l = model.cross_entropy(inp)
            at += l.item() * inp.size(0)
            opt.zero_grad()
            l.backward()
            opt.step()
        at /= train.size(0)

        model.eval()
        aT = 0.0
        # FIX: evaluation previously built autograd graphs per batch;
        # no_grad avoids the wasted memory and compute.
        with torch.no_grad():
            for inp in test.split(B):
                aT += model.cross_entropy(inp).item() * inp.size(0)
        aT /= test.size(0)

        # Blank everything after the prefix and regenerate it autoregressively.
        i = test[:B]
        r = i.clone()
        r[:, pl:] = 0
        model.inplace_ar(r, pl)
        ne = (i[:, pl:] != r[:, pl:]).long().sum().item()
        er = ne / i[:, pl:].numel()
        print(f"n_epoch {ep} train_loss {at} test_loss {aT} token_error {er*100:.01f}%")
        if ep % 10 == 0:
            print("-" * 70)
            for s_, t_ in zip(i[:5], r):
                print("true: " + "".join(t2c[x.item()] for x in s_))
                print("generated: " + "".join(t2c[x.item()] for x in t_))
            print("-" * 70)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment