dots_mps_parse.py
# dots_mps_parse.py
# macOS/M1-friendly runner for dots.ocr (local, offline, SDPA on MPS)
# - Uses the repo's built-in prompts via --prompt-mode (JSON-ready)
# - Harmless flash_attn shim (with a valid __spec__) to avoid import crashes
# - Forces SDPA attention (not Flash-Attn)
# - Disables Sliding-Window Attention for SDPA/Qwen2
# - Avoids BF16 by forcing FP16 everywhere and monkey-patching the vision tower
# - Processes PDFs page by page with pixel caps to avoid OOM on 16 GB
# - Parses JSON output flexibly (array OR object) and can drop headers/footers
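# Assumed environment (not pinned here, inferred from the imports below):
# PyMuPDF (imported as fitz), Pillow, torch with MPS support, transformers,
# qwen_vl_utils, and a local dots.ocr checkout providing dots_ocr.utils.prompts;
# model weights live in --model-dir (default ./DotsOCR).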
# --- harmless flash_attn shim with a valid __spec__ --------------------------
import sys, types, importlib.machinery as _machinery

if "flash_attn" not in sys.modules:
    _m = types.ModuleType("flash_attn")
    _m.__spec__ = _machinery.ModuleSpec(name="flash_attn", loader=None)
    _m.__path__ = []

    def flash_attn_varlen_func(*args, **kwargs):
        raise RuntimeError(
            "flash_attn was requested but isn't available on this platform. "
            "Use attn_implementation='sdpa'."
        )

    _m.flash_attn_varlen_func = flash_attn_varlen_func
    sys.modules["flash_attn"] = _m
# ---------------------------------------------------------------------------
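# Net effect of the shim: `import flash_attn` now succeeds (so upstream import
# checks pass), while any actual call to flash_attn_varlen_func fails loudly
# with a pointer to SDPA.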
import os, gc, argparse, fitz, json, re
from PIL import Image
import torch
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
from qwen_vl_utils import process_vision_info
from dots_ocr.utils.prompts import dict_promptmode_to_prompt  # repo prompts


def pil_from_page(page, dpi=144):
    # Render the PDF page to an RGB PIL.Image at a modest DPI to cap memory.
    pix = page.get_pixmap(dpi=dpi)
    return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
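# Rough DPI arithmetic (illustrative, assuming a US-Letter 8.5x11 in page):
# 144 DPI renders ~1224x1584 px per page; 120 DPI renders ~1020x1320 px,
# about 30% fewer pixels before the processor's min/max pixel caps apply.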
# ---- JSON helpers -----------------------------------------------------------
def parse_json_flex(s: str):
    """
    Accepts:
      - a full JSON object: {...}
      - a full JSON array: [...]
      - messy text with one JSON object/array embedded: ...{...}... or ...[...]...
    Returns the parsed Python value.
    """
    t = s.strip()
    # Fast paths (trim trailing junk if the model echoes EOS tokens etc.)
    if t.startswith("{"):
        return json.loads(t[: t.rfind("}") + 1])
    if t.startswith("["):
        return json.loads(t[: t.rfind("]") + 1])
    # Slow path: extract the first {...} or [...] block
    m_obj = re.search(r"\{[\s\S]*\}", s)
    m_arr = re.search(r"\[[\s\S]*\]", s)
    if m_obj and m_arr:
        m = m_obj if m_obj.start() < m_arr.start() else m_arr
    else:
        m = m_obj or m_arr
    if not m:
        raise ValueError("No JSON object/array found in model output.")
    return json.loads(m.group(0))
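# Illustrative behaviour (the literals below are examples, not model output);
# all three inputs parse to [{'category': 'Text'}]:
#   parse_json_flex('[{"category": "Text"}]')
#   parse_json_flex('[{"category": "Text"}]<|endoftext|>')
#   parse_json_flex('Here is the layout: [{"category": "Text"}] Done.')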
def maybe_filter_blocks(obj, drop_headers_footers: bool):
    if not drop_headers_footers:
        return obj
    obj["blocks"] = [
        b for b in obj.get("blocks", [])
        if b.get("category") not in ("Page-header", "Page-footer")
    ]
    return obj
# ----------------------------------------------------------------------------
def main(args):
    # Offline caches: safe to set in-code
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
    # If you hit MPS-missing ops, set this in your shell before running:
    #   export PYTORCH_ENABLE_MPS_FALLBACK=1
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    dtype = torch.float16 if device == "mps" else torch.float32

    # Cap visual tokens to avoid OOM on 16 GB (raise max_pix later if you have headroom)
    min_pix = 256 * 28 * 28
    max_pix = 640 * 28 * 28
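    # Sizing note (illustrative arithmetic, assuming the Qwen2-VL-style
    # processor budgets total pixels in 28x28 patch units):
    # max_pix = 640*28*28 = 501,760 px (~708x708 image, at most ~640 patches);
    # min_pix = 256*28*28 = 200,704 px (exactly 448x448).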
    # Force SDPA & disable SWA (SWA not implemented for SDPA/Qwen2)
    config = AutoConfig.from_pretrained(
        args.model_dir, trust_remote_code=True, local_files_only=True
    )
    if getattr(config, "vision_config", None) is not None:
        config.vision_config.attn_implementation = "sdpa"
    else:
        setattr(config, "attn_implementation", "sdpa")
    if hasattr(config, "use_sliding_window"):
        config.use_sliding_window = False
    if hasattr(config, "sliding_window"):
        config.sliding_window = None
    processor = AutoProcessor.from_pretrained(
        args.model_dir,
        local_files_only=True,
        min_pixels=min_pix,
        max_pixels=max_pix,
        trust_remote_code=True,
        # use_fast=True,  # uncomment to quiet the "slow processor" warning
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        local_files_only=True,
        trust_remote_code=True,
        config=config,  # keep SDPA + SWA=off
        attn_implementation="sdpa",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    if device == "mps":
        model.to("mps")

    # CRITICAL: the model's vision tower defaults to bf16; force fp16 to avoid
    # a dtype mismatch on MPS.
    import types as _types
    orig_forward_func = model.vision_tower.forward.__func__  # unbound function

    def _forward_no_bf16(self, hidden_states, grid_thw, bf16=True):
        return orig_forward_func(self, hidden_states, grid_thw, bf16=False)

    model.vision_tower.forward = _types.MethodType(_forward_no_bf16, model.vision_tower)

    attn_impl = getattr(
        getattr(model.config, "vision_config", model.config), "attn_implementation", None
    )
    print(f"[info] device={device}, dtype={dtype}, vision_attn={attn_impl}, "
          f"use_sliding_window={getattr(model.config, 'use_sliding_window', None)}")
    doc = fitz.open(args.pdf)
    for i, page in enumerate(doc, start=1):
        img = pil_from_page(page, dpi=args.dpi)

        # Choose the prompt text: repo prompt-mode first; --prompt overrides if provided
        prompt_text = dict_promptmode_to_prompt.get(args.prompt_mode, "")
        if args.prompt is not None:
            prompt_text = args.prompt

        # Qwen-style chat with a user turn (image + text). The repo prompt already asks for JSON.
        messages = [
            {"role": "user", "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt_text},
            ]},
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")

        # Move to device + force FP16 on floating-point tensors
        for k, v in list(inputs.items()):
            if isinstance(v, torch.Tensor):
                v = v.to(device)
                if torch.is_floating_point(v):
                    v = v.to(dtype=dtype)
                inputs[k] = v

        with torch.inference_mode():
            out_ids = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                temperature=None,  # deterministic
                do_sample=False,
            )
        trimmed = [o[len(iids):] for iids, o in zip(inputs["input_ids"], out_ids)]
        out_text = processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        if args.print_raw:
            print("\n--- RAW MODEL OUTPUT ---\n")
            print(out_text)

        # Parse JSON flexibly (array OR object)
        parsed = parse_json_flex(out_text)
        obj = {"blocks": parsed} if isinstance(parsed, list) else parsed

        # Ensure page metadata exists
        w, h = img.size
        obj.setdefault("page", i)
        obj.setdefault("width", w)
        obj.setdefault("height", h)
        obj = maybe_filter_blocks(obj, args.drop_headers_footers)

        # Print for debugging
        print(f"\n===== Page {i}/{len(doc)} JSON =====\n")
        print(json.dumps(obj, ensure_ascii=False, indent=2))

        # Persist to disk if requested
        if args.json_out:
            out_line = json.dumps(obj, ensure_ascii=False)
            if args.json_out.lower().endswith(".jsonl"):
                with open(args.json_out, "a", encoding="utf-8") as f:
                    f.write(out_line + "\n")
            else:
                # If a plain .json was given, write alongside as .jsonl
                with open(args.json_out + ".jsonl", "a", encoding="utf-8") as f:
                    f.write(out_line + "\n")

        # Free per-page memory
        del img, inputs, out_ids, trimmed, obj, parsed
        gc.collect()
        if device == "mps":
            torch.mps.empty_cache()  # release unoccupied cached memory
if __name__ == "__main__":
    # Build argparse choices from repo prompts at runtime
    PROMPT_KEYS = tuple(dict_promptmode_to_prompt.keys())
    DEFAULT_PROMPT_MODE = (
        "prompt_layout_all_en"
        if "prompt_layout_all_en" in PROMPT_KEYS
        else (PROMPT_KEYS[0] if PROMPT_KEYS else None)
    )

    ap = argparse.ArgumentParser()
    ap.add_argument("pdf")
    ap.add_argument("--model-dir", default="./DotsOCR")
    ap.add_argument("--dpi", type=int, default=144)  # drop to 120 if memory spikes
    ap.add_argument("--max-new-tokens", type=int, default=1536)  # 512–2048 is realistic on 16 GB
    ap.add_argument("--prompt", default=None,
                    help="Custom user prompt (overrides --prompt-mode).")
    ap.add_argument("--prompt-mode",
                    default=DEFAULT_PROMPT_MODE,
                    choices=list(PROMPT_KEYS),
                    help="Use a built-in dots.ocr prompt (e.g. prompt_layout_all_en outputs JSON).")
    ap.add_argument("--json-out", default=None,
                    help="Path to write JSONL (.jsonl) or base name (.json -> will write .jsonl). "
                         "Appends one JSON object per page.")
    ap.add_argument("--drop-headers-footers", action="store_true",
                    help="Drop Page-header/Page-footer blocks from JSON.")
    ap.add_argument("--print-raw", action="store_true",
                    help="Print raw model output before JSON parsing.")
    args = ap.parse_args()
    main(args)
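# Example invocations (illustrative; adjust paths and flags to your setup):
#   python dots_mps_parse.py paper.pdf --json-out paper.jsonl
#   python dots_mps_parse.py scan.pdf --dpi 120 --max-new-tokens 1024 \
#       --drop-headers-footers --print-raw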