SemanticDiffusionEncoder
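A minimal PyTorch sketch of a text encoder that scatters token embeddings onto a 2D grid, runs a small convolution for several "diffusion" steps so that neighboring cells exchange information, and mean-pools the grid into one fixed-size vector per sequence.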
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
# Step 1: Preprocessing
class TextDataset(Dataset):
    def __init__(self, texts, vocab, tokenizer, max_len=50):
        self.texts = texts
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize, map tokens to vocabulary indices, and truncate to max_len
        tokens = self.tokenizer(self.texts[idx])
        indices = [self.vocab[token] for token in tokens[:self.max_len]]
        return torch.tensor(indices, dtype=torch.long)
def build_vocab(texts, tokenizer):
    def yield_tokens(data_iter):
        for text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(texts), specials=["<pad>", "<unk>"])
    vocab.set_default_index(vocab["<unk>"])
    return vocab
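
# The dataset yields variable-length sequences, which the default DataLoader
# collate cannot stack into a batch. This padding collate function is an
# addition to the original gist: a minimal sketch that left-aligns each batch
# and fills the remainder with the <pad> index.
def make_collate_fn(pad_index):
    def collate(batch):
        max_len = max(seq.size(0) for seq in batch)
        padded = torch.full((len(batch), max_len), pad_index, dtype=torch.long)
        for i, seq in enumerate(batch):
            padded[i, :seq.size(0)] = seq
        return padded
    return collate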
# Step 2: Semantic Diffusion Encoder
class SemanticDiffusionEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, grid_size, diffusion_steps):
        super(SemanticDiffusionEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.grid_size = grid_size
        self.diffusion_steps = diffusion_steps
        # 3x3 convolution acts as the local diffusion operator on the grid
        self.conv = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
    def forward(self, x):
        # Embed tokens
        embedded = self.embedding(x)  # Shape: (batch_size, seq_len, embedding_dim)

        # Initialize grid (semantic field) on the same device as the input
        grid = torch.zeros(x.size(0), embedded.size(2), self.grid_size, self.grid_size,
                           device=embedded.device)

        # Scatter each token's embedding onto a grid cell, filling row by row;
        # sequences longer than grid_size**2 cells wrap around
        seq_len = x.size(1)
        for i in range(seq_len):
            position = (i % self.grid_size, (i // self.grid_size) % self.grid_size)
            grid[:, :, position[0], position[1]] += embedded[:, i, :]

        # Diffusion process: each step spreads information between neighboring cells
        for _ in range(self.diffusion_steps):
            grid = F.relu(self.conv(grid))

        # Aggregate final representations by mean-pooling over the grid
        aggregated = grid.mean(dim=(2, 3))  # Shape: (batch_size, embedding_dim)
        return aggregated
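
# Illustrative smoke test (an addition, not part of the original gist): the
# encoder should map any (batch_size, seq_len) batch of token ids to one
# fixed-size vector per sequence. Call _shape_check() to verify.
def _shape_check():
    enc = SemanticDiffusionEncoder(vocab_size=100, embedding_dim=16, grid_size=8, diffusion_steps=2)
    out = enc(torch.randint(0, 100, (2, 10)))
    assert out.shape == (2, 16)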
# Step 3: Training Loop
def train_model(model, dataloader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for input_seq in dataloader:
            optimizer.zero_grad()
            output = model(input_seq)  # Shape: (batch_size, embedding_dim)
            # Dummy loss toward a zero target, for demonstration only
            loss = criterion(output, torch.zeros_like(output))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")
# Step 4: Putting It Together
if __name__ == "__main__":
    # Sample text data
    texts = [
        "Today, there are more than an estimated 220 million fans of Korean entertainment.",
        "Squid Game, Netflix's most popular show ever, has returned for a much-anticipated second season."
    ]
    tokenizer = get_tokenizer("basic_english")
    vocab = build_vocab(texts, tokenizer)

    # Hyperparameters
    embedding_dim = 32
    grid_size = 8
    diffusion_steps = 5
    batch_size = 2
    epochs = 5

    # Dataset and DataLoader (padding collate keeps batch shapes consistent)
    dataset = TextDataset(texts, vocab, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=make_collate_fn(vocab["<pad>"]))

    # Model, optimizer, loss
    model = SemanticDiffusionEncoder(len(vocab), embedding_dim, grid_size, diffusion_steps)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    # Training
    train_model(model, dataloader, optimizer, criterion, epochs=epochs)

    # Example encoding
    model.eval()
    with torch.no_grad():
        sample_text = "Today there are more fans"
        sample_tokens = [vocab[token] for token in tokenizer(sample_text)]
        sample_tensor = torch.tensor([sample_tokens], dtype=torch.long)
        encoded_output = model(sample_tensor)
        print("Encoded Vector:", encoded_output)