Last active
          September 26, 2025 00:20 
        
      - 
      
- 
        Save capttwinky/346bec3a15929816e7162093a53fbdc8 to your computer and use it in GitHub Desktop. 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
#!/usr/bin/env python3
# Quick start:
#   macOS/Linux — install uv: curl -LsSf https://astral.sh/uv/install.sh | sh
#   Windows (PowerShell): powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
#   Run with uvx: uvx --from nltk python bongo_sentances.py --target 1000 --metric typed \
#       --lm-order 4 --punct end-only --progress off
"""
Typing practice with an **NLTK POS-class language model** (letters-only words,
punctuation only at the end of a sentence).

Optional "entertainment" tweaks layered on top of the model:
- length shaping (target length drawn from a normal distribution)
- adaptive temperature (separate temperatures for function vs content tags)
- soft alliteration bias
- tiny micro-templates
- a "rarity" re-weighting of emission probabilities, weight = (p*(1-p))**rarity

Design: a POS n-gram LM (default order 4) over the Brown corpus (Universal
tagset) -> P(word|tag) emissions (letters only). Session I/O is isolated behind
`SessionParams`, `SessionIO` and `SessionSummary`.

Doctests (runs whatever examples the helper docstrings contain):
    python -m doctest -v bongo_sentances.py
"""
| from __future__ import annotations | |
| import argparse | |
| from dataclasses import dataclass, field | |
| import math | |
| import random | |
| import re | |
| import sys | |
| from typing import Callable, Dict, Iterable, Iterator, List, Literal, Optional, Sequence, Tuple | |
# -------------------- Constants --------------------
# Line terminator used in feedback messages and as the default separator
# between generated sentences.
NEWLINE = "\n"
# Carriage return: returns the cursor to column 0 so progress_bar can redraw
# its line in place.
CARRIAGE_RETURN = "\r"
# Number of cells in the text progress bar.
BAR_WIDTH: int = 40
# -------------------- Progress UI --------------------
def progress_bar(current: int, total: int, width: int = BAR_WIDTH) -> None:
    """Redraw a one-line text progress bar on stderr.

    Each call begins with a carriage return, so successive calls overwrite the
    same terminal line. No trailing newline is written. The fill ratio is
    capped at 100%, and a non-positive *total* is drawn as complete.
    """
    if total > 0:
        fraction = current / total
    else:
        fraction = 1.0
    fraction = min(fraction, 1.0)
    cells = int(width * fraction)
    track = "█" * cells + "─" * (width - cells)
    err = sys.stderr
    err.write(f"{CARRIAGE_RETURN}[{track}] {current}/{total} ({fraction*100:5.1f}%)")
    err.flush()
def noop_progress(_c: int, _t: int) -> None:
    """Progress callback that ignores both arguments (progress output disabled)."""
    return None
def choose_progress(mode: str) -> Callable[[int, int], None]:
    """Return the progress callback for a ``--progress`` mode.

    ``"off"`` selects the no-op callback. ``"auto"`` selects the bar only when
    stderr is a TTY. Any other value (e.g. ``"bar"``) always selects the bar.
    """
    if mode == "auto":
        return progress_bar if sys.stderr.isatty() else noop_progress
    if mode != "off":
        return progress_bar
    return noop_progress
| # -------------------- Detokenizer -------------------- | |
| def _detok_space_join(tokens: Iterable[str], *, add_final_period: bool) -> str: | |
| s = " ".join(tokens).strip() | |
| if not s: | |
| return s | |
| s = s[:1].upper() + s[1:] | |
| if add_final_period and s[-1] not in ".!?": | |
| s += "." | |
| return s | |
# -------------------- NLTK data + models --------------------
# A "word" is a run of ASCII letters only (no digits, apostrophes or hyphens).
WORD_RX = re.compile(r"^[A-Za-z]+$")
# Vowel letters, case-insensitive. "y" is included so that real words whose
# only vowel is "y" ("why", "sky", "try", "gym") survive the vowel filter,
# which is meant to drop vowel-less abbreviations such as "mrs".
VOWEL_RX = re.compile(r"[aeiouy]", re.I)
# Punctuation policy: "end-only" appends a final period, "none" adds nothing.
PunctMode = Literal["end-only", "none"]
def ensure_brown_available() -> None:
    """Ensure the NLTK Brown corpus and the Universal-tagset mapping are installed.

    Each resource is looked up with ``nltk.data.find``. If it is missing, it is
    downloaded quietly and looked up again, so a failed download surfaces as
    NLTK's own ``LookupError``.

    Raises:
        SystemExit: if the ``nltk`` package cannot be imported.
        LookupError: if a resource is still missing after the download attempt.
    """
    try:
        import nltk  # type: ignore
    except ImportError:
        # Keep this try narrow. Only a failed `import nltk` means "NLTK is
        # required"; an ImportError raised later (e.g. inside nltk.download)
        # must not be reported as a missing package.
        raise SystemExit("NLTK is required. Install with: pip install nltk") from None
    for resource_path, package in (
        ("corpora/brown", "brown"),
        ("taggers/universal_tagset", "universal_tagset"),
    ):
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, quiet=True)
            nltk.data.find(resource_path)
# ---- POS tag LM (class model) ----
def build_tag_lm(order: int = 4, smoothing: str = "kneser_ney", lidstone_gamma: float = 0.2):
    """Fit an n-gram LM over Universal POS tag sequences from Brown.

    Args:
        order: n-gram order of the tag model.
        smoothing: one of ``"kneser_ney"``, ``"witten_bell"``, ``"laplace"`` or
            ``"lidstone"``. Any other value falls back to unsmoothed MLE.
        lidstone_gamma: additive constant, used only for ``"lidstone"``.

    Returns:
        ``(model, tagged_sents)``: the fitted NLTK model and Brown's
        universal-tagged sentences (the caller reuses them for emissions).
    """
    ensure_brown_available()
    from nltk.corpus import brown  # type: ignore
    from nltk.lm import MLE, Laplace, Lidstone  # type: ignore
    from nltk.lm.models import KneserNeyInterpolated, WittenBellInterpolated  # type: ignore
    from nltk.lm.preprocessing import padded_everygram_pipeline  # type: ignore

    tagged_sents = brown.tagged_sents(tagset="universal")
    # The LM sees only the tag column of each (word, tag) pair.
    tag_sequences: List[List[str]] = [[tag for _word, tag in sent] for sent in tagged_sents]

    factories = {
        "kneser_ney": lambda: KneserNeyInterpolated(order),
        "witten_bell": lambda: WittenBellInterpolated(order),
        "laplace": lambda: Laplace(order),
        "lidstone": lambda: Lidstone(gamma=lidstone_gamma, order=order),
    }
    model = factories.get(smoothing, lambda: MLE(order))()

    train_data, vocab = padded_everygram_pipeline(order, tag_sequences)
    model.fit(train_data, vocab)
    return model, tagged_sents
# ---- Emissions P(word|tag) ----
def _has_vowel(w: str) -> bool:
    """Return True when *w* contains at least one character matched by VOWEL_RX."""
    return VOWEL_RX.search(w) is not None
def build_emissions(tagged_sents, *, min_count: int = 3, vowel_filter: bool = True, min_len: int = 3) -> Dict[str, Tuple[List[str], List[float]]]:
    """Estimate P(word | tag) from (word, tag) sentences, keeping only clean words.

    Words are lower-cased and must be letters only (WORD_RX). Except for "a"
    and "i", a word must also contain a vowel (when *vowel_filter* is set, per
    ``_has_vowel``), be at least *min_len* characters long, and occur at least
    *min_count* times under the tag.

    Returns:
        ``{tag: (words, probs)}`` where *probs* are relative frequencies within
        the tag. Tags left with no words are omitted.
    """
    from collections import Counter

    always_ok = {"a", "i"}

    def keep(word: str) -> bool:
        # Mirrors the filter order: letters-only, then exemptions, vowel, length.
        if not WORD_RX.fullmatch(word):
            return False
        if word in always_ok:
            return True
        if vowel_filter and not _has_vowel(word):
            return False
        return len(word) >= min_len

    per_tag: Dict[str, Counter] = {}
    for sent in tagged_sents:
        for token, tag in sent:
            lowered = token.lower()
            if keep(lowered):
                per_tag.setdefault(tag, Counter())[lowered] += 1

    emissions: Dict[str, Tuple[List[str], List[float]]] = {}
    for tag, counts in per_tag.items():
        survivors = [(w, n) for w, n in counts.items() if n >= min_count or w in always_ok]
        if not survivors:
            continue
        total = float(sum(n for _w, n in survivors)) or 1.0
        emissions[tag] = ([w for w, _n in survivors], [n / total for _w, n in survivors])
    return emissions
| # -------------------- Sampling utils -------------------- | |
| def _softmax_sample(cands: List[str], probs: List[float], rnd: random.Random, *, | |
| temperature: float = 1.0, topk: Optional[int] = None, topp: Optional[float] = None) -> str: | |
| assert len(cands) == len(probs) and len(cands) > 0 | |
| if temperature <= 0: | |
| temperature = 1e-6 | |
| scaled = [max(p, 0.0) ** (1.0 / temperature) for p in probs] | |
| order_idx = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True) | |
| cands = [cands[i] for i in order_idx] | |
| scaled = [scaled[i] for i in order_idx] | |
| if topk is not None and topk > 0: | |
| cands, scaled = cands[:topk], scaled[:topk] | |
| if topp is not None and 0.0 < topp < 1.0: | |
| total = sum(scaled) or 1.0 | |
| kept_c, kept_p, cum = [], [], 0.0 | |
| for c, p in zip(cands, scaled): | |
| kept_c.append(c); kept_p.append(p); cum += p | |
| if cum >= topp * total: | |
| break | |
| cands, scaled = kept_c, kept_p | |
| total = sum(scaled) | |
| if total <= 0: | |
| return cands[0] | |
| r, acc = rnd.uniform(0.0, total), 0.0 | |
| for c, p in zip(cands, scaled): | |
| acc += p | |
| if acc >= r: | |
| return c | |
| return cands[-1] | |
| # -------------------- POS-class sentence generator -------------------- | |
| def _allowed_tag(tag: str) -> bool: | |
| return tag != "." # exclude sentence punctuation | |
# Tag classes for biasing and run-length control. NLTK's "universal" tagset
# (used for Brown here) names particles "PRT", not Universal Dependencies'
# "PART". Both spellings are listed, together with the other UD names, so
# either tagset works.
CONTENT_TAGS = {"NOUN", "VERB", "ADJ", "ADV", "PROPN", "NUM"}
FUNCTION_TAGS = {"AUX", "DET", "ADP", "PRON", "PART", "PRT", "CCONJ", "CONJ", "SCONJ"}
# Function tags may not repeat back-to-back (see _mask_same_tag_if_needed).
GUARD_SAME_TAG = FUNCTION_TAGS
| def _tag_candidates(model, history: Sequence[str]) -> Tuple[List[str], List[float]]: | |
| EOS = "</s>" | |
| tags: List[str] = [] | |
| probs: List[float] = [] | |
| for tok in model.vocab: # type: ignore[attr-defined] | |
| if tok == "<s>": | |
| continue | |
| if tok != EOS and not _allowed_tag(tok): | |
| continue | |
| p = float(model.score(tok, history)) | |
| if p > 0.0: | |
| tags.append(tok); probs.append(p) | |
| return (tags, probs) if tags else ([EOS], [1.0]) | |
| def _apply_content_bias(tags: List[str], probs: List[float], *, bias: float) -> List[float]: | |
| if bias <= 1.0: | |
| return probs | |
| scaled = [p * (bias if t in CONTENT_TAGS else 1.0) for t, p in zip(tags, probs)] | |
| s = sum(scaled) or 1.0 | |
| return [x / s for x in scaled] | |
| def _mask_same_tag_if_needed(tags: List[str], probs: List[float], last_tag: Optional[str]) -> Tuple[List[str], List[float]]: | |
| if not last_tag or last_tag not in GUARD_SAME_TAG: | |
| return tags, probs | |
| kept = [(t, p) for t, p in zip(tags, probs) if t != last_tag] | |
| return (list(zip(*kept))[0], list(zip(*kept))[1]) if kept else (tags, probs) | |
| def _avoid_recent_word_repeats(words: List[str], cand_words: List[str], cand_probs: List[float], window: int) -> Tuple[List[str], List[float]]: | |
| if not words or window <= 0: | |
| return cand_words, cand_probs | |
| recent = set(words[-window:]) | |
| kept = [(w, p) for w, p in zip(cand_words, cand_probs) if (not words or w != words[-1]) and (w not in recent)] | |
| return (list(zip(*kept))[0], list(zip(*kept))[1]) if kept else (cand_words, cand_probs) | |
| # --- Entertainment helpers --- | |
| def _rarity_weights(probs: List[float], rarity: float) -> List[float]: | |
| if rarity <= 0: | |
| return [1.0] * len(probs) | |
| # bias toward mid-frequency: weight ~ (p*(1-p))^rarity | |
| return [max(p * (1 - p), 1e-12) ** rarity for p in probs] | |
| def _allit_weights(words: List[str], probs: List[float], allit_char: Optional[str], alliteration: float) -> List[float]: | |
| if not allit_char or alliteration <= 0: | |
| return [1.0] * len(probs) | |
| return [1.0 + alliteration if w.startswith(allit_char) else 1.0 for w in words] | |
def _adjust_emission_probs(words: List[str], probs: List[float], *, allit_char: Optional[str], alliteration: float, rarity: float) -> List[float]:
    """Combine the rarity and alliteration multipliers with *probs* and renormalize.

    Each result is ``p * rarity_weight * alliteration_weight``, divided by the
    sum of those products (or by 1.0 when the sum is 0).
    """
    rarity_w = _rarity_weights(probs, rarity)
    allit_w = _allit_weights(words, probs, allit_char, alliteration)
    combined: List[float] = []
    for p, r, a in zip(probs, rarity_w, allit_w):
        combined.append(p * r * a)
    total = sum(combined) or 1.0
    return [c / total for c in combined]
# --- Micro-templates ---
def maybe_template(rnd: random.Random, p: float, emissions) -> Optional[List[str]]:
    """With probability *p*, build a short fixed-shape phrase instead of an LM sentence.

    Patterns mix literal function words (lower-case) with tag slots (upper-case:
    ADJ/NOUN/VERB). Each slot is filled uniformly at random from
    ``emissions[tag]``'s words. All randomness comes from *rnd*, so a seeded
    session reproduces template output as well.

    Args:
        rnd: random source.
        p: probability of using a template; ``p <= 0`` disables templates.
        emissions: mapping ``tag -> (words, probs)``.

    Returns:
        The token list, or ``None`` if no template was chosen or a required
        tag has no words.
    """
    if p <= 0 or rnd.random() > p:
        return None
    patterns = (
        ("as", "ADJ", "as", "a", "NOUN"),
        ("like", "a", "ADJ", "NOUN"),
        ("the", "NOUN", "of", "NOUN"),
        ("VERB", "the", "ADJ", "NOUN"),
    )
    literals = {"a", "as", "like", "the", "of"}
    pat = rnd.choice(patterns)
    out: List[str] = []
    for tok in pat:
        if tok in literals:
            out.append(tok)
            continue
        words, _probs = emissions.get(tok, ([], []))
        if not words:
            return None
        # Use the session RNG, not the global `random` module, so --seed works.
        out.append(rnd.choice(words))
    # Article agreement ("a odd" -> "an odd"). This is a vowel-letter
    # heuristic, so words like "hour" or "unit" are not special-cased.
    for i in range(len(out) - 1):
        if out[i] == "a" and out[i + 1][:1] in ("a", "e", "i", "o", "u"):
            out[i] = "an"
    return out
# -------------------- Generation --------------------
def generate_sentence_pos(
    tag_lm,
    emissions: Dict[str, Tuple[List[str], List[float]]],
    rnd: random.Random,
    *,
    order: int = 4,
    max_len: int = 30,
    punct_mode: PunctMode = "end-only",
    min_tokens: int = 1,
    temperature: float = 1.0,
    topk: Optional[int] = 16,
    topp: Optional[float] = 0.95,
    content_bias: float = 1.15,
    no_repeat_window: int = 4,
    max_func_run: int = 2,
    len_mean: float = 12.0,
    len_stdev: float = 3.0,
    temp_func: float = 0.85,
    temp_content: float = 1.05,
    alliteration: float = 0.2,
    templates: float = 0.15,
    rarity: float = 0.5,
) -> str:
    """Generate one practice sentence from the POS-class model.

    First ``maybe_template(rnd, templates, emissions)`` is tried. If it returns
    a phrase, that phrase is the sentence. Otherwise a target length is drawn
    from N(len_mean, len_stdev), rounded, capped at *max_len* and floored at
    *min_tokens*.

    Each step samples a tag from *tag_lm*:
    - a guarded tag may not repeat back-to-back;
    - after *max_func_run* consecutive function tags, only content tags or
      ``</s>`` are allowed (if any exist);
    - content tags are up-weighted by *content_bias*.

    The tag is then realized as a word from ``emissions[tag]``, with recent
    repeats avoided, alliteration/rarity re-weighting applied, and
    *temp_content* or *temp_func* as the sampling temperature.

    A pass stops on ``</s>``, on a tag with no emission words, at the target
    length, or after *max_len* steps. Passes with fewer than *min_tokens* words
    are retried, up to 10 attempts; the last attempt is used regardless.

    Returns:
        The space-joined sentence with its first letter capitalized, plus a
        final period when ``punct_mode == "end-only"``. It is ``""`` if the
        final attempt produced no words.
    """
    EOS = "</s>"
    # Template branch
    tpl = maybe_template(rnd, templates, emissions)
    if tpl is not None:
        return _detok_space_join(tpl, add_final_period=(punct_mode == "end-only"))
    # Length shaping
    target_len = max(min_tokens, min(max_len, int(round(rnd.normalvariate(len_mean, len_stdev)))))
    max_retries = 10
    for _attempt in range(max_retries):
        # Sentence-start context: order-1 "<s>" pads, matching the padded
        # sequences the tag LM was trained on.
        history = tuple(["<s>"] * (max(order - 1, 0)))
        words: List[str] = []
        last_tag: Optional[str] = None
        last_content_initial: Optional[str] = None
        func_run = 0
        for _ in range(max_len):
            tags, probs = _tag_candidates(tag_lm, history)
            tags, probs = _mask_same_tag_if_needed(tags, probs, last_tag)
            # Too many function tags in a row: restrict to content tags (or
            # end of sentence), unless that would leave nothing.
            if func_run >= max_func_run:
                keep = [(t, p) for t, p in zip(tags, probs) if t in CONTENT_TAGS or t == EOS]
                if keep:
                    tags, probs = list(zip(*keep))  # type: ignore
                    tags, probs = list(tags), list(probs)
            probs = _apply_content_bias(tags, probs, bias=content_bias)
            tag = _softmax_sample(tags, probs, rnd, temperature=temperature, topk=topk, topp=topp)
            if tag == EOS:
                break
            # Emission
            cand_w, cand_p = emissions.get(tag, ([], []))
            # No words for this tag: end this pass early.
            if not cand_w:
                break
            cand_w, cand_p = _avoid_recent_word_repeats(words, cand_w, cand_p, window=no_repeat_window)
            adj_p = _adjust_emission_probs(cand_w, cand_p, allit_char=last_content_initial, alliteration=alliteration, rarity=rarity)
            # Content words get a different sampling temperature than function words.
            t_emit = temp_content if tag in CONTENT_TAGS else temp_func
            word = _softmax_sample(cand_w, adj_p, rnd, temperature=t_emit, topk=topk, topp=topp)
            # Defensive: if the repeat filter fell back to the full list, re-sample
            # once without the previous word.
            if words and word == words[-1]:  # defensive
                alt = [(w, p) for w, p in zip(cand_w, adj_p) if w != word]
                if alt:
                    cw, cp = list(zip(*alt)); word = _softmax_sample(list(cw), list(cp), rnd, temperature=t_emit, topk=topk, topp=topp)
            words.append(word)
            if tag in FUNCTION_TAGS:
                func_run += 1
            else:
                func_run = 0
                # Initial letter of the latest non-function word drives the
                # alliteration bias.
                # NOTE(review): placement under this branch follows the
                # variable name; confirm against the original layout.
                last_content_initial = word[0] if word else last_content_initial
            last_tag = tag
            # Slide the fixed-size tag context window.
            if order > 1:
                history = (*history[1:], tag)
            if len(words) >= target_len:
                break
        if len(words) >= min_tokens:
            break
    return _detok_space_join(words, add_final_period=(punct_mode == "end-only"))
def sentence_stream_pos(
    tag_lm,
    emissions: Dict[str, Tuple[List[str], List[float]]],
    *,
    seed: Optional[int] = None,
    order: int = 4,
    max_len: int = 30,
    punct_mode: PunctMode = "end-only",
    temperature: float = 1.0,
    topk: Optional[int] = 16,
    topp: Optional[float] = 0.95,
    min_tokens: int = 1,
    content_bias: float = 1.15,
    no_repeat_window: int = 4,
    max_func_run: int = 2,
    len_mean: float = 12.0,
    len_stdev: float = 3.0,
    temp_func: float = 0.85,
    temp_content: float = 1.05,
    alliteration: float = 0.2,
    templates: float = 0.15,
    rarity: float = 0.5,
) -> Iterator[str]:
    """Yield generated sentences forever.

    One ``random.Random(seed)`` is created when iteration starts and is shared
    by every call to ``generate_sentence_pos``. All other keyword options are
    forwarded unchanged.
    """
    rnd = random.Random(seed)
    options = dict(
        order=order,
        max_len=max_len,
        punct_mode=punct_mode,
        min_tokens=min_tokens,
        temperature=temperature,
        topk=topk,
        topp=topp,
        content_bias=content_bias,
        no_repeat_window=no_repeat_window,
        max_func_run=max_func_run,
        len_mean=len_mean,
        len_stdev=len_stdev,
        temp_func=temp_func,
        temp_content=temp_content,
        alliteration=alliteration,
        templates=templates,
        rarity=rarity,
    )
    while True:
        yield generate_sentence_pos(tag_lm, emissions, rnd, **options)
# -------------------- Checking / metrics --------------------
CheckMode = Literal["none", "exact", "lev"]


def normalize_text(s: str, *, case_sensitive: bool, keep_whitespace: bool) -> str:
    """Lower-case *s* unless *case_sensitive*; collapse whitespace runs unless *keep_whitespace*.

    >>> normalize_text("  Hello   World ", case_sensitive=False, keep_whitespace=False)
    'hello world'
    """
    text = s if case_sensitive else s.lower()
    if keep_whitespace:
        return text
    return " ".join(text.split())
def levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein edit distance (insert/delete/substitute, each cost 1).

    Uses a single rolling row sized by the shorter string.

    >>> levenshtein("kitten", "sitting")
    3
    """
    if a == b:
        return 0
    longer, shorter = (a, b) if len(a) >= len(b) else (b, a)
    if not shorter:
        return len(longer)
    row = list(range(len(shorter) + 1))
    for i, ch_long in enumerate(longer, start=1):
        # diag holds the previous row's value at column j-1.
        diag, row[0] = row[0], i
        for j, ch_short in enumerate(shorter, start=1):
            above = row[j]
            row[j] = min(above + 1, row[j - 1] + 1, diag + (ch_long != ch_short))
            diag = above
    return row[-1]
def compute_metrics(expected: str, typed: str, *, mode: CheckMode,
                    case_sensitive: bool, keep_whitespace: bool) -> dict:
    """Compare *typed* with *expected* after ``normalize_text``.

    Returns a dict with keys ``exact``, ``distance``, ``accuracy``,
    ``expected_len`` and ``typed_len`` (lengths of the normalized strings).
    - ``"none"``: the first three are ``None``.
    - ``"exact"``: distance is 0 on a match, otherwise the longer length;
      accuracy is 1.0 or 0.0.
    - Otherwise (``"lev"``): Levenshtein distance and
      ``1 - distance / max(len_e, len_t, 1)``.
    """
    def norm(text: str) -> str:
        return normalize_text(text, case_sensitive=case_sensitive, keep_whitespace=keep_whitespace)

    e = norm(expected)
    t = norm(typed)
    lengths = {"expected_len": len(e), "typed_len": len(t)}
    if mode == "none":
        return {"exact": None, "distance": None, "accuracy": None, **lengths}
    if mode == "exact":
        matched = e == t
        return {
            "exact": matched,
            "distance": 0 if matched else max(len(e), len(t)),
            "accuracy": 1.0 if matched else 0.0,
            **lengths,
        }
    dist = levenshtein(e, t)
    scale = max(len(e), len(t), 1)
    return {"exact": e == t, "distance": dist, "accuracy": 1 - dist / scale, **lengths}
# -------------------- Session types --------------------
Metric = Literal["typed", "generated"]


@dataclass(frozen=True)
class SessionParams:
    """Immutable settings for ``run_session``."""

    # Character budget: the session stops once the running count reaches it.
    target: int = 1000
    # "typed": count typed characters; otherwise count generated characters + separator.
    metric: Metric = "typed"
    # Separator printed after each sentence; its length is counted in "generated" mode.
    gen_sep: str = NEWLINE
    # In "typed" mode, count the Enter keypress as one extra character.
    include_enter: bool = True
    # Correctness checking: "none", "exact" or "lev".
    check: CheckMode = "none"
    # Options forwarded to normalize_text when checking.
    case_sensitive: bool = True
    keep_whitespace: bool = True
@dataclass
class SessionSummary:
    """Totals reported by ``run_session`` when a session ends."""

    # Final value of the running character count.
    final_count: int
    # Number of lines that were checked (0 when checking is off).
    samples: int
    # Sums of normalized expected/typed lengths and distances over checked lines.
    total_expected: int
    total_typed: int
    total_distance: int
    # Aggregate accuracy, or None when it was not computed.
    avg_accuracy: float | None
@dataclass
class SessionIO:
    """Callbacks through which ``run_session`` talks to the outside world.

    ``show`` displays a sentence and ``read`` returns the user's line.
    ``on_progress(count, target)`` and ``emit_feedback(message)`` are optional
    and default to callbacks that do nothing.
    """

    show: Callable[[str], None]
    read: Callable[[], str]
    on_progress: Callable[[int, int], None] = lambda _c, _t: None
    emit_feedback: Callable[[str], None] = lambda _m: None
# -------------------- Session runner --------------------
def run_session(stream: Iterable[str], io: SessionIO, params: SessionParams) -> SessionSummary:
    """Show sentences and read the user's lines until the character target is reached.

    For each sentence from *stream*: ``io.show`` it, ``io.read`` the typed line,
    score it when ``params.check != "none"`` (per-line feedback goes through
    ``io.emit_feedback``), add to the running count, and call
    ``io.on_progress(count, target)``.

    Count increments:
    - ``metric == "typed"``: ``len(typed)``, plus 1 when ``include_enter``.
    - otherwise: ``len(sentence) + len(gen_sep)``.

    Returns:
        A ``SessionSummary``. ``avg_accuracy`` is ``None`` when no line was
        scored. For ``"lev"`` it is
        ``1 - total_distance / max(total_expected, total_typed, 1)``; for
        ``"exact"`` it is the fraction of lines typed exactly.
    """
    if params.metric == "typed":
        enter_len = 1 if params.include_enter else 0

        def increment(typed: str, _sentence: str) -> int:
            return len(typed) + enter_len
    else:
        sep_len = len(params.gen_sep)

        def increment(_typed: str, sentence: str) -> int:
            return len(sentence) + sep_len

    count = total_expected = total_typed = total_distance = 0
    samples = exact_hits = 0
    for sentence in stream:
        io.show(sentence)
        typed = io.read()
        if params.check != "none":
            m = compute_metrics(sentence, typed, mode=params.check,
                                case_sensitive=params.case_sensitive,
                                keep_whitespace=params.keep_whitespace)
            samples += 1
            total_expected += m["expected_len"]
            total_typed += m["typed_len"]
            if m["distance"] is not None:
                total_distance += m["distance"]
            if m["exact"]:
                exact_hits += 1
            if params.check == "exact":
                io.emit_feedback(f" ✓ exact{NEWLINE}" if m["exact"] else f" ✗ mismatch{NEWLINE}")
            else:
                io.emit_feedback(f" dist={m['distance']} acc={m['accuracy']:.3f}{NEWLINE}")
        count += increment(typed, sentence)
        io.on_progress(count, params.target)
        if count >= params.target:
            break

    avg: Optional[float] = None
    if samples:
        if params.check == "lev":
            avg = 1 - total_distance / max(total_expected, total_typed, 1)
        elif params.check == "exact":
            # Previously hard-coded to 1.0; report the real hit rate instead.
            avg = exact_hits / samples
    return SessionSummary(final_count=count, samples=samples, total_expected=total_expected,
                          total_typed=total_typed, total_distance=total_distance, avg_accuracy=avg)
# -------------------- CLI --------------------
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse command-line options (``argv=None`` parses ``sys.argv[1:]``).

    Options are declared in one table and registered in order, so ``--help``
    lists them in the same sequence.
    """
    parser = argparse.ArgumentParser(description="Typing practice (letters-only) using an NLTK POS-class LM.")
    option_table = [
        # Session, counting, checking, output
        (("--target",), dict(type=int, default=1000, help="Stop after this many characters (default: 1000)")),
        (("--seed",), dict(type=int, help="Random seed for reproducibility")),
        (("--max-sent-len",), dict(type=int, default=30, help="Maximum tokens per generated sentence")),
        (("--metric",), dict(choices=["typed", "generated"], default="typed",
                             help='Counting mode: "typed" (keypresses) or "generated" (program output)')),
        (("--sep",), dict(choices=["newline", "space", "none"], default="newline",
                          help="Printed separator when --metric=generated")),
        (("--check",), dict(choices=["none", "exact", "lev"], default="none",
                            help="Correctness checking mode (default: none)")),
        (("--ignore-case",), dict(action="store_true", help="Ignore case for checking")),
        (("--keep-whitespace",), dict(action="store_true", help="Do not collapse whitespace when checking")),
        (("--progress",), dict(choices=["off", "bar", "auto"], default="off",
                               help="Progress output mode (default: off)")),
        (("--feedback",), dict(choices=["none", "summary", "lines"], default="summary",
                               help="Per-line feedback mode (default: summary)")),
        (("--punct",), dict(choices=["end-only", "none"], default="end-only",
                            help='Punctuation policy: "end-only" (final period) or "none"')),
        # Tag LM controls
        (("--lm-order",), dict(type=int, default=4, help="POS LM n-gram order (default: 4)")),
        (("--lm-smoothing",), dict(choices=["kneser_ney", "witten_bell", "laplace", "lidstone", "mle"],
                                   default="kneser_ney", help="LM smoothing method (default: kneser_ney)")),
        (("--lidstone-gamma",), dict(type=float, default=0.2, help="Gamma for Lidstone smoothing")),
        # Sampling controls (shared by tag and emission sampling)
        (("--temperature",), dict(type=float, default=1.0, help="Sampling temperature (default: 1.0)")),
        (("--topk",), dict(type=int, default=16, help="Top-K sampling cap (default: 16; 0 disables)")),
        (("--topp",), dict(type=float, default=0.95, help="Nucleus (top-p) sampling (default: 0.95; 0 disables)")),
        (("--min-tokens",), dict(type=int, default=1, help="Minimum words per sentence (default: 1)")),
        # Emission & quality controls
        (("--min-count",), dict(type=int, default=3, help="Minimum token count per tag for emissions (default: 3)")),
        (("--min-len",), dict(type=int, default=3, help="Minimum word length for emissions (default: 3; 'a'/'i' allowed)")),
        (("--no-vowel-filter",), dict(action="store_true", help="Allow words without vowels (except a/i)")),
        (("--no-repeat-window",), dict(type=int, default=4, help="Window to suppress recent word repeats (default: 4)")),
        (("--content-bias",), dict(type=float, default=1.15, help="Up-weight for content tags (default: 1.15)")),
        (("--max-func-run",), dict(type=int, default=2, help="Max consecutive functional tags before forcing content (default: 2)")),
        # Entertainment
        (("--len-mean",), dict(type=float, default=12.0, help="Target sentence length (mean)")),
        (("--len-stdev",), dict(type=float, default=3.0, help="Target sentence length (stdev)")),
        (("--temp-func",), dict(type=float, default=0.85, help="Temperature for function-tag emissions")),
        (("--temp-content",), dict(type=float, default=1.05, help="Temperature for content-tag emissions")),
        (("--alliteration",), dict(type=float, default=0.2, help="Alliteration bonus (0=off)")),
        (("--templates",), dict(type=float, default=0.15, help="Probability to use a micro-template (0-1)")),
        (("--rarity",), dict(type=float, default=0.5, help="Bias toward mid-frequency emissions (0=off)")),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser.parse_args(argv)
def cli(argv: List[str] | None = None) -> int:
    """Command-line entry point: build the models, run one session, print a summary.

    Args:
        argv: arguments to parse instead of ``sys.argv[1:]`` (``None`` uses sys.argv).

    Returns:
        0 when the session finishes normally. 1 when the user ends it early
        with Ctrl-D (EOF at the prompt) or Ctrl-C.
    """
    args = parse_args(argv)
    # Train tag LM and build emissions
    tag_lm, tagged_sents = build_tag_lm(order=args.lm_order, smoothing=args.lm_smoothing, lidstone_gamma=args.lidstone_gamma)
    emissions = build_emissions(tagged_sents, min_count=args.min_count, vowel_filter=not args.no_vowel_filter, min_len=args.min_len)
    # Decide printed separator and build IO
    if args.metric == "generated":
        gen_sep = NEWLINE if args.sep == "newline" else (" " if args.sep == "space" else "")

        def show_fn(s: str) -> None:
            print(s, end=gen_sep)
    else:
        gen_sep = NEWLINE

        def show_fn(s: str) -> None:
            print(s)

    on_progress = choose_progress(args.progress)
    emit_feedback = (sys.stderr.write if args.feedback == "lines" else (lambda _msg: None))
    params = SessionParams(target=args.target, metric=args.metric, gen_sep=gen_sep,
                           include_enter=True, check=args.check,
                           case_sensitive=not args.ignore_case, keep_whitespace=args.keep_whitespace)
    io = SessionIO(show=show_fn, read=lambda: input("> "), on_progress=on_progress, emit_feedback=emit_feedback)
    # Build the sentence stream
    sentences = sentence_stream_pos(
        tag_lm, emissions, seed=args.seed, order=args.lm_order, max_len=args.max_sent_len,
        punct_mode=args.punct, temperature=args.temperature, topk=(None if args.topk <= 0 else args.topk),
        topp=(None if args.topp <= 0 else args.topp), min_tokens=args.min_tokens,
        content_bias=args.content_bias, no_repeat_window=max(0, args.no_repeat_window),
        max_func_run=max(0, args.max_func_run), len_mean=args.len_mean, len_stdev=args.len_stdev,
        temp_func=args.temp_func, temp_content=args.temp_content, alliteration=args.alliteration,
        templates=args.templates, rarity=args.rarity,
    )
    try:
        summary = run_session(sentences, io, params)
    except (EOFError, KeyboardInterrupt):
        # input() raises EOFError on Ctrl-D and KeyboardInterrupt on Ctrl-C.
        # Treat both as "stop now" and exit without a traceback.
        sys.stderr.write(NEWLINE)
        print("@ Aborted")
        return 1
    if on_progress is not noop_progress:
        # The progress bar redraws with a carriage return and never ends its
        # line; finish it so the summary starts on a fresh line.
        sys.stderr.write(NEWLINE)
    if args.feedback != "none" and args.check != "none":
        print(f"@ Done | samples={summary.samples} avg_accuracy={summary.avg_accuracy}")
    else:
        print("@ Done")
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with cli()'s return value as the exit status.
    sys.exit(cli())
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment