Paste all of the contents below, and replace
{TO_CONVERT_CONTENT}
with a working example of the new model to integrate.

Here's the base inference class for a library called procslib:
# src/procslib/models/base_inference.py
from abc import ABC, abstractmethod
from typing import List
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
class BaseImageInference(ABC):
\"\"\"A base class to define a common interface for image inference models.
All models should extend this class and implement required methods.
\"\"\"
def __init__(self, device: str = "cuda", batch_size: int = 32, **kwargs):
\"\"\"Initialize inference class with a device and default batch size.\"\"\"
self.device = device
self.batch_size = batch_size
self.model = None
@abstractmethod
def _load_model(self, checkpoint_path: str):
\"\"\"Load the model from a checkpoint.
Should set self.model as a torch.nn.Module instance on self.device.
\"\"\"
@abstractmethod
def _preprocess_image(self, pil_image: Image.Image) -> torch.Tensor:
\"\"\"Preprocess a single PIL image into the format required by the model.
Return a torch.Tensor of shape (C, H, W).
\"\"\"
@abstractmethod
def _postprocess_output(self, logits: torch.Tensor):
\"\"\"Postprocess the raw logits from the model into desired predictions.
For classification models, this might mean applying softmax and argmax.
\"\"\"
def infer_one(self, pil_image: Image.Image, **kwargs):
\"\"\"Infer for a single image (PIL).\"\"\"
self.model.eval()
with torch.no_grad():
image = self._preprocess_image(pil_image).unsqueeze(0).to(self.device)
output = self.model(image)
            return self._postprocess_output(output.logits)  # assumes an HF-style output object with a .logits attribute
def infer_batch(self, pil_images: List[Image.Image], **kwargs):
\"\"\"Infer for a batch of images (list of PIL images).\"\"\"
self.model.eval()
with torch.no_grad():
images = torch.stack([self._preprocess_image(img) for img in pil_images]).to(self.device)
output = self.model(images)
return self._postprocess_output(output.logits)
def infer_many(self, image_paths: List[str], **kwargs):
\"\"\"Infer for many images given their paths using a DataLoader for efficiency.
Returns a pandas DataFrame with the results.
\"\"\"
dataset = ImagePathDataset(
image_files=image_paths,
preprocess_fn=self._preprocess_image,
)
dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=8,
pin_memory=True,
collate_fn=custom_collate,
)
self.model.eval()
results = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Inferring paths"):
if batch is None:
continue
images, paths = batch
images = images.to(self.device)
output = self.model(images)
preds = self._postprocess_output(output.logits)
for path, pred in zip(paths, preds):
results.append({"path": path, "prediction": pred})
return pd.DataFrame(results)
def extract_features(self, pil_image: Image.Image):
\"\"\"Extract features from a single image at some layer before the classifier.
This might vary per model and can be implemented as needed.
By default, raise NotImplementedError if a model doesn't support it.
\"\"\"
raise NotImplementedError("Feature extraction not implemented for this model.")
def extract_features_batch(self, pil_images: List[Image.Image]):
\"\"\"Extract features for a batch of images.\"\"\"
raise NotImplementedError("Batch feature extraction not implemented.")
def predict_proba(self, pil_images: List[Image.Image], **kwargs):
\"\"\"Predict class probabilities for a list of images.
May be applicable only for classification models.
\"\"\"
raise NotImplementedError("Probability prediction not implemented for this model.")
class ImagePathDataset(Dataset):
\"\"\"A dataset that loads images from paths and preprocesses them using a given preprocess function.\"\"\"
def __init__(self, image_files: List[str], preprocess_fn):
self.image_files = image_files
self.preprocess_fn = preprocess_fn
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_path = self.image_files[idx]
try:
image = Image.open(image_path).convert("RGB")
image = self.preprocess_fn(image)
return image, image_path
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None
def custom_collate(batch):
\"\"\"Custom collate function to filter out None values.\"\"\"
batch = [item for item in batch if item is not None]
if not batch:
return None
return torch.utils.data.dataloader.default_collate(batch)
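For illustration, a minimal subclass only needs the three abstract methods. The sketch below is hypothetical (ToyClassifierInference is not part of procslib, and the ViT checkpoint is an arbitrary public one); it relies on the model returning an HF-style output with .logits, which is what the base infer_* methods expect:

# hypothetical example, not part of procslib
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

from procslib.models.base_inference import BaseImageInference


class ToyClassifierInference(BaseImageInference):
    def __init__(self, device: str = "cuda", batch_size: int = 32):
        super().__init__(device=device, batch_size=batch_size)
        self._load_model(None)

    def _load_model(self, checkpoint_path=None):
        # Arbitrary public checkpoint, used purely for illustration
        self.processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
        self.model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to(self.device)
        self.model.eval()

    def _preprocess_image(self, pil_image: Image.Image) -> torch.Tensor:
        inputs = self.processor(images=pil_image, return_tensors="pt")
        return inputs.pixel_values.squeeze(0)  # (C, H, W)

    def _postprocess_output(self, logits: torch.Tensor):
        # Return predicted class indices; a real wrapper might map them to label names
        return logits.argmax(dim=-1).cpu().tolist()

With this in place, infer_one, infer_batch, and infer_many from the base class all work unchanged.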
Various models implement this ABC so they can all be called the same way. Here's a real example, a depth-estimation wrapper:
\"\"\"Using MiDaS 3.0 to analyze the "depthness" of images and returns a numerical metric
[Intel/dpt-hybrid-midas · Hugging Face](https://huggingface.co/Intel/dpt-hybrid-midas)
- To improve speed of inference: lower out_size to 512x512, or use more workers (12 is mostly enough)
- To improve accuracy of inference: increase out_size to 1024x1024
\"\"\"
# src/procslib/models/depth_wrapper.py
import os
from typing import List
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import DPTForDepthEstimation, DPTImageProcessor
from .base_inference import BaseImageInference, ImagePathDataset, custom_collate
class DepthEstimationInference(BaseImageInference):
def __init__(
self,
device="cuda",
batch_size=48,
lower_percentile=15,
upper_percentile=95,
        out_size=(768, 768),  # trade-off: 512x512 is faster, 1024x1024 is more accurate
        num_workers=12,  # CPU workers for data loading
):
super().__init__(device=device, batch_size=batch_size)
self.lower_percentile = lower_percentile
self.upper_percentile = upper_percentile
self.out_size = out_size
self.num_workers = num_workers
self._load_model(None)
def _load_model(self, checkpoint_path=None):
self.feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(self.device)
self.model.eval()
def _preprocess_image(self, pil_image: Image.Image):
inputs = self.feature_extractor(images=pil_image, return_tensors="pt")
return inputs.pixel_values.squeeze(0) # shape [3, H, W]
def _postprocess_output(self, depth_map: torch.Tensor):
# Depth map shape: [B, H, W]
if depth_map.ndim == 3:
depth_map = depth_map.unsqueeze(1) # -> [B, 1, H, W]
# Resize/normalize depth map
depth_map = F.interpolate(
depth_map,
            size=self.out_size,  # e.g. (768, 768); smaller is faster, larger is more accurate
mode="bicubic",
align_corners=False,
)
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
depth_map_norm = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
results = []
        # Compute percentiles on CPU for each image; cost scales with out_size
        # (512x512 is 262,144 elements, a quarter the cost of 1024x1024).
for i in range(depth_map_norm.size(0)):
            dm = depth_map_norm[i].squeeze().cpu().numpy()  # shape == out_size
depth_values = dm.ravel()
p_low = np.percentile(depth_values, self.lower_percentile)
p_high = np.percentile(depth_values, self.upper_percentile)
depth_score = float(p_high - p_low)
results.append({"depth_score": depth_score})
return results
def infer_many(self, image_paths: List[str]):
dataset = ImagePathDataset(image_paths, preprocess_fn=self._preprocess_image)
dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
            num_workers=self.num_workers,  # parallel CPU loading
pin_memory=True,
collate_fn=custom_collate,
)
self.model.eval()
results = []
with torch.no_grad(), torch.autocast("cuda"):
for batch in tqdm(dataloader, desc="Inferring paths"):
if batch is None:
continue
images, paths = batch
images = images.to(self.device, non_blocking=True)
output = self.model(images)
batch_results = self._postprocess_output(output.predicted_depth)
for path, res in zip(paths, batch_results):
filename = os.path.basename(path)
res["filename"] = filename
results.append(res)
return pd.DataFrame(results)
import glob
# Demo usage
def demo_depth_wrapper():
folder_to_infer = "/rmt/image_data/dataset-ingested/gallery-dl/twitter/___Jenil"
image_paths = glob.glob(folder_to_infer + "/*.jpg")
inference = DepthEstimationInference(
device="cuda",
batch_size=24,
lower_percentile=15,
upper_percentile=95,
)
    # Many images, batched through a DataLoader with multiple workers
df = inference.infer_many(image_paths)
df.to_csv("depth_scores.csv", index=False)
print("Inference completed. Results saved to 'depth_scores.csv'.")
if __name__ == "__main__":
demo_depth_wrapper()
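Note that DepthEstimationInference overrides infer_many rather than reusing the base implementation: DPTForDepthEstimation returns an output with predicted_depth instead of logits, so the base class's logits-based infer_one and infer_batch would need similar overrides before they could be used with this model.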
The models are initialized in model_builder.py:
- If a model requires additional model files, download them with hf_hub_download.
- Otherwise, initialize the model directly.
- Add a key to MODEL_REGISTRY so the model can be called by name (a template sketch follows this list).
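As a rough template, a new getter might look like the sketch below; every "my_new_model" name is hypothetical and stands in for your own repo, file, and class names:

from huggingface_hub import hf_hub_download

from procslib.config import get_config


def get_my_new_model():
    """One-line description of the model (surfaced by get_model_keys).
    Input: an image; output: whatever the model predicts.
    """
    from procslib.models import MyNewModelInference  # hypothetical class

    # Pattern 1: the model needs extra files -> fetch them from the Hugging Face Hub
    curr_org = get_config("HF_ORG")
    checkpoint_path = hf_hub_download(
        repo_id=f"{curr_org}/my-new-model",  # hypothetical repo
        filename="my_new_model.ckpt",  # hypothetical file
    )
    return MyNewModelInference(checkpoint_path=checkpoint_path, device="cuda")

    # Pattern 2: no extra files -> initialize directly
    # return MyNewModelInference(device="cuda")

# and register it:
# MODEL_REGISTRY["my_new_model"] = get_my_new_model

Here's the actual model_builder.py: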
# src/procslib/model_builder.py
import hashlib
import os
from huggingface_hub import hf_hub_download
from procslib.config import get_config
# ===== ADD YOUR MODELS BELOW =====
def get_cv2_metrics_model():
\"\"\"Calculates OpenCV-based image metrics such as brightness, contrast, noise level, etc.
输入图片, 输出图片质量评分
\"\"\"
from procslib.models import OpenCVMetricsInference
return OpenCVMetricsInference(device="cpu", batch_size=32)
def get_rtmpose_model():
\"\"\"A model trained for human pose estimation using RTMPose.
输入图片, 输出人体姿势关键点
\"\"\"
from procslib.models import RTMPoseInference
# Download the model checkpoint from the Hugging Face Hub
curr_org = get_config("HF_ORG")
checkpoint_path = hf_hub_download(
repo_id=f"{curr_org}/{'rtm-pose-model'}",
filename="rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611_e2e.onnx"
)
print(checkpoint_path)
    return RTMPoseInference(onnx_file=checkpoint_path, device="cuda")
def get_laion_watermark_model(**overrides):
\"\"\"A model trained for predicting watermarks using Laion.
输入图片, 输出水印评分
\"\"\"
from procslib.models.laion_watermark import LaionWatermarkInference
return LaionWatermarkInference(**overrides)
MODEL_REGISTRY = {  # getters not shown above are defined elsewhere in this file
"twitter_logfav": get_twitter_logfav_model,
"weakm_v2": get_weakm_v2_model,
"weakm_v3": get_weakm_v3_model,
"siglip_aesthetic": get_siglip_aesthetic_model,
"pixiv_compound_score": get_pixiv_compound_score_model,
"cv2_metrics": get_cv2_metrics_model,
"complexity_ic9600": get_complexity_ic9600_model,
"rtmpose": get_rtmpose_model,
"depth": get_depth_model,
"q_align_quality": get_q_align_quality_model,
"q_align_aesthetics": get_q_align_aesthetics_model,
"laion_watermark": get_laion_watermark_model,
"clip_aesthetic": get_cached_clip_aesthetic_model,
"vila": get_vila_model,
}
# ============ DO NOT EDIT BELOW THIS LINE ============
def get_model_keys():
\"\"\"Retrieves the keys and descriptions of the model registry.
Returns:
dict: A dictionary where keys are model names and values are descriptions.
\"\"\"
return {key: func.__doc__.strip() for key, func in MODEL_REGISTRY.items()}
def get_model(descriptor: str, **overrides):
\"\"\"Retrieves the actual model instance associated with the given descriptor.
Args:
descriptor (str): The model descriptor key in the MODEL_REGISTRY.
Returns:
object: The model instance.
Raises:
ValueError: If the descriptor is not found in MODEL_REGISTRY.
\"\"\"
if descriptor not in MODEL_REGISTRY:
raise ValueError(f"Descriptor '{descriptor}' not found in MODEL_REGISTRY.")
return MODEL_REGISTRY[descriptor](**overrides)
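Once a key is registered, any model can be fetched and run through the same interface. A quick usage sketch (the image paths are placeholders):

from procslib.model_builder import get_model, get_model_keys

print(get_model_keys())  # lists every registered key with its docstring

model = get_model("depth")
df = model.infer_many(["/path/to/img1.jpg", "/path/to/img2.jpg"])
print(df)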
And that's how to integrate new models into procslib.

Now, here's a new model that I want to add to procslib. You're to integrate it by writing working code for the model class, its initialization in model_builder.py, and an inference test.
{TO_CONVERT_CONTENT}
Now, give me the working code (model class, initialization, and inference test).