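"""Extract figure-caption pairs from raw PubMed Central (PMC) webdataset
shards and re-shard them as (image, caption) webdataset tar files.

Each raw sample carries the article's .tar.gz archive (stored under the
"flac" key) plus a JSON record with the article URL; figures are matched to
their captions through the article's .nxml file.
"""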
import io
import json
import os
import random
import tarfile
import time
from collections import defaultdict

import webdataset as wds
from lxml import etree
from PIL import Image

def get_wds(path):
    """Build a streaming loader over the raw PMC shards.

    Each worker reads whole shards, expands every article into
    (key, image bytes, caption) samples via preprocess_pmc, and groups
    them into batches of `bs`.
    """
    bs = 8
    ds = wds.DataPipeline(
        wds.SimpleShardList(path),
        wds.split_by_worker,  # each worker reads a disjoint set of shards
        wds.tarfile_to_samples(),
        preprocess_pmc,  # article archive -> figure-caption samples
        wds.rename(image="jpg;png", txt="txt"),
        wds.to_tuple("key", "image", "txt"),
        wds.batched(bs),
    )
    # batch_size=None: batching already happens in the pipeline above, so the
    # loader just passes batches through unchanged.
    loader = wds.WebLoader(ds, num_workers=bs, batch_size=None,
                           collate_fn=lambda x: x, persistent_workers=False)
    return loader

def preprocess_pmc(src):
    """Turn raw PMC article samples into figure-caption samples."""
    for sample in src:
        try:
            js = json.loads(sample['json'])
            K = js['url'].replace('/', '_')
            # The per-article .tar.gz payload is stored under the "flac" key.
            desc = io.BytesIO(sample['flac'])
            tf = tarfile.open(fileobj=desc, mode="r:gz")
            # Index the archive members by basename and by extension.
            by_name_and_ext = {}
            names = set()
            by_ext = defaultdict(list)
            members = {}
            for member in tf.getmembers():
                f = member.name
                members[f] = member
                name = os.path.basename(os.path.splitext(f)[0])
                ext = os.path.splitext(f)[1]
                by_name_and_ext[(name, ext)] = f
                names.add(name)
                by_ext[ext].append(f)
            # The article's .nxml file holds the <fig> tags and captions.
            if not by_ext['.nxml']:
                continue
            xml_file = tf.extractfile(members[by_ext['.nxml'][0]])
            tree = etree.parse(xml_file)
            fig_tags = tree.xpath('//fig')
            if len(fig_tags) == 0:
                continue
            for tag in fig_tags:
                captions = [get_text(c) for c in tag.findall("caption")]
                imgs = [get_href(i) for i in tag.findall("graphic")]
                # Keep only unambiguous figures: exactly one image and one caption.
                if len(imgs) != 1 or len(captions) != 1:
                    continue
                img_name = imgs[0]
                caption = captions[0]
                if img_name not in names:
                    continue
                if (img_name, '.jpg') in by_name_and_ext:
                    filename = by_name_and_ext[(img_name, '.jpg')]
                elif (img_name, '.png') in by_name_and_ext:
                    filename = by_name_and_ext[(img_name, '.png')]
                else:
                    continue
                img = tf.extractfile(members[filename]).read()
                ext = os.path.splitext(filename)[1][1:]  # "jpg" or "png"
                key = K + "-" + img_name
                yield {"key": key, "__key__": key, ext: img, "txt": caption}
        except Exception as ex:
            # Skip any article that fails to parse, but keep streaming.
            print(ex)
            continue

class ShuffledIter:
    """Cycle over `data` forever, reshuffling its order on every pass."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        while True:
            random.shuffle(self.data)
            yield from self.data

def get_text(node):
    """Concatenate all text inside an XML node (used for caption text)."""
    return ''.join(node.itertext())


def get_href(i):
    """Return the xlink:href of a <graphic> tag, handling both the prefixed
    and the Clark-notation attribute name that lxml uses."""
    if 'xlink:href' in i.attrib:
        return i.attrib['xlink:href']
    elif '{http://www.w3.org/1999/xlink}href' in i.attrib:
        return i.attrib['{http://www.w3.org/1999/xlink}href']
    else:
        return None

def main():
    random.seed(0)
    nb_shards = 2500
    # Spread samples across output shards by writing each sample to a
    # randomly cycled sink.
    sinks = [wds.TarWriter(f"/p/fastdata/datasets/pubmed/figure-captions/{i:05d}.tar") for i in range(nb_shards)]
    sink_iter = iter(ShuffledIter(sinks))
    dataset = get_wds("/p/fastdata/datasets/pubmed/raw/{00000..00520}.tar")
    t0 = time.time()
    nb = 0
    i = 0
    for keys, ims, txts in dataset:
        for key, im, txt in zip(keys, ims, txts):
            data = {
                "__key__": key,
                # Note: image bytes are written under "jpg" even when the
                # source file was a PNG.
                "jpg": im,
                "txt": txt,
            }
            sink = next(sink_iter)
            sink.write(data)
        nb += len(keys)
        if i % 1000 == 0:
            dt = time.time() - t0
            print(dt, nb / dt)  # elapsed seconds, samples/sec
        i += 1
    for sink in sinks:
        sink.close()


if __name__ == "__main__":
    main()
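
# A minimal read-back sketch (not part of the original script): stream the
# written figure-caption shards with webdataset. The brace-expanded path is
# an assumption matching nb_shards = 2500 in main().
def example_read_back():
    ds = (
        wds.WebDataset("/p/fastdata/datasets/pubmed/figure-captions/{00000..02499}.tar")
        .decode("pil")            # decode the stored bytes into PIL images
        .to_tuple("jpg", "txt")   # yield (image, caption) pairs
    )
    for img, caption in ds:
        print(img.size, caption[:80])
        break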