# Gist ryanwebster90/5a645aceb519e60499089b24813e5ec5, created December 27, 2022 00:55.
# GitHub page notice: this file may contain bidirectional Unicode text that can be
# interpreted or compiled differently than it appears; open it in an editor that
# reveals hidden Unicode characters to review it.
import glob

import fire
import numpy as np
import torch
def abs_ind_to_feat_file(abs_ind, cum_sz, feat_files):
    """Map a global row index to the feature file that contains it.

    Args:
        abs_ind: Zero-based index into the concatenation of all feature files.
        cum_sz: Non-decreasing 1-D sequence of cumulative row counts, where
            ``cum_sz[i]`` is the global index of the first row of
            ``feat_files[i]`` (so ``cum_sz[0]`` is 0).
        feat_files: Feature files, in the same order used to build ``cum_sz``.

    Returns:
        A tuple ``(feat_file, file_ind, local_ind)``: the matching entry of
        ``feat_files``, its position in ``feat_files`` as a Python int, and
        the row offset of ``abs_ind`` inside that file.

    Raises:
        IndexError: If ``abs_ind`` lies before the first file (e.g. the ``-1``
            that faiss returns for a missing neighbour) or past the last one.
    """
    cum_sz = np.asarray(cum_sz)
    # Last position whose starting offset is <= abs_ind; binary search instead
    # of materialising every matching position with argwhere.
    file_ind = int(np.searchsorted(cum_sz, abs_ind, side='right')) - 1
    if file_ind < 0 or file_ind >= len(feat_files):
        raise IndexError(
            f'abs_ind {abs_ind} is outside the range covered by '
            f'{len(feat_files)} feature files'
        )
    local_ind = abs_ind - cum_sz[file_ind]
    return feat_files[file_ind], file_ind, local_ind
def pil_to_torch(img, none_to_black = True):
    """Convert an image to a float tensor of shape ``(1, 3, H, W)``.

    Args:
        img: Anything ``np.array`` can turn into an ``(H, W)`` or
            ``(H, W, C)`` array (e.g. a PIL image), or ``None``.
        none_to_black: Accepted for interface compatibility; it is not
            consulted. ``None`` always yields a black image.

    Returns:
        A ``torch.float`` tensor with exactly three channels. Pixel values are
        copied as-is (no rescaling). ``None`` yields a ``1x3x256x256`` tensor
        of zeros.
    """
    if img is None:
        return torch.zeros(1, 3, 256, 256).to(torch.float)

    pixels = torch.from_numpy(np.array(img)).to(torch.float)
    if pixels.dim() == 2:
        # Single-plane (black-and-white) image: replicate it into 3 channels.
        chw = pixels.unsqueeze(0).repeat(3, 1, 1)
    else:
        chw = pixels.permute((2, 0, 1))
        channels = chw.shape[0]
        if channels > 3:
            # e.g. RGBA: drop the extra planes so every result can be
            # concatenated with the 3-channel tensors from the other branches.
            chw = chw[:3]
        elif channels < 3:
            # e.g. gray or gray+alpha stored as (H, W, 1|2): replicate the
            # first plane into 3 channels.
            chw = chw[:1].repeat(3, 1, 1)
    return chw.unsqueeze(0)
def feat_key_to_file_loc(feat_path):
    """Return the data directory for the shard range named in ``feat_path``.

    Args:
        feat_path: Path (or any string) containing one of the known shard-range
            keys such as ``'24_26'`` or ``'1_4_0'``.

    Returns:
        The absolute directory path associated with the first matching key.

    Raises:
        ValueError: If ``feat_path`` contains none of the known keys.
    """
    # Checked in this order; the first key found as a substring wins.
    shard_roots = (
        ('24_26', '/home/ryan/dw3/hdd5tb/laion_24_26/'),
        ('16_20', '/home/ryan/hdd1/laion-data-16-20/'),
        ('27_31', '/home/ryan/dw3/hdd/laion_data_27_31/'),
        ('1_4_1', '/home/ryan/dw0/laion400m-data-old/'),
        ('21_23', '/home/ryan/dw3/hdd5tb/laion_21_23_data/'),
        ('9_12', '/home/ryan/dw0/hdd1/laion400m-p9_12/'),
        ('1_4_0', '/home/ryan/dw0/hdd1/laion_data_old_first/'),
    )
    for key, root in shard_roots:
        if key in feat_path:
            return root
    # Fail with the offending path instead of falling through to an unbound
    # local variable.
    raise ValueError(f'ERROR, no file path found for {feat_path!r}')
def vis_nns_from_nn_inds(nn_file, chunk_size=4, num_chunks=2, query_image_folder=None, query_wds_folder = None, query_captions = None):
    """Save the nearest neighbours of the first query as an image grid plus captions.

    For the first row of the neighbour-index array, each global index is
    resolved to its feature file, the per-file metadata entry and the data
    directory; the image and caption are fetched through
    ``wds_utils.retrieve_image``. The images are written to
    ``text_query_demo.jpg`` (4 per row, normalized) and the captions to
    ``text_query_demo.txt``.

    Args:
        nn_file: 2-D array of global neighbour indices (despite the name, this
            is the index array itself, not a file name). Only row 0 is used,
            and at most the first 12 columns.
        chunk_size: Accepted for interface compatibility; not used.
        num_chunks: Accepted for interface compatibility; not used.
        query_image_folder: Accepted for interface compatibility; not used.
        query_wds_folder: Accepted for interface compatibility; not used.
        query_captions: Accepted for interface compatibility; not used.

    Raises:
        IndexError: If a neighbour index is outside the indexed feature files.
    """
    import glob

    import torchvision
    import wds_utils

    max_neighbours = 12
    nn_inds = nn_file

    # Cumulative row counts let a global index be mapped back to its file.
    feat_files = sorted(glob.glob('../hdd14tb/vitb32_overlap_feats/img_emb/*.npy'))
    cum_sz = [0]
    for feat in feat_files:
        # mmap_mode='r' reads only the header to get the row count.
        cum_sz.append(cum_sz[-1] + np.load(feat, mmap_mode='r').shape[0])
    cum_sz = np.array(cum_sz).astype('int')
    pq_files = sorted(glob.glob('../hdd14tb/vitb32_overlap_feats/metadata/*.npy'))

    query_idx = 0
    # Never read past the columns actually supplied by the caller.
    num_shown = min(max_neighbours, nn_inds.shape[1])
    # Each metadata file is loaded once, even when several neighbours share it.
    metadata_cache = {}
    images = []
    caption_lines = [' \n']

    for k in range(num_shown):
        abs_ind = nn_inds[query_idx, k]
        print("abs ind = ",abs_ind)
        feat_file, feat_file_ind, local_ind = abs_ind_to_feat_file(abs_ind, cum_sz, feat_files)
        file_path = feat_key_to_file_loc(feat_file)
        print(f"ff={feat_file}, ffi={feat_file_ind}, file_path={file_path}, local_ind={local_ind}")

        if feat_file_ind not in metadata_cache:
            metadata_cache[feat_file_ind] = np.load(pq_files[feat_file_ind])
        nn_entry = metadata_cache[feat_file_ind][local_ind]

        # NOTE(review): no directory returned by feat_key_to_file_loc contains
        # 'laion_4_8', so this override looks unreachable - confirm.
        tar_size = 10000
        if 'laion_4_8' in file_path:
            tar_size = 1000
        print('file path and nn file',file_path, nn_entry)
        img, caption = wds_utils.retrieve_image(file_path, nn_entry, tar_size=tar_size, verb=True)
        images.append(pil_to_torch(img))
        caption_lines.append(f"{query_idx},{k+1} {caption} \n")

    torchvision.utils.save_image(torch.cat(images, dim=0), 'text_query_demo.jpg', nrow=4, normalize=True)
    # Explicit UTF-8 so non-ASCII captions do not depend on the locale default.
    with open('text_query_demo.txt', 'w', encoding='utf-8') as text_file:
        text_file.write(''.join(caption_lines))
import os

import faiss
import numpy as np
import open_clip
import torch
import torch.nn.functional as F
import torchvision
from open_clip import tokenize
from PIL import Image
from tqdm import tqdm


def main():
    """Interactive text-to-image search over the prebuilt faiss image index.

    Loads the index and the CLIP ViT-B/32 model once, then repeatedly prompts
    for a text query, embeds it, retrieves the 16 nearest image indices and
    passes them to ``vis_nns_from_nn_inds``. An empty line or end-of-input
    ends the session.
    """
    import clip

    index = faiss.read_index('vitb32_overlap_index/image.index')
    # model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14',pretrained = 'laion2b_s32b_b79k')
    device = torch.device("cuda")
    # clip.load places the model on `device`, so no separate .cuda() call.
    model, preprocess = clip.load("ViT-B/32", device=device, jit=True)

    with torch.no_grad():
        while True:
            # Prompt inside the loop so each pass runs a new query, not the
            # first one forever.
            try:
                s = input('type your text query...')
            except EOFError:
                break
            if not s.strip():
                break

            texts = tokenize([s]).to(device)  # tokenize
            text_embeddings = model.encode_text(texts)
            # Unit-normalise each row, average, then re-normalise the mean.
            text_embedding = F.normalize(text_embeddings, dim=-1).mean(dim=0)
            text_embedding /= text_embedding.norm()
            print("embeddings finished!")
            print(text_embedding.size())

            d, nns = index.search(text_embedding.reshape(1,-1).cpu().numpy().astype('float32'),16)
            vis_nns_from_nn_inds(nns)
            print('done!')


if __name__ == "__main__":
    main()
# End of gist. (GitHub page footer: sign up for free or sign in on GitHub to
# join the conversation and comment.)