@keskival
Created September 1, 2025 12:02
Ngram
# ngram_lm.py
# A pure-Python n-gram language model with interpolated Kneser–Ney smoothing.
# No external libraries. Intended to make complexity explicit.
from collections import Counter, defaultdict
from functools import lru_cache
import re, math, random
class NGramLM:
"""
n-gram language model with interpolated Kneser–Ney smoothing.
Key ideas & complexity:
- Training counts all k-grams for k=1..n. Time ~ O(T * n), T = total tokens.
Memory proportional to number of observed k-grams (worst case O(V^n), typically far less).
- Probability P(w | context) uses recursive interpolation:
max(c(context w) - D, 0) / c(context) + λ(context) * P_lower(w | context[1:])
* Each evaluation costs O(n) dictionary lookups, but generation needs a
probability for every candidate w, i.e. O(|V| * n) per context. We cache the
full distribution per context, so that cost is paid once per distinct context
and each further sample from the same context is O(|V|).
- Generation draws from the context distribution; with caching it’s fast for
repeated contexts (common in short texts).
Features:
- Tokenization, simple sentence splitting, BOS/EOS handling
- <unk> for rare tokens via min_count threshold
- Text generation with optional temperature
- Perplexity calculation
Notation:
- BOS = "<s>" is only used as context (not generated)
- EOS = "</s>" may be generated to terminate
"""
BOS = "<s>"
EOS = "</s>"
UNK = "<unk>"
def __init__(self, n=3, discount=0.75, min_count=1, lowercase=True):
assert n >= 1
self.n = n
self.discount = float(discount)
self.min_count = int(min_count)
self.lowercase = lowercase
# Learned structures
self.vocab = set() # includes </s> and <unk>, excludes <s>
self.vocab_no_bos = set()
self.counts_by_order = {} # k -> Counter of k-grams (tuples)
self.next_counts = {} # (k-1) -> {context(tuple): Counter(next_word)}
self.num_follow_types = {} # context -> # distinct successors (for λ)
self.distinct_predecessors = defaultdict(set) # w -> set of distinct predecessors (for KN base)
self.total_bigram_types = 0
# Caches for speed during generation
self._dist_cache = {}
# ---------------------- Text processing ----------------------
def _normalize(self, text):
return text.lower() if self.lowercase else text
def _tokenize(self, text):
# simple word/punctuation tokenizer
text = self._normalize(text)
return re.findall(r"\w+|[^\w\s]", text, flags=re.UNICODE)
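# For example (with the default lowercase=True):
#   _tokenize("Hello, world!") -> ["hello", ",", "world", "!"]
# i.e. words and individual punctuation marks become separate tokens.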
def _sentences(self, text):
# naive sentence splitter: split on punctuation followed by whitespace
return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]
def _wrap_sentence(self, tokens):
# add (n-1) BOS and one EOS
return [self.BOS] * (self.n - 1) + tokens + [self.EOS]
def _detokenize(self, tokens):
# light-weight detokenizer to make output readable
out = []
for tok in tokens:
if tok in {".", ",", "!", "?", ";", ":", "%", ")", "]", "}", "»"} and out:
out[-1] = out[-1] + tok
elif tok in {"(", "[", "{", "«"}:
out.append(tok)
elif tok in {"'", "’"} and out and re.match(r".*[A-Za-z0-9]$", out[-1]):
out[-1] = out[-1] + tok
elif tok == self.EOS:
continue
else:
out.append((" " if out else "") + tok)
return "".join(out).strip()
# ---------------------- Training ----------------------
def fit(self, texts):
if isinstance(texts, str):
texts = [texts]
# Pass 1: raw token counts to build vocab with <unk>
raw_counts = Counter()
sentence_token_lists = []
for text in texts:
for sent in self._sentences(text):
toks = self._tokenize(sent)
raw_counts.update(toks)
sentence_token_lists.append(toks)
# Build vocab: keep >= min_count; always include EOS & UNK; never include BOS
self.vocab = {t for t, c in raw_counts.items() if c >= self.min_count}
self.vocab.discard(self.BOS)
self.vocab.update({self.EOS, self.UNK})
self.vocab_no_bos = set(self.vocab)
# Pass 2: map rares to <unk>, add BOS/EOS, accumulate counts for all orders
self.counts_by_order = {k: Counter() for k in range(1, self.n + 1)}
self.next_counts = {k - 1: defaultdict(Counter) for k in range(2, self.n + 1)}
self.distinct_predecessors.clear()
for toks in sentence_token_lists:
proc = [t if t in self.vocab_no_bos else self.UNK for t in toks]
wrapped = self._wrap_sentence(proc)
for k in range(1, self.n + 1):
for i in range(len(wrapped) - k + 1):
ng = tuple(wrapped[i:i + k])
self.counts_by_order[k][ng] += 1
if k >= 2:
ctx, w = ng[:-1], ng[-1]
self.next_counts[k - 1][ctx][w] += 1
if k == 2:
self.distinct_predecessors[w].add(ctx[0])
# Precompute #successor types per context and total bigram types (for KN base)
self.num_follow_types = {
ctx: len(ctr) for ctx_map in self.next_counts.values() for ctx, ctr in ctx_map.items()
}
self.total_bigram_types = sum(len(prevs) for prevs in self.distinct_predecessors.values())
# Invalidate caches (e.g., if refit)
self._dist_cache.clear()
# Also clear the lru_cache on _prob_kn (defined below); it is shared at class level
try:
self._prob_kn.cache_clear() # type: ignore[attr-defined]
except AttributeError:
pass
return self
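# To make the learned structures concrete: fitting a trigram model on the single
# sentence "the cat sat." (with min_count=1, so nothing maps to <unk>) stores,
# among other entries,
#   counts_by_order[2][("the", "cat")] == 1
#   next_counts[1][("the",)] == Counter({"cat": 1})
#   distinct_predecessors["cat"] == {"the"}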
# ---------------------- Kneser–Ney smoothing ----------------------
def _cont_unigram_prob(self, w):
# KN base: P_cont(w) = N1+(· w) / N1+(· ·)
if self.total_bigram_types == 0:
total = sum(self.counts_by_order[1].values())
return self.counts_by_order[1].get((w,), 0) / total if total else 0.0
return len(self.distinct_predecessors.get(w, ())) / self.total_bigram_types
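# The classic motivation for counting predecessors rather than occurrences: a
# token such as "francisco" can be frequent overall yet follow almost nothing
# but "san", so N1+(. francisco) is small and its continuation probability
# stays low, unlike a raw unigram estimate.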
@lru_cache(maxsize=100_000)
def _prob_kn(self, w, context):
"""
Recursively compute KN probability P(w | context).
context: tuple of length <= n-1
"""
if len(context) == 0:
return self._cont_unigram_prob(w)
k = len(context) + 1
c_ctx = self.counts_by_order.get(k - 1, Counter()).get(tuple(context), 0)
if c_ctx == 0:
# unseen context: pure backoff
return self._prob_kn(w, context[1:])
c_ctxw = self.counts_by_order.get(k, Counter()).get(tuple(context) + (w,), 0)
D = self.discount
numerator = max(c_ctxw - D, 0.0) / c_ctx
N1plus = self.num_follow_types.get(tuple(context), 0)
lamb = (D * N1plus) / c_ctx
lower = self._prob_kn(w, context[1:])
return numerator + lamb * lower
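# For n = 3 the recursion is at most three levels deep,
#   P(w | u, v) -> P(w | v) -> P_cont(w),
# so a single probability query costs O(n) dictionary lookups.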
# ---------------------- Inference APIs ----------------------
def distribution(self, context):
"""
Return (vocab_list, probs_list) for the given context (truncated to n-1).
Cached per context to make repeated sampling efficient.
"""
ctx = tuple(context[-(self.n - 1):]) if self.n > 1 else tuple()
if ctx in self._dist_cache:
return self._dist_cache[ctx]
vocab = sorted(self.vocab_no_bos) # exclude BOS from candidates
probs = [self._prob_kn(w, ctx) for w in vocab]
s = sum(probs)
if s == 0.0:
probs = [1.0 / len(vocab)] * len(vocab)
else:
probs = [p / s for p in probs]
self._dist_cache[ctx] = (vocab, probs)
return self._dist_cache[ctx]
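# Sketch of inspecting a fitted model (assumes `lm` is a fitted NGramLM):
#   vocab, probs = lm.distribution(["the"])
#   top5 = sorted(zip(vocab, probs), key=lambda t: -t[1])[:5]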
def sample_next(self, context, temperature=1.0):
vocab, probs = self.distribution(context)
if temperature <= 0:
return vocab[max(range(len(probs)), key=probs.__getitem__)]
if temperature != 1.0:
probs = [p ** (1.0 / temperature) for p in probs]
s = sum(probs)
probs = [p / s for p in probs]
r, acc = random.random(), 0.0
for w, p in zip(vocab, probs):
acc += p
if r <= acc:
return w
return vocab[-1] # numeric safety
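# Note on temperature: values below 1.0 sharpen the distribution (exponent
# 1/T > 1 before renormalization), values above 1.0 flatten it, and
# temperature <= 0 degenerates to greedy argmax decoding.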
def generate(self, max_tokens=50, temperature=1.0, seed=None):
if seed is not None:
random.seed(seed)
context = tuple([self.BOS] * (self.n - 1)) if self.n > 1 else tuple()
out = []
for _ in range(max_tokens):
w = self.sample_next(context, temperature=temperature)
if w == self.EOS:
break
out.append(w)
if self.n > 1:
context = context[1:] + (w,)
return self._detokenize(out)
def log_prob(self, tokens):
"""
Sum of log probabilities for a token sequence (wrapped with BOS/EOS).
Unknown tokens are mapped to <unk>.
"""
if self.n == 1:
wrapped = tokens + [self.EOS]
else:
wrapped = [self.BOS] * (self.n - 1) + tokens + [self.EOS]
logp = 0.0
for i in range(self.n - 1, len(wrapped)):
context = tuple(wrapped[i - (self.n - 1):i]) if self.n > 1 else tuple()
w = wrapped[i] if wrapped[i] in self.vocab_no_bos else self.UNK
p = self._prob_kn(w, context)
logp += math.log(max(p, 1e-12))
return logp
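# The 1e-12 floor above mainly matters for <unk>: if no training token was
# actually mapped to <unk>, it has no observed predecessors, so its
# continuation probability (and hence its full KN probability) is exactly zero.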
def perplexity(self, text):
"""
Tokenize text and compute per-token perplexity (lower is better).
"""
tokens = []
for sent in self._sentences(text):
tokens.extend([t if t in self.vocab_no_bos else self.UNK for t in self._tokenize(sent)])
N = max(len(tokens) + 1, 1) # +1 for EOS
return math.exp(-self.log_prob(tokens) / N)
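# Sanity check: a model assigning uniform probability 1/|V| to every token
# would score perplexity exactly |V|, so values well below the vocabulary size
# mean the n-gram statistics are doing real work.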
# ---------------------- Usage example ----------------------
if __name__ == "__main__":
corpus = """
Alice was beginning to get very tired of sitting by her sister on the bank.
She had nothing to do: once or twice she had peeped into the book her sister was reading,
but it had no pictures or conversations in it, "and what is the use of a book," thought Alice "without pictures or conversations?"
"""
lm = NGramLM(n=3, discount=0.75, min_count=1, lowercase=True).fit(corpus)
print("Vocab size (excl. <s>):", len(lm.vocab_no_bos))
print("Sample:", lm.generate(max_tokens=40, temperature=1.0, seed=1))
print("Perplexity on training text:", round(lm.perplexity(corpus), 3))