# ngram_lm.py
# A pure-Python n-gram language model with interpolated Kneser–Ney smoothing.
# No external libraries. Intended to make complexity explicit.

from collections import Counter, defaultdict
from functools import lru_cache
import re, math, random


class NGramLM:
| """ | |
| n-gram language model with interpolated Kneser–Ney smoothing. | |
| Key ideas & complexity: | |
| - Training counts all k-grams for k=1..n. Time ~ O(T * n), T = total tokens. | |
| Memory proportional to number of observed k-grams (worst case O(V^n), typically far less). | |
| - Probability P(w | context) uses recursive interpolation: | |
| max(c(context w) - D, 0) / c(context) + λ(context) * P_lower(w | context[1:]) | |
| * Each evaluation is O(1) given the counters, but generation naively needs | |
| probabilities for every candidate w (O(|V| * n)). We cache distributions per | |
| context to make repeated sampling O(|V|) once per context. | |
| - Generation draws from the context distribution; with caching it’s fast for | |
| repeated contexts (common in short texts). | |
| Features: | |
| - Tokenization, simple sentence splitting, BOS/EOS handling | |
| - <unk> for rare tokens via min_count threshold | |
| - Text generation with optional temperature | |
| - Perplexity calculation | |
| Notation: | |
| - BOS = "<s>" is only used as context (not generated) | |
| - EOS = "</s>" may be generated to terminate | |
| """ | |
    BOS = "<s>"
    EOS = "</s>"
    UNK = "<unk>"

    def __init__(self, n=3, discount=0.75, min_count=1, lowercase=True):
        assert n >= 1
        self.n = n
        self.discount = float(discount)
        self.min_count = int(min_count)
        self.lowercase = lowercase
        # Learned structures
        self.vocab = set()          # includes </s> and <unk>, excludes <s>
        self.vocab_no_bos = set()
        self.counts_by_order = {}   # k -> Counter of k-grams (tuples)
        self.next_counts = {}       # (k-1) -> {context(tuple): Counter(next_word)}
        self.num_follow_types = {}  # context -> # distinct successors (for λ)
        self.distinct_predecessors = defaultdict(set)  # w -> set of distinct predecessors (for KN base)
        self.total_bigram_types = 0
        # Caches for speed during generation
        self._dist_cache = {}

    # ---------------------- Text processing ----------------------
    def _normalize(self, text):
        return text.lower() if self.lowercase else text

    def _tokenize(self, text):
        # simple word/punctuation tokenizer
        text = self._normalize(text)
        return re.findall(r"\w+|[^\w\s]", text, flags=re.UNICODE)

    def _sentences(self, text):
        # naive sentence splitter: split on punctuation followed by whitespace
        return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]

    def _wrap_sentence(self, tokens):
        # add (n-1) BOS and one EOS
        return [self.BOS] * (self.n - 1) + tokens + [self.EOS]

    def _detokenize(self, tokens):
        # light-weight detokenizer to make output readable
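        # e.g. ["hello", ",", "world", "!"] -> "hello, world!"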
        out = []
        for tok in tokens:
            if tok in {".", ",", "!", "?", ";", ":", "%", ")", "]", "}", "»"} and out:
                out[-1] = out[-1] + tok
            elif tok in {"(", "[", "{", "«"}:
                out.append(tok)
            elif tok in {"'", "’"} and out and re.match(r".*[A-Za-z0-9]$", out[-1]):
                out[-1] = out[-1] + tok
            elif tok == self.EOS:
                continue
            else:
                out.append((" " if out else "") + tok)
        return "".join(out).strip()

    # ---------------------- Training ----------------------
    def fit(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        # Pass 1: raw token counts to build vocab with <unk>
        raw_counts = Counter()
        sentence_token_lists = []
        for text in texts:
            for sent in self._sentences(text):
                toks = self._tokenize(sent)
                raw_counts.update(toks)
                sentence_token_lists.append(toks)
        # Build vocab: keep >= min_count; always include EOS & UNK; never include BOS
        self.vocab = {t for t, c in raw_counts.items() if c >= self.min_count}
        self.vocab.discard(self.BOS)
        self.vocab.update({self.EOS, self.UNK})
        self.vocab_no_bos = set(self.vocab)
        # Pass 2: map rare tokens to <unk>, add BOS/EOS, accumulate counts for all orders
        self.counts_by_order = {k: Counter() for k in range(1, self.n + 1)}
        self.next_counts = {k - 1: defaultdict(Counter) for k in range(2, self.n + 1)}
        self.distinct_predecessors.clear()
        for toks in sentence_token_lists:
            proc = [t if t in self.vocab_no_bos else self.UNK for t in toks]
            wrapped = self._wrap_sentence(proc)
            for k in range(1, self.n + 1):
                for i in range(len(wrapped) - k + 1):
                    ng = tuple(wrapped[i:i + k])
                    self.counts_by_order[k][ng] += 1
                    if k >= 2:
                        ctx, w = ng[:-1], ng[-1]
                        self.next_counts[k - 1][ctx][w] += 1
                        if k == 2:
                            self.distinct_predecessors[w].add(ctx[0])
        # Precompute #successor types per context and total bigram types (for KN base)
        self.num_follow_types = {
            ctx: len(ctr) for ctx_map in self.next_counts.values() for ctx, ctr in ctx_map.items()
        }
        self.total_bigram_types = sum(len(prevs) for prevs in self.distinct_predecessors.values())
        # Invalidate caches (e.g., if refit)
        self._dist_cache.clear()
        # Also clear the memoized probability cache on _prob_kn (decorated below)
        try:
            self._prob_kn.cache_clear()  # type: ignore[attr-defined]
        except AttributeError:
            pass
        return self

    # ---------------------- Kneser–Ney smoothing ----------------------
    def _cont_unigram_prob(self, w):
        # KN base: P_cont(w) = N1+(· w) / N1+(· ·)
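        # Intuition: this counts how many *distinct* words precede w, not how often
        # w occurs. A frequent word that follows only a few contexts (the classic
        # "Francisco", seen almost exclusively after "San") gets a low continuation
        # probability.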
        if self.total_bigram_types == 0:
            total = sum(self.counts_by_order[1].values())
            return self.counts_by_order[1].get((w,), 0) / total if total else 0.0
        return len(self.distinct_predecessors.get(w, ())) / self.total_bigram_types

    @lru_cache(maxsize=100_000)
    def _prob_kn(self, w, context):
        """
        Recursively compute KN probability P(w | context).
        context: tuple of length <= n-1
        """
        if len(context) == 0:
            return self._cont_unigram_prob(w)
        k = len(context) + 1
        c_ctx = self.counts_by_order.get(k - 1, Counter()).get(tuple(context), 0)
        if c_ctx == 0:
            # unseen context: pure backoff
            return self._prob_kn(w, context[1:])
        c_ctxw = self.counts_by_order.get(k, Counter()).get(tuple(context) + (w,), 0)
        D = self.discount
        numerator = max(c_ctxw - D, 0.0) / c_ctx
        N1plus = self.num_follow_types.get(tuple(context), 0)
        lamb = (D * N1plus) / c_ctx
        lower = self._prob_kn(w, context[1:])
        return numerator + lamb * lower

    # ---------------------- Inference APIs ----------------------
    def distribution(self, context):
        """
        Return (vocab_list, probs_list) for the given context (truncated to n-1).
        Cached per context to make repeated sampling efficient.
        """
        ctx = tuple(context[-(self.n - 1):]) if self.n > 1 else tuple()
        if ctx in self._dist_cache:
            return self._dist_cache[ctx]
        vocab = sorted(self.vocab_no_bos)  # exclude BOS from candidates
        probs = [self._prob_kn(w, ctx) for w in vocab]
        s = sum(probs)
        if s == 0.0:
            probs = [1.0 / len(vocab)] * len(vocab)
        else:
            probs = [p / s for p in probs]
        self._dist_cache[ctx] = (vocab, probs)
        return self._dist_cache[ctx]

    def sample_next(self, context, temperature=1.0):
        vocab, probs = self.distribution(context)
        if temperature <= 0:
            return vocab[max(range(len(probs)), key=probs.__getitem__)]
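        # Temperature reshapes the already-normalized distribution: p ** (1/T),
        # then renormalize (T < 1 sharpens, T > 1 flattens).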
        if temperature != 1.0:
            probs = [p ** (1.0 / temperature) for p in probs]
            s = sum(probs)
            probs = [p / s for p in probs]
        r, acc = random.random(), 0.0
        for w, p in zip(vocab, probs):
            acc += p
            if r <= acc:
                return w
        return vocab[-1]  # numeric safety

    def generate(self, max_tokens=50, temperature=1.0, seed=None):
        if seed is not None:
            random.seed(seed)
        context = tuple([self.BOS] * (self.n - 1)) if self.n > 1 else tuple()
        out = []
        for _ in range(max_tokens):
            w = self.sample_next(context, temperature=temperature)
            if w == self.EOS:
                break
            out.append(w)
            if self.n > 1:
                context = context[1:] + (w,)
        return self._detokenize(out)

    def log_prob(self, tokens):
        """
        Sum of log probabilities for a token sequence (wrapped with BOS/EOS).
        Unknown tokens are mapped to <unk>.
        """
        if self.n == 1:
            wrapped = tokens + [self.EOS]
        else:
            wrapped = [self.BOS] * (self.n - 1) + tokens + [self.EOS]
        logp = 0.0
        for i in range(self.n - 1, len(wrapped)):
            context = tuple(wrapped[i - (self.n - 1):i]) if self.n > 1 else tuple()
            w = wrapped[i] if wrapped[i] in self.vocab_no_bos else self.UNK
            p = self._prob_kn(w, context)
            logp += math.log(max(p, 1e-12))
        return logp

    def perplexity(self, text):
        """
        Tokenize text and compute per-token perplexity (lower is better).
        """
        tokens = []
        for sent in self._sentences(text):
            tokens.extend([t if t in self.vocab_no_bos else self.UNK for t in self._tokenize(sent)])
        N = max(len(tokens) + 1, 1)  # +1 for EOS
        return math.exp(-self.log_prob(tokens) / N)


# ---------------------- Usage example ----------------------
if __name__ == "__main__":
    corpus = """
    Alice was beginning to get very tired of sitting by her sister on the bank.
    She had nothing to do: once or twice she had peeped into the book her sister was reading,
    but it had no pictures or conversations in it, "and what is the use of a book," thought Alice "without pictures or conversations?"
    """
    lm = NGramLM(n=3, discount=0.75, min_count=1, lowercase=True).fit(corpus)
    print("Vocab size (excl. <s>):", len(lm.vocab_no_bos))
    print("Sample:", lm.generate(max_tokens=40, temperature=1.0, seed=1))
    print("Perplexity on training text:", round(lm.perplexity(corpus), 3))