Skip to content

Instantly share code, notes, and snippets.

@MadcowD
Created December 10, 2024 01:16
Show Gist options
  • Save MadcowD/1cbfea0b7ee50134a9309442295229d9 to your computer and use it in GitHub Desktop.
femtogpt.py
"""Copyright (c) 2024, William Guss"""
import math,random,torch,shutil,sys,time
from torch import nn
from torch.nn import functional as F
from typing import Optional
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a (batch, seq, dim) tensor.

    Even channels carry sin at a channel-dependent wavelength; odd channels
    get the same argument phase-shifted by pi/2 (i.e. cos). `l` is the
    wavelength base (e.g. 1e5).
    """

    def __init__(self, l):
        super().__init__()
        self.l = l

    def forward(self, x):
        # Position index as a column vector (T, 1) and channel index as a
        # row vector (1, D), so the phase broadcasts to (T, D).
        pos = torch.arange(x.size(1), device=x.device)[:, None]
        chan = torch.arange(x.size(2), device=x.device)[None, :]
        odd = chan % 2  # 1 on odd channels: shifts sin into cos below
        phase = pos / (self.l ** ((chan - odd) / x.size(2))) + (math.pi / 2) * odd
        return x + torch.sin(phase)
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_in: input (and output) feature dimension.
        d_qk: per-head query/key dimension.
        d_v: per-head value dimension.
        h: number of attention heads.
        c: if True, apply a causal (autoregressive) mask.
        dr: attention dropout rate (active only in training mode).
    """

    def __init__(self, d_in, d_qk, d_v, h=1, c=False, dr=0.0):
        super().__init__()

        def rw(*s):
            # Gaussian init scaled by 1/sqrt(last dim) to keep activations O(1).
            return nn.Parameter(torch.randn(*s) / math.sqrt(s[-1]))

        self.c, self.dr = c, dr
        self.w_q = rw(h, d_qk, d_in)
        self.w_k = rw(h, d_qk, d_in)
        self.w_v = rw(h, d_v, d_in)
        # Output projection mapping the concatenated heads back to d_in.
        self.w_o = rw(d_v * h, d_in)

    def forward(self, x):
        # x: (N, T, d_in) -> per-head projections (N, h, T, d).
        q = torch.einsum("ntc,hdc->nhtd", x, self.w_q)
        k = torch.einsum("ntc,hdc->nhtd", x, self.w_k)
        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)
        # Scaled dot-product attention scores: (N, h, T, T).
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(self.w_q.size(1))
        if self.c:
            # Causal mask: query position t may not attend to keys s > t.
            t = torch.arange(x.size(1), device=q.device)
            a = a.masked_fill(t[None, None, :, None] < t[None, None, None, :], float("-inf"))
        a = F.dropout(a.softmax(dim=3), self.dr, self.training)
        # Attention-weighted values, heads concatenated: (N, T, h*d_v).
        y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2)
        # BUG FIX: the original created w_o but never applied it, leaving a
        # dead parameter that received no gradient. Project back to d_in.
        return y @ self.w_o
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then a ReLU MLP,
    each sub-layer wrapped in a residual connection.

    dm: model dim, dk: query/key dim, dh: MLP hidden dim, h: heads,
    c: causal masking flag, dr: dropout rate.
    """

    def __init__(self, dm, dk, dh, h, c, dr):
        super().__init__()
        self.ln1 = nn.LayerNorm((dm,))
        # Per-head value dim dm // h keeps the concatenated output at dm.
        self.att = QKVAttention(dm, dk, dm // h, h, c, dr)
        self.ln2 = nn.LayerNorm((dm,))
        self.fc1 = nn.Linear(dm, dh)
        self.fc2 = nn.Linear(dh, dm)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        x = x + self.att(self.ln1(x))
        # Feed-forward sub-layer with residual connection.
        return x + self.fc2(F.relu(self.fc1(self.ln2(x))))
class FemtoGPT(nn.Module):
    """A minimal decoder-only transformer language model.

    v: vocab size, dm: model dim, dk: query/key dim, dh: MLP hidden dim,
    h: heads, b: number of blocks, c: causal masking, dr: dropout rate,
    lm: positional-encoding wavelength base.
    """

    def __init__(self, v, dm, dk, dh, h, b, c, dr=0.0, lm=1e5):
        super().__init__()
        # Token embedding -> dropout -> positional encoding.
        self.s = nn.Sequential(nn.Embedding(v, dm), nn.Dropout(dr), AddPositionalEncoding(lm))
        # Stack of b transformer blocks.
        self.t = nn.Sequential(*[TransformerBlock(dm, dk, dh, h, c, dr) for _ in range(b)])
        # Readout to vocabulary logits.
        self.o = nn.Linear(dm, v)
        # Re-initialize embeddings (small std) and layer norms (identity).
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(0, 2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, x):
        # Shift right by one (prepend 0, drop last) so the logits at
        # position t are the prediction for token t of the unshifted input.
        x = F.pad(x, (1, -1))
        return self.o(self.t(self.s(x)))

    def cross_entropy(self, x):
        # Next-token loss; transpose logits to (N, V, T) as F.cross_entropy expects.
        return F.cross_entropy(self(x).transpose(1, 2), x)

    def inplace_ar(self, x, t_s):
        """Autoregressively fill x[:, t_s:] in place, sampling one position
        per full forward pass."""
        for pos in range(t_s, x.size(1)):
            logits = self(x)[:, pos:pos + 1, :]
            x[:, pos:pos + 1] = torch.distributions.categorical.Categorical(logits=logits).sample()
if __name__ == "__main__":
    # Pick the best available accelerator, falling back to CPU.
    if torch.cuda.is_available():
        dev = torch.device("cuda")
    elif torch.backends.mps.is_available():
        dev = torch.device("mps")
    elif torch.xpu.is_available():
        dev = torch.device("xpu")
    else:
        dev = torch.device("cpu")

    def gen(n, pl=25):
        """Generate n strings of the form '<pl random letters>><same reversed>'."""
        l = ["".join([random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ") for _ in range(pl)]) for _ in range(n)]
        return [x + ">" + x[::-1] for x in l]

    nt, tr = 10000, 1000
    data = gen(nt + tr)
    pl = data[0].find(">")  # prompt length: everything before the separator
    # Build char <-> token-id maps over the observed alphabet.
    s = set("".join(data))
    c2t = {c: i for i, c in enumerate(s)}
    t2c = {i: c for c, i in c2t.items()}
    v = len(s)
    data = torch.cat([torch.tensor([c2t[c] for c in st])[None, :] for st in data])
    train, test = data[:nt], data[nt:]

    d = 128; b = 4; h = 4
    model = FemtoGPT(v, d, d // h, d, h, b, True, 0.1)
    e, B = 11, 100  # epochs, batch size
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    train, test = train.to(dev), test.to(dev)
    model.to(dev)
    print("nb_parameters", sum(p.numel() for p in model.parameters()), "device", dev)

    for ep in range(e):
        # --- training pass ---
        model.train()
        at = 0.0
        for inp in train.split(B):
            l = model.cross_entropy(inp)
            at += l.item() * inp.size(0)
            opt.zero_grad()
            l.backward()
            opt.step()
        at /= train.size(0)

        # --- evaluation pass ---
        model.eval()
        aT = 0.0
        # BUG FIX: evaluation and sampling previously ran with autograd
        # enabled, needlessly building graphs; disable gradient tracking.
        with torch.no_grad():
            for inp in test.split(B):
                l = model.cross_entropy(inp)
                aT += l.item() * inp.size(0)
            aT /= test.size(0)
            # Sample completions for one test batch: zero out everything
            # after the prompt and regenerate it autoregressively.
            i = test[:B]
            r = i.clone()
            r[:, pl:] = 0
            model.inplace_ar(r, pl)
        ne = (i[:, pl:] != r[:, pl:]).long().sum().item()
        er = ne / i[:, pl:].numel()
        print(f"n_epoch {ep} train_loss {at} test_loss {aT} token_error {er*100:.01f}%")
        if ep % 10 == 0:
            # Show a few true/generated pairs for qualitative inspection.
            print("-" * 70)
            for s_, t_ in zip(i[:5], r):
                print("true: " + "".join([t2c[x.item()] for x in s_]))
                print("generated: " + "".join([t2c[x.item()] for x in t_]))
            print("-" * 70)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment