Skip to content

Instantly share code, notes, and snippets.

@sizhky
Created November 14, 2019 16:50
Show Gist options
  • Save sizhky/87938dc01eb53691289a42785b35faca to your computer and use it in GitHub Desktop.
Save sizhky/87938dc01eb53691289a42785b35faca to your computer and use it in GitHub Desktop.
loader.py
# Names exposed by a star-import of this module: the helpers defined in this
# file plus the modules/objects they are built on.
# NOTE(review): `makedir` is defined in this file but is not listed here --
# confirm whether that omission is intentional.
__all__ = [
    'B', 'BB', 'C', 'choose', 'crop_from_bb', 'cv2', 'dumpdill', 'df2bbs',
    'FName', 'glob', 'Glob',
    'line', 'loaddill', 'logger', 'extn', 'np', 'now', 'os', 'pd', 'parent',
    'Path', 'pdb',
    'plt', 'puttext', 'randint', 'rand', 'read', 'rect', 'see', 'show',
    'stem', 'tqdm', 'Tqdm',
]
import cv2, glob, numpy as np, pandas as pd, tqdm, os
import matplotlib#; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import pdb, datetime, dill
from pathlib import Path
from loguru import logger
def line(N=66):
    """Print a horizontal rule made of ``N`` '=' characters (66 by default)."""
    rule = '=' * N
    return print(rule)
def see(*X):
    """Print each argument beneath a 66-character '=' rule.

    Returns a list holding the result of every ``print`` call, i.e. one
    ``None`` per argument.
    """
    return [print('='*66+'\n{}'.format(item)) for item in X]
def choose(List, n=1):
    """Draw random element(s) from ``List`` using ``np.random.choice``.

    With ``n == 1`` the ``size`` argument is ``None``, so a single element is
    returned rather than an array; any other ``n`` is forwarded as ``size``.
    """
    if n == 1:
        return np.random.choice(List, size=None)
    return np.random.choice(List, size=n)
def rand():
    """Return a 6-character string of digits and ASCII letters picked via ``choose``."""
    alphabet = list('1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM')
    picked = choose(alphabet, n=6)
    return ''.join(picked)
def randint(high):
    """Return ``np.random.randint(high)``: a random integer below ``high``."""
    value = np.random.randint(high)
    return value
def Tqdm(x, total=None):
    """Wrap the iterable ``x`` in a ``tqdm.tqdm`` progress bar.

    Args:
        x: iterable to wrap.
        total: expected number of items. When omitted, ``len(x)`` is used if
            ``x`` supports it.

    Returns:
        The ``tqdm.tqdm`` object wrapping ``x``.
    """
    if total is None:
        try:
            total = len(x)
        except TypeError:
            # Generators and other unsized iterables have no len(); leave the
            # total unknown instead of crashing before the bar is created.
            total = None
    return tqdm.tqdm(x, total=total)
def now():
    """Return the current local time as ``'YYYY-MM-DD_HH:MM'``.

    The previous implementation sliced ``str(datetime.now())[:-10]``, which
    assumes the string always ends in ``':SS.ffffff'``. ``str()`` omits the
    fractional part when the microsecond is exactly 0, so the slice then cut
    into the date. Formatting explicitly avoids that.
    """
    return datetime.datetime.now().isoformat(sep='_', timespec='minutes')
def read(fname, mode=0):
    """Load an image from disk with ``cv2.imread``.

    Args:
        fname: path to the image; passed through ``str`` so path-like objects
            are accepted.
        mode: flag forwarded to ``cv2.imread``. When it equals 1 the last
            axis of the result is reversed (BGR to RGB).

    Returns:
        The array produced by ``cv2.imread``, or ``None`` when ``cv2.imread``
        itself returns ``None`` (e.g. the file could not be read).
    """
    img = cv2.imread(str(fname), mode)
    if img is None:
        # Without this guard, mode=1 failed with an opaque TypeError from
        # subscripting None, while other modes quietly returned None.
        return None
    if mode == 1:
        img = img[..., ::-1]  # BGR to RGB
    return img
def crop_from_bb(im, bb):
    """Return an independent copy of the region of ``im`` described by ``bb``.

    Args:
        im: array indexed as ``im[row, col, ...]``.
        bb: four values ``(x, y, X, Y)``; rows ``y:Y`` and columns ``x:X``
            are extracted.

    Returns:
        A new array holding the cropped region; writing to it does not
        affect ``im``.
    """
    x, y, X, Y = bb
    # Slice first, then copy: only the crop is duplicated instead of the
    # whole image, and no full-size copy is kept alive behind the result.
    return im[y:Y, x:X].copy()
def rect(im, bb, c=None, th=2):
    """Draw the box ``bb = (x, y, X, Y)`` onto ``im`` in place via ``cv2.rectangle``.

    ``c`` is the colour passed to OpenCV (``(0, 255, 0)`` when omitted) and
    ``th`` the line thickness. Nothing is returned.
    """
    left, top, right, bottom = bb
    colour = c if c is not None else (0, 255, 0)
    cv2.rectangle(im, (left, top), (right, bottom), colour, th)
def B(im, th=180):
    """Binarize an image: values above ``th`` become 255, all others 0 (``uint8``)."""
    above = im > th
    return above.astype(np.uint8) * 255
def C(im):
    """Make a 3-channel image from a 2D image.

    Args:
        im: numpy array. A 3-dimensional array is returned unchanged (same
            object); otherwise a trailing axis is added and repeated 3 times.

    Returns:
        ``im`` itself when it already has 3 dimensions, else a new array
        whose last axis has length 3.
    """
    # The original test was `im.shape == 3`, which compares a tuple with an
    # int and is therefore always False; the number of dimensions is what
    # has to be checked.
    if im.ndim == 3:
        return im
    return np.repeat(im[..., None], 3, 2)
def makedir(x):
    """Create directory ``x`` with ``os.makedirs``; an existing directory is not an error."""
    return os.makedirs(x, exist_ok=True)
def FName(fpath):
    """Return the part of ``fpath`` after its last '/' (the whole string if there is none)."""
    _, _, tail = fpath.rpartition('/')
    return tail
def stem(fpath):
    """Return the file name of ``fpath`` without its last extension.

    Args:
        fpath: '/'-separated path string.

    Returns:
        The final path component with everything from its last '.' removed.
        A name with no extension (``'README'``) or only a leading dot
        (``'.bashrc'``) is returned whole; previously both yielded ``''``.
    """
    # Same basename rule as FName: the text after the last '/'.
    name = fpath.split('/')[-1]
    head, dot, _ = name.rpartition('.')
    return head if dot and head else name
def parent(fpath):
    """Return the directory part of a '/'-separated path string.

    Args:
        fpath: path string.

    Returns:
        Everything before the last '/'. ``'./'`` when ``fpath`` contains no
        '/', and ``'/'`` when the only separator is the leading root slash
        (``'/file'`` previously yielded ``'./'``).
    """
    head, sep, _ = fpath.rpartition('/')
    if not sep:
        return './'
    # An empty head with a separator present means the file sits directly
    # under the filesystem root.
    return head or '/'
def extn(x):
    """Return the extension (text after the last '.') of the file name in ``x``.

    Args:
        x: '/'-separated path string.

    Returns:
        The extension without the dot, or ``''`` when the file name has no
        '.'. Only the final path component is inspected, so dots in
        directory names are ignored. Previously a dot-less name returned the
        whole input, and ``'dir.v2/file'`` returned ``'v2/file'``.
    """
    name = x.rpartition('/')[2]
    _, dot, ext = name.rpartition('.')
    return ext if dot else ''
def Glob(x, silent=False):
    """List paths matching ``x`` with ``glob.glob``.

    If ``x`` contains no '*', ``'/*'`` is appended so that the entries
    directly inside ``x`` are listed. Unless ``silent`` is true, the number
    of matches is logged.
    """
    pattern = x if '*' in x else x + '/*'
    matches = glob.glob(pattern)
    if not silent:
        logger.info('{} files found at {}'.format(len(matches), x))
    return matches
def puttext(im, string, org, scale=1, color=(255,0,0), thickness=2):
    """Draw ``string`` onto image ``im`` in place with ``cv2.putText``.

    The anchor ``org = (x, y)`` is shifted down by ``30 * scale`` before
    being handed to OpenCV.

    NOTE(review): this file defines ``puttext`` a second time further down
    (an axes-based variant); once the module has finished importing, the
    name refers to that later definition, not this one.
    """
    x, y = org
    shifted = (x, int(y + 30*scale))
    cv2.putText(im, str(string), shifted, cv2.FONT_HERSHEY_COMPLEX, scale, color, thickness)
def show(img=None, ax=None, title=None, sz=None, bbs=None, texts=None, bb_colors=None, cmap='gray', grid=False, save_path=None, **kwargs):
    """Show an image with matplotlib, optionally with boxes and text labels.

    Args:
        img: image array. It is copied first, so boxes are never drawn on
            the caller's array.
        ax: axes to draw on. When omitted a new figure is created and, unless
            ``save_path`` is given, displayed with ``plt.show()``.
        title: axes title.
        sz: figure size for a newly created figure; an int ``n`` means
            ``(n, n)``. Defaults to 5, 10 or 20 depending on image width.
        bbs: sequence of ``(x, y, X, Y)`` boxes drawn on the image copy.
        texts: one label per box (list or pandas Series), written at each
            box's ``(x, y)`` corner.
        bb_colors: one colour per box; ``None`` entries use ``rect``'s default.
        cmap: colormap forwarded to ``imshow``.
        grid: when false (default) the axes are switched off.
        save_path: if given, the figure is saved there and not displayed.
        **kwargs: forwarded to ``ax.imshow``. The special key ``th`` is
            removed and used as the box line thickness (default 2, 3 or 4
            depending on image width).

    Raises:
        AssertionError: if ``texts`` and ``bbs`` differ in length.
    """
    img = np.copy(img)
    if img.max() == 255:
        img = img.astype(np.uint8)
    w = img.shape[1]

    # Always consume 'th' so it can never leak into imshow(), which does not
    # accept it (previously it was only removed when boxes were supplied).
    th = kwargs.pop('th', None)
    if th is None:
        if w < 800:
            th = 2
        elif w < 1600:
            th = 3
        else:
            th = 4

    if sz is None:
        if w < 300:
            sz = 5
        elif w < 600:
            sz = 10
        else:
            sz = 20
    if isinstance(sz, int):
        sz = (sz, sz)

    _show = ax is None
    if _show:
        fig, ax = plt.subplots(figsize=sz)
    else:
        # A caller-supplied axes previously left `fig` undefined, so passing
        # save_path together with ax raised NameError.
        fig = ax.figure

    if isinstance(texts, pd.Series):
        texts = texts.tolist()

    # Explicit length checks: plain truthiness is ambiguous for numpy arrays.
    has_texts = texts is not None and len(texts) > 0
    has_bbs = bbs is not None and len(bbs) > 0

    if has_texts:
        assert len(texts) == len(bbs), 'Expecting as many texts as bounding boxes'
        for ix, text in enumerate(texts):
            # Escape '$' so matplotlib does not interpret it as math markup.
            puttext(ax, str(text).replace('$', '\\$'), tuple(bbs[ix][:2]), size=20)

    if has_bbs:
        if bb_colors is None:
            bb_colors = [None] * len(bbs)
        if len(img.shape) == 2:
            img = C(img)  # give the boxes colour channels to be drawn in
        for ix, bb in enumerate(bbs):
            rect(img, tuple(bb), c=bb_colors[ix], th=th)

    ax.imshow(img, cmap=cmap, **kwargs)
    ax.set_title(title)
    if not grid:
        ax.set_axis_off()
    if save_path:
        fig.savefig(save_path)
        return
    if _show:
        plt.show()
def puttext(ax, string, org, size=30, color=(255,0,0), thickness=2):
    """Write ``string`` on a matplotlib axes with a white outline.

    Args:
        ax: axes-like object providing ``text(...)``.
        string: label; converted with ``str``.
        org: ``(x, y)`` anchor position.
        size: font size forwarded to ``ax.text``.
        color: text colour. Either a colour string, or an RGB(A) sequence;
            a sequence with any component above 1 is treated as 0-255 and
            rescaled to 0-1. The default ``(255, 0, 0)`` therefore renders
            red, as before. Previously this argument was ignored and the
            text was always red.
        thickness: accepted but not used by this function.
    """
    x, y = org
    # Anchors with y below 15 get the text hung below the anchor ('top'
    # alignment); all others get it placed above ('bottom').
    va = 'top' if y < 15 else 'bottom'

    if color is None:
        color = 'red'
    elif not isinstance(color, str):
        color = tuple(color)
        if max(color) > 1:
            color = tuple(channel / 255 for channel in color)

    text = ax.text(x, y, str(string), color=color, ha='left', va=va, size=size)
    text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='white'),
                           path_effects.Normal()])
def dumpdill(obj, fpath):
    """Serialise ``obj`` to ``fpath`` with ``dill``.

    The directory returned by ``parent(fpath)`` is created first if needed,
    and a log line is emitted once the object has been written.
    """
    target_dir = parent(fpath)
    os.makedirs(target_dir, exist_ok=True)
    with open(fpath, 'wb') as sink:
        dill.dump(obj, sink)
    logger.info('Dumped object @ {}'.format(fpath))
def loaddill(fpath):
    """Open ``fpath`` in binary mode and return the object ``dill.load`` reads from it."""
    # NOTE(review): dill, like pickle, can run arbitrary code while loading --
    # only use this on files from a trusted source.
    with open(fpath, 'rb') as source:
        return dill.load(source)
class BB:
    """Bounding box described by four values ``(x, y, X, Y)``.

    The original sequence is kept as ``bb``; the four values are also
    exposed as ``x``, ``y``, ``X``, ``Y`` together with ``w = X - x`` and
    ``h = Y - y``. Instances can be indexed, have length 4, and compare and
    hash by their four coordinates.
    """

    def __init__(self, bb):
        assert len(bb) == 4, 'expecting a list/tuple of 4 values respectively for (x,y,X,Y)'
        self.bb = bb
        self.x, self.y, self.X, self.Y = bb
        self.h = self.Y - self.y
        self.w = self.X - self.x

    def __getitem__(self, i):
        return self.bb[i]

    def __repr__(self):
        return repr(self.bb)

    def __len__(self):
        return 4

    def __eq__(self, other):
        # Comparing with a non-BB used to raise AttributeError; returning
        # NotImplemented lets Python fall back to its default (unequal).
        if not isinstance(other, BB):
            return NotImplemented
        return (self.x, self.y, self.X, self.Y) == (other.x, other.y, other.X, other.Y)

    def __hash__(self):
        # Built from exactly the fields __eq__ compares, so equal boxes
        # always share a hash.
        return hash((self.x, self.y, self.X, self.Y))
def df2bbs(df):
    """Build one ``BB`` per row of ``df`` from its ``x``, ``y``, ``X``, ``Y`` columns."""
    rows = df[['x', 'y', 'X', 'Y']].values.tolist()
    return list(map(BB, rows))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment