Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Created December 4, 2025 21:51
Show Gist options
  • Select an option

  • Save pszemraj/95f45c092a2480375fc1a94c7a0690dd to your computer and use it in GitHub Desktop.

Select an option

Save pszemraj/95f45c092a2480375fc1a94c7a0690dd to your computer and use it in GitHub Desktop.
Slightly optimized image tiling for VLMs, based on "Jina-VLM: Small Multilingual Vision Language Model"
"""
Slightly optimized image tiling for VLMs, based on "Jina-VLM: Small Multilingual Vision Language Model".
Based on the pseudocode in Appendix A.1: https://arxiv.org/abs/2512.04032
"""
import math
from typing import List, Tuple
import torch
import torch.nn.functional as F
# 1. CANDIDATE TILE GRIDS (built once, at import time)
# Every (rows, cols) layout that fits inside the tile budget is enumerated up
# front, so choosing a grid at inference time needs no Python-level loops.
MAX_TILES = 12
VALID_GRIDS = torch.tensor(
    [
        (rows, cols)
        for rows in range(1, MAX_TILES + 1)
        # cols <= MAX_TILES // rows is equivalent to rows * cols <= MAX_TILES
        for cols in range(1, MAX_TILES // rows + 1)
    ],
    dtype=torch.float32,
)  # Shape: (N, 2); each row is one (rows, cols) candidate
def get_tiles_pytorch(
    image_tensor: torch.Tensor,  # Shape: (C, H, W)
    base_size: int = 378,
    patch_size: int = 14,
    overlap_margins: Tuple[int, int] = (4, 4),
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Tile an image tensor into overlapping crops plus a global thumbnail.

    The image is resized to the candidate grid (from ``VALID_GRIDS``) whose
    aspect ratio is closest to the image's in log space, then cut into
    ``base_size`` x ``base_size`` windows with ``F.unfold``.

    Args:
        image_tensor: Single image of shape (C, H, W). It is passed to
            bilinear ``F.interpolate`` and ``F.unfold`` unchanged, so it is
            expected to be a floating-point tensor.
        base_size: Side length in pixels of every tile and of the thumbnail.
        patch_size: Side length in pixels of one patch; the overlap margins
            are expressed in patches.
        overlap_margins: Number of patches shared with the neighbouring tile
            on each side, as a 2-tuple.

    Returns:
        A tuple ``(batch, (rows, cols))``:
            * ``batch`` has shape (1 + rows * cols, C, base_size, base_size).
              Index 0 is the thumbnail (the whole image resized to
              ``base_size``); the remaining entries are the tiles in
              row-major order.
            * ``(rows, cols)`` is the selected tile grid.

    Raises:
        ValueError: If ``image_tensor`` is not 3-D, has an empty spatial
            dimension, or the overlap margins leave no positive stride.
    """
    if image_tensor.dim() != 3:
        raise ValueError(
            f"image_tensor must have shape (C, H, W), got {tuple(image_tensor.shape)}"
        )
    c, h, w = image_tensor.shape
    if h <= 0 or w <= 0:
        raise ValueError(f"image_tensor has an empty spatial dimension: H={h}, W={w}")

    # Constants derived from paper
    stride = (base_size // patch_size - sum(overlap_margins)) * patch_size  # 266
    if stride <= 0:
        raise ValueError(
            f"overlap_margins={overlap_margins} leave no room to advance: "
            f"stride={stride} for base_size={base_size}, patch_size={patch_size}"
        )
    # Pixels shared by neighbouring tiles. Deriving this from `stride` (rather
    # than from the margins directly) guarantees that a canvas of
    # n * stride + overlap_pixels unfolds into exactly n tiles, even when
    # base_size is not a multiple of patch_size. For the defaults it is the
    # same value as patch_size * sum(overlap_margins).
    overlap_pixels = base_size - stride  # 112

    # 2. VECTORIZED GRID SEARCH
    # Pixel size of every candidate canvas: (rows, cols) * stride + overlap
    grid_pixel_dims = VALID_GRIDS * stride + overlap_pixels
    # Aspect ratios are width / height for both the image and the canvases
    img_ar = w / h
    grid_ars = grid_pixel_dims[:, 1] / grid_pixel_dims[:, 0]
    # Compare in log space so "twice as wide" and "twice as tall" are equally
    # far away: abs(log(img_ar) - log(grid_ar)) == abs(log(img_ar / grid_ar))
    ar_diffs = torch.abs(torch.log(img_ar / grid_ars))
    # NOTE: torch.argmin returns the first minimum, so grids with identical
    # aspect ratio (e.g. 1x1, 2x2, 3x3) resolve to the one listed first in
    # VALID_GRIDS.
    best_grid_idx = torch.argmin(ar_diffs)
    best_rows, best_cols = VALID_GRIDS[best_grid_idx].int().tolist()
    target_h = best_rows * stride + overlap_pixels
    target_w = best_cols * stride + overlap_pixels

    # 3. RESIZE & TILING
    # interpolate/unfold need a batch dimension, hence unsqueeze(0)
    batched_img = image_tensor.unsqueeze(0)  # Shape: (1, C, H, W)
    resized_img = F.interpolate(
        batched_img,
        size=(target_h, target_w),
        mode="bilinear",
        align_corners=False,
    )  # Shape: (1, C, Target_H, Target_W)
    # unfold extracts every base_size x base_size sliding window in one call
    tiles = F.unfold(
        resized_img, kernel_size=(base_size, base_size), stride=(stride, stride)
    )  # Shape: (1, C * base_size * base_size, Num_Tiles)
    # Move the tile axis first and restore the image layout:
    # (Num_Tiles, C, base_size, base_size)
    tiles = tiles.transpose(1, 2).reshape(-1, c, base_size, base_size)

    # 4. ADD THUMBNAIL
    # The batch dimension is kept so the thumbnail can be concatenated as-is
    thumbnail = F.interpolate(
        batched_img,
        size=(base_size, base_size),
        mode="bilinear",
        align_corners=False,
    )  # Shape: (1, C, base_size, base_size)

    # Thumbnail first, then the tiles
    final_batch = torch.cat([thumbnail, tiles], dim=0)
    return final_batch, (best_rows, best_cols)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment