# Full Example: https://gist.github.com/danielhavir/407a6cfd592dfc2ad1e23a1ed3539e07
import os
from typing import Callable, List, Tuple, Generator, Dict

import torch
import torch.utils.data
from PIL.Image import Image as ImageType
# pil_loader is defined in the full example linked above; torchvision's
# standard loader (open the file, convert to RGB) fills in here.
from torchvision.datasets.folder import pil_loader


def list_items_local(path: str) -> List[str]:
    # Item ids are file names without their extension, sorted for determinism.
    return sorted(os.path.splitext(f)[0] for f in os.listdir(path))


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_root: str, items: List[str],
                 loader: Callable[[str], ImageType] = pil_loader, transform: Callable = None):
        self.data_root = data_root
        self.loader = loader
        self.items = items
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, item):
        item_id = self.items[item]
        image = self.loader(os.path.join(self.data_root, "images", item_id + ".jpg"))
        label = self.loader(os.path.join(self.data_root, "labels", item_id + ".png"))
        if self.transform is not None:
            # The transform receives the whole (image, label) pair so that
            # random augmentations stay in sync between image and label.
            image, label = self.transform((image, label))
        return image, label


def get_local_dataloaders(local_data_root: str, batch_size: int = 8, transform: Callable = None,
                          test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    # Local training
    local_items = list_items_local(os.path.join(local_data_root, "images"))
    dataset = ImageDataset(local_data_root, local_items, transform=transform)
    # Split using consistent hashing
    train_indices, test_indices = consistent_train_test_split(local_items, test_ratio)
    return {
        "train": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             sampler=torch.utils.data.SubsetRandomSampler(train_indices),
                                             num_workers=num_workers),
        "test": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                            sampler=torch.utils.data.SubsetRandomSampler(test_indices),
                                            num_workers=num_workers),
    }
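consistent_train_test_split is defined in the full example linked above; the comment only tells us the split uses consistent hashing. A minimal sketch of what such a split might look like (the MD5 scheme and the tensor return type are assumptions, the latter chosen so that the indices[i].item() calls in the async file below work):

import hashlib

def consistent_train_test_split(items: List[str], test_ratio: float) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hash each item id into [0, 1); an item's fold then depends only on its
    # own id, so the assignment is stable across runs, machines, and dataset
    # growth (unlike an unseeded random shuffle).
    def bucket(item: str) -> float:
        digest = hashlib.md5(item.encode("utf-8")).hexdigest()
        return int(digest, 16) / 16 ** 32

    test_mask = [bucket(item) < test_ratio for item in items]
    train_indices = torch.tensor([i for i, in_test in enumerate(test_mask) if not in_test])
    test_indices = torch.tensor([i for i, in_test in enumerate(test_mask) if in_test])
    return train_indices, test_indices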
# ImageDataset and consistent_train_test_split are defined in the first file.
import io
import os
from typing import Callable, Dict, List
from urllib.parse import urlparse

import torch.utils.data
from PIL import Image
from PIL.Image import Image as ImageType
from google.cloud import storage
from google.api_core.retry import Retry


@Retry()
def gcs_pil_loader(uri: str) -> ImageType:
    # Download a single blob from GCS and decode it with PIL.
    uri = urlparse(uri)
    client = storage.Client()
    bucket = client.get_bucket(uri.netloc)
    b = bucket.blob(uri.path[1:], chunk_size=None)
    image = Image.open(io.BytesIO(b.download_as_string()))
    return image.convert("RGB")


@Retry()
def load_items_gcs(path: str) -> List[str]:
    # List the item ids under gs://<bucket>/<prefix>.
    uri = urlparse(path)
    client = storage.Client()
    bucket = client.get_bucket(uri.netloc)
    blobs = bucket.list_blobs(prefix=uri.path[1:], delimiter=None)
    return sorted(os.path.splitext(os.path.basename(blob.name))[0] for blob in blobs)


def get_streamed_dataloaders(gcs_data_root: str, batch_size: int = 8, transform: Callable = None,
                             test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    # Streaming: each item is downloaded from GCS on demand.
    streamed_items = load_items_gcs(os.path.join(gcs_data_root, "images"))
    dataset = ImageDataset(gcs_data_root, streamed_items, loader=gcs_pil_loader, transform=transform)
    # Identical for both local and streamed.
    # This is handy for cross-validation; use consistent hashing.
    train_indices, test_indices = consistent_train_test_split(streamed_items, test_ratio)
    return {
        "train": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             sampler=torch.utils.data.SubsetRandomSampler(train_indices),
                                             num_workers=num_workers),
        "test": torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                            sampler=torch.utils.data.SubsetRandomSampler(test_indices),
                                            num_workers=num_workers),
    }
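Since the datasets return PIL images, the DataLoader's default collate cannot batch them directly; the transform argument is expected to operate on the whole (image, label) pair. A hypothetical paired transform and call, for illustration only (the function name and the gs:// path are not from the gist):

import torchvision.transforms.functional as TF

def paired_transform(pair):
    image, label = pair
    # A real segmentation pipeline would map label colors to class indices;
    # converting both to tensors is enough to make batching work.
    return TF.to_tensor(image), TF.to_tensor(label)

loaders = get_streamed_dataloaders("gs://my-bucket/dataset", batch_size=16, transform=paired_transform)
images, labels = next(iter(loaders["train"]))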
# load_items_gcs and consistent_train_test_split come from the previous files;
# worker_init_fn is defined in the full example (a sketch follows below).
import io
import os
import random
import logging
import asyncio
from threading import Thread
from typing import Callable, Dict, Generator, List
from urllib.parse import urlparse

import aiohttp
import torch.utils.data
from PIL import Image
from janus import Queue
from gcloud.aio.storage import Storage


def generate_stream(items: List[str]) -> Generator[str, None, None]:
    # Endless stream of item ids sampled uniformly at random.
    while True:
        # Python's randint has an inclusive upper bound
        index = random.randint(0, len(items) - 1)
        yield items[index]


class AsyncImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, data_root: str, items: List[str], transform: Callable = None, concurrency: int = 64):
        self.data_root = data_root
        self.items = items
        self.transform = transform
        self.worker_initialized = False
        self.loop_thread = None
        self.q = None
        self.creds = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
        self.concurrency = concurrency
        self.stream = generate_stream(self.items)

    async def run(self, loop, session):
        for item in self.stream:
            try:
                image_gs = urlparse(os.path.join(self.data_root, "images", item + ".jpg"))
                label_gs = urlparse(os.path.join(self.data_root, "labels", item + ".png"))
                aio_storage = Storage(service_file=self.creds, session=session)
                # Download image and label concurrently. Note: the explicit
                # loop= argument was removed from asyncio.gather in Python 3.10.
                blobs = await asyncio.gather(
                    aio_storage.download(image_gs.netloc, image_gs.path[1:]),
                    aio_storage.download(label_gs.netloc, label_gs.path[1:]),
                    loop=loop
                )
                # JPEGs already decode as RGB; the PNG label is converted
                # explicitly, matching the RGB output of the non-async loaders.
                image = Image.open(io.BytesIO(blobs[0]))
                label = Image.open(io.BytesIO(blobs[1])).convert("RGB")
                await self.q.async_q.put((image, label))
            except aiohttp.ClientError as e:
                logging.debug(e)
            except asyncio.TimeoutError:
                # Skip the item and move on to the next one in the stream.
                pass
            except Exception as e:
                logging.exception(e)

    def init_worker(self):
        loop = asyncio.new_event_loop()
        # limit=0 removes aiohttp's cap on concurrent connections.
        session = aiohttp.ClientSession(loop=loop, connector=aiohttp.TCPConnector(limit=0, loop=loop),
                                        raise_for_status=True)
        # Bounded janus queue bridging the asyncio producers and the
        # synchronous consumer in __iter__.
        self.q = Queue(self.concurrency, loop=loop)
        # Spin up workers
        for _ in range(self.concurrency):
            loop.create_task(self.run(loop, session))

        def loop_in_thread(loop):
            asyncio.set_event_loop(loop)
            loop.run_forever()

        # Run the event loop on a daemon thread so the sync side can block.
        self.loop_thread = Thread(target=loop_in_thread, args=(loop,), daemon=True)
        self.loop_thread.start()
        self.worker_initialized = True

    def __iter__(self):
        while True:
            if not self.worker_initialized:
                self.init_worker()
            image, label = self.q.sync_q.get()
            if self.transform is not None:
                image, label = self.transform((image, label))
            yield image, label


def get_async_dataloaders(gcs_data_root: str, batch_size: int = 8, transform: Callable = None,
                          test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    # Async streaming
    streamed_items = load_items_gcs(os.path.join(gcs_data_root, "images"))
    train_indices, test_indices = consistent_train_test_split(streamed_items, test_ratio)
    train_items = [streamed_items[i.item()] for i in train_indices]
    train_dataset = AsyncImageDataset(gcs_data_root, train_items, transform=transform, concurrency=128)
    test_items = [streamed_items[i.item()] for i in test_indices]
    test_dataset = AsyncImageDataset(gcs_data_root, test_items, transform=transform, concurrency=128)
    return {
        "train": torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, worker_init_fn=worker_init_fn,
                                             num_workers=num_workers),
        "test": torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, worker_init_fn=worker_init_fn,
                                            num_workers=num_workers),
    }
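worker_init_fn above is also defined only in the full example. Since __iter__ already calls init_worker() lazily on first use, a minimal version only has to start each DataLoader worker's event loop eagerly; this sketch assumes that is all it does:

def worker_init_fn(worker_id: int):
    # Each DataLoader worker process gets its own event loop, aiohttp session,
    # and janus queue before iteration begins.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        worker_info.dataset.init_worker()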