Created
August 3, 2025 02:07
-
-
Save calebrob6/2f8e9b2ca2f182a5694ec0c59c52e123 to your computer and use it in GitHub Desktop.
PyTorch local binary pattern histograms
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def batch_histogram(data_tensor, num_classes=-1):
    """
    Compute histograms of integral values along the last dimension, in batches.

    Adapted from https://discuss.pytorch.org/t/batched-torch-histc/179741
    (unlike torch.histc and torch.histogram, this works on batched input).

    The counts are accumulated with ``scatter_add_`` instead of summing a
    one-hot encoding, so peak memory is proportional to the size of the input
    plus the size of the output rather than ``input.numel() * num_classes``.

    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor of non-negative values.
        num_classes (optional): the number of classes present in data.
            If not provided (-1), data_tensor.max() + 1 is used (an error is
            thrown if data_tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor on the
        same device as data_tensor, containing histograms of the last
        dimension D_n of data_tensor, that is,
        result[d_1,...,d_{n-1}, c] = number of times c appears in
        data_tensor[d_1,...,d_{n-1}].
    Raises:
        RuntimeError: if a value is negative or >= num_classes, or if
            num_classes must be inferred from an empty tensor.
    """
    if num_classes == -1:
        # Same inference rule as F.one_hot; .max() raises on an empty tensor.
        num_classes = int(data_tensor.max()) + 1

    # One bin vector per leading index; new_zeros keeps dtype and device.
    counts = data_tensor.new_zeros((*data_tensor.shape[:-1], num_classes))
    # Every element adds 1 to the bin selected by its own value.
    counts.scatter_add_(-1, data_tensor, data_tensor.new_ones(data_tensor.shape))
    return counts
class LBP(nn.Module):
    """Local binary pattern (LBP) histogram descriptor for batches of images.

    Each pixel is compared with its 8 neighbors in a 3x3 window; a neighbor
    that is >= the center contributes a 1 bit. The 8 bits form a code in
    [0, 255], and ``forward`` returns the normalized histogram of those codes
    for every sample in the batch.
    """

    # (row, col) position of each neighbor inside the 3x3 window, listed from
    # the most-significant bit (128) down to the least-significant bit (1).
    _NEIGHBOR_OFFSETS = (
        (0, 2),  # top-right
        (0, 1),  # top
        (0, 0),  # top-left
        (1, 2),  # right
        (2, 2),  # bottom-right
        (2, 1),  # bottom
        (2, 0),  # bottom-left
        (1, 0),  # left
    )

    def __init__(self):
        super().__init__()
        # One-hot 3x3 selection filters, shape [8, 1, 3, 3]. forward() no
        # longer convolves with them (see the note there); the attribute and
        # the "weight" buffer are kept so existing checkpoints and callers
        # that access them keep working.
        filters = torch.zeros(8, 1, 3, 3, dtype=torch.float32)
        for k, (row, col) in enumerate(self._NEIGHBOR_OFFSETS):
            filters[k, 0, row, col] = 1.0
        self.filters = filters
        self.register_buffer("weight", self.filters)

    def forward(self, x):
        """
        Compute the normalized LBP histogram of each sample.

        Arguments:
            x: [B, C, H, W] image batch of any real dtype. Channels are
                encoded independently and pooled into a single histogram.
        Returns:
            [B, 256] float32 tensor. Row b holds the relative frequency of
            each LBP code over all channels and pixels of sample b (rows sum
            to 1; all zeros if the sample has no pixels).
        """
        B, C, H, W = x.shape

        # Neighbors are read by slicing a zero-padded copy rather than by a
        # float convolution: slicing is exact for every dtype, whereas a
        # conv2d requires float32 input here and may run in reduced precision
        # (e.g. TF32 on recent GPUs), which can flip the >= comparison for
        # nearly-equal pixels. Pixels outside the image count as 0.
        padded = F.pad(x, (1, 1, 1, 1))

        codes = torch.zeros(x.shape, dtype=torch.long, device=x.device)
        for k, (row, col) in enumerate(self._NEIGHBOR_OFFSETS):
            neighbor = padded[:, :, row:row + H, col:col + W]
            # Binary pattern: neighbor >= center -> 1, weighted by its bit.
            codes += (neighbor >= x).long() * (1 << (7 - k))

        # Explicit size (not -1) so an empty batch or empty image reshapes.
        codes = codes.reshape(B, C * H * W)
        lbp_histogram = batch_histogram(codes, num_classes=256)

        # Normalize; clamp avoids 0/0 when a sample contains no pixels.
        total = lbp_histogram.sum(dim=1, keepdim=True).clamp_min(1)
        return lbp_histogram.float() / total
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment