import argparse
import logging
import multiprocessing as mp
import os
import sys
from typing import List, Optional

import torch
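# Usage sketch (the script filename is illustrative; adjust to the actual name):
#   python process_pdfs.py --input-dir ./pdfs --output-dir ./out \
#       --num-gpus 2 --num-workers 4 --log-level INFO
# Each worker process is pinned to one GPU (round-robin) and converts its share of
# the PDFs to Markdown with magic_pdf.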
def setup_logging(log_level: str = "INFO", log_file: Optional[str] = None) -> None:
    """Configure logging with both console and (optional) file handlers."""
    numeric_level = getattr(logging, log_level.upper(), logging.INFO)

    # Base configuration
    handlers = []
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    handlers.append(console_handler)

    # File handler if specified
    if log_file:
        log_dir = os.path.dirname(log_file)
        if log_dir:  # guard against os.makedirs("") when the path has no directory part
            os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        handlers.append(file_handler)

    # Configure root logger
    logging.basicConfig(level=numeric_level, handlers=handlers)

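# Example call (the log path is illustrative): setup_logging("DEBUG", "logs/run.log")
# emits the same timestamped records to stdout and to logs/run.log.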
def get_output_path(input_path: str, output_dir: str, suffix: str = ".md") -> str:
    """Generate the output file path: <output_dir>/<pdf name>/<pdf name><suffix>."""
    name_without_ext = os.path.splitext(os.path.basename(input_path))[0]
    output_subdir = os.path.join(output_dir, name_without_ext)
    return os.path.join(output_subdir, f"{name_without_ext}{suffix}")

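# Illustrative mapping (paths are made up): get_output_path("/data/raw/report.pdf", "/data/out")
# returns "/data/out/report/report.md", matching the layout that process_pdf() writes below.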
def find_pdf_files(
    input_dir: str, output_dir: str, force_process: bool = False
) -> List[str]:
    """Find PDF files that need processing"""
    pdf_files = []
    skipped_files = []

    for root, _, files in os.walk(input_dir):
        for file in files:
            if not file.lower().endswith(".pdf"):
                continue

            pdf_path = os.path.join(root, file)
            md_path = get_output_path(pdf_path, output_dir)

            if (
                not force_process
                and os.path.exists(md_path)
                and os.path.getsize(md_path) > 0
            ):
                skipped_files.append(pdf_path)
                logging.debug(f"Skipping existing PDF: {pdf_path}")
            else:
                pdf_files.append(pdf_path)
                logging.debug(f"Found PDF to process: {pdf_path}")

    logging.info(f"Total PDFs found: {len(pdf_files) + len(skipped_files)}")
    logging.info(f"PDFs to process: {len(pdf_files)}")
    logging.info(f"PDFs skipped: {len(skipped_files)}")

    return pdf_files

def process_pdf(pdf_path: str, output_dir: str) -> bool:
    """Process a single PDF file. Returns True if successful, False otherwise."""
    try:
        from magic_pdf.data.data_reader_writer import (
            FileBasedDataWriter,
            FileBasedDataReader,
        )
        from magic_pdf.data.dataset import PymuDocDataset
        from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
        from magic_pdf.config.enums import SupportedPdfParseMethod

        name_without_ext = os.path.splitext(os.path.basename(pdf_path))[0]
        doc_output_dir = os.path.join(output_dir, name_without_ext)
        image_dir = os.path.join(doc_output_dir, "images")

        os.makedirs(image_dir, exist_ok=True)

        # Initialize writers
        image_writer = FileBasedDataWriter(image_dir)
        md_writer = FileBasedDataWriter(doc_output_dir)

        # Read PDF content
        reader = FileBasedDataReader("")
        pdf_bytes = reader.read(pdf_path)

        # Process PDF
        ds = PymuDocDataset(pdf_bytes)
        infer_result = ds.apply(
            doc_analyze, ocr=(ds.classify() == SupportedPdfParseMethod.OCR)
        )

        # Generate output
        pipe_result = (
            infer_result.pipe_ocr_mode(image_writer)
            if ds.classify() == SupportedPdfParseMethod.OCR
            else infer_result.pipe_txt_mode(image_writer)
        )

        pipe_result.dump_md(
            md_writer, f"{name_without_ext}.md", os.path.basename(image_dir)
        )
        logging.info(f"Successfully processed {pdf_path}")
        return True

    except Exception as e:
        logging.error(f"Error processing {pdf_path}: {str(e)}")
        return False

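# Expected on-disk layout for a successfully processed "paper.pdf" (name is illustrative):
#   <output_dir>/paper/paper.md
#   <output_dir>/paper/images/   (figures extracted by magic_pdf, referenced from the Markdown)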
def init_worker() -> None:
    """Initialize worker process"""
    # NOTE: limits PyTorch to one CPU thread per worker; defined as an initializer
    # but not currently invoked by main() below.
    torch.set_num_threads(1)

def worker(worker_id: int, gpu_id: int, file_list: List[str], output_dir: str) -> None:
    """Worker process function"""
    try:
        # With the "spawn" start method the parent's logging configuration is not
        # inherited, so give each worker a basic console configuration of its own.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
        )

        # Pin this worker to one physical GPU. These variables must be set before the
        # first CUDA call so that the chosen GPU appears as device 0 in this process.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

        torch.cuda.empty_cache()

        device_id = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_id)
        logging.info(
            f"Worker {worker_id} using GPU {gpu_id} (CUDA device {device_id}): {device_name}"
        )

        for file in file_list:
            logging.info(f"Worker {worker_id} (GPU {gpu_id}) processing {file}")
            success = process_pdf(file, output_dir)
            if not success:
                logging.warning(
                    f"Worker {worker_id} (GPU {gpu_id}) skipped {file} due to errors"
                )

    except Exception as e:
        logging.error(f"Worker {worker_id} failed: {str(e)}")

def parse_args() -> argparse.Namespace:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="PDF Processing Script with Multi-GPU Support"
    )

    parser.add_argument(
        "--input-dir", required=True, help="Input directory containing PDF files"
    )
    parser.add_argument(
        "--output-dir", required=True, help="Output directory for processed files"
    )
    parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument(
        "--num-workers", type=int, default=3, help="Number of worker processes"
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level",
    )
    parser.add_argument("--log-file", help="Log file path (optional)")
    parser.add_argument(
        "--force", action="store_true", help="Force processing of all PDFs"
    )

    return parser.parse_args()

def distribute_files(files: List[str], num_workers: int) -> List[List[str]]:
    """Distribute files among workers round-robin.

    Always returns exactly num_workers lists (some may be empty), so main() can
    index the result by worker id even when there are fewer files than workers.
    """
    chunks: List[List[str]] = [[] for _ in range(num_workers)]
    for index, file_path in enumerate(files):
        chunks[index % num_workers].append(file_path)
    return chunks

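# Illustrative split (file names are made up): distribute_files(["a.pdf", "b.pdf", "c.pdf"], 2)
# yields [["a.pdf", "c.pdf"], ["b.pdf"]]; with 3 files and 4 workers the fourth list is empty.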
def main() -> None:
    """Main function"""
    args = parse_args()

    # Setup logging
    setup_logging(args.log_level, args.log_file)

    # Verify CUDA availability
    if not torch.cuda.is_available():
        logging.error("CUDA is not available. This script requires GPU support.")
        sys.exit(1)

    logging.info(f"Found {torch.cuda.device_count()} CUDA devices")

    # Configure multiprocessing
    mp.set_start_method("spawn", force=True)

    # Find PDF files
    pdf_files = find_pdf_files(args.input_dir, args.output_dir, args.force)
    if not pdf_files:
        logging.info("No PDFs to process. Exiting.")
        return

    # Distribute files among workers
    file_chunks = distribute_files(pdf_files, args.num_workers)

    # Start worker processes
    processes = []
    for worker_id in range(args.num_workers):
        gpu_id = worker_id % args.num_gpus
        p = mp.Process(
            target=worker,
            args=(worker_id, gpu_id, file_chunks[worker_id], args.output_dir),
        )
        p.start()
        processes.append(p)

    # Wait for completion
    for p in processes:
        p.join()

    logging.info("All processing completed")

if __name__ == "__main__":
    main()