Created
December 4, 2025 21:51
-
-
Save pszemraj/95f45c092a2480375fc1a94c7a0690dd to your computer and use it in GitHub Desktop.
Slightly optimized image tiling for VLMs, based on "Jina-VLM: Small Multilingual Vision Language Model"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| slightly optimized image tiling for vlms based on "Jina-VLM: Small Multilingual Vision Language Model" | |
| Based on the pseudocode in Appendix A.1: https://arxiv.org/abs/2512.04032 | |
| """ | |
| import math | |
| from typing import List, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
# 1. PRE-COMPUTE VALID GRIDS
# Every (rows, cols) grid with rows * cols <= MAX_TILES is enumerated once at
# import time, so the per-image grid search below is a single vectorized pass
# instead of nested Python loops.
MAX_TILES = 12
VALID_GRIDS = torch.tensor(
    [
        (rows, cols)
        for rows in range(1, MAX_TILES + 1)
        for cols in range(1, MAX_TILES + 1)
        if rows * cols <= MAX_TILES
    ],
    dtype=torch.float32,
)  # Shape: (N, 2); column 0 = rows, column 1 = cols


def get_tiles_pytorch(
    image_tensor: torch.Tensor,  # Shape: (C, H, W)
    base_size: int = 378,
    patch_size: int = 14,
    overlap_margins: Tuple[int, int] = (4, 4),
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Split an image into overlapping square tiles plus a global thumbnail.

    The (rows, cols) grid from VALID_GRIDS whose pixel aspect ratio is closest
    (in log space) to the image's aspect ratio is selected. When several grids
    tie exactly on aspect ratio (1x1, 2x2 and 3x3 for square images), the one
    whose pixel area is closest (in log space) to the image's area wins. The
    image is resized to exactly cover that grid, and the tiles are extracted
    with a single ``F.unfold`` call.

    Args:
        image_tensor: Image of shape (C, H, W). Assumed to be floating point;
            NOTE(review): integer dtypes may not be supported by
            ``F.interpolate``/``F.unfold`` on every backend — confirm.
        base_size: Side length in pixels of every tile and of the thumbnail.
        patch_size: Vision-encoder patch size in pixels; the tile stride is a
            whole number of patches.
        overlap_margins: Overlap between neighbouring tiles, in patches; the
            two values are summed.

    Returns:
        ``(batch, (rows, cols))``. ``batch`` has shape
        (1 + rows * cols, C, base_size, base_size): the thumbnail first,
        followed by the tiles in row-major order.

    Raises:
        ValueError: If the image has an empty spatial dimension, or if the
            overlap margins leave no positive tile stride.
    """
    c, h, w = image_tensor.shape
    if h == 0 or w == 0:
        raise ValueError(f"image must have non-empty spatial dims, got {(h, w)}")

    # Tile stride snapped to whole patches: with the defaults,
    # 378 // 14 = 27 patches per tile, minus 4 + 4 overlap patches = 19
    # patches, i.e. 266 px.
    stride = (base_size // patch_size - sum(overlap_margins)) * patch_size
    if stride <= 0:
        raise ValueError(
            f"overlap_margins {overlap_margins} leave no positive stride for "
            f"base_size={base_size}, patch_size={patch_size}"
        )
    # Pixel overlap between neighbouring tiles (112 with the defaults). It is
    # derived from base_size so that n * stride + overlap_pixels equals
    # (n - 1) * stride + base_size. That identity is what makes F.unfold
    # return exactly rows x cols tiles, even when base_size is not a multiple
    # of patch_size.
    overlap_pixels = base_size - stride

    # 2. VECTORIZED GRID SEARCH
    # Pixel (height, width) each candidate grid would be resized to.
    grid_pixel_dims = VALID_GRIDS * stride + overlap_pixels
    img_ar = w / h
    grid_ars = grid_pixel_dims[:, 1] / grid_pixel_dims[:, 0]
    # |log(img_ar) - log(grid_ar)| == |log(img_ar / grid_ar)|
    ar_diffs = torch.abs(torch.log(img_ar / grid_ars))

    # Square grids share AR exactly 1.0. A plain argmin would always pick the
    # first of them (1x1), giving any square image one tile identical to the
    # thumbnail whatever its resolution. Among the exact AR ties, prefer the
    # grid whose pixel area best matches the image's area.
    # NOTE(review): aspect ratio remains the primary criterion, so small
    # non-square images can still be upscaled into many tiles — confirm
    # against the paper's Appendix A.1 selection rule.
    is_best_ar = ar_diffs == ar_diffs.min()
    area_diffs = torch.abs(torch.log(grid_pixel_dims.prod(dim=1) / (h * w)))
    best_grid_idx = torch.argmin(area_diffs.masked_fill(~is_best_ar, float("inf")))

    best_rows, best_cols = VALID_GRIDS[best_grid_idx].int().tolist()
    target_h = best_rows * stride + overlap_pixels
    target_w = best_cols * stride + overlap_pixels

    # 3. RESIZE & TILING (runs on image_tensor's device)
    # interpolate and unfold both expect a batch dim: (1, C, target_h, target_w).
    resized_img = F.interpolate(
        image_tensor.unsqueeze(0),
        size=(target_h, target_w),
        mode="bilinear",
        align_corners=False,
    )
    # unfold extracts every base_size x base_size window at the given stride
    # in one kernel: (1, C * base_size * base_size, rows * cols), row-major.
    tiles = F.unfold(
        resized_img, kernel_size=(base_size, base_size), stride=(stride, stride)
    )
    # -> (rows * cols, C, base_size, base_size)
    tiles = tiles.transpose(1, 2).reshape(-1, c, base_size, base_size)

    # 4. ADD THUMBNAIL: the whole image squashed to one tile, (1, C, base, base).
    thumbnail = F.interpolate(
        image_tensor.unsqueeze(0),
        size=(base_size, base_size),
        mode="bilinear",
        align_corners=False,
    )

    final_batch = torch.cat([thumbnail, tiles], dim=0)
    return final_batch, (best_rows, best_cols)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment