@imflash217
Last active April 6, 2021 19:21
[AutoEncoder NN] #encoder #decoder #meenet
import torch

# Input: (batch=1, channels=2, length=1025)
x = torch.randn(1, 2, 1025)
##### ENCODER
# layer-1
downsample_1a = torch.nn.Conv1d(2, 20, 5, stride=1, padding=0)
downsample_1b = torch.nn.Conv1d(2, 20, 50, stride=1, padding=0)
downsample_1c = torch.nn.Conv1d(2, 20, 256, stride=1, padding=0)
downsample_1d = torch.nn.Conv1d(2, 20, 512, stride=1, padding=0)
downsample_1e = torch.nn.Conv1d(2, 20, 1025, stride=1, padding=0)
out_1a = downsample_1a(x); print(out_1a.shape) # [1,20,1021]
out_1b = downsample_1b(x); print(out_1b.shape) # [1,20,976]
out_1c = downsample_1c(x); print(out_1c.shape) # [1,20,770]
out_1d = downsample_1d(x); print(out_1d.shape) # [1,20,514]
out_1e = downsample_1e(x); print(out_1e.shape) # [1,20,1]
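# With stride=1 and padding=0, Conv1d shrinks the length: L_out = L_in - kernel_size + 1,
# e.g. 1025 - 5 + 1 = 1021 and 1025 - 1025 + 1 = 1. Each branch is therefore
# zero-padded back to a common length of 1025 below before concatenation.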
# zero-padded buffers: bring every branch back to a common length of 1025
temp_1a = torch.zeros(1,20,1025)
temp_1b = torch.zeros(1,20,1025)
temp_1c = torch.zeros(1,20,1025)
temp_1d = torch.zeros(1,20,1025)
temp_1e = torch.zeros(1,20,1025)
temp_1a[:,:,0:out_1a.shape[2]] = out_1a
temp_1b[:,:,0:out_1b.shape[2]] = out_1b
temp_1c[:,:,0:out_1c.shape[2]] = out_1c
temp_1d[:,:,0:out_1d.shape[2]] = out_1d
temp_1e[:,:,0:out_1e.shape[2]] = out_1e
out_1x = torch.cat((temp_1a, temp_1b, temp_1c, temp_1d, temp_1e), dim=1)
print(f'encoder_lyr1-------{out_1x.shape}')
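# The pad-to-length-and-concat pattern above repeats at every layer. A minimal
# helper sketch that captures it (my addition, not in the original gist;
# `pad_and_cat` is a hypothetical name):
def pad_and_cat(outputs, target_len):
    """Right-pad each (N, C, L) tensor with zeros to target_len, then concat channels."""
    padded = [torch.nn.functional.pad(o, (0, target_len - o.shape[2])) for o in outputs]
    return torch.cat(padded, dim=1)
# e.g. out_1x above could equivalently be built as:
# out_1x = pad_and_cat([out_1a, out_1b, out_1c, out_1d, out_1e], 1025)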
# layer-2
downsample_2a = torch.nn.Conv1d(100, 50, 5, stride=1, padding=0)
downsample_2b = torch.nn.Conv1d(100, 25, 50, stride=1, padding=0)
downsample_2c = torch.nn.Conv1d(100, 20, 256, stride=1, padding=0)
downsample_2d = torch.nn.Conv1d(100, 20, 512, stride=1, padding=0)
downsample_2e = torch.nn.Conv1d(100, 20, 1025, stride=1, padding=0)
out_2a = downsample_2a(out_1x); print(out_2a.shape) # [1,50,1021]
out_2b = downsample_2b(out_1x); print(out_2b.shape) # [1,25,976]
out_2c = downsample_2c(out_1x); print(out_2c.shape) # [1,20,770]
out_2d = downsample_2d(out_1x); print(out_2d.shape) # [1,20,514]
out_2e = downsample_2e(out_1x); print(out_2e.shape) # [1,20,1]
# zero-padded buffers (length 1025), as in layer-1
temp_2a = torch.zeros(1,50,1025)
temp_2b = torch.zeros(1,25,1025)
temp_2c = torch.zeros(1,20,1025)
temp_2d = torch.zeros(1,20,1025)
temp_2e = torch.zeros(1,20,1025)
temp_2a[:,:,0:out_2a.shape[2]] = out_2a
temp_2b[:,:,0:out_2b.shape[2]] = out_2b
temp_2c[:,:,0:out_2c.shape[2]] = out_2c
temp_2d[:,:,0:out_2d.shape[2]] = out_2d
temp_2e[:,:,0:out_2e.shape[2]] = out_2e
out_2x = torch.cat((temp_2a, temp_2b, temp_2c, temp_2d, temp_2e), dim=1)
print(f'encoder_lyr2-------{out_2x.shape}')
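# Note: with stride=1 and zero-padding back to length 1025, the encoder never
# shortens the sequence; only the channel count changes (2 -> 100 -> 135).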
##### DECODER
# layer-1
upsample_1a = torch.nn.ConvTranspose1d(135, 50, 5, stride=1, padding=0)
upsample_1b = torch.nn.ConvTranspose1d(135, 25, 50, stride=1, padding=0)
upsample_1c = torch.nn.ConvTranspose1d(135, 20, 256, stride=1, padding=0)
upsample_1d = torch.nn.ConvTranspose1d(135, 20, 512, stride=1, padding=0)
upsample_1e = torch.nn.ConvTranspose1d(135, 20, 1025, stride=1, padding=0)
out_3a = upsample_1a(out_2x); print(out_3a.shape) # [1,50,1029]
out_3b = upsample_1b(out_2x); print(out_3b.shape) # [1,25,1074]
out_3c = upsample_1c(out_2x); print(out_3c.shape) # [1,20,1280]
out_3d = upsample_1d(out_2x); print(out_3d.shape) # [1,20,1536]
out_3e = upsample_1e(out_2x); print(out_3e.shape) # [1,20,2049]
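# With stride=1 and padding=0, ConvTranspose1d grows the length: L_out = L_in - 1 + kernel_size,
# e.g. 1025 - 1 + 5 = 1029 and 1025 - 1 + 1025 = 2049, hence buffers of length 2049 below.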
# zero-padded buffers: the widest branch (kernel 1025) gives length 2049
temp_3a = torch.zeros(1,50,2049)
temp_3b = torch.zeros(1,25,2049)
temp_3c = torch.zeros(1,20,2049)
temp_3d = torch.zeros(1,20,2049)
temp_3e = torch.zeros(1,20,2049)
temp_3a[:,:,0:out_3a.shape[2]] = out_3a
temp_3b[:,:,0:out_3b.shape[2]] = out_3b
temp_3c[:,:,0:out_3c.shape[2]] = out_3c
temp_3d[:,:,0:out_3d.shape[2]] = out_3d
temp_3e[:,:,0:out_3e.shape[2]] = out_3e
out_3x = torch.cat((temp_3a, temp_3b, temp_3c, temp_3d, temp_3e), dim=1)
print(f'decoder_lyr1-------{out_3x.shape}')
# layer-2
upsample_2a = torch.nn.ConvTranspose1d(135, 20, 5, stride=1, padding=0)
upsample_2b = torch.nn.ConvTranspose1d(135, 20, 50, stride=1, padding=0)
upsample_2c = torch.nn.ConvTranspose1d(135, 20, 256, stride=1, padding=0)
upsample_2d = torch.nn.ConvTranspose1d(135, 20, 512, stride=1, padding=0)
upsample_2e = torch.nn.ConvTranspose1d(135, 20, 1025, stride=1, padding=0)
out_4a = upsample_2a(out_3x); print(out_4a.shape) # [1,20,2053]
out_4b = upsample_2b(out_3x); print(out_4b.shape) # [1,20,2098]
out_4c = upsample_2c(out_3x); print(out_4c.shape) # [1,20,2304]
out_4d = upsample_2d(out_3x); print(out_4d.shape) # [1,20,2560]
out_4e = upsample_2e(out_3x); print(out_4e.shape) # [1,20,3073]
m = out_4e.shape[2]  # widest branch sets the common length: 3073
temp_4a = torch.zeros(1,20,m)
temp_4b = torch.zeros(1,20,m)
temp_4c = torch.zeros(1,20,m)
temp_4d = torch.zeros(1,20,m)
temp_4e = torch.zeros(1,20,m)
temp_4a[:,:,0:out_4a.shape[2]] = out_4a
temp_4b[:,:,0:out_4b.shape[2]] = out_4b
temp_4c[:,:,0:out_4c.shape[2]] = out_4c
temp_4d[:,:,0:out_4d.shape[2]] = out_4d
temp_4e[:,:,0:out_4e.shape[2]] = out_4e
out_4x = torch.cat((temp_4a, temp_4b, temp_4c, temp_4d, temp_4e), dim=1)
print(f'decoder_lyr2-------{out_4x.shape}')
##### OUTPUT LAYER
output_layer = torch.nn.ConvTranspose1d(100, 2, 1025, stride=1, padding=0)
out_5x = output_layer(out_4x)
print(f'output_lyr-------{out_5x.shape}')
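# Note: the reconstruction has length 4097, not the input length 1025, because
# every transposed conv grows the sequence. A shape-preserving autoencoder would
# crop, e.g. out_5x[:, :, :1025] (my suggestion, not in the original), or pick
# kernels/strides that exactly invert the encoder.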
@imflash217 (Author):

Output result:

torch.Size([1, 20, 1021])
torch.Size([1, 20, 976])
torch.Size([1, 20, 770])
torch.Size([1, 20, 514])
torch.Size([1, 20, 1])
encoder_lyr1-------torch.Size([1, 100, 1025])
torch.Size([1, 50, 1021])
torch.Size([1, 25, 976])
torch.Size([1, 20, 770])
torch.Size([1, 20, 514])
torch.Size([1, 20, 1])
encoder_lyr2-------torch.Size([1, 135, 1025])
torch.Size([1, 50, 1029])
torch.Size([1, 25, 1074])
torch.Size([1, 20, 1280])
torch.Size([1, 20, 1536])
torch.Size([1, 20, 2049])
decoder_lyr1-------torch.Size([1, 135, 2049])
torch.Size([1, 20, 2053])
torch.Size([1, 20, 2098])
torch.Size([1, 20, 2304])
torch.Size([1, 20, 2560])
torch.Size([1, 20, 3073])
decoder_lyr2-------torch.Size([1, 100, 3073])
output_lyr-------torch.Size([1, 2, 4097])
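
For reference, here is a minimal sketch of the same pipeline packaged as a reusable nn.Module (my own refactoring of the script above; MeeNetAutoEncoder and pad_and_cat are hypothetical names, and the hyperparameters are taken verbatim from the gist):

import torch

def pad_and_cat(outputs, target_len):
    # Right-pad each (N, C, L) branch with zeros to target_len, then concat channels.
    padded = [torch.nn.functional.pad(o, (0, target_len - o.shape[2])) for o in outputs]
    return torch.cat(padded, dim=1)

class MeeNetAutoEncoder(torch.nn.Module):
    KERNELS = (5, 50, 256, 512, 1025)

    def __init__(self):
        super().__init__()
        self.enc1 = torch.nn.ModuleList([torch.nn.Conv1d(2, 20, k) for k in self.KERNELS])
        self.enc2 = torch.nn.ModuleList(
            [torch.nn.Conv1d(100, c, k) for c, k in zip((50, 25, 20, 20, 20), self.KERNELS)])
        self.dec1 = torch.nn.ModuleList(
            [torch.nn.ConvTranspose1d(135, c, k) for c, k in zip((50, 25, 20, 20, 20), self.KERNELS)])
        self.dec2 = torch.nn.ModuleList([torch.nn.ConvTranspose1d(135, 20, k) for k in self.KERNELS])
        self.out = torch.nn.ConvTranspose1d(100, 2, 1025)

    def forward(self, x):                                       # (N, 2, 1025)
        x = pad_and_cat([conv(x) for conv in self.enc1], 1025)  # (N, 100, 1025)
        x = pad_and_cat([conv(x) for conv in self.enc2], 1025)  # (N, 135, 1025)
        x = pad_and_cat([conv(x) for conv in self.dec1], 2049)  # (N, 135, 2049)
        x = pad_and_cat([conv(x) for conv in self.dec2], 3073)  # (N, 100, 3073)
        return self.out(x)                                      # (N, 2, 4097)

model = MeeNetAutoEncoder()
print(model(torch.randn(1, 2, 1025)).shape)  # torch.Size([1, 2, 4097])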
