Last active: April 6, 2021 19:21
-
-
Save imflash217/51a90cdf993e97c7b5dd199f97f473e9 to your computer and use it in GitHub Desktop.
[AutoEncoder NN] #encoder #decoder #meenet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# Multi-branch 1-D convolutional autoencoder shape experiment.
# Every hidden layer runs five parallel convolutions (one per kernel size)
# over the same input. Each branch output is right-padded with zeros to a
# common length, and the branches are concatenated along the channel dim.

# Kernel sizes of the five parallel branches used by every hidden layer.
KERNEL_SIZES = (5, 50, 256, 512, 1025)

# Output channels of each branch, one tuple per layer (branch order matches
# KERNEL_SIZES). A layer's total output channel count is the tuple's sum.
ENCODER_CHANNELS = (
    (20, 20, 20, 20, 20),  # encoder layer 1 -> 100 channels
    (50, 25, 20, 20, 20),  # encoder layer 2 -> 135 channels
)
DECODER_CHANNELS = (
    (50, 25, 20, 20, 20),  # decoder layer 1 -> 135 channels
    (20, 20, 20, 20, 20),  # decoder layer 2 -> 100 channels
)


def pad_and_concat(outputs, length):
    """Right-pad each tensor in *outputs* with zeros to *length*, then concat on dim 1.

    Args:
        outputs: tensors shaped (batch, channels_i, length_i); batch sizes must agree.
        length: target size of the last dimension.

    Returns:
        Tensor of shape (batch, sum(channels_i), length).

    Raises:
        ValueError: if any length_i exceeds *length*. F.pad with a negative
            amount would otherwise silently crop the tensor.
    """
    padded = []
    for out in outputs:
        extra = length - out.shape[-1]
        if extra < 0:
            raise ValueError(
                f'branch length {out.shape[-1]} exceeds target length {length}'
            )
        padded.append(torch.nn.functional.pad(out, (0, extra)))
    return torch.cat(padded, dim=1)


def multi_branch_layer(x, conv_cls, channels, kernel_sizes, verbose=False):
    """Run one parallel convolution per branch over *x* and merge the results.

    Each branch is ``conv_cls(x.shape[1], out_ch, k, stride=1, padding=0)``.
    With these settings, Conv1d branches shorten the length by k-1 and
    ConvTranspose1d branches lengthen it by k-1. All branch outputs are
    zero-padded to the larger of the input length and the longest branch
    output, then concatenated along channels.

    Args:
        x: input tensor (batch, in_channels, length).
        conv_cls: ``torch.nn.Conv1d`` or ``torch.nn.ConvTranspose1d``.
        channels: output channels per branch.
        kernel_sizes: kernel size per branch; same length as *channels*.
        verbose: print each branch's output shape.

    Raises:
        ValueError: if *channels* and *kernel_sizes* differ in length.
    """
    if len(channels) != len(kernel_sizes):
        raise ValueError(
            f'{len(channels)} channel specs but {len(kernel_sizes)} kernel sizes'
        )
    in_channels = x.shape[1]
    # Build every branch before running any, mirroring the original script's
    # construction order (keeps random weight-init order identical).
    branches = [
        conv_cls(in_channels, out_ch, k, stride=1, padding=0)
        for out_ch, k in zip(channels, kernel_sizes)
    ]
    outputs = [branch(x) for branch in branches]
    if verbose:
        for out in outputs:
            print(out.shape)
    # Encoders: the input is longest. Decoders: the largest-kernel branch is longest.
    length = max([x.shape[-1], *(out.shape[-1] for out in outputs)])
    return pad_and_concat(outputs, length)


def run_autoencoder(x, kernel_sizes=KERNEL_SIZES, *, output_kernel=None, verbose=True):
    """Pass *x* through two encoder layers, two decoder layers and an output layer.

    Args:
        x: input tensor (batch, channels, length). Every encoder kernel must be
            <= length, or Conv1d raises.
        kernel_sizes: per-branch kernel sizes; must have one entry per branch
            of the channel specs (five).
        output_kernel: kernel size of the final ConvTranspose1d. Defaults to
            ``max(kernel_sizes)``, which is 1025 in the original configuration.
        verbose: print branch and layer shapes, as the original script did.

    Returns:
        Tensor of shape (batch, x.shape[1], length_out).
    """
    if output_kernel is None:
        output_kernel = max(kernel_sizes)
    h = x
    for i, channels in enumerate(ENCODER_CHANNELS, start=1):
        h = multi_branch_layer(h, torch.nn.Conv1d, channels, kernel_sizes, verbose)
        if verbose:
            print(f'encoder_lyr{i}-------{h.shape}')
    for i, channels in enumerate(DECODER_CHANNELS, start=1):
        h = multi_branch_layer(h, torch.nn.ConvTranspose1d, channels, kernel_sizes, verbose)
        if verbose:
            print(f'decoder_lyr{i}-------{h.shape}')
    # Map back to the input's channel count (100 -> 2 in the original).
    output_layer = torch.nn.ConvTranspose1d(
        h.shape[1], x.shape[1], output_kernel, stride=1, padding=0
    )
    out = output_layer(h)
    if verbose:
        print(f'output_lyr-------{out.shape}')
    return out


def main():
    """Run the original experiment on one random (1, 2, 1025) input."""
    x = torch.randn(1, 2, 1025)
    out = run_autoencoder(x)
    # NOTE(review): stride-1, padding-0 transposed convs grow the length, so
    # the output is (1, 2, 4097) rather than the input's (1, 2, 1025). A
    # reconstruction loss against x would need cropping or a different
    # decoder design — confirm intent.
    return out


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Output result: