Skip to content

Instantly share code, notes, and snippets.

@rom1504
Created July 8, 2022 16:01
Show Gist options
  • Save rom1504/0fbc00fced63b15f9cb610c70ab401d7 to your computer and use it in GitHub Desktop.
Save rom1504/0fbc00fced63b15f9cb610c70ab401d7 to your computer and use it in GitHub Desktop.
Using webdataset with on-disk tags
class WDSDataset(data.IterableDataset):
    """Stream filtered, tagged image/caption pairs from LAION-5B webdataset shards.

    Each sample is looked up in two on-disk key/value stores, keyed by a hash
    of its URL + caption, to obtain watermark, safety and aesthetic scores.
    It is then filtered on those scores and on the original image size, and
    yielded as a ``(jpg, txt, punsafe_class)`` tuple.

    Args:
        min_size: Minimum ``original_width`` and ``original_height`` (read from
            the sample's ``json`` metadata) required to keep a sample.
        transform: Applied to the decoded RGB image; identity when ``None``.
        target_transform: Applied to the caption text; identity when ``None``.
        location: Shard URL spec handed to ``wds.ResampledShards``.
        kv_file: Path of the store whose values unpack to
            ``(pwatermark, punsafe)``.
        aesthetic_kv_file: Path of the store whose first value is the
            aesthetic score.
        pwatermark_threshold: Samples with ``pwatermark`` at or above this are
            dropped.
        punsafe_threshold: ``punsafe`` at or above this maps to class 1,
            otherwise 0.
        aesthetic_threshold: Samples with ``aesthetic`` below this are dropped.
        shuffle_buffer: Size of the in-stream shuffle buffer.
    """

    # Three LAION-5B subsets, each streamed from S3 through `aws s3 cp ... -`.
    # They are joined with "::", presumably so webdataset expands them into a
    # single combined shard list — confirm against the webdataset version in use.
    DEFAULT_LOCATION = "::".join([
        'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -',
        'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion1B-nolang-data/{000000..127231}.tar -',
        'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-multi-data/{000000..226687}.tar -',
    ])
    DEFAULT_KV_FILE = '/home/ubuntu/laion5B-watermark-safety-ordered'
    DEFAULT_AESTHETIC_KV_FILE = '/home/ubuntu/laion5B-aesthetic-tags-kv'

    def __init__(self, min_size, transform=None, target_transform=None, *,
                 location=DEFAULT_LOCATION,
                 kv_file=DEFAULT_KV_FILE,
                 aesthetic_kv_file=DEFAULT_AESTHETIC_KV_FILE,
                 pwatermark_threshold=0.8,
                 punsafe_threshold=0.5,
                 aesthetic_threshold=5.,
                 shuffle_buffer=1000):
        self.min_size = min_size
        # nn.Identity (rather than a lambda) keeps the dataset picklable for
        # DataLoader worker processes.
        self.transform = transform if transform is not None else nn.Identity()
        self.target_transform = target_transform if target_transform is not None else nn.Identity()
        # key_format/value_format look like struct codes: 'q' is a signed
        # 64-bit key (matching mmh3.hash64) and 'e' a half-precision float.
        # Presumably; confirm against OnDiskKV.
        self.kv = OnDiskKV(file=kv_file, key_format='q', value_format='ee')
        self.kv_aesthetic = OnDiskKV(file=aesthetic_kv_file, key_format='q', value_format='e')
        self.pwatermark_threshold = pwatermark_threshold
        self.punsafe_threshold = punsafe_threshold
        self.aesthetic_threshold = aesthetic_threshold
        # Counters are not updated anywhere in this class; they are kept in
        # case external code reads or writes them.
        self.total_samples = 0
        self.samples = 0
        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(shuffle_buffer, handler=wds.warn_and_continue),
            wds.decode('pilrgb', handler=wds.warn_and_continue),
            # Samples whose tag lookup raises (e.g. a hash absent from a
            # store, presumably) are dropped without a warning.
            wds.map(self._add_tags, handler=wds.ignore_and_continue),
            wds.select(self._filter_predicate),
            wds.map_dict(jpg=self.transform, txt=self.target_transform,
                         punsafe=self._punsafe_to_class,
                         handler=wds.warn_and_continue),
            wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
        )

    @staticmethod
    def _compute_hash(url, text):
        """Return the first element of ``mmh3.hash64`` over UTF-8 ``url + text``.

        ``None`` for either argument is treated as the empty string.
        """
        if url is None:
            url = ''
        if text is None:
            text = ''
        total = (url + text).encode('utf-8')
        return mmh3.hash64(total)[0]

    def _add_tags(self, x):
        """Return a copy of sample ``x`` with ``pwatermark``, ``punsafe`` and ``aesthetic`` added."""
        hsh = self._compute_hash(x['json']['url'], x['txt'])
        pwatermark, punsafe = self.kv[hsh]
        aesthetic = self.kv_aesthetic[hsh][0]
        return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}

    def _punsafe_to_class(self, punsafe):
        """Map a ``punsafe`` score to a long tensor: 1 if ``>= punsafe_threshold`` else 0."""
        return torch.tensor(punsafe >= self.punsafe_threshold).long()

    def _filter_predicate(self, x):
        """Return True if sample ``x`` passes the watermark, aesthetic and size filters.

        A sample with missing metadata keys (``KeyError``) or non-comparable
        values such as a ``None`` size (``TypeError``) is rejected rather than
        crashing the stream. Any other exception propagates instead of being
        silently swallowed.
        """
        try:
            return (
                x['pwatermark'] < self.pwatermark_threshold
                and x['aesthetic'] >= self.aesthetic_threshold
                and x['json']['original_width'] >= self.min_size
                and x['json']['original_height'] >= self.min_size
            )
        except (KeyError, TypeError):
            return False

    def __iter__(self):
        """Iterate over ``(jpg, txt, punsafe_class)`` tuples from the pipeline."""
        return iter(self.inner_dataset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment