Skip to content

Instantly share code, notes, and snippets.

@recoilme
Created December 19, 2024 05:37
Show Gist options
  • Save recoilme/41272b4143b270df068f97afbdba3843 to your computer and use it in GitHub Desktop.
Save recoilme/41272b4143b270df068f97afbdba3843 to your computer and use it in GitHub Desktop.
import webdataset as wds
from PIL import Image
import io
import json
import time
import cv2
import numpy as np
from PIL import Image
# Destination directory for the <id>.jpg / <id>.txt pairs produced below.
output_path: str = "nij"
# Target number of image/caption pairs; the main loop stops once this is reached.
num_images: int = 100_000
def downscale_image_by(image, max_size, x=64):
    """Resize *image* so its longer side equals *max_size*, then crop to multiples of *x*.

    The image is resized with ``cv2.INTER_AREA`` so that its longer side is
    exactly ``max_size`` pixels, preserving the aspect ratio. The shorter
    side is truncated with ``int()``. Both sides are then rounded down to the
    nearest multiple of ``x`` by cropping from the top-left corner.

    Note: despite the name, images whose longer side is below ``max_size``
    are upscaled too. The longer side is always normalised to ``max_size``.

    Args:
        image: Input image. The caller passes an RGB ``PIL.Image``. Anything
            ``np.array`` turns into an ``H x W [x C]`` uint8 array should
            work (assumption — only RGB is exercised by this script).
        max_size: Target length, in pixels, of the longer side.
        x: Both output dimensions are rounded down to a multiple of this.

    Returns:
        The resized and cropped ``PIL.Image.Image``. Returns ``None`` (after
        printing an error) if the image is too thin to leave a non-empty
        crop, or if any step raises.
    """
    try:
        pixels = np.array(image)
        height, width = pixels.shape[:2]
        if width > height:
            new_width = max_size
            new_height = int(height * (max_size / width))
        else:
            new_height = max_size
            new_width = int(width * (max_size / height))

        # Final size after cropping each side down to a multiple of x.
        crop_width = (new_width // x) * x
        crop_height = (new_height // x) * x
        if crop_width <= 0 or crop_height <= 0:
            # A very thin image would give a zero-sized side. cv2.resize rejects
            # that, and an empty array cannot become a usable image, so bail
            # out explicitly instead of failing somewhere downstream.
            print(
                f"Error downscaling image: {width}x{height} is too small "
                f"for max_size={max_size}, x={x}"
            )
            return None

        resized = cv2.resize(pixels, (new_width, new_height), interpolation=cv2.INTER_AREA)
        # Floor division guarantees crop_* <= new_*, so this slice is always
        # exactly (crop_height, crop_width); no second resize is needed.
        cropped = resized[:crop_height, :crop_width]
        return Image.fromarray(cropped)
    except Exception as e:
        # Best-effort: report and signal failure to the caller with None.
        print(f"Error downscaling image: {e}")
        return None
# Main loop. Stream shards of the CaptionEmporium/midjourney-niji-1m-llavanext
# WebDataset from huggingface.co, starting at shard `current_ds`. Keep only
# samples whose JSON metadata has model_source == "nijijourney_v6". For each
# kept sample, write the downscaled image (<id>.jpg) and its "caption_llava"
# text (<id>.txt) into output_path, stopping after num_images pairs.
import os

# image.save / open(..., "w") fail if the destination directory is missing.
os.makedirs(output_path, exist_ok=True)

images_processed = 0
start_time = time.time()
current_ds = 281  # index of the first shard to download
# NOTE(review): current_ds has no upper bound. If the dataset runs out of
# shards before num_images is reached, the outcome depends on how
# wds.WebDataset handles a missing URL — verify.
while images_processed < num_images:
    # Shard names are zero-padded to six digits, e.g. train-000281.tar.
    dataset_url = f"https://huggingface.co/datasets/CaptionEmporium/midjourney-niji-1m-llavanext/resolve/main/wds/train-{current_ds:06d}.tar"
    print("current_ds", current_ds)
    current_ds += 1

    # Create the WebDataset reader for this shard.
    dataset = wds.WebDataset(dataset_url)
    for sample in dataset:
        if images_processed >= num_images:
            break  # quota reached mid-shard; don't process the rest of it
        if "json" not in sample or "jpg" not in sample:
            continue
        json_data = json.loads(sample["json"])
        if json_data.get("model_source", "") != "nijijourney_v6":
            continue

        sample_id = json_data.get("id", "")
        if not sample_id:
            # Without an id every such sample would overwrite "<output_path>/.jpg".
            continue

        try:
            image = Image.open(io.BytesIO(sample["jpg"])).convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError
            print(f"Skipping {sample_id}: cannot decode image: {err}")
            continue
        image = downscale_image_by(image, 768, 64)
        if image is None:
            continue  # downscale_image_by already printed the reason

        caption = json_data.get("caption_llava", "")
        image.save(f"{output_path}/{sample_id}.jpg", quality=96)
        text_filename = f"{output_path}/{sample_id}.txt"

        if images_processed > 0 and images_processed % 10000 == 0:
            elapsed_time = time.time() - start_time
            estimated_total_time = (elapsed_time / images_processed) * num_images
            remaining_time = estimated_total_time - elapsed_time
            print(f"File: {sample_id}, Caption: {caption}\n")
            print(f"Processed {images_processed}/{num_images} files, approximate remaining time: {time.strftime('%H:%M:%S', time.gmtime(remaining_time))}")

        # Explicit UTF-8: captions may contain non-ASCII text, and the platform
        # default encoding is not guaranteed to handle it.
        with open(text_filename, 'w', encoding="utf-8") as file_txt:
            file_txt.write(f"{caption}")
        images_processed += 1
print('done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment