@Sinjhin, created November 18, 2024
Multi-Dim-Example-Test
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiDimTrial(nn.Module):
    def __init__(self, input_shape, output_shape, num_demos, mode='train'):
        super(MultiDimTrial, self).__init__()
        self.name = 'MultiDimTrial'
        self.mode = mode
        self.num_demos = num_demos
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.batch_size = 10
        self.dropout3d = nn.Dropout3d(p=0.2)
        self.dropout2d = nn.Dropout2d(p=0.2)
        self.dropout1d = nn.Dropout(p=0.2)
        self.bn1 = nn.BatchNorm3d(32)
        self.bn2 = nn.BatchNorm3d(64)
        self.bn3 = nn.BatchNorm3d(32)
        self.upsample = nn.Upsample(
            size=(num_demos, output_shape[0], output_shape[1]),
            mode='trilinear',
            align_corners=False
        )

        # Transformer to learn relationships across the 3D demo dimension
        transformer_dim = 10 * output_shape[0] * output_shape[1]
        self.pos_encoding = nn.Parameter(torch.randn(self.batch_size, self.num_demos, transformer_dim))
        self.transformer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            # d_model must be divisible by nhead; this expression happens to
            # satisfy that for a 9x9 output (810 % 4 + 1 = 3) but is not
            # guaranteed to for other output shapes
            nhead=transformer_dim % 4 + 1,
            dim_feedforward=transformer_dim * 2,
            dropout=0.2
        )
        # 3D-to-3D transformation learning
        self.transform3d_1 = nn.ConvTranspose3d(
            10,
            32,
            kernel_size=(self.num_demos, 3, 3),
            stride=(1, 1, 1),
            padding=(0, 1, 1)
        )
        self.transform3d_2 = nn.ConvTranspose3d(
            32,
            64,
            kernel_size=(self.num_demos, 3, 3),
            padding=(0, 1, 1)
        )

        # 3D feature extraction and pattern learning
        self.conv3d_1 = nn.Conv3d(64, 32, kernel_size=(self.num_demos, 2, 2), padding=(0, 2, 2))
        self.conv3d_2 = nn.Conv3d(32, 10, kernel_size=(self.num_demos, 4, 4), padding=0)

        # 3D-to-2D projection layers (defined but not used in forward)
        self.project_3d_to_2d = nn.Sequential(
            nn.Linear(transformer_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 81),
        )

        # Compression of 3D knowledge into 2D operations (defined but not used in forward)
        self.compress_knowledge = nn.Sequential(
            nn.Linear(16 * 9 * 9 * 5, 16 * 9 * 9),  # Compress demo dimension
            nn.ReLU(),
            nn.Linear(16 * 9 * 9, 9 * 9)  # Project to output size
        )

        # Final output projection (defined but not used in forward)
        self.output_projection = nn.Linear(16 * 9 * 9, 9 * 9)

    def forward(self, input):
        x = input.float()
        x = x.unsqueeze(0).unsqueeze(0)  # [D, H, W] -> [1, 1, D, H, W]
        if self.mode == 'train':
            x = x.expand(self.batch_size, 10, -1, -1, -1)  # [B, 10, D, H, W]
            x = F.relu(self.upsample(x))
            x = self.dropout3d(x)
            # Reshape for transformer
            x = x.permute(0, 2, 1, 3, 4)  # [batch, demos, features, 9, 9]
            x = x.flatten(2)  # [batch, demos, features*9*9]
            # Add positional encoding
            x = x + self.pos_encoding
            # Apply transformer across demo dimension
            x = x.permute(1, 0, 2)  # [demos, batch, F*H*W]
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # [batch, demos, F*H*W]
            x = F.interpolate(
                x.view(self.batch_size, self.num_demos, 10, 9, 9),
                size=(10, 9, 9),
                mode='trilinear',
                align_corners=False
            )
            x = x.permute(0, 2, 1, 3, 4)  # [batch, 10, demos, 9, 9]
            x = self.dropout3d(x)
        else:  # test mode
            # Input was [D, H, W], now [1, 1, D, H, W]; keep batch 1, expand channels
            x = x.expand(1, 10, -1, -1, -1)  # [1, 10, D, H, W]
            x = F.relu(self.upsample(x))
            x = self.dropout3d(x)
            # Same transformer processing as in training
            x = x.permute(0, 2, 1, 3, 4)  # [1, demos, features, 9, 9]
            x = x.flatten(2)  # [1, demos, features*9*9]
            # Add positional encoding and apply the transformer
            x = x + self.pos_encoding[0:1]  # Only use the first batch item
            x = x.permute(1, 0, 2)  # [demos, 1, F*H*W]
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # [1, demos, F*H*W]
            # Reshape back
            x = x.view(1, self.num_demos, 10, 9, 9)
            x = x.permute(0, 2, 1, 3, 4)  # [1, 10, demos, 9, 9]
            x = self.dropout3d(x)

        # Transform in 3D space
        x = F.relu(self.transform3d_1(x))
        x = F.relu(self.bn1(x))  # We might not want this
        x = self.dropout3d(x)
        x = F.relu(self.transform3d_2(x))  # [batch, 64, demos', H, W]
        x = F.relu(self.bn2(x))  # We might not want this
        x = self.dropout3d(x)

        # Extract 3D patterns
        x = F.relu(self.conv3d_1(x))
        x = F.relu(self.bn3(x))  # We might not want this
        x = self.dropout3d(x)
        x = F.relu(self.conv3d_2(x))

        x = x[0]  # [C, D, H, W]
        if self.mode != 'train':
            x = x.mean(dim=1)  # Average over the demo dimension -> [10, 9, 9]
            weights = torch.arange(10, dtype=torch.float32, device=x.device).view(10, 1, 1)
            x = (x * weights).sum(dim=0)  # Weighted sum across channels -> [9, 9]
        else:
            weights = torch.arange(10, dtype=torch.float32, device=x.device).view(10, 1, 1, 1)
            x = (x * weights).sum(dim=0)  # Weighted sum across channels -> [demos, 9, 9]
        return x

    def set_mode(self, mode):
        """Switch between train and test modes."""
        assert mode in ['train', 'test']
        self.mode = mode
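
# A minimal smoke test, for context. The shapes below are assumptions, not
# prescribed by the class: the hard-coded 9x9 views in forward mean
# output_shape must be (9, 9), while num_demos is a free choice.
if __name__ == '__main__':
    num_demos = 3
    model = MultiDimTrial(input_shape=(9, 9), output_shape=(9, 9), num_demos=num_demos)

    demos = torch.randint(0, 10, (num_demos, 9, 9))  # a stack of integer demo grids
    out = model(demos)
    print(out.shape)  # torch.Size([3, 9, 9]) in train mode

    model.set_mode('test')
    model.eval()  # also disables dropout and freezes batchnorm statistics
    with torch.no_grad():
        out = model(demos)
    print(out.shape)  # torch.Size([9, 9]) in test mode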