@rtkclouds
Created August 19, 2024 18:21
This code defines a custom attention mechanism and transformer layer based on the hexastore concept. A hexastore is an indexing scheme typically used to query (Subject, Predicate, Object) triples in a knowledge graph efficiently by storing each triple under all six orderings of its components; the attention module mirrors that idea by processing all six permutations of learned S, P, O projections.
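For reference, the six orderings a hexastore maintains are simply the permutations of (S, P, O); they match the keys of the permutation-specific projections in the module below (a tiny illustrative snippet, not part of the original gist):

from itertools import permutations

# The six index orderings of a hexastore; these match the W_perm keys used below.
orderings = [''.join(p) for p in permutations('spo')]
print(orderings)  # ['spo', 'sop', 'pso', 'pos', 'osp', 'ops']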
import torch
import torch.nn as nn
import torch.nn.functional as F


class HexastoreAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % 3 == 0, "d_model must be divisible by 3 to split into S, P, O components"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # used only to scale the attention scores

        # Linear transformations for S, P, O (each component is d_model // 3 wide after the split)
        self.W_s = nn.Linear(d_model // 3, d_model)
        self.W_p = nn.Linear(d_model // 3, d_model)
        self.W_o = nn.Linear(d_model // 3, d_model)

        # Permutation-specific linear transformations (one per hexastore ordering)
        self.W_perm = nn.ModuleDict({
            'spo': nn.Linear(3 * d_model, d_model),
            'sop': nn.Linear(3 * d_model, d_model),
            'pso': nn.Linear(3 * d_model, d_model),
            'pos': nn.Linear(3 * d_model, d_model),
            'osp': nn.Linear(3 * d_model, d_model),
            'ops': nn.Linear(3 * d_model, d_model),
        })

        # Final projection layer combining the six permutation outputs
        self.W_out = nn.Linear(6 * d_model, d_model)

    def attention(self, q, k, v):
        # Scaled dot-product self-attention over the sequence dimension
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, v)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Split x into S, P, O components along the feature dimension
        s, p, o = x.chunk(3, dim=-1)

        # Apply linear transformations to S, P, O
        s = self.W_s(s)
        p = self.W_p(p)
        o = self.W_o(o)

        # Build the 6 permutations, project each, and run self-attention on the result
        permuted_outputs = []
        for perm in ['spo', 'sop', 'pso', 'pos', 'osp', 'ops']:
            perm_input = {
                'spo': torch.cat([s, p, o], dim=-1),
                'sop': torch.cat([s, o, p], dim=-1),
                'pso': torch.cat([p, s, o], dim=-1),
                'pos': torch.cat([p, o, s], dim=-1),
                'osp': torch.cat([o, s, p], dim=-1),
                'ops': torch.cat([o, p, s], dim=-1),
            }[perm]
            perm_output = self.W_perm[perm](perm_input)
            permuted_outputs.append(self.attention(perm_output, perm_output, perm_output))

        # Combine the results from all permutations
        combined = torch.cat(permuted_outputs, dim=-1)

        # Final projection back to d_model
        output = self.W_out(combined)
        return output


class HexastoreTransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = HexastoreAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Post-norm residual blocks: attention, then position-wise feed-forward
        attn_output = self.attention(x)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
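A minimal usage sketch, assuming a d_model divisible by both 3 (for the S/P/O split) and num_heads; the shapes below are arbitrary illustrative choices, not values from the original gist:

if __name__ == "__main__":
    d_model, num_heads, d_ff = 96, 4, 256   # 96 is divisible by 3 and by 4
    layer = HexastoreTransformerLayer(d_model, num_heads, d_ff)
    x = torch.randn(2, 10, d_model)         # (batch, sequence length, d_model)
    y = layer(x)
    print(y.shape)                          # torch.Size([2, 10, 96])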