Last active
September 17, 2020 09:06
-
-
Save hengck23/6ebe1c75f8b3bcc953c0599ac76bad45 to your computer and use it in GitHub Desktop.
model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from common import * | |
num_class = 266 | |
# https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/8d9999d72b282d2dc50a5b5f668dd91369f853c5/pytorch/models.py | |
# https://www.kaggle.com/hidehisaarai1213/introduction-to-sound-event-detection | |
# https://github.com/qiuqiangkong/sed_from_wekaly_labelled_data/blob/master/spectrogram_to_wave.py | |
class ConvBlock(nn.Module):
    """Two (3x3 conv -> batch-norm -> ReLU) stages with optional average pooling.

    Args:
        in_channel: number of channels of the incoming feature map.
        out_channel: number of channels produced by both convolutions.
        pool_size: kernel size passed to ``F.avg_pool2d`` after the second
            stage; the default ``1`` skips pooling altogether.
    """

    def __init__(self, in_channel, out_channel, pool_size=1):
        super().__init__()
        self.pool_size = pool_size

        # Both convolutions share the same geometry: 3x3 kernel, stride 1,
        # padding 1 (spatial size is preserved) and no bias term; a batch-norm
        # layer is applied right after each of them in forward().
        conv_geometry = dict(kernel_size=(3, 3), stride=(1, 1),
                             padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channel,
                               out_channels=out_channel, **conv_geometry)
        self.conv2 = nn.Conv2d(in_channels=out_channel,
                               out_channels=out_channel, **conv_geometry)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Run both conv stages on ``x`` and average-pool unless ``pool_size`` is 1."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)), inplace=True)

        if self.pool_size == 1:
            return x
        return F.avg_pool2d(x, kernel_size=self.pool_size)
class AttentPool(nn.Module):
    """Attention weights along the last axis from a 1x1 Conv1d + BatchNorm1d.

    Args:
        in_channel: number of channels of the incoming sequence.
        out_channel: number of attention maps to produce.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1,
                              stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm1d(out_channel)

    def forward(self, x):
        """Map ``x`` of shape (batch_size, C, num_time) to softmax weights over the last axis."""
        score = self.bn(self.conv(x))
        # Smoothly squash the scores into (-10, 10) before the softmax.
        squashed = torch.tanh(score / 10) * 10
        return squashed.softmax(dim=-1)
# type-1 roi | |
class ROI1(nn.Module):
    """Collapse axis 3 by averaging, then sum a max-pool and an avg-pool along the new last axis.

    Both 1-D poolings use a window of 3 with stride 1 and padding 1, so the
    length of the pooled axis is unchanged.
    """

    def forward(self, x):
        """Reduce a 4-D tensor to 3-D and return max-pooled + avg-pooled features."""
        collapsed = x.mean(dim=3)
        window = dict(kernel_size=3, stride=1, padding=1)
        return F.max_pool1d(collapsed, **window) + F.avg_pool1d(collapsed, **window)
#----------------------------------------------------------------------------- | |
# Cnn14_DecisionLevelAtt | |
class Net(nn.Module):
    """CNN14-style backbone with a frame-wise attention head.

    Six ``ConvBlock`` stages are followed by a frame-wise ROI head; per-frame
    class probabilities are combined with per-class attention weights into
    clip-level predictions.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Normalises each of the 64 frequency bins; forward() temporarily
        # moves the frequency axis into the channel position for this layer.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channel=1, out_channel=64, pool_size=2)
        self.conv_block2 = ConvBlock(in_channel=64, out_channel=128, pool_size=2)
        self.conv_block3 = ConvBlock(in_channel=128, out_channel=256, pool_size=2)
        self.conv_block4 = ConvBlock(in_channel=256, out_channel=512, pool_size=2)
        self.conv_block5 = ConvBlock(in_channel=512, out_channel=1024, pool_size=2)
        self.conv_block6 = ConvBlock(in_channel=1024, out_channel=2048, pool_size=1)

        self.roi = nn.Sequential(
            ROI1(),
            nn.Conv1d(2048, 2048, 1, bias=False),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
        )
        self.probability = nn.Conv1d(2048, num_class, kernel_size=1, bias=False)
        self.attention = AttentPool(2048, num_class)

    def load_pretrain(self, skip=None, is_print=True, checkpoint=None):
        """Load pretrained weights, ignoring entries this network should not take.

        Args:
            skip: optional iterable of extra substrings (or a single string);
                any checkpoint key containing one of them is dropped.
            is_print: accepted for backward compatibility; currently has no
                effect.
            checkpoint: path of the checkpoint file. Defaults to the original
                hard-coded location of the Cnn14_DecisionLevelAtt weights.

        Returns:
            The result of ``load_state_dict(..., strict=False)``, which lists
            the missing and unexpected keys.
        """
        if checkpoint is None:
            checkpoint = '/root/share1/kaggle/2020/birdsong/data/pretrain/Cnn14_DecisionLevelAtt_mAP0.425.pth'
        if skip is None:
            skip = ()
        elif isinstance(skip, str):
            # A bare string would otherwise be split into single characters.
            skip = (skip,)

        # Deserialise every tensor onto the CPU regardless of where it was saved.
        state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)['model']

        # Keys containing any of these substrings are removed before loading
        # (presumably layers this network does not define or defines with a
        # different shape - TODO confirm against the checkpoint).
        excluded = ['att_block', 'spectrogram_extractor', 'logmel_extractor', 'fc1']
        excluded.extend(skip)
        for key in [k for k in state_dict if any(s in k for s in excluded)]:
            del state_dict[key]

        # strict=False: parameters absent from the filtered checkpoint keep
        # their fresh initialisation.
        return self.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        """Compute clip-level predictions and attention weights.

        Args:
            x: tensor of shape (batch_size, 1, num_freq, num_frame); num_freq
                must be 64 to match ``bn0``.

        Returns:
            Tuple ``(pool, attention)``: ``pool`` is the attention-weighted sum
            of the frame-wise probabilities over the time axis (clip-wise, one
            value per class); ``attention`` holds the per-class softmax weights
            over the downsampled time axis.
        """
        x = x.permute(0, 2, 1, 3).contiguous()  # (batch_size, num_freq, c, num_frame)
        x = self.bn0(x)
        x = x.permute(0, 2, 3, 1).contiguous()  # (batch_size, c, num_frame, num_freq)

        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        x = self.conv_block6(x)  # e.g. (8, 2048, 15, 2) for an (8, 1, 64, 501) input

        # Frame-wise features: ROI1 averages the frequency axis away.
        roi = self.roi(x)

        probability = torch.sigmoid(self.probability(roi))  # frame-wise, e.g. (8, 266, 15)
        attention = self.attention(roi)  # e.g. (8, 266, 15)
        pool = torch.sum(attention * probability, dim=2)  # clip-wise, e.g. (8, 266)
        return pool, attention
def binary_cross_entropy_with_logit_loss(logit, truth, n_class=None):
    """Class-balanced binary cross-entropy computed from raw logits.

    The positive (target-class) and negative terms are averaged separately and
    then mixed with weights ``w = 1 / n_class`` and ``1 - w``.

    Args:
        logit: float tensor of raw scores; must be broadcast-compatible with
            ``F.one_hot(truth, n_class)`` (typically (batch_size, n_class)).
        truth: int64 tensor of class indices.
        n_class: number of classes; defaults to the module-level ``num_class``.

    Returns:
        Scalar loss tensor.
    """
    if n_class is None:
        n_class = num_class
    w = 1 / n_class

    onehot = F.one_hot(truth, n_class).type(logit.dtype)

    # one_hot() sets exactly one entry per label, so the counts follow from the
    # tensor sizes; this avoids two .item() host synchronisations.
    num_p = truth.numel()
    num_n = onehot.numel() - num_p

    log_p = -F.logsigmoid(logit)   # -log(sigmoid(z)): cost of a positive
    log_n = -F.logsigmoid(-logit)  # -log(1 - sigmoid(z)): cost of a negative

    loss_p = (onehot * log_p).sum() / num_p
    loss_n = ((1 - onehot) * log_n).sum() / num_n
    return w * loss_p + (1 - w) * loss_n
def binary_cross_entropy_loss(probability, truth, n_class=None):
    """Class-balanced binary cross-entropy computed from probabilities.

    The positive (target-class) and negative terms are averaged separately and
    then mixed with weights ``w = 1 / n_class`` and ``1 - w``.

    Args:
        probability: float tensor of probabilities; must be
            broadcast-compatible with ``F.one_hot(truth, n_class)``
            (typically (batch_size, n_class)).
        truth: int64 tensor of class indices.
        n_class: number of classes; defaults to the module-level ``num_class``.

    Returns:
        Scalar loss tensor.
    """
    if n_class is None:
        n_class = num_class
    w = 1 / n_class

    # Cast the mask to the probability dtype (the original cast to truth.dtype,
    # which left it as int64 and relied on implicit type promotion).
    onehot = F.one_hot(truth, n_class).type(probability.dtype)

    # one_hot() sets exactly one entry per label, so the counts follow from the
    # tensor sizes; this avoids two .item() host synchronisations.
    num_p = truth.numel()
    num_n = onehot.numel() - num_p

    # Keep both logarithms finite for probabilities of exactly 0 or 1.
    probability = torch.clamp(probability, 1e-5, 1 - 1e-5)
    log_p = -torch.log(probability)
    log_n = -torch.log(1 - probability)

    loss_p = (onehot * log_p).sum() / num_p
    loss_n = ((1 - onehot) * log_n).sum() / num_n
    return w * loss_p + (1 - w) * loss_n
# check ################################################################# | |
def run_check_net():
    """Smoke test: build the net, load pretrained weights, run one random batch and print shapes."""
    model = Net()
    model.load_pretrain()

    # Random stand-in for a mel-spectrogram batch: (batch_size, 1, num_freq, num_frame).
    batch_size, num_freq, num_frame = 8, 64, 501
    melspec = torch.randn((batch_size, 1, num_freq, num_frame))
    probability, attention = model(melspec)

    print('')
    report = (
        ('melspec: ', melspec),
        ('probability: ', probability),
        ('attention: ', attention),
    )
    for label, tensor in report:
        print(label, tensor.shape)
# main ################################################################# | |
if __name__ == '__main__':
    # Announce the running script by its file name, then run the smoke test.
    script_name = os.path.basename(__file__)
    print('%s: calling main function ... ' % script_name)
    run_check_net()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment