Minimal Wikipedia semantic search wrapper for web integration.
#!/usr/bin/env python3
"""
Minimal Wikipedia semantic search wrapper for web integration.
Install hints (CPU):
    - Python 3.8+
    - pip install -U sentence-transformers faiss-cpu numpy
Expected index artifacts (see zip):
    - path/to/pages_embeddings.index (main data item, 200 MB)
    - path/to/pages_metadata.json (JSON list of {title, main_cat} entries)
    - path/to/manifest.json (just model_name)
Files are found at
    https://drive.google.com/file/d/1H1bqKrnd8z1aAVTPpI0eN0n7cYCSK0OP/view?usp=drive_link
Also, INDEX_DIR_REL (see Config below) has to point to your local location,
or pass index_dir=... to WikiPageSearcher directly.
Usage:
    python3
    >>> from path.to.wiki_page_searcher import WikiPageSearcher
    >>> searcher = WikiPageSearcher()
    >>> TOP_K = 20
    >>> INPUT_TEXT_EXAMPLE = "There's a curious statistical phenomenon where 1's end up appearing more often in real-life documents than other digits."
    >>> res = searcher(INPUT_TEXT_EXAMPLE, top_k=TOP_K)
    >>> res
Note: the encoder truncates long inputs, so only roughly the first 2000
characters of the input text are used.
"""
from __future__ import annotations

import json
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Suppress noisy LibreSSL warning that can appear on some macOS Python builds.
warnings.filterwarnings(
    "ignore",
    message=r"urllib3 v2 only supports OpenSSL 1\.1\.1\+.*",
)


@dataclass(frozen=True)
class Config:
    """
    Adjust these paths for your environment if needed.
    The defaults assume this file lives at:
        path/to/wiki_page_searcher.py
    """

    # Three directory levels above this file's folder.
    PROJECT_ROOT: Path = Path(__file__).resolve().parents[3]
    INDEX_DIR_REL: Path = Path("data/derived/wikipedia_api/page_search")
    INDEX_FILENAME: str = "pages_embeddings.index"  # 200 MB
    METADATA_FILENAME: str = "pages_metadata.json"
    MANIFEST_FILENAME: str = "manifest.json"
    DEFAULT_MODEL_NAME: str = "sentence-transformers/all-mpnet-base-v2"
    DEFAULT_DEVICE: str = "cpu"
    # Useful on macOS/CPU to avoid rare BLAS/torch threading issues.
    TORCH_THREADS: int = 1
    # Output formatting
    WIKI_HOST: str = "en.wikipedia.org/wiki"
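
# Example: instead of editing INDEX_DIR_REL above, the index location can be
# passed in directly (the path below is a placeholder for your local copy):
#   searcher = WikiPageSearcher(index_dir=Path("/path/to/page_search"))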


class WikiPageSearcher:
    """
    Minimal callable wrapper around:
        - SentenceTransformer encoder
        - FAISS index (L2 distance)
        - metadata list (title + main_cat)
    Call:
        results = searcher(input_text, top_k=10)
    Returns:
        list[dict] where each element is:
            {
                "page_link": "en.wikipedia.org/wiki/Sparse_polynomial",
                "page_category": "Polynomials",
                "distance": 1.016
            }
    """

    def __init__(
        self,
        cfg: Config = Config(),
        index_dir: Optional[Path] = None,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
    ) -> None:
        self.cfg = cfg
        self.project_root = cfg.PROJECT_ROOT.resolve()
        self.index_dir = (
            (self.project_root / cfg.INDEX_DIR_REL).resolve()
            if index_dir is None
            else Path(index_dir).resolve()
        )
        self.index_path = self.index_dir / cfg.INDEX_FILENAME
        self.metadata_path = self.index_dir / cfg.METADATA_FILENAME
        self.manifest_path = self.index_dir / cfg.MANIFEST_FILENAME
        # Thread limits. These only affect libraries initialized afterwards
        # (torch is imported lazily in _load_model); numpy's BLAS pool is
        # already set up by the module-level import.
        os.environ.setdefault("OMP_NUM_THREADS", str(cfg.TORCH_THREADS))
        os.environ.setdefault("MKL_NUM_THREADS", str(cfg.TORCH_THREADS))
        os.environ.setdefault("OPENBLAS_NUM_THREADS", str(cfg.TORCH_THREADS))
        self.model_name, self.device = self._resolve_model_device(
            model_name=model_name,
            device=device,
        )
        self._faiss = self._import_faiss()
        self._index = self._load_faiss_index()
        self._meta = self._load_metadata()
        self._model = self._load_model()
        self._sanity_check()

    def __call__(self, input_text: str, top_k: int = 10) -> List[Dict[str, Any]]:
        if not isinstance(input_text, str) or not input_text.strip():
            return []
        top_k = int(top_k)
        if top_k <= 0:
            return []
        # Encode query
        q = self._model.encode(
            [input_text.strip()],
            batch_size=1,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,
        )
        q = np.asarray(q, dtype="float32")
        k = min(top_k, int(self._index.ntotal))
        distances, indices = self._index.search(q, k)
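        # `distances` / `indices` each have shape (1, k); neighbors come back
        # sorted by ascending L2 distance, and `indices` can contain -1 for
        # missing results (hence the guard below).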
        out: List[Dict[str, Any]] = []
        for dist, idx in zip(distances[0].tolist(), indices[0].tolist()):
            if idx < 0:
                continue
            row = self._meta[idx]
            title = str(row.get("title", "")).strip()
            main_cat = str(row.get("main_cat", "unknown")).strip()
            if not title:
                continue
            out.append(
                {
                    # Wikipedia page URLs use underscores instead of spaces.
                    "page_link": f"{self.cfg.WIKI_HOST}/{title.replace(' ', '_')}",
                    "page_category": main_cat,
                    "distance": float(dist),
                }
            )
        return out

    def _resolve_model_device(
        self,
        model_name: Optional[str],
        device: Optional[str],
    ) -> Tuple[str, str]:
        resolved_device = (device or self.cfg.DEFAULT_DEVICE).strip()
        if model_name is not None and model_name.strip():
            return model_name.strip(), resolved_device
        # Try to read from manifest.json if present, otherwise fall back.
        if self.manifest_path.exists():
            try:
                obj = json.loads(self.manifest_path.read_text(encoding="utf-8"))
                m = str(obj.get("model_name", "")).strip()
                if m:
                    return m, resolved_device
            except Exception:
                pass
        return self.cfg.DEFAULT_MODEL_NAME, resolved_device

    def _import_faiss(self):
        try:
            import faiss  # type: ignore
        except Exception as e:
            raise RuntimeError(
                "faiss not installed. Install with:\n"
                "    pip install faiss-cpu\n"
                f"Original error: {e}"
            ) from e
        return faiss

    def _load_faiss_index(self):
        if not self.index_path.exists():
            raise FileNotFoundError(f"FAISS index not found. Given path: {self.index_path}")
        return self._faiss.read_index(str(self.index_path))

    def _load_metadata(self) -> List[Dict[str, Any]]:
        if not self.metadata_path.exists():
            raise FileNotFoundError(f"Metadata file not found: {self.metadata_path}")
        obj = json.loads(self.metadata_path.read_text(encoding="utf-8"))
        if not isinstance(obj, list):
            raise RuntimeError(f"Expected a JSON list in: {self.metadata_path}")
        return obj  # type: ignore[return-value]

    def _load_model(self):
        try:
            import torch
            from sentence_transformers import SentenceTransformer
        except Exception as e:
            raise RuntimeError(
                "Missing dependencies. Install with:\n"
                "    pip install sentence-transformers torch numpy\n"
                f"Original error: {e}"
            ) from e
        torch.set_num_threads(int(self.cfg.TORCH_THREADS))
        return SentenceTransformer(self.model_name, device=self.device)

    def _sanity_check(self) -> None:
        if int(self._index.ntotal) != len(self._meta):
            raise RuntimeError(
                "Index and metadata length mismatch.\n"
                f"index.ntotal={int(self._index.ntotal)}, metadata_len={len(self._meta)}\n"
                f"index_path={self.index_path}\n"
                f"metadata_path={self.metadata_path}"
            )
        if int(self._index.ntotal) == 0:
            raise RuntimeError(f"Index is empty: {self.index_path}")


if __name__ == "__main__":
    INPUT_TEXT_EXAMPLE = "There's a curious statistical phenomenon where 1's end up appearing more often in real-life documents than other digits."
    TOP_K = 20  # Number of results to return
    searcher = WikiPageSearcher()
    res = searcher(INPUT_TEXT_EXAMPLE, top_k=TOP_K)  # returns a list of dicts
    # Pretty-print the results.
    print(f"\nInput text example:\n'{INPUT_TEXT_EXAMPLE}'\n\nClosest {TOP_K} Wikipedia pages:")
    for page in res:
        print(f"{page['page_link']}, \t\tCategory: {page['page_category']}, \t\tdistance in R^768: {round(page['distance'], 3)}")