Paste all contents below and replace {TO_CONVERT_CONTENT} with a working example of the new model to integrate.

Here's the base inference class for a library called procslib:

# src/procslib/models/base_inference.py

from abc import ABC, abstractmethod
from typing import List

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm


class BaseImageInference(ABC):
    \"\"\"A base class to define a common interface for image inference models.
    All models should extend this class and implement required methods.
    \"\"\"

    def __init__(self, device: str = "cuda", batch_size: int = 32, **kwargs):
        \"\"\"Initialize inference class with a device and default batch size.\"\"\"
        self.device = device
        self.batch_size = batch_size
        self.model = None

    @abstractmethod
    def _load_model(self, checkpoint_path: str):
        \"\"\"Load the model from a checkpoint.
        Should set self.model as a torch.nn.Module instance on self.device.
        \"\"\"

    @abstractmethod
    def _preprocess_image(self, pil_image: Image.Image) -> torch.Tensor:
        \"\"\"Preprocess a single PIL image into the format required by the model.
        Return a torch.Tensor of shape (C, H, W).
        \"\"\"

    @abstractmethod
    def _postprocess_output(self, logits: torch.Tensor):
        \"\"\"Postprocess the raw logits from the model into desired predictions.
        For classification models, this might mean applying softmax and argmax.
        \"\"\"

    def infer_one(self, pil_image: Image.Image, **kwargs):
        \"\"\"Infer for a single image (PIL).\"\"\"
        self.model.eval()
        with torch.no_grad():
            image = self._preprocess_image(pil_image).unsqueeze(0).to(self.device)
            output = self.model(image)
            return self._postprocess_output(output.logits)

    def infer_batch(self, pil_images: List[Image.Image], **kwargs):
        \"\"\"Infer for a batch of images (list of PIL images).\"\"\"
        self.model.eval()
        with torch.no_grad():
            images = torch.stack([self._preprocess_image(img) for img in pil_images]).to(self.device)
            output = self.model(images)
            return self._postprocess_output(output.logits)

    def infer_many(self, image_paths: List[str], **kwargs):
        \"\"\"Infer for many images given their paths using a DataLoader for efficiency.
        Returns a pandas DataFrame with the results.
        \"\"\"
        dataset = ImagePathDataset(
            image_files=image_paths,
            preprocess_fn=self._preprocess_image,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=8,
            pin_memory=True,
            collate_fn=custom_collate,
        )

        self.model.eval()
        results = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Inferring paths"):
                if batch is None:
                    continue
                images, paths = batch
                images = images.to(self.device)
                output = self.model(images)
                preds = self._postprocess_output(output.logits)

                for path, pred in zip(paths, preds):
                    results.append({"path": path, "prediction": pred})

        return pd.DataFrame(results)

    def extract_features(self, pil_image: Image.Image):
        \"\"\"Extract features from a single image at some layer before the classifier.
        This might vary per model and can be implemented as needed.
        By default, raise NotImplementedError if a model doesn't support it.
        \"\"\"
        raise NotImplementedError("Feature extraction not implemented for this model.")

    def extract_features_batch(self, pil_images: List[Image.Image]):
        \"\"\"Extract features for a batch of images.\"\"\"
        raise NotImplementedError("Batch feature extraction not implemented.")

    def predict_proba(self, pil_images: List[Image.Image], **kwargs):
        \"\"\"Predict class probabilities for a list of images.
        May be applicable only for classification models.
        \"\"\"
        raise NotImplementedError("Probability prediction not implemented for this model.")


class ImagePathDataset(Dataset):
    \"\"\"A dataset that loads images from paths and preprocesses them using a given preprocess function.\"\"\"

    def __init__(self, image_files: List[str], preprocess_fn):
        self.image_files = image_files
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.preprocess_fn(image)
            return image, image_path
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None


def custom_collate(batch):
    \"\"\"Custom collate function to filter out None values.\"\"\"
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.utils.data.dataloader.default_collate(batch)
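
A subclass only needs to implement the three abstract methods; batching, DataLoader handling, and DataFrame output come from the base class. Below is a minimal sketch: the class name, the Hugging Face ViT checkpoint, and the use of AutoImageProcessor are illustrative assumptions, not part of procslib.

# hypothetical example (not part of procslib)
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

from procslib.models.base_inference import BaseImageInference


class ToyClassifierInference(BaseImageInference):
    """Minimal sketch: wraps a Hugging Face image classifier so the base class's
    infer_one / infer_batch / infer_many work unchanged."""

    def __init__(self, device="cuda", batch_size=32, checkpoint="google/vit-base-patch16-224"):
        super().__init__(device=device, batch_size=batch_size)
        self._load_model(checkpoint)

    def _load_model(self, checkpoint_path: str):
        self.processor = AutoImageProcessor.from_pretrained(checkpoint_path)
        self.model = AutoModelForImageClassification.from_pretrained(checkpoint_path).to(self.device)
        self.model.eval()

    def _preprocess_image(self, pil_image: Image.Image) -> torch.Tensor:
        # the processor returns a batch of size 1; squeeze back to (C, H, W)
        return self.processor(images=pil_image, return_tensors="pt").pixel_values.squeeze(0)

    def _postprocess_output(self, logits: torch.Tensor):
        # predicted class index per image in the batch
        return logits.softmax(dim=-1).argmax(dim=-1).cpu().tolist()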

Various models implement the ABC so they can all be called in the same way. For example, the depth-estimation wrapper:

\"\"\"Using MiDaS 3.0 to analyze the "depthness" of images and returns a numerical metric

[Intel/dpt-hybrid-midas · Hugging Face](https://huggingface.co/Intel/dpt-hybrid-midas)

- To improve speed of inference: lower out_size to 512x512, or use more workers (12 is mostly enough)

- To improve accuracy of inference: increase out_size to 1024x1024

\"\"\"

# src/procslib/models/depth_wrapper.py

import os
from typing import List

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import DPTForDepthEstimation, DPTImageProcessor

from .base_inference import BaseImageInference, ImagePathDataset, custom_collate


class DepthEstimationInference(BaseImageInference):
    def __init__(
        self,
        device="cuda",
        batch_size=48,
        lower_percentile=15,
        upper_percentile=95,
        out_size=(768, 768),  # depth maps are resized to this before scoring; smaller is faster, larger is more accurate
        num_workers=12,  # CPU workers for the DataLoader
    ):
        super().__init__(device=device, batch_size=batch_size)
        self.lower_percentile = lower_percentile
        self.upper_percentile = upper_percentile
        self.out_size = out_size
        self.num_workers = num_workers
        self._load_model(None)

    def _load_model(self, checkpoint_path=None):
        self.feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
        self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(self.device)
        self.model.eval()

    def _preprocess_image(self, pil_image: Image.Image):
        inputs = self.feature_extractor(images=pil_image, return_tensors="pt")
        return inputs.pixel_values.squeeze(0)  # shape [3, H, W]

    def _postprocess_output(self, depth_map: torch.Tensor):
        # Depth map shape: [B, H, W]
        if depth_map.ndim == 3:
            depth_map = depth_map.unsqueeze(1)  # -> [B, 1, H, W]

        # Resize/normalize depth map
        depth_map = F.interpolate(
            depth_map,
            size=self.out_size,  # e.g. (768, 768); smaller sizes trade accuracy for speed
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map_norm = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)

        results = []
        # Compute percentiles on CPU for each image; smaller out_size means fewer elements and a cheaper percentile computation
        for i in range(depth_map_norm.size(0)):
            dm = depth_map_norm[i].squeeze().cpu().numpy()  # shape [512, 512]
            depth_values = dm.ravel()
            p_low = np.percentile(depth_values, self.lower_percentile)
            p_high = np.percentile(depth_values, self.upper_percentile)
            depth_score = float(p_high - p_low)

            results.append({"depth_score": depth_score})
        return results

    def infer_many(self, image_paths: List[str]):
        dataset = ImagePathDataset(image_paths, preprocess_fn=self._preprocess_image)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=custom_collate,
        )
        self.model.eval()
        results = []
        with torch.no_grad(), torch.autocast("cuda"):
            for batch in tqdm(dataloader, desc="Inferring paths"):
                if batch is None:
                    continue
                images, paths = batch
                images = images.to(self.device, non_blocking=True)
                output = self.model(images)
                batch_results = self._postprocess_output(output.predicted_depth)

                for path, res in zip(paths, batch_results):
                    filename = os.path.basename(path)
                    res["filename"] = filename
                    results.append(res)

        return pd.DataFrame(results)


import glob


# Demo usage
def demo_depth_wrapper():
    folder_to_infer = "/rmt/image_data/dataset-ingested/gallery-dl/twitter/___Jenil"
    image_paths = glob.glob(folder_to_infer + "/*.jpg")
    inference = DepthEstimationInference(
        device="cuda",
        batch_size=24,
        lower_percentile=15,
        upper_percentile=95,
    )

    # Many images (batched through a DataLoader with multiple worker processes)
    df = inference.infer_many(image_paths)
    df.to_csv("depth_scores.csv", index=False)
    print("Inference completed. Results saved to 'depth_scores.csv'.")


if __name__ == "__main__":
    demo_depth_wrapper()
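
As the module docstring notes, out_size and num_workers trade accuracy for speed. A speed-oriented configuration might look like the sketch below; the exact values are illustrative, and image_paths is reused from the demo above.

# speed-oriented configuration (illustrative values)
fast_inference = DepthEstimationInference(
    device="cuda",
    batch_size=48,
    out_size=(512, 512),  # lower resolution -> cheaper interpolation and percentile computation
    num_workers=12,       # more DataLoader workers for image loading
)
df_fast = fast_inference.infer_many(image_paths)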

The models are initialized in model_builder.py:

  • if the model requires additional model files, use hf_hub_download to fetch them.
  • otherwise, initialize the model directly.
  • a key is added to MODEL_REGISTRY so the model can be retrieved by name.
# src/procslib/model_builder.py
import hashlib
import os

from huggingface_hub import hf_hub_download
from procslib.config import get_config

# ===== ADD YOUR MODELS BELOW =====


def get_cv2_metrics_model():
    \"\"\"Calculates OpenCV-based image metrics such as brightness, contrast, noise level, etc.
    输入图片, 输出图片质量评分
    \"\"\"
    from procslib.models import OpenCVMetricsInference

    return OpenCVMetricsInference(device="cpu", batch_size=32)


def get_rtmpose_model():
    \"\"\"A model trained for human pose estimation using RTMPose.
    输入图片, 输出人体姿势关键点
    \"\"\"
    from procslib.models import RTMPoseInference

    # Download the model checkpoint from the Hugging Face Hub
    curr_org = get_config("HF_ORG")
    checkpoint_path = hf_hub_download(
        repo_id=f"{curr_org}/{'rtm-pose-model'}",
        filename="rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611_e2e.onnx"
    )
    print(checkpoint_path)    
    return RTMPoseInference(onnx_file=onnx_file, device="cuda")



def get_laion_watermark_model(**overrides):
    \"\"\"A model trained for predicting watermarks using Laion.
    输入图片, 输出水印评分
    \"\"\"
    from procslib.models.laion_watermark import LaionWatermarkInference

    return LaionWatermarkInference(**overrides)
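

# NOTE: the registry below also references getters not shown in this excerpt
# (twitter_logfav, weakm_v2, depth, ...). As an illustrative sketch only, the
# depth getter could simply wrap DepthEstimationInference from depth_wrapper.py:
def get_depth_model(**overrides):
    """Estimates the "depthness" of images using MiDaS 3.0 (DPT-hybrid).
    Input: images; output: depth scores.
    """
    from procslib.models.depth_wrapper import DepthEstimationInference

    return DepthEstimationInference(**overrides)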




MODEL_REGISTRY = {
    "twitter_logfav": get_twitter_logfav_model,
    "weakm_v2": get_weakm_v2_model,
    "weakm_v3": get_weakm_v3_model,
    "siglip_aesthetic": get_siglip_aesthetic_model,
    "pixiv_compound_score": get_pixiv_compound_score_model,
    "cv2_metrics": get_cv2_metrics_model,
    "complexity_ic9600": get_complexity_ic9600_model,
    "rtmpose": get_rtmpose_model,
    "depth": get_depth_model,
    "q_align_quality": get_q_align_quality_model,
    "q_align_aesthetics": get_q_align_aesthetics_model,
    "laion_watermark": get_laion_watermark_model,
    "clip_aesthetic": get_cached_clip_aesthetic_model,
    "vila": get_vila_model,
}


# ============ DO NOT EDIT BELOW THIS LINE ============


def get_model_keys():
    \"\"\"Retrieves the keys and descriptions of the model registry.

    Returns:
        dict: A dictionary where keys are model names and values are descriptions.
    \"\"\"
    return {key: func.__doc__.strip() for key, func in MODEL_REGISTRY.items()}


def get_model(descriptor: str, **overrides):
    \"\"\"Retrieves the actual model instance associated with the given descriptor.

    Args:
        descriptor (str): The model descriptor key in the MODEL_REGISTRY.

    Returns:
        object: The model instance.

    Raises:
        ValueError: If the descriptor is not found in MODEL_REGISTRY.
    \"\"\"
    if descriptor not in MODEL_REGISTRY:
        raise ValueError(f"Descriptor '{descriptor}' not found in MODEL_REGISTRY.")
    return MODEL_REGISTRY[descriptor](**overrides)
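
Once a getter is registered, any model can be retrieved and run the same way through the registry. A usage sketch (the batch_size override and the image paths are illustrative):

from procslib.model_builder import get_model, get_model_keys

print(get_model_keys())  # maps each registry key to its getter's docstring
model = get_model("depth", batch_size=16)  # overrides are forwarded to the getter
df = model.infer_many(["/path/to/a.jpg", "/path/to/b.jpg"])
print(df.head())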

That's how new models are integrated into procslib.

Now, here's a new model that I want to add to procslib. Your task is to integrate it by writing working code for the model class, its initialization in model_builder.py, and an inference test.

{TO_CONVERT_CONTENT}

Now, give me the working code (model class, initialization, and inference test).
