Created November 14, 2019 16:50
-
-
Save sizhky/87938dc01eb53691289a42785b35faca to your computer and use it in GitHub Desktop.
loader.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Names exported by `from <this module> import *`: the helpers defined below
# plus convenience re-exports of commonly used modules (cv2, np, pd, plt, ...).
__all__ = ['B', 'BB', 'C', 'choose', 'crop_from_bb', 'cv2', 'dumpdill', 'df2bbs', 'FName', 'glob', 'Glob',
           'line', 'loaddill', 'logger', 'extn', 'makedir', 'np', 'now', 'os', 'pd', 'parent', 'Path', 'pdb',
           'plt', 'puttext', 'randint', 'rand', 'read', 'rect', 'see', 'show', 'stem', 'tqdm', 'Tqdm']
import cv2, glob, numpy as np, pandas as pd, tqdm, os | |
import matplotlib#; matplotlib.use('Agg') | |
import matplotlib.pyplot as plt | |
import matplotlib.patheffects as path_effects | |
import pdb, datetime, dill | |
from pathlib import Path | |
from loguru import logger | |
def line(N=66):
    """Print a horizontal rule made of ``N`` ``=`` characters (66 by default)."""
    print('=' * N)
def see(*X):
    """Print every argument, each preceded by a 66-character ``=`` rule.

    Returns a list holding one ``None`` (the ``print`` result) per argument.
    """
    return [print('=' * 66 + '\n{}'.format(item)) for item in X]
def choose(List, n=1):
    """Randomly sample from ``List`` using ``np.random.choice``.

    With ``n == 1`` a single element is returned (no ``size`` is passed);
    otherwise an array of ``n`` samples is returned.
    """
    size = n if n != 1 else None
    return np.random.choice(List, size=size)
def rand():
    """Return a random 6-character alphanumeric string.

    Built on ``np.random`` via ``choose``; not suitable for secrets or tokens.
    """
    alphabet = list('1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM')
    return ''.join(choose(alphabet, n=6))
def randint(high):
    """Return a random integer in ``[0, high)`` drawn with ``np.random.randint``."""
    return np.random.randint(high)
def Tqdm(x, total=None):
    """Wrap iterable ``x`` in a ``tqdm`` progress bar.

    ``total`` defaults to ``len(x)``. If ``x`` has no length (e.g. a generator),
    the bar is created without a total instead of raising ``TypeError``.
    """
    if total is None and hasattr(x, '__len__'):
        total = len(x)
    return tqdm.tqdm(x, total=total)
def now(dt=None):
    """Return a ``YYYY-MM-DD_HH:MM`` timestamp for ``dt`` (default: current local time).

    ``strftime`` is used instead of slicing ``str(datetime)``. The string form
    drops the microseconds field when it is zero, so fixed-width slicing could
    return a truncated date.
    """
    dt = datetime.datetime.now() if dt is None else dt
    return dt.strftime('%Y-%m-%d_%H:%M')
def read(fname, mode=0):
    """Read an image from ``fname`` with OpenCV.

    ``mode`` is passed to ``cv2.imread`` as its flag (0 = grayscale, 1 = color).
    For ``mode == 1`` the channel order is reversed from OpenCV's BGR to RGB.
    Returns ``None`` if OpenCV cannot read the file, for every ``mode``.
    """
    img = cv2.imread(str(fname), mode)
    # Guard against a failed read; slicing None would raise TypeError.
    if mode == 1 and img is not None:
        img = img[..., ::-1]  # BGR to RGB
    return img
def crop_from_bb(im, bb):
    """Return a copy of the region of ``im`` inside box ``bb = (x, y, X, Y)``.

    Rows ``y:Y`` and columns ``x:X`` are taken. Only the cropped region is
    copied, so the result owns its data and does not keep a full-size copy of
    the source image alive.
    """
    x, y, X, Y = bb
    return im[y:Y, x:X].copy()
def rect(im, bb, c=None, th=2):
    """Draw the rectangle with corners ``(x, y)`` and ``(X, Y)`` onto ``im`` in place.

    ``c`` is the line color (default ``(0, 255, 0)``) and ``th`` the line
    thickness; both are passed to ``cv2.rectangle``. Coordinates are cast to
    ``int`` because ``cv2.rectangle`` rejects float points, such as boxes read
    from a DataFrame.
    """
    c = (0, 255, 0) if c is None else c
    x, y, X, Y = (int(v) for v in bb)
    cv2.rectangle(im, (x, y), (X, Y), c, th)
def B(im, th=180):
    """Binarize Image: return a ``uint8`` array that is 255 where ``im > th`` and 0 elsewhere."""
    mask = im > th
    return mask.astype(np.uint8) * 255
def C(im):
    """Make 3D image from 2D image: repeat a single-channel image into 3 channels.

    Arrays that are already 3-dimensional are returned unchanged. The check uses
    ``ndim``; comparing ``im.shape`` (a tuple) with ``3`` is always False.
    """
    if im.ndim == 3:
        return im
    return np.repeat(im[..., None], 3, axis=2)
def makedir(x):
    """Create directory ``x`` and any missing parents; an existing directory is not an error."""
    os.makedirs(x, exist_ok=True)
def FName(fpath):
    """Return the last ``/``-separated component (the file name) of ``fpath``.

    ``fpath`` may be a ``str`` or any ``os.PathLike`` such as ``pathlib.Path``.
    """
    return os.fspath(fpath).split('/')[-1]
def stem(fpath):
    """Return the file name of ``fpath`` without its last extension.

    For example, ``'a/b.tar.gz'`` gives ``'b.tar'``. A file name that contains
    no ``.`` is returned whole rather than as ``''``.
    """
    name = FName(fpath)
    if '.' not in name:
        return name
    return '.'.join(name.split('.')[:-1])
def parent(fpath):
    """Return the directory part of ``fpath`` (everything before the last ``/``).

    Returns ``'./'`` when ``fpath`` contains no ``/``, and ``'/'`` for entries
    directly under the filesystem root (e.g. ``'/file'``). ``fpath`` may be a
    ``str`` or ``os.PathLike``.
    """
    head, sep, _ = os.fspath(fpath).rpartition('/')
    if not sep:
        return './'
    # An empty head with a separator means the path is rooted at '/'.
    return head or '/'
def extn(x):
    """Return the extension (the text after the last ``.``) of the file name in ``x``.

    Only the last path component is inspected, so dots in directory names are
    ignored. A file name with no ``.`` yields ``''``.
    """
    name = os.fspath(x).split('/')[-1]
    return name.rpartition('.')[2] if '.' in name else ''
def Glob(x, silent=False):
    """List the paths matching ``x``, sorted.

    If ``x`` contains ``*`` it is used directly as a glob pattern. Otherwise it
    is treated as a directory and its entries (``x + '/*'``) are listed.
    Results are sorted because ``glob.glob`` returns them in arbitrary order.
    Unless ``silent`` is set, the number of matches is logged. ``x`` may be a
    ``str`` or ``os.PathLike``.
    """
    x = os.fspath(x)
    pattern = x if '*' in x else x + '/*'
    files = sorted(glob.glob(pattern))
    if not silent:
        logger.info('{} files found at {}'.format(len(files), x))
    return files
def puttext(im, string, org, scale=1, color=(255,0,0), thickness=2):
    """Draw ``string`` onto image ``im`` in place with ``cv2.putText``.

    The y coordinate of ``org = (x, y)`` is shifted down by ``30 * scale``
    pixels before drawing. This is presumably so that ``org`` acts roughly as
    the text's top-left corner, since OpenCV anchors text at its baseline.

    NOTE(review): a second ``puttext`` (matplotlib-axes based) is defined later
    in this module and rebinds this name at import time, so this OpenCV variant
    is not reachable as ``puttext`` once the module has loaded.
    """
    x, y = org
    shifted = (x, int(y + 30 * scale))
    cv2.putText(im, str(string), shifted, cv2.FONT_HERSHEY_COMPLEX, scale, color, thickness)
def show(img=None, ax=None, title=None, sz=None, bbs=None, texts=None, bb_colors=None, cmap='gray', grid=False, save_path=None, **kwargs):
    """Display an image with matplotlib, optionally annotated with boxes and labels.

    Parameters
    ----------
    img : array-like
        Image to show. It is copied, so boxes are never drawn on the caller's
        array. If its maximum equals 255 it is cast to ``uint8``.
    ax : matplotlib Axes, optional
        Axes to draw into. When omitted, a new figure is created and
        ``plt.show()`` is called at the end.
    title : str, optional
        Axes title.
    sz : int or (w, h), optional
        Size of a newly created figure. Defaults to 5, 10 or 20 depending on
        the image width (< 300, < 600, otherwise).
    bbs : sequence of (x, y, X, Y), optional
        Boxes drawn onto the image copy with ``rect``. A 2-D image is first
        expanded to 3 channels with ``C`` so colored boxes are visible.
    texts : sequence or pandas Series, optional
        One label per box, drawn via ``puttext`` at each box's ``(x, y)``.
        ``$`` is escaped so matplotlib does not interpret it as mathtext.
    bb_colors : sequence, optional
        Per-box colors passed to ``rect`` (``None`` means its default).
    cmap : str
        Colormap passed to ``ax.imshow``.
    grid : bool
        Keep the axes (ticks and frame) visible when True.
    save_path : str, optional
        If given, the figure is saved there and the function returns without
        calling ``plt.show()``.
    **kwargs
        ``th`` sets the box line thickness (default 2/3/4 based on image
        width). Everything else is passed to ``ax.imshow``.
    """
    img = np.copy(img)
    if img.max() == 255:
        img = img.astype(np.uint8)
    _, w = img.shape[:2]

    # Default figure size scales with image width.
    if sz is None:
        if w < 300:
            sz = 5
        elif w < 600:
            sz = 10
        else:
            sz = 20
    if isinstance(sz, int):
        sz = (sz, sz)

    if ax is None:
        fig, ax = plt.subplots(figsize=sz)
        _show = True
    else:
        # Caller-supplied axes: use its figure so `save_path` also works on
        # this path (`fig` was previously undefined here, raising NameError).
        fig = ax.get_figure()
        _show = False

    # Always remove `th` so it never reaches `imshow`, even when no boxes are given.
    th = kwargs.pop('th', None)

    if isinstance(texts, pd.core.series.Series):
        texts = texts.tolist()
    if texts:
        assert len(texts) == len(bbs), 'Expecting as many texts as bounding boxes'
        for ix, text in enumerate(texts):
            # Raw string: a plain '\$' literal is an invalid escape sequence.
            label = str(text).replace('$', r'\$')
            puttext(ax, label, tuple(bbs[ix][:2]), size=20)

    # Explicit None/len check also works for NumPy arrays, whose truth value is ambiguous.
    if bbs is not None and len(bbs) > 0:
        if th is None:
            # Default rectangle thickness scales with image width.
            if w < 800:
                th = 2
            elif w < 1600:
                th = 3
            else:
                th = 4
        bb_colors = [None] * len(bbs) if bb_colors is None else bb_colors
        # Promote grayscale to 3 channels so colored boxes can be drawn.
        img = C(img) if len(img.shape) == 2 else img
        for ix, bb in enumerate(bbs):
            rect(img, tuple(bb), c=bb_colors[ix], th=th)

    ax.imshow(img, cmap=cmap, **kwargs)
    ax.set_title(title)
    if not grid:
        ax.set_axis_off()
    if save_path:
        fig.savefig(save_path)
        return
    if _show:
        plt.show()
def puttext(ax, string, org, size=30, color=(255,0,0), thickness=2):
    """Write ``string`` on matplotlib axes ``ax`` at ``org = (x, y)``.

    The text is left-aligned and drawn with a white outline stroke.
    Vertical alignment is ``'top'`` when ``y < 15`` (close to the top of an
    image) and ``'bottom'`` otherwise.

    ``color`` may be any matplotlib color. RGB(A) tuples/lists on a 0-255
    scale (as used by the OpenCV helpers) are rescaled to 0-1, so the default
    ``(255, 0, 0)`` renders red. ``thickness`` is accepted for signature
    compatibility and is not used.
    """
    x, y = org
    # matplotlib expects 0-1 channel values; rescale 0-255 style tuples.
    if isinstance(color, (tuple, list)) and any(v > 1 for v in color):
        color = tuple(v / 255 for v in color)
    va = 'top' if y < 15 else 'bottom'
    text = ax.text(x, y, str(string), color=color, ha='left', va=va, size=size)
    text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='white'),
                           path_effects.Normal()])
def dumpdill(obj, fpath):
    """Serialize ``obj`` to ``fpath`` with ``dill``, creating parent directories first.

    ``fpath`` may be a ``str`` or ``os.PathLike``. The output path is logged.
    """
    # Normalize to str: `parent` splits on '/', which a Path object cannot do.
    fpath = os.fspath(fpath)
    os.makedirs(parent(fpath), exist_ok=True)
    with open(fpath, 'wb') as f:
        dill.dump(obj, f)
    logger.info('Dumped object @ {}'.format(fpath))
def loaddill(fpath):
    """Load and return the object that ``dill`` serialized into ``fpath``.

    SECURITY: ``dill.load``, like ``pickle``, can execute arbitrary code while
    deserializing. Only load files from trusted sources.
    """
    with open(fpath, 'rb') as handle:
        return dill.load(handle)
class BB:
    """Bounding box given by corner coordinates ``(x, y)`` and ``(X, Y)``.

    Behaves like a read-only 4-item sequence (indexing, ``len``, iteration and
    unpacking). It also exposes ``x``, ``y``, ``X``, ``Y`` plus the derived
    height ``h = Y - y`` and width ``w = X - x``. Boxes compare equal and hash
    equally when their four coordinates match.
    """

    def __init__(self, bb):
        assert len(bb) == 4, 'expecting a list/tuple of 4 values respectively for (x,y,X,Y)'
        # `self.bb` keeps the caller's container (tuple or list) as-is.
        self.bb = x, y, X, Y = bb
        self.x, self.y, self.X, self.Y = x, y, X, Y
        self.h = Y - y
        self.w = X - x

    def __getitem__(self, i):
        return self.bb[i]

    def __repr__(self):
        return self.bb.__repr__()

    def __len__(self):
        return 4

    def __eq__(self, other):
        # Objects without x/y/X/Y attributes are not comparable: return
        # NotImplemented (so `==` yields False) instead of raising AttributeError.
        try:
            return (self.x, self.y, self.X, self.Y) == (other.x, other.y, other.X, other.Y)
        except AttributeError:
            return NotImplemented

    def __hash__(self):
        # Consistent with __eq__: both are based on the four coordinates.
        return hash(tuple(self))
def df2bbs(df):
    """Convert the ``x``, ``y``, ``X``, ``Y`` columns of DataFrame ``df`` into a list of ``BB``, one per row."""
    coords = df[['x', 'y', 'X', 'Y']].values.tolist()
    return [BB(row) for row in coords]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.