Created
May 26, 2023 01:59
-
-
Save shnhrtkyk/66d7fdc68c8fb13f31259fc8a11131af to your computer and use it in GitHub Desktop.
FWNet architecture using PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Tue Dec 17 08:04:29 2019 | |
| @author: shino | |
| """ | |
| from __future__ import print_function | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.parallel | |
| import torch.utils.data | |
| from torch.autograd import Variable | |
| import numpy as np | |
| import torch.nn.functional as F | |
class STN3d(nn.Module):
    """Spatial transformer that regresses a per-sample 3x3 matrix from point features.

    The network output is offset by the identity matrix, so if the last layer
    outputs zeros the predicted transform is exactly the identity.

    Args:
        in_channels: number of input channels per point (default 163, the
            value this file was written for).
    """

    def __init__(self, in_channels=163):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(in_channels, 128, 1)
        self.conv2 = torch.nn.Conv1d(128, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()  # not used by forward(); kept for attribute compatibility
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Predict a transform for each sample.

        Args:
            x: tensor of shape (batch, in_channels, n_points) (Conv1d layout).

        Returns:
            Tensor of shape (batch, 3, 3).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Symmetric max-pool over the point axis.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        # Build the identity directly on x's device. The former `.cuda()` call
        # moved it to the *current* CUDA device, which need not be x's device.
        identity = torch.eye(3, dtype=torch.float32, device=x.device).view(1, 9)
        x = x + identity  # broadcasts over the batch dimension
        return x.view(-1, 3, 3)
class STNkd(nn.Module):
    """Spatial transformer that regresses a per-sample k x k matrix.

    The network output is offset by the identity matrix, so if the last layer
    outputs zeros the predicted transform is exactly the identity.

    Args:
        k: number of input channels per point, and the side length of the
            predicted square matrix (default 163).
    """

    def __init__(self, k=163):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 128, 1)
        self.conv2 = torch.nn.Conv1d(128, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()  # not used by forward(); kept for attribute compatibility
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        self.k = k

    def forward(self, x):
        """Predict a k x k transform for each sample.

        Args:
            x: tensor of shape (batch, k, n_points) (Conv1d layout).

        Returns:
            Tensor of shape (batch, k, k).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Symmetric max-pool over the point axis.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        # Build the identity directly on x's device. The former `.cuda()` call
        # moved it to the *current* CUDA device, which need not be x's device.
        identity = torch.eye(self.k, dtype=torch.float32, device=x.device).view(1, self.k * self.k)
        x = x + identity  # broadcasts over the batch dimension
        return x.view(-1, self.k, self.k)
class PointNetfeat(nn.Module):
    """PointNet encoder producing either a global feature or per-point features.

    Args:
        global_feat: if True, forward() returns the max-pooled (batch, 1024)
            global feature. Otherwise the global feature is tiled over the
            points and concatenated with the 128-channel per-point features,
            giving (batch, 1152, n_points).
        feature_transform: if True, apply a learned 128x128 transform to the
            per-point features produced by the first convolution.
        isVAE: if True and ``istrain`` is True, add uniform [0, 1) noise to
            the 1024-channel features before max pooling.
        istrain: only consulted when ``isVAE`` is True (see above).
        in_channels: number of input channels per point (default 163).
    """

    def __init__(self, global_feat = True, feature_transform = False, isVAE=False, istrain = True, in_channels=163):
        super(PointNetfeat, self).__init__()
        self.stn = STNkd(k=in_channels)
        self.conv1 = torch.nn.Conv1d(in_channels, 128, 1)
        self.conv2 = torch.nn.Conv1d(128, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 1024, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            # conv1 emits 128 channels, so the feature transform must be 128x128.
            # The previous STNkd(k=64) could not accept conv1's output.
            self.fstn = STNkd(k=128)
        self.isVAE = isVAE
        self.istrain = istrain

    def forward(self, x):
        """Encode a batch of point sets.

        Args:
            x: tensor of shape (batch, in_channels, n_points).

        Returns:
            (features, trans, trans_feat), where ``trans`` is the
            (batch, in_channels, in_channels) input transform and
            ``trans_feat`` is the (batch, 128, 128) feature transform, or
            None when ``feature_transform`` is False.
        """
        n_pts = x.size(2)
        # Apply the learned input transform in (batch, n_points, channels) layout.
        trans = self.stn(x)
        x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)
        else:
            trans_feat = None
        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        if self.isVAE and self.istrain:
            # Noise is created on x's device/dtype. It was hard-coded to
            # `.cuda()`, which failed on CPU and could pick the wrong GPU.
            x = x + torch.rand_like(x)
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
        return torch.cat([x, pointfeat], 1), trans, trans_feat
class PointNetCls(nn.Module):
    """Point-cloud classifier: PointNet global feature followed by an MLP head.

    Args:
        k: number of output classes.
        feature_transform: forwarded to PointNetfeat.
    """

    def __init__(self, k=6, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        # Classification head: 1024 -> 512 -> 256 -> k.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Classify each point set in the batch.

        Returns:
            (log_probs, trans, trans_feat), where ``log_probs`` has shape
            (batch, k) and the transforms come straight from PointNetfeat.
        """
        global_feature, trans, trans_feat = self.feat(x)
        hidden = self.fc1(global_feature)
        hidden = F.relu(self.bn1(hidden))
        # Dropout sits between fc2 and its batch norm.
        hidden = self.dropout(self.fc2(hidden))
        hidden = F.relu(self.bn2(hidden))
        logits = self.fc3(hidden)
        log_probs = F.log_softmax(logits, dim=1)
        return log_probs, trans, trans_feat
class PointNetDenseCls(nn.Module):
    """Per-point prediction head on top of PointNet point features.

    Args:
        k: number of output channels per point (default 163).
        feature_transform: forwarded to PointNetfeat.
        isVAE: forwarded to PointNetfeat.
        istrain: forwarded to PointNetfeat.
    """

    def __init__(self, k = 163, feature_transform=False, isVAE = False, istrain = True):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform, isVAE=isVAE, istrain=istrain)
        # 1152 input channels = 1024 tiled global + 128 per-point channels from PointNetfeat.
        self.conv1 = torch.nn.Conv1d(1152, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, k, 1)
        # conv4_1 / conv4_2 are not used by forward(). They are retained so the
        # module's parameters (state_dict keys) stay unchanged.
        self.conv4_1 = torch.nn.Conv1d(128, 3, 1)
        self.conv4_2 = torch.nn.Conv1d(128, 163, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        """Predict per-point outputs.

        Args:
            x: tensor of shape (batch, channels, n_points), passed to PointNetfeat.

        Returns:
            (out, trans, trans_feat, bottleneck), where ``out`` has shape
            (batch, k, n_points) and holds raw conv4 outputs (no softmax).
            ``bottleneck`` is the concatenated feature fed to conv1.
        """
        bottleneck, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(bottleneck)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        return x, trans, trans_feat, bottleneck
def feature_transform_regularizer(trans):
    """Penalise deviation of each transform from orthogonality.

    Computes the batch mean of ||A A^T - I||_F over the matrices A in ``trans``.

    Args:
        trans: tensor of shape (batch, d, d).

    Returns:
        Scalar tensor.
    """
    d = trans.size(1)
    # Create the identity on trans's device. The former `.cuda()` call moved it
    # to the *current* CUDA device, which need not be trans's device.
    identity = torch.eye(d, device=trans.device).unsqueeze(0)
    gram = torch.bmm(trans, trans.transpose(2, 1))
    return torch.mean(torch.norm(gram - identity, dim=(1, 2)))
def mse(pred,gt):
    """Return the mean squared error between ``pred`` and ``gt``.

    The previous body referenced the undefined names ``trans`` and ``I``
    (NameError on every call) and ignored both arguments.

    Args:
        pred: prediction tensor.
        gt: target tensor, broadcastable to ``pred``.

    Returns:
        Scalar tensor: mean of (pred - gt) ** 2 over all elements.
    """
    return torch.mean((pred - gt) ** 2)
if __name__ == '__main__':
    # Smoke test on random data: 32 point sets, 163 channels per point,
    # 2048 points each. The models default to 163 input channels; the former
    # 62-channel input could not pass through the input transform.
    sim_data = torch.rand(32, 163, 2048)
    print(sim_data.size())
    # Inference only: skip autograd bookkeeping to keep memory use down.
    with torch.no_grad():
        pointfeat = PointNetfeat(global_feat=True)
        out, _, _ = pointfeat(sim_data)
        print('global feat', out.size())
        pointfeat = PointNetfeat(global_feat=False)
        out, _, _ = pointfeat(sim_data)
        print('point feat', out.size())
        seg = PointNetDenseCls(k = 163)
        # PointNetDenseCls.forward returns (out, trans, trans_feat, bottleneck).
        out, _, _, _ = seg(sim_data)
        print('seg', out.size())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment