Skip to content

Instantly share code, notes, and snippets.

@cnmoro
Created January 8, 2025 03:59
Show Gist options
  • Save cnmoro/6beb12eadf26a3fd70e9a0fa2516dd8c to your computer and use it in GitHub Desktop.
Save cnmoro/6beb12eadf26a3fd70e9a0fa2516dd8c to your computer and use it in GitHub Desktop.
SpatioTemporalGraphEncoding
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
class WordGraph:
    """Directed, weighted word graph stored as adjacency lists.

    ``graph`` maps a source word to a list of ``(target_word, weight)``
    tuples, kept in the order the edges were added. Parallel edges are
    allowed: adding the same pair twice stores two entries.
    """

    def __init__(self):
        self.graph = defaultdict(list)

    def add_edge(self, word1, word2, weight=1.0):
        """Add a directed edge with a weight (e.g., distance or co-occurrence score)."""
        self.graph[word1].append((word2, weight))

    def get_neighbors(self, word):
        """Retrieve neighbors of a word.

        Returns the list of ``(neighbor, weight)`` tuples for ``word``, or an
        empty list when ``word`` has no outgoing edges.
        """
        # Use .get() rather than indexing: indexing a defaultdict inserts an
        # empty entry for every unknown word, so a read-only query would
        # silently grow the graph.
        return self.graph.get(word, [])
class GraphEncoder:
    """Encode words as fixed-size vectors by walking a word graph.

    Args:
        vocab: Mapping from word to its column index in the output vector.
        max_steps: Maximum number of hops to follow from the start word.
        decay: Factor multiplied into the path score on every hop.
    """

    def __init__(self, vocab, max_steps=5, decay=0.8):
        self.vocab = vocab
        self.max_steps = max_steps  # How far to traverse the graph
        self.decay = decay  # Time-decay factor for longer paths

    def encode(self, graph, start_word):
        """Encode a word as a fixed-size vector based on graph traversal.

        The start word scores 1.0. Every other word reachable within
        ``max_steps`` hops scores the product of ``weight * decay`` along the
        path by which it was first discovered; unreachable words stay 0.

        Args:
            graph: Object exposing ``get_neighbors(word)`` that returns an
                iterable of ``(neighbor, weight)`` pairs.
            start_word: Word at which the traversal begins.

        Returns:
            A 1-D NumPy array of length ``len(self.vocab)``.

        Raises:
            KeyError: If a visited word is missing from ``self.vocab``
                (assuming ``vocab`` is a plain dict).
        """
        vector = np.zeros(len(self.vocab))  # Fixed-size vector for each word

        # Breadth-first, one hop level at a time. A depth-first walk sharing a
        # single visited set can reach a word through a long path first, which
        # both under-scores it and, at the depth limit, stops its neighbours
        # from being explored even though a shorter route would reach them.
        visited = {start_word}
        frontier = [(start_word, 1.0)]
        for _ in range(self.max_steps + 1):  # depths 0..max_steps inclusive
            if not frontier:
                break
            next_frontier = []
            for node, path_score in frontier:
                vector[self.vocab[node]] += path_score
                for neighbor, weight in graph.get_neighbors(node):
                    if neighbor in visited:
                        continue
                    # Mark on discovery so each word is scored exactly once.
                    visited.add(neighbor)
                    next_frontier.append(
                        (neighbor, path_score * weight * self.decay)
                    )
            frontier = next_frontier
        return vector
# Step 1: Dataset Preparation
class TextDataset(Dataset):
    """Dataset over raw text strings that yields whitespace-split tokens.

    Args:
        texts: Indexable collection of strings.
        vocab: Vocabulary object; stored on the instance but not consulted
            by any method of this class.
    """

    def __init__(self, texts, vocab):
        self.vocab = vocab
        self.texts = texts

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the text at ``idx`` split on whitespace into a token list."""
        sentence = self.texts[idx]
        return sentence.split()
# Step 2: Build Vocabulary
def build_vocab(texts):
    """Map every distinct whitespace-separated word in ``texts`` to an index.

    Indices are assigned in order of first appearance, so the same input
    always yields the same mapping. (Enumerating a ``set`` of strings gives
    an order that can change between interpreter runs, which would make the
    encoded vectors non-reproducible.)

    Args:
        texts: Iterable of strings.

    Returns:
        A dict ``{word: index}`` with indices ``0..n-1``.
    """
    words = " ".join(texts).split()
    # dict.fromkeys de-duplicates while preserving insertion order.
    return {word: idx for idx, word in enumerate(dict.fromkeys(words))}
# Step 3: Build Graph
def build_graph(texts, vocab):
    """Build a word graph linking each token to the token that follows it.

    Every text is split on whitespace and one edge of weight 1.0 is added for
    each pair of adjacent tokens, in reading order.

    Args:
        texts: Iterable of strings.
        vocab: Accepted for signature compatibility; not used here.

    Returns:
        The populated ``WordGraph``.
    """
    word_graph = WordGraph()
    for sentence in texts:
        words = sentence.split()
        # Pair each word with its successor; a text of 0 or 1 words adds nothing.
        for current, following in zip(words, words[1:]):
            word_graph.add_edge(current, following, weight=1.0)
    return word_graph
# Step 4: Encode Texts
def encode_texts(texts, graph, vocab):
    """Encode every text as a matrix with one graph-encoded row per token.

    Args:
        texts: Iterable of strings; each is split on whitespace.
        graph: Word graph passed through to ``GraphEncoder.encode``.
        vocab: Word-to-index mapping; its length is the row width.

    Returns:
        A list with one NumPy array per text, each of shape
        ``(number_of_tokens, len(vocab))``.
    """
    encoder = GraphEncoder(vocab)

    def encode_sentence(sentence):
        # One row per token, filled in token order.
        words = sentence.split()
        matrix = np.zeros((len(words), len(vocab)))
        for row, word in enumerate(words):
            matrix[row] = encoder.encode(graph, word)
        return matrix

    return [encode_sentence(sentence) for sentence in texts]
# Putting It All Together
if __name__ == "__main__":
    # Two short sample sentences used to demonstrate the pipeline.
    texts = [
        "Today there are more than an estimated number of fans",
        "Netflix has just released another popular series",
    ]

    # Vocabulary first, then the word graph built from the same sentences.
    vocab = build_vocab(texts)
    graph = build_graph(texts, vocab)

    # One matrix per sentence, one row per token.
    encoded_texts = encode_texts(texts, graph, vocab)

    print("Encoded Text Vectors:")
    for index, text in enumerate(texts):
        encoded = encoded_texts[index]
        print(f"Text: {text}")
        print(f"Encoded: {encoded}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment