## HRNet+OCR
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Modified from: https://github.com/AlexHex7/Non-local_pytorch
## Microsoft Research
## [email protected]
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import matplotlib
matplotlib.use('Agg')  # render figures to files; no display needed
import os
import cv2
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from PIL import Image as PILImage

# NOTE: ModuleHelper (used for BNReLU below) is not defined in this gist; it is
# assumed to come from the openseg.pytorch codebase, e.g.
# from lib.models.tools.module_helper import ModuleHelper

torch_ver = torch.__version__[:3]
ignore_label = 255
# map raw Cityscapes label ids to the 19 train ids (255 = ignore)
id_to_trainid = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
                 3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
                 7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
                 14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
                 18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
                 28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
class_name_dict = {0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence', 5: 'pole',
                   6: 'traffic light', 7: 'traffic sign', 8: 'vegetation', 9: 'terrain', 10: 'sky',
                   11: 'person', 12: 'rider', 13: 'car', 14: 'truck', 15: 'bus', 16: 'train',
                   17: 'motorcycle', 18: 'bicycle', 255: 'none'}
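
# Quick sanity check of the two tables above: raw Cityscapes id 7 ('road')
# becomes train id 0, whose name is 'road'.
assert class_name_dict[id_to_trainid[7]] == 'road'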
def get_palette(num_cls):
    """ Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map
    """
    palette = [0] * (num_cls * 3)
    palette[0:3]   = (128,  64, 128)  # 0  'road'
    palette[3:6]   = (244,  35, 232)  # 1  'sidewalk'
    palette[6:9]   = ( 70,  70,  70)  # 2  'building'
    palette[9:12]  = (102, 102, 156)  # 3  'wall'
    palette[12:15] = (190, 153, 153)  # 4  'fence'
    palette[15:18] = (153, 153, 153)  # 5  'pole'
    palette[18:21] = (250, 170,  30)  # 6  'traffic light'
    palette[21:24] = (220, 220,   0)  # 7  'traffic sign'
    palette[24:27] = (107, 142,  35)  # 8  'vegetation'
    palette[27:30] = (152, 251, 152)  # 9  'terrain'
    palette[30:33] = ( 70, 130, 180)  # 10 'sky'
    palette[33:36] = (220,  20,  60)  # 11 'person'
    palette[36:39] = (255,   0,   0)  # 12 'rider'
    palette[39:42] = (  0,   0, 142)  # 13 'car'
    palette[42:45] = (  0,   0,  70)  # 14 'truck'
    palette[45:48] = (  0,  60, 100)  # 15 'bus'
    palette[48:51] = (  0,  80, 100)  # 16 'train'
    palette[51:54] = (  0,   0, 230)  # 17 'motorcycle'
    palette[54:57] = (119,  11,  32)  # 18 'bicycle'
    palette[57:60] = (105, 105, 105)  # 19 ignore/void
    return palette
palette = get_palette(20)
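
# --------------------------------------------------------------------------
# A minimal usage sketch (not part of the original pipeline): colorize a
# train-id label map with the palette above and save it to disk. The function
# name and output file name are illustrative.
# --------------------------------------------------------------------------
def _demo_colorize_label(label, out_file='demo_label.png'):
    # label: H x W uint8 array of train ids (255 = ignore)
    img = PILImage.fromarray(label.astype(np.uint8))
    img.putpalette(palette)  # attaching a palette converts the image to 'P' mode
    img.save(out_file)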
def id2trainId(label, id_to_trainid, reverse=False):
    """Map raw Cityscapes ids to train ids (or back, when reverse=True)."""
    label_copy = label.copy()
    if reverse:
        for v, k in id_to_trainid.items():
            label_copy[label == k] = v
    else:
        for k, v in id_to_trainid.items():
            label_copy[label == k] = v
    return label_copy
def down_sample_target(target, scale):
    """Nearest-neighbor downsample of an H x W label map by strided slicing."""
    row, col = target.shape
    step = scale
    r_target = target[0:row:step, :]
    c_target = r_target[:, 0:col:step]
    return c_target
def visualize_map(atten, shape, out_path):
    """Save each of the first 18 attention/assignment maps as a separate image."""
    atten_np = atten.detach().cpu().numpy()  # c x hw
    (h, w) = shape
    for row in range(2):
        for col in range(9):
            cm = atten_np[row * 9 + col]  # 2 x 9 grid -> maps 0..17
            cm = np.reshape(cm, (h, w))
            plt.tight_layout()
            plt.imshow(cm, cmap='Blues', interpolation='nearest')
            plt.axis('off')
            plt.savefig(out_path + 'regionmap_' + str(row * 9 + col) + '.png',
                        bbox_inches='tight', pad_inches=0)
def Vis_A2_Atten(img_path,
                 label_path,
                 image,
                 label,
                 atten,
                 shape,
                 cmap=plt.cm.Blues,
                 index=1,
                 choice=1,
                 maps_count=32):
    """
    Plot the A2 attention weight maps next to the input image and its label.
    Input:
        choice: 1 reads the image/label from the given file paths;
                any other value uses the image/label tensors passed in directly.
    """
    atten_np = atten.detach().cpu().numpy()  # c x hw
    (h, w) = shape
    if choice == 1:
        # read image/label from the given paths
        image = cv2.imread(img_path[index], cv2.IMREAD_COLOR)  # 1024x2048x3
        image = image[:, :, ::-1]  # BGR -> RGB
        image = cv2.resize(image, dsize=(w, h), interpolation=cv2.INTER_CUBIC)
        label = cv2.imread(label_path[index], cv2.IMREAD_GRAYSCALE)  # 1024x2048
        label = id2trainId(label, id_to_trainid)
        label = down_sample_target(label, 8)
    else:
        # use the image crop directly
        image = image.astype(np.float32)[index]  # 3x1024x2048
        image = np.transpose(image, (1, 2, 0))
        mean = (102.9801, 115.9465, 122.7717)
        image += mean  # undo the mean subtraction applied during preprocessing
        image = image.astype(np.uint8)
        image = cv2.resize(image, dsize=(w, h), interpolation=cv2.INTER_CUBIC)
        label = label.cpu().numpy().astype(np.uint8)[index]
        label = down_sample_target(label, 8)
    img_label = PILImage.fromarray(label)
    img_label.putpalette(palette)
    plt.figure(figsize=(48, 24))
    plt.axis('off')
    plt.subplot(5, 8, 1)
    plt.imshow(image)
    plt.axis('off')
    plt.subplot(5, 8, 2)
    plt.imshow(img_label)
    plt.axis('off')
    for row in range(4):
        for col in range(8):
            plt.subplot(5, 8, 9 + row * 8 + col)
            cm = atten_np[row * 8 + col]
            cm = np.reshape(cm, (h, w))
            plt.imshow(cm, cmap='Blues', interpolation='nearest')
            plt.axis('off')
            plt.gca().set_title("Attention Map %d" % (row * 8 + col))
    # with the Agg backend the figure is written to disk rather than shown
    plt.tight_layout()
    outpath = './object_context_vis/a2map_32/'
    os.makedirs(outpath, exist_ok=True)  # make sure the output directory exists
    img_id = os.path.splitext(os.path.basename(img_path[0]))[0]
    plt.savefig(outpath + 'a2map_' + img_id + '.png', bbox_inches='tight', pad_inches=0)
    print("image id: {}".format(img_id))
def Vis_FastOC_Atten(img_path,
                     label_path,
                     image,
                     label,
                     atten,
                     shape,
                     cmap=plt.cm.Blues,
                     index=1,
                     choice=1,
                     subplot=False):
    """
    Plot the attention weight maps of the fast base-OC head.
    Input:
        choice:  1 reads the image/label from the given file paths;
                 any other value uses the image/label tensors passed in directly.
        subplot: True draws everything on one 3x7 grid and saves a single figure;
                 False saves each attention map as a separate image.
    """
    atten_np = atten.detach().cpu().numpy()  # c x hw
    (h, w) = shape
    if choice == 1:
        # read image/label from the given paths
        image = cv2.imread(img_path[index], cv2.IMREAD_COLOR)  # 1024x2048x3
        image = image[:, :, ::-1]  # BGR -> RGB
        image = cv2.resize(image, dsize=(w, h), interpolation=cv2.INTER_CUBIC)
        label = cv2.imread(label_path[index], cv2.IMREAD_GRAYSCALE)  # 1024x2048
        label = id2trainId(label, id_to_trainid)
        label = down_sample_target(label, 8)
    else:
        # use the image crop directly
        image = image.astype(np.float32)[index]  # 3x1024x2048
        image = np.transpose(image, (1, 2, 0))
        mean = (102.9801, 115.9465, 122.7717)
        image += mean  # undo the mean subtraction applied during preprocessing
        image = image.astype(np.uint8)
        image = cv2.resize(image, dsize=(w, h), interpolation=cv2.INTER_CUBIC)
        label = label.cpu().numpy().astype(np.uint8)[index]
        label = down_sample_target(label, 8)
    img_label = PILImage.fromarray(label)
    img_label.putpalette(palette)
    plt.figure(figsize=(48, 24))
    plt.axis('off')
    if subplot:
        plt.subplot(3, 7, 1)
        plt.imshow(image)
        plt.axis('off')
        plt.subplot(3, 7, 2)
        plt.imshow(img_label)
        plt.axis('off')
    outpath = './object_context_vis/fast_baseoc_map/'
    os.makedirs(outpath, exist_ok=True)  # make sure the output directory exists
    img_id = os.path.splitext(os.path.basename(img_path[0]))[0]
    for row in range(3):
        for col in range(7):
            if (row * 7 + col) == 0 or (row * 7 + col) == 1:
                continue  # the first two grid cells hold the image and the label
            if subplot:
                plt.subplot(3, 7, row * 7 + col + 1)
            cm = atten_np[row * 7 + col - 2]
            cm = np.reshape(cm, (h, w))
            plt.imshow(cm, cmap='Blues', interpolation='nearest')
            plt.axis('off')
            if not subplot:
                # one file per attention map
                plt.savefig(outpath + 'fast_baseoc_map_' + img_id + '_' + str(row * 7 + col - 2) + '.png',
                            bbox_inches='tight', pad_inches=0)
            else:
                plt.gca().set_title("Attention Map %d" % (row * 7 + col - 2))
    if subplot:
        plt.tight_layout()
        plt.savefig(outpath + 'fast_baseoc_map_' + img_id + '.png', bbox_inches='tight', pad_inches=0)
    print("image id: {}".format(img_id))
# ---------------------------------------------------------------------------
# OCR modules: SpatialGather_Module pools one context vector per class, and
# _ObjectAttentionBlock distributes that context back to every pixel.
# ---------------------------------------------------------------------------
class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial predicted probability
    distribution, using a soft-weighted average of the pixel features per class.
    """
    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale
        self.relu = nn.ReLU(inplace=True)

    def forward(self, feats, probs):
        batch_size, c = probs.size(0), probs.size(1)
        probs = probs.view(batch_size, c, -1)                 # batch x k x hw
        feats = feats.view(batch_size, feats.size(1), -1)
        feats = feats.permute(0, 2, 1)                        # batch x hw x c
        # probs = F.normalize(probs, p=2, dim=1)              # (alternative weighting, unused)
        probs = F.softmax(self.scale * probs, dim=2)          # per-class weights over hw
        cc = torch.matmul(probs, feats)                       # batch x k x c
        return cc.permute(0, 2, 1).unsqueeze(3)               # batch x c x k x 1
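
# --------------------------------------------------------------------------
# A quick sanity check (a sketch with made-up sizes): the gathered context is
# a softmax-weighted average of the pixel features for each class.
# --------------------------------------------------------------------------
def _check_spatial_gather():
    feats = torch.randn(1, 8, 4, 4)                        # N x C x H x W
    probs = torch.randn(1, 3, 4, 4)                        # N x K x H x W
    proxy = SpatialGather_Module(cls_num=3)(feats, probs)  # N x C x K x 1
    weights = F.softmax(probs.view(1, 3, -1), dim=2)       # per-class weights over hw
    ref = torch.matmul(weights, feats.view(1, 8, -1).permute(0, 2, 1))  # N x K x C
    assert torch.allclose(proxy.squeeze(3).permute(0, 2, 1), ref, atol=1e-6)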
class _ObjectAttentionBlock(nn.Module):
    '''
    The basic implementation for the self-attention / non-local block.
    Input:
        N X C X H X W
    Parameters:
        in_channels  : the dimension of the input feature map
        key_channels : the dimension after the key/query/value transforms
        scale        : downsampling factor for the input feature maps (saves memory)
        bn_type      : the BatchNorm variant passed through to ModuleHelper.BNReLU
    Return:
        N X C X H X W
        position-aware context features (without concatenation or addition to the input).
    '''
    def __init__(self, in_channels, key_channels, scale=1, bn_type=None):
        super(_ObjectAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        self.f_object = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        self.f_down = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        self.f_up = nn.Sequential(
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type),
        )

    def forward(self, x, proxy):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)
        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)                                      # batch x hw x key
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)  # batch x key x k
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)                                      # batch x k x key
        sim_map = torch.matmul(query, key)                                  # batch x hw x k
        sim_map = (self.key_channels ** -.5) * sim_map                      # scaled dot-product
        sim_map = F.softmax(sim_map, dim=-1)
        # visualize the assignment maps
        # assign_map = sim_map[0].permute(1, 0)
        # assign_map = assign_map.view(19, h, w)
        # from lib.vis.attention_visualizer import visualize_map
        # visualize_map(assign_map, [h, w],
        #               out_path="/msravcshare/yuyua/code/segmentation/openseg.pytorch/visualize/assign_maps/")
        context = torch.matmul(sim_map, value)                              # batch x hw x key
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)                                        # back to in_channels
        if self.scale > 1:
            context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
        return context
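
# --------------------------------------------------------------------------
# A usage sketch of the full OCR pipeline (assumptions: ModuleHelper is
# available from openseg.pytorch, 'torchbn' selects plain BatchNorm there, and
# all sizes are illustrative): SpatialGather_Module pools one context vector
# per class from the coarse soft segmentation, then _ObjectAttentionBlock lets
# every pixel attend to those class-context vectors to produce pixel-wise
# object-context features.
# --------------------------------------------------------------------------
def _demo_ocr_pipeline():
    feats = torch.randn(2, 512, 32, 64)   # N x C x H x W backbone features
    probs = torch.randn(2, 19, 32, 64)    # N x K x H x W coarse class scores
    gather = SpatialGather_Module(cls_num=19)
    proxy = gather(feats, probs)          # N x C x K x 1 class-context vectors
    attn = _ObjectAttentionBlock(in_channels=512, key_channels=256, bn_type='torchbn')
    context = attn(feats, proxy)          # N x C x H x W object-context features
    # An OCR head would concatenate `context` with `feats` before the classifier.
    return context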