Skip to content

Instantly share code, notes, and snippets.

View satwiksunnam19's full-sized avatar
🏠
Working from home

Satwik Sunnam satwiksunnam19

🏠
Working from home
  • University of Bridgeport
  • United States of America
View GitHub Profile
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 3
patch_dim = channels * patch_size ** 3
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))