MilaDatasetBuilder
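"""Hugging Face `datasets` builder that streams examples out of (possibly split)
tar.gz archives hosted in a dataset repository on the Hub.

`SplittedFile` exposes an ordered list of file parts as a single read-only,
seekable file object, and `MilaDatasetBuilder` downloads the configured data
files and yields one ``{"filename", "bytes"}`` example per regular archive
member.
"""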
from dataclasses import dataclass
import io
import os
from pathlib import Path
import re
import tarfile

import datasets
from datasets.data_files import DataFilesDict
from datasets.download.download_manager import DownloadManager
from datasets.features import Features
from datasets.info import DatasetInfo
import numpy as np
from huggingface_hub import HfFileSystem
import huggingface_hub


def strip(string: str, chars=r"\s"):
    """Strip trailing characters matching the regex character class `chars`."""
    return re.sub(f"{chars}+$", "", string)


class SplittedFile(io.RawIOBase):
    """Read-only, seekable file object that exposes a list of file parts
    (e.g. a tar archive split into multiple chunks) as one contiguous file."""

    @dataclass
    class Split:
        name: str
        pos: int
        size: int

    def __init__(self, filesplits: list, mode: str = "rb") -> None:
        super().__init__()
        self._splits = [
            self.Split(fn, 0, Path(fn).stat().st_size)
            for fn in filesplits
        ]
        # Compute the absolute position of each split within the virtual file
        self._size = 0
        for split in self._splits:
            split.pos = self._size
            self._size += split.size
        self._file: io.IOBase | None = None
        self._mode = mode
        self._split_index = None

    def __enter__(self):
        if self.closed:
            self._open_split(0)
        return self

    def __exit__(self, *args, **kwargs):
        del args, kwargs
        self.close()

    def close(self) -> None:
        if not self.closed:
            self._file.close()
        self._file = None
        self._split_index = None

    @property
    def closed(self):
        return self._file is None or self._file.closed

    @property
    def _current_split(self) -> "SplittedFile.Split | None":
        return self._splits[self._split_index] if self._split_index is not None else None

    def flush(self) -> None:
        pass

    def isatty(self) -> bool:
        return False

    def readable(self) -> bool:
        return True

    def read(self, size: int = -1) -> bytes | None:
        buffer = np.empty(size if size > -1 else self._size, dtype="<u1")
        size = self.readinto(memoryview(buffer))
        return bytes(buffer[:size])

    def readall(self) -> bytes:
        return self.read(-1)

    def readinto(self, buffer: memoryview | bytearray) -> int | None:
        if not isinstance(buffer, memoryview):
            buffer = memoryview(buffer)
        cum_bytes_read = 0
        while cum_bytes_read < len(buffer):
            bytes_read = self._file.readinto(buffer[cum_bytes_read:])
            cum_bytes_read += bytes_read
            if not bytes_read:
                if self._split_index + 1 >= len(self._splits):
                    break
                # Open the next split to keep reading from
                self._open_split(self._split_index + 1)
        return cum_bytes_read

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        if self.closed:
            raise ValueError("seek of closed file")
        if whence == io.SEEK_CUR:
            offset += self.tell()
        elif whence == io.SEEK_END:
            offset += self._size
        # Find the split containing the absolute offset and seek within it
        for i, split in enumerate(self._splits):
            if offset < split.pos + split.size:
                self._open_split(i, offset - split.pos)
                break
        return offset

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        if self.closed:
            raise ValueError("I/O operation on closed file")
        return self._current_split.pos + self._file.tell()

    def writable(self) -> bool:
        return False

    def _open_split(self, split_index, split_offset=0) -> None:
        split = self._splits[split_index]
        if split_index != self._split_index:
            self.close()
            self._file = open(split.name, self._mode)
            self._split_index = split_index
        self._file.seek(split_offset)
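
# A minimal usage sketch for `SplittedFile` (the part names below are hypothetical;
# any ordered list of existing file parts works):
#
#     with SplittedFile(["archive.tar.gz.part0", "archive.tar.gz.part1"]) as sf:
#         with tarfile.open(fileobj=sf) as tf:
#             print(tf.getnames())
#
# The parts are presented as a single contiguous, read-only, seekable stream, which
# lets `tarfile` walk a split archive without first reassembling it on disk.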


class MilaDatasetBuilder(datasets.GeneratorBasedBuilder):
    DEFAULT_VERSION = "main"

    def __init__(
        self, cache_dir: str | None = None, dataset_name: str | None = None,
        config_name: str | None = None, hash: str | None = None,
        base_path: str | None = None, info: DatasetInfo | None = None,
        features: Features | None = None, token: bool | str | None = None,
        use_auth_token="deprecated", repo_id: str | None = None,
        data_files: str | list | dict | DataFilesDict | None = None,
        data_dir: str | None = None, storage_options: dict | None = None,
        writer_batch_size: int | None = None, name="deprecated",
        **config_kwargs
    ):
        if not dataset_name:
            dataset_name = repo_id.replace("/", "___")
        if not base_path and config_name:
            base_path = config_name
        if not os.path.isdir(base_path or ""):
            # Resolve non-local base paths against the dataset repository on the Hub
            base_path = f"hf://datasets/{repo_id}@{config_kwargs.get('version', self.DEFAULT_VERSION)}/{base_path or ''}".rstrip("/")
        if any(isinstance(data_files, t) for t in (dict, DataFilesDict)):
            # Map a catch-all "*" entry to the ALL split
            if "*" in data_files and datasets.Split.ALL not in data_files:
                data_files[datasets.Split.ALL] = data_files["*"]
                del data_files["*"]
        super().__init__(
            cache_dir, dataset_name, config_name, hash, base_path, info,
            features, token, use_auth_token, repo_id, data_files, data_dir,
            storage_options, writer_batch_size, name, **config_kwargs
        )
        # Keep downloads of different repositories and revisions in separate directories
        self._cache_downloaded_dir = str(
            Path(self._cache_downloaded_dir) / self.repo_id / self._version()
        )

    def _build_cache_dir(self):
        cache_dir = super()._build_cache_dir()
        version = self._version()
        return str(
            Path(
                strip(
                    "/".join(cache_dir.split(self.dataset_name)[0:-1]),
                    "[^a-zA-Z-]"
                )
            ) / self.repo_id / version
        )

    def _info(self) -> DatasetInfo:
        """Construct the DatasetInfo object. See `DatasetInfo` for details.

        Warning: This function is only called once and the result is cached for all
        following .info() calls.

        Returns:
            info: (DatasetInfo) The dataset information
        """
        return DatasetInfo()

    def _split_generators(self, dl_manager: DownloadManager):
        """Specify feature dictionary generators and dataset splits.

        This function returns a list of `SplitGenerator`s defining how to generate
        data and what splits to use.

        Example:

            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={'file': 'train_data.zip'},
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.TEST,
                    gen_kwargs={'file': 'test_data.zip'},
                ),
            ]

        The above code will first call `_generate_examples(file='train_data.zip')`
        to write the train data, then `_generate_examples(file='test_data.zip')` to
        write the test data.

        Datasets are typically split into different subsets to be used at various
        stages of training and evaluation.

        Note that for datasets without a `VALIDATION` split, you can use a
        fraction of the `TRAIN` data for evaluation as you iterate on your model
        so as not to overfit to the `TEST` data.

        For downloads and extractions, use the given `download_manager`.
        Note that the `DownloadManager` caches downloads, so it is fine to have each
        generator attempt to download the source data.

        A good practice is to download all data in this function, and then
        distribute the relevant parts to each split with the `gen_kwargs` argument.

        Args:
            dl_manager (`DownloadManager`):
                Download manager to download the data.

        Returns:
            `list<SplitGenerator>`.
        """
        downloaded_files = {
            s: dl_manager.download(files)
            for s, files in self.config.data_files.items()
        }
        # Recreate the repository-relative paths as symlinks next to the downloaded
        # blobs so the archive parts keep their original names and ordering.
        symlinks = {}
        for s in self.config.data_files:
            for _file, _downloaded_file in zip(self.config.data_files[s], downloaded_files[s]):
                url_path = _file.split(self.repo_id)[-1]
                url_path = "/".join(url_path.split("/")[1:])
                symlink = Path(_downloaded_file).parent / url_path
                if not symlink.parent.exists():
                    symlink.parent.mkdir()
                if not symlink.exists():
                    symlink.symlink_to(_downloaded_file)
                symlinks.setdefault(s, [])
                symlinks[s].append(str(symlink))
        return [
            datasets.SplitGenerator(name=s, gen_kwargs={"filepaths": files})
            for s, files in symlinks.items()
        ]

    def _generate_examples(self, filepaths, **_kwargs):
        """Default function generating examples for each `SplitGenerator`.

        This function preprocesses the examples from the raw data to the preprocessed
        dataset files.

        This function is called once for each `SplitGenerator` defined in
        `_split_generators`. The examples yielded here will be written on disk.

        Args:
            **kwargs (additional keyword arguments):
                Arguments forwarded from the SplitGenerator.gen_kwargs

        Yields:
            key: `str` or `int`, a unique deterministic example identification key.
                * Unique: An error will be raised if two examples are yielded with the
                  same key.
                * Deterministic: When generating the dataset twice, the same example
                  should have the same key.
                Good keys can be the image id, or line number if examples are extracted
                from a text file.
                The key will be hashed and sorted to shuffle examples deterministically,
                so that generating the dataset multiple times keeps examples in the
                same order.
            example: `dict<str feature_name, feature_value>`, a feature dictionary
                ready to be encoded and written to disk. The example will be
                encoded with `self.info.features.encode_example({...})`.
        """
        import hashlib

        id_ = 0
        # Treat the list of downloaded parts as a single tar archive
        with SplittedFile(filepaths) as sf:
            tf = tarfile.open(fileobj=sf)
            while True:
                tarinfo = tf.next()
                if tarinfo is None:
                    break
                f = tf.extractfile(tarinfo)
                if f is not None:
                    b = f.read()
                    # Debug output: checksum of every extracted member
                    print(f"{hashlib.md5(b).hexdigest()} {tarinfo.path}")
                    yield id_, {"filename": tarinfo.path, "bytes": b}
                    id_ += 1

    def _version(self):
        # Fall back to the revision embedded in `base_path` when no explicit version is set
        return self.config.version if self.config.version > "0.0.0" else self.base_path.split(self.repo_id)[-1].strip("@")


if __name__ == "__main__":
    MilaDatasetBuilder(
        repo_id="satyaortiz-gagne/bigearthnet",
        data_files={"S1": ["S1/**.tar.gz*"], "S2": ["S2/**.tar.gz*"]},
    ).download_and_prepare()
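
# A hedged sketch of reading the prepared data back, assuming the prepare step above
# completed (split names mirror the keys of `data_files`):
#
#     builder = MilaDatasetBuilder(
#         repo_id="satyaortiz-gagne/bigearthnet",
#         data_files={"S1": ["S1/**.tar.gz*"], "S2": ["S2/**.tar.gz*"]},
#     )
#     builder.download_and_prepare()
#     ds = builder.as_dataset(split="S1")  # one {"filename", "bytes"} record per archive member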