Skip to content

Instantly share code, notes, and snippets.

@ronghanghu
Created August 14, 2017 16:53
Show Gist options
  • Save ronghanghu/71a1cb8dbda406058d3a78cbfd89e8d1 to your computer and use it in GitHub Desktop.
Save ronghanghu/71a1cb8dbda406058d3a78cbfd89e8d1 to your computer and use it in GitHub Desktop.
ImageNet-22k LASER test
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import logging
import os
import cv2
from laser.client.py import LaserClient, Options
logger = logging.getLogger(__name__)

# Per-dataset LASER settings, one row per dataset:
#   (dataset name, LASER provider name, key prefix inside that provider).
# Keeping both values on one row makes it hard for the two lookup tables
# below to drift apart.
_LASER_DATASETS = (
    ('imagenet1k', 'fair_laser_imagenet_full_size', '/imagenet_full_size/'),
    ('imagenet5k', 'fair_laser_imagenet_5k', '/imagenet_5k/'),
    ('imagenet22k', 'fair_laser_imagenet22k_1', '/imagenet22k/'),
)

# Dataset name -> provider passed to LaserClient.
LASER_PROVIDER_CFG = {
    name: provider for name, provider, _prefix in _LASER_DATASETS
}

# Dataset name -> prefix that every LASER image key for that dataset must
# start with when fetching images.
LASER_DATA_PREFIX = {
    name: prefix for name, _provider, prefix in _LASER_DATASETS
}
class LaserLoader(object):
    """Read ImageNet images from a LASER provider instead of local disk.

    Local image paths are rewritten into LASER keys (see ``convert_path``),
    fetched with ``LaserClient.get`` and decoded with OpenCV.
    """

    def __init__(self, dataset, split):
        """Create a LASER client for ``dataset``.

        Args:
            dataset: A key of ``LASER_PROVIDER_CFG`` / ``LASER_DATA_PREFIX``
                ('imagenet1k', 'imagenet5k' or 'imagenet22k'); any other
                value raises ``KeyError``.
            split: Split name (e.g. 'train'). It is inserted into the LASER
                key for every dataset except 'imagenet22k'.
        """
        logger.info('Using LASER')
        self.split = split
        self.dataset = dataset
        # Much faster! (per the original author) -- skip memcache lookups.
        opts = Options(ignoreMemcache=True)
        self.client = LaserClient(
            provider=LASER_PROVIDER_CFG[dataset], options=opts
        )

    def convert_path(self, image_path):
        """Map a local image path to its LASER key.

        Only the last two path components (``<dir>/<file>``) are kept and
        prefixed with the dataset's ``LASER_DATA_PREFIX`` entry. For
        'imagenet22k' every '1k_' substring is removed from the key and no
        split is added. For the other datasets the split is inserted between
        the prefix and the key.

        Example (imagenet22k)::

            '/x/train/n07877187/n07877187_5.JPEG'
            -> '/imagenet22k/n07877187/n07877187_5.JPEG'
        """
        prefix = LASER_DATA_PREFIX[self.dataset]
        key = '/'.join(image_path.split('/')[-2:])
        if self.dataset == 'imagenet22k':
            # NOTE(review): presumably strips a '1k_' tag carried by some
            # 22k file names; this replaces it anywhere in the key -- confirm.
            key = key.replace('1k_', '')
            return os.path.join(prefix, key)
        return os.path.join(prefix, self.split, key)

    def imread(self, image_path):
        """Fetch ``image_path`` from LASER and decode it with OpenCV.

        Args:
            image_path: Local-style image path; converted with
                ``convert_path`` before the lookup.

        Returns:
            The image decoded by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``,
            or ``None`` if the fetched bytes cannot be decoded (the same
            contract as ``cv2.imread``).
        """
        laser_key = self.convert_path(image_path)
        result = self.client.get(laser_key)
        # frombuffer wraps the bytes without copying; np.fromstring is
        # deprecated for binary input.
        img_array = np.frombuffer(result.value(), dtype=np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        if img is None:
            logger.warning(
                'Failed to decode image %s (LASER key %s)',
                image_path, laser_key,
            )
        # Return the decoded image, not the raw encoded bytes.
        return img

    # TODO: Use a multithreaded queue to hide second-long retries, e.g. a
    # batched read that converts each path with ``self.convert_path`` and
    # fetches them together via ``self.client.multiget`` (API unverified).
if __name__ == '__main__':
    # Smoke test: fetch and decode one ImageNet-22k training image via LASER.
    # Guarded so that importing this module does not construct a LaserClient
    # or fetch an image as a side effect.
    loader = LaserLoader('imagenet22k', 'train')
    img = loader.imread(
        '/home/prigoyal/local/imagenet_22k/train/n07877187/n07877187_5.JPEG'
    )
    # Report the outcome instead of silently discarding the result.
    print('Decoded image shape:', None if img is None else img.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment