Created
July 8, 2022 16:01
-
-
Save rom1504/0fbc00fced63b15f9cb610c70ab401d7 to your computer and use it in GitHub Desktop.
Using webdataset and on disk tags
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class WDSDataset(data.IterableDataset):
    """Stream LAION-5B samples from S3 via webdataset, filtered by on-disk tags.

    Each sample is joined (by a murmur hash of url+caption) against two
    on-disk key/value stores holding pwatermark/punsafe and aesthetic scores,
    then filtered on watermark probability, aesthetic score, and original
    image size. Yields ``(image, caption, punsafe_class)`` tuples.
    """

    def __init__(self, min_size, transform=None, target_transform=None,
                 pwatermark_threshold=0.8, punsafe_threshold=0.5,
                 aesthetic_threshold=5.):
        """
        Args:
            min_size: minimum original width AND height (pixels) to keep a sample.
            transform: callable applied to the decoded PIL image (default: identity).
            target_transform: callable applied to the caption text (default: identity).
            pwatermark_threshold: keep samples with pwatermark strictly below this.
            punsafe_threshold: punsafe >= this maps to class 1 in `_punsafe_to_class`.
            aesthetic_threshold: keep samples with aesthetic score >= this.
        """
        self.min_size = min_size
        self.transform = transform if transform is not None else nn.Identity()
        self.target_transform = target_transform if target_transform is not None else nn.Identity()
        # key_format 'q' = int64 hash; value formats are little-endian half floats.
        self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
        self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
        self.pwatermark_threshold = pwatermark_threshold
        self.punsafe_threshold = punsafe_threshold
        self.aesthetic_threshold = aesthetic_threshold
        self.total_samples = 0
        self.samples = 0
        # Three shard sources (2B-en, 1B-nolang, 2B-multi) piped from S3.
        location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -::pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion1B-nolang-data/{000000..127231}.tar -::pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-multi-data/{000000..226687}.tar -'
        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode('pilrgb', handler=wds.warn_and_continue),
            # ignore_and_continue: samples missing from the KV stores are dropped silently.
            wds.map(self._add_tags, handler=wds.ignore_and_continue),
            wds.select(self._filter_predicate),
            wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
            wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
        )

    @staticmethod
    def _compute_hash(url, text):
        """Return the int64 murmur3 hash of url+text (the KV-store key)."""
        if url is None:
            url = ''
        if text is None:
            text = ''
        total = (url + text).encode('utf-8')
        return mmh3.hash64(total)[0]

    def _add_tags(self, x):
        """Attach pwatermark/punsafe/aesthetic tags looked up by sample hash.

        Raises KeyError when the hash is absent from a KV store; the pipeline's
        ignore_and_continue handler then drops the sample.
        """
        hsh = self._compute_hash(x['json']['url'], x['txt'])
        pwatermark, punsafe = self.kv[hsh]
        aesthetic = self.kv_aesthetic[hsh][0]
        return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}

    def _punsafe_to_class(self, punsafe):
        """Binarize the punsafe score against the threshold as a long tensor (0/1)."""
        return torch.tensor(punsafe >= self.punsafe_threshold).long()

    def _filter_predicate(self, x):
        """True iff the sample passes the watermark/aesthetic/size filters.

        Narrowed from a bare `except:`: only missing keys (KeyError) or
        None-valued tags (TypeError on comparison) should reject the sample;
        anything else (e.g. KeyboardInterrupt) must propagate.
        """
        try:
            return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
        except (KeyError, TypeError):
            return False

    def __iter__(self):
        """Delegate iteration to the inner webdataset pipeline."""
        return iter(self.inner_dataset)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment