Skip to content

Instantly share code, notes, and snippets.

@dheshanm
Last active June 12, 2021 09:11
Show Gist options
  • Save dheshanm/24fe79d251ca8a4aaf4e74ddb97e067a to your computer and use it in GitHub Desktop.
Save dheshanm/24fe79d251ca8a4aaf4e74ddb97e067a to your computer and use it in GitHub Desktop.
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
# import tensorflow_io as tfio
import cv2
from sklearn.metrics import accuracy_score, precision_score, recall_score
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model
# Enable memory growth on every GPU TensorFlow can see.  Iterating over the
# list (instead of indexing element 0) keeps the script runnable on machines
# where no GPU is visible, where the list is empty.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)
from denoiser_v2 import *
from patcher import *
# Path of the image that will be denoised.
inp_img = "data/RH/19648321_RH.tif"

# IMREAD_UNCHANGED asks OpenCV to load the file as stored.
image = cv2.imread(inp_img, cv2.IMREAD_UNCHANGED)
if image is None:
    # cv2.imread reports an unreadable/missing file by returning None rather
    # than raising, so fail here with a message that names the file.
    raise FileNotFoundError(f"Could not read input image: {inp_img}")

# Buffer the processed tiles are written into; same shape as the input.
# Zero-initialised so any region that is never written holds a defined value.
output = np.zeros(image.shape)

# Edge length of the square tiles requested from the patcher.
IMAGE_SIZE = 128
patcher = Patcher(image, IMAGE_SIZE, IMAGE_SIZE)

# Saved Keras model; the processing loop uses its `encoder` and `decoder`.
model = tf.keras.models.load_model('saved_model/denoiser_v1.1')
def prerocess_tile(tile, scale=50000):
    """Scale an image tile and reshape it into a single-sample model input.

    Args:
        tile: Array of pixel values for one tile.  Expected to be 2-D
            (rows x cols) -- TODO confirm against the image being loaded.
        scale: Divisor applied to the pixel values.  Defaults to 50000, the
            value this script has always used; whatever multiplies the model
            output back must use the same factor.

    Returns:
        The scaled tile as a tensor with one axis inserted at index 2 and one
        at index 0; for a 2-D tile that is shape (1, rows, cols, 1).
    """
    image = tf.convert_to_tensor(tile / scale)
    image = tf.expand_dims(image, 2)
    image = tf.expand_dims(image, 0)
    return image
from tqdm import tqdm
# Upper bound on the number of tiles the patcher hands out; it is only used
# to size the progress bar.
chunk_size = (patcher.numTilesX + 1) * (patcher.numTilesY + 1)

# Factor that maps model output back to pixel values.  It must equal the
# divisor applied to tiles before they are fed to the model (50000).
OUTPUT_SCALE = 50000

with tqdm(total=chunk_size) as pbar:
    count = 0
    while True:
        count = count + 1
        (tile, startX, endX, startY, endY) = patcher.getNextTile()
        if tile is None:
            # The patcher signals exhaustion with an all-None tuple.
            break
        pbar.update(1)
        if tile.size == 0:
            # Zero-area tile: nothing to process and nothing to write back.
            continue

        processed_tile = prerocess_tile(tile)
        processed_tile = model.encoder(processed_tile).numpy()
        processed_tile = model.decoder(processed_tile).numpy()
        # Undo the input scaling and drop the size-1 axes added for the model.
        processed_tile = np.squeeze(processed_tile * OUTPUT_SCALE)

        try:
            output[startY:endY, startX:endX] = processed_tile
        except ValueError:
            # Shape mismatch between the model output and the target window:
            # report the offending tile and keep going (best effort).
            print("Count: ", count, " StartX :", startX, " endX: ", endX, " StartY: ", startY, "endY: ", endY)

# Write the assembled result to disk as an uncompressed TIFF.
from PIL import Image
im = Image.fromarray(output)
im.save('output.tif', format='TIFF', compression=None)
import math
class Patcher():
    """Hand out fixed-size tiles of an image, one per call, in row-major order.

    The image is indexed as ``img[y, x]``: axis 0 is Y (rows) and axis 1 is
    X (columns).  Tiles are ``height`` rows by ``width`` columns.

    Attributes:
        data: The image being tiled.
        height, width: Tile size in rows / columns.
        numTilesX, numTilesY: Number of tiles along X / Y.
        makeLastPartFull: When True (the default), a tile that would run past
            the right or bottom edge is shifted back so it keeps the full
            tile size and overlaps its neighbour; when False it is cropped.
        nTileX, nTileY: Grid position of the next tile to be returned.
    """

    def __init__(self, img, height, width):
        """Set up tiling of ``img`` into ``height`` x ``width`` tiles.

        Raises:
            ValueError: If ``height`` or ``width`` is not positive.
        """
        if height <= 0 or width <= 0:
            raise ValueError("tile height and width must be positive")
        self.data = img
        self.height = height
        self.width = width
        # X runs along axis 1 and is measured in tile widths; Y runs along
        # axis 0 and is measured in tile heights.
        self.numTilesX = math.ceil(img.shape[1] / width)
        self.numTilesY = math.ceil(img.shape[0] / height)
        self.makeLastPartFull = True
        self.nTileX = 0
        self.nTileY = 0

    def getNextTile(self):
        """Return the next tile and its bounds, advancing the cursor.

        Returns:
            ``(tile, startX, endX, startY, endY)`` where
            ``tile == data[startY:endY, startX:endX]``, or a tuple of five
            ``None`` values once every tile has been handed out.
        """
        if self.numTilesX == 0 or self.nTileY >= self.numTilesY:
            return None, None, None, None, None

        imgHeight, imgWidth = self.data.shape[0], self.data.shape[1]

        # Nominal window for this grid cell, clamped to the image bounds.
        startX = self.nTileX * self.width
        endX = min(startX + self.width, imgWidth)
        startY = self.nTileY * self.height
        endY = min(startY + self.height, imgHeight)

        if self.makeLastPartFull:
            # Only a clamped end changes the start here, so interior tiles
            # are unaffected.  max(0, ...) covers images smaller than a tile.
            startX = max(0, endX - self.width)
            startY = max(0, endY - self.height)

        currentTile = self.data[startY:endY, startX:endX]

        # Advance left-to-right, then wrap to the start of the next row.
        self.nTileX += 1
        if self.nTileX >= self.numTilesX:
            self.nTileX = 0
            self.nTileY += 1

        return (currentTile, startX, endX, startY, endY)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment