Last active
June 8, 2019 15:47
-
-
Save leopd/4adf59135049641916d41efe50f0af16 to your computer and use it in GitHub Desktop.
PyTorch Dataset class to access line-delimited text files too big to hold in memory.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import locale
import subprocess
from functools import lru_cache

from torch.utils.data import Dataset
class FileReaderDataset(Dataset):
    """Exposes a line-delimited text file as a PyTorch Dataset.

    Maintains an LRU cache of lines it has read, while supporting random
    access into files too large to hold in memory. Memory requirement still
    scales by O(N), but only for one integer byte offset per line scanned so
    far. After the file has been scanned, random access costs one seek plus
    one line read.

    The file is read in binary mode and each line is decoded individually, so
    the recorded offsets are true byte offsets and stay valid for multi-byte
    encodings. Lines are split on the ``\\n`` byte, which means the encoding
    must be ASCII-compatible for newlines (UTF-8, Latin-1, ...; not UTF-16).

    NOTE: every instance owns a single open file handle and a single read
    cursor; call :meth:`close` when done with the dataset.

    Args:
        filename: Path of the line-delimited text file.
        line_cache_size: Maximum number of decoded lines kept in the LRU cache.
        encoding: Text encoding used to decode lines. ``None`` (the default)
            uses the locale's preferred encoding, matching ``open()``.
    """

    def __init__(self, filename: str, line_cache_size: int = 1048576,
                 encoding: "str | None" = None):
        super().__init__()
        self._filename = filename
        if encoding is None:
            encoding = locale.getpreferredencoding(False)
        self._encoding = encoding
        # Binary mode: len() of what we read is then a real byte count, which
        # is what seek() needs. Text-mode offsets are opaque and do not equal
        # the number of characters read.
        self._filehandle = open(filename, "rb")
        self._pos = 0       # byte offset of the read cursor
        self._linenum = 0   # index of the line the cursor is about to read
        # _lineseeks[k] is the byte offset at which line k starts, for every
        # line boundary discovered so far.
        self._lineseeks = [0]
        self._cached_getitem = lru_cache(maxsize=line_cache_size)(self._getitem)
        self._file_len = None

    def close(self) -> None:
        """Close the underlying file handle. Safe to call more than once."""
        self._filehandle.close()

    def _readnextline(self) -> str:
        """Read and decode the line under the cursor, then advance the cursor.

        Raises:
            IndexError: If the cursor is already at end of file.
        """
        raw = self._filehandle.readline()
        if not raw:
            # readline() returns b"" only at EOF (a blank line is b"\n").
            raise IndexError(
                f"line {self._linenum} is past the end of {self._filename}"
            )
        self._pos += len(raw)
        self._linenum += 1
        # Record the start of the following line the first time we reach it.
        if len(self._lineseeks) == self._linenum:
            self._lineseeks.append(self._pos)
        line = raw.decode(self._encoding)
        # Text-mode reading used to translate CRLF to LF; keep that contract.
        if line.endswith("\r\n"):
            line = line[:-2] + "\n"
        return line

    def _seektoline(self, linenum: int) -> None:
        """Move the cursor to the start of an already-indexed line."""
        pos = self._lineseeks[linenum]
        self._filehandle.seek(pos)
        self._pos = pos
        self._linenum = linenum

    def __getitem__(self, n: int) -> str:
        """Return line ``n`` (including its trailing newline, if any).

        Negative indices count from the end of the file.

        Raises:
            IndexError: If ``n`` is outside the file.
        """
        if n < 0:
            # Normalise before the cache lookup so -1 and len-1 share an entry
            # and a negative value never reaches the offset table.
            n += len(self)
            if n < 0:
                raise IndexError("line index out of range")
        return self._cached_getitem(n)

    def _getitem(self, n: int) -> str:
        """Uncached version of __getitem__; expects a non-negative ``n``."""
        if n != self._linenum:
            if n < len(self._lineseeks):
                # Offset already known: jump straight to it.
                self._seektoline(n)
            else:
                # Unknown territory: resume from the furthest indexed line
                # rather than from wherever the cursor happens to be, then
                # scan forward. _readnextline raises IndexError at EOF.
                # NOTE: lines scanned over on the way are not cached.
                self._seektoline(len(self._lineseeks) - 1)
                while self._linenum < n:
                    self._readnextline()
        return self._readnextline()

    def __len__(self) -> int:
        """Return the number of lines in the file (computed once, then cached).

        A final line without a trailing newline is counted as a line.
        """
        if self._file_len is None:
            newline_count = 0
            tail = b""
            # Separate handle so the read cursor used by __getitem__ is not
            # disturbed; chunked reads keep memory use bounded.
            with open(self._filename, "rb") as handle:
                for chunk in iter(lambda: handle.read(1 << 20), b""):
                    newline_count += chunk.count(b"\n")
                    tail = chunk[-1:]
            if tail and tail != b"\n":
                newline_count += 1
            self._file_len = newline_count
        return self._file_len
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment