Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created August 3, 2025 02:07
Show Gist options
  • Save calebrob6/2f8e9b2ca2f182a5694ec0c59c52e123 to your computer and use it in GitHub Desktop.
Save calebrob6/2f8e9b2ca2f182a5694ec0c59c52e123 to your computer and use it in GitHub Desktop.
PyTorch local binary pattern histograms
def batch_histogram(data_tensor, num_classes=-1):
    """
    Count occurrences of integer values along the last dimension, batched.

    Adapted from https://discuss.pytorch.org/t/batched-torch-histc/179741.
    Unlike ``torch.histc`` / ``torch.histogram``, this works on integer data
    with arbitrary leading (batch) dimensions.

    Counts are accumulated with ``scatter_add_`` into a
    ``[D_1, ..., D_{n-1}, num_classes]`` buffer instead of materialising a
    ``[D_1, ..., D_n, num_classes]`` one-hot tensor, so the extra memory is
    proportional to ``num_classes`` per row rather than ``D_n * num_classes``.

    Arguments:
        data_tensor: a D_1 x ... x D_n torch.LongTensor of values in
            ``[0, num_classes)``.
        num_classes (optional): the number of bins. If -1 (the default),
            ``data_tensor.max() + 1`` is used (an error is raised if
            ``data_tensor`` is empty).

    Returns:
        A D_1 x ... x D_{n-1} x num_classes torch.LongTensor ``result`` with
        result[d_1, ..., d_{n-1}, c] = number of times c appears in
        data_tensor[d_1, ..., d_{n-1}, :].

    Raises:
        RuntimeError: on CPU, if a value lies outside ``[0, num_classes)``
            (on CUDA the out-of-range index triggers a device-side assert).
    """
    import torch

    if num_classes == -1:
        # Same inference rule as F.one_hot(num_classes=-1).
        num_classes = int(data_tensor.max()) + 1
    counts = torch.zeros(
        (*data_tensor.shape[:-1], num_classes),
        dtype=torch.long,
        device=data_tensor.device,
    )
    # Each element adds 1 to the bin named by its value, within its own row.
    return counts.scatter_add_(-1, data_tensor, torch.ones_like(data_tensor))
class LBP(nn.Module):
    """
    Batched 8-neighbour local binary pattern (LBP) histogram extractor.

    For every pixel of every channel, each of the 8 pixels in its 3x3
    neighbourhood contributes one bit: 1 if ``neighbour >= centre``, else 0.
    Bits are weighted top-right=128, top=64, top-left=32, right=16,
    bottom-right=8, bottom=4, bottom-left=2, left=1, giving a code in
    [0, 255]. Neighbours outside the image are treated as 0 (zero padding).

    The codes of all channels and pixels of an image are pooled into one
    256-bin histogram, normalised to sum to 1.
    """

    def __init__(self):
        super().__init__()
        # One 3x3 one-hot mask per neighbour; row k of this stack is the
        # neighbour that sets bit 7 - k (row 0 -> most significant bit).
        self.filters = torch.tensor([
            [[0, 0, 1], [0, 0, 0], [0, 0, 0]],  # top-right
            [[0, 1, 0], [0, 0, 0], [0, 0, 0]],  # top
            [[1, 0, 0], [0, 0, 0], [0, 0, 0]],  # top-left
            [[0, 0, 0], [0, 0, 1], [0, 0, 0]],  # right
            [[0, 0, 0], [0, 0, 0], [0, 0, 1]],  # bottom-right
            [[0, 0, 0], [0, 0, 0], [0, 1, 0]],  # bottom
            [[0, 0, 0], [0, 0, 0], [1, 0, 0]],  # bottom-left
            [[0, 0, 0], [1, 0, 0], [0, 0, 0]],  # left
        ]).float().unsqueeze(1)  # shape: [8, 1, 3, 3]
        # Registered as a buffer so ``weight`` stays in the state_dict;
        # forward() gathers neighbours by exact slicing instead of convolving.
        self.register_buffer("weight", self.filters)
        # (row, col) of each mask's 1: the neighbour's position inside the
        # zero-padded image relative to the output pixel's top-left corner.
        self._offsets = [
            tuple(rc) for rc in self.filters[:, 0].nonzero()[:, 1:].tolist()
        ]

    def forward(self, x):
        """
        Compute a normalised LBP histogram for each image in the batch.

        Args:
            x: [B, C, H, W] image tensor. Channels are encoded independently
               and pooled into a single histogram per image. Any dtype that
               supports ``>=`` (float or integer) is accepted.

        Returns:
            [B, 256] float32 tensor; row b holds the fraction of image b's
            C*H*W LBP codes falling in each bin (all zeros if C*H*W == 0).
        """
        B, C, H, W = x.shape
        # Fold channels into the batch so every channel is encoded in one pass.
        planes = x.reshape(B * C, 1, H, W)
        # Zero padding: out-of-image neighbours compare as 0.
        padded = F.pad(planes, (1, 1, 1, 1))
        # Slicing (rather than conv2d with one-hot masks) keeps neighbour
        # values bit-exact: cuDNN may run convolutions in TF32 or via
        # Winograd/FFT, which perturbs values and flips ``>=`` on ties.
        codes = torch.zeros_like(planes, dtype=torch.long)
        for bit, (row, col) in zip(range(7, -1, -1), self._offsets):
            neighbour = padded[:, :, row:row + H, col:col + W]
            codes += (neighbour >= planes).long() << bit
        # Explicit size (not -1) so an empty batch (B == 0) still reshapes.
        codes = codes.reshape(B, C * H * W)
        hist = batch_histogram(codes, num_classes=256).float()
        # clamp_min avoids 0 / 0 -> NaN when an image has no pixels.
        return hist / hist.sum(dim=1, keepdim=True).clamp_min(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment