Last active
November 6, 2022 12:18
-
-
Save benob/1a2457637f6d02f9643342b37b186674 to your computer and use it in GitHub Desktop.
Simple searchable image gallery using CLIP vectors to represent picture content and text queries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Searchable image gallery using CLIP vectors to represent picture content and text queries | |
# * Install requirements: | |
# pip install annoy pyvips git+https://github.com/Lednik7/CLIP-ONNX.git git+https://github.com/openai/CLIP.git bottle protobuf==3.20 | |
# * Download models: | |
# wget https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/ViT-B-32/{visual,textual}.onnx | |
# * Run server and open http://127.0.0.1:8080: | |
# python gallery.py /path/to/pictures/directory | |
import os
import sys
import glob
import struct
import hashlib
import sqlite3
from html import escape

import clip
import pyvips
import numpy as np
from PIL import Image
from annoy import AnnoyIndex
from clip_onnx import clip_onnx
from bottle import run, get, static_file, redirect, request, abort
# the pictures directory is the only command-line argument
if len(sys.argv) < 2:
    sys.exit('usage: %s /path/to/pictures/directory' % sys.argv[0])
directory = sys.argv[1]
os.makedirs('thumbnails', exist_ok=True)
os.makedirs('features', exist_ok=True)
# paths are stored in a simple sqlite db; the implicit rowid is also used as
# the Annoy item id and as the /<id> URL of each picture
con = sqlite3.connect('db.sqlite')
cur = con.cursor()
cur.execute('CREATE TABLE IF NOT EXISTS pictures (path TEXT, date DATE, width INT, height INT, thumbnail TEXT, features TEXT)')
# process() looks every crawled file up by path; without an index each lookup
# scans the whole table, making a full crawl quadratic in the number of pictures
cur.execute('CREATE INDEX IF NOT EXISTS pictures_path ON pictures (path)')
con.commit()
# load clip models: only the preprocessing transform of the PyTorch model is
# kept; image and text encoding run through the ONNX sessions below
_, preprocess = clip.load("ViT-B/32", device="cpu", jit=False)
onnx_model = clip_onnx(None)
onnx_model.load_onnx(visual_path="visual.onnx", textual_path="textual.onnx", logit_scale=100.0)  # model.logit_scale.exp()
onnx_model.start_sessions(providers=["CPUExecutionProvider"])
# generate a vector of features with clip
def extract_features(path):
    """Return the CLIP image embedding of the picture stored at *path*.

    The picture goes through CLIP's preprocessing transform, is batched as a
    single image and encoded with the ONNX visual model. The first (only) row
    of the batch is returned. process() stores it as 512 floats.
    """
    # close the underlying file as soon as the pixels have been transformed
    with Image.open(path) as image_input:
        image = preprocess(image_input).unsqueeze(0)  # [1, 3, 224, 224]
    image_onnx = image.detach().cpu().numpy().astype(np.float32)
    return onnx_model.encode_image(image_onnx)[0]
# process a single picture
def _sharded_path(root, name, extension):
    """Return root/<name[:4]>/<name[4:]><extension>, creating its parent directory."""
    path = os.path.join(root, name[:4], name[4:] + extension)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path


def process(path):
    """Thumbnail one picture, store its CLIP features and record it in the db.

    Pictures whose path is already in the ``pictures`` table are skipped, so
    re-running the crawl only handles new files. Files are named after the
    SHA-1 of the picture path and sharded by its first four hex digits. The
    EXIF DateTime is stored when present, otherwise NULL.
    """
    # pass already processed files
    if cur.execute('SELECT 1 FROM pictures WHERE path = ?', (path,)).fetchone() is not None:
        return
    # compute unique identifier from path
    name = hashlib.sha1(path.encode('utf8')).hexdigest()
    # generate and save thumbnail
    image = pyvips.Image.new_from_file(path, access='sequential').autorot()
    # Image.get() raises for a missing field; many pictures carry no EXIF
    # timestamp and must still be indexed, so fall back to NULL
    date_field = 'exif-ifd0-DateTime'
    date = image.get(date_field) if image.get_typeof(date_field) != 0 else None
    thumb = pyvips.Image.thumbnail_image(image, 224, crop='attention')
    thumbnail_path = _sharded_path('thumbnails', name, '.jpg')
    thumb.write_to_file(thumbnail_path)
    # generate and save clip features (512 native floats, read back by create_index)
    features = extract_features(thumbnail_path)
    features_path = _sharded_path('features', name, '.bin')
    vector_format = struct.Struct('f' * 512)
    with open(features_path, 'wb') as fp:
        fp.write(vector_format.pack(*features))
    # save to db
    cur.execute('INSERT INTO pictures VALUES (?, ?, ?, ?, ?, ?)',
                (path, date, image.width, image.height, thumbnail_path, features_path))
    con.commit()
# process files recursively in directory
def crawl(directory, pattern='*.jpg'):
    """Run process() on every file below *directory* whose name matches *pattern*.

    The directory part is glob-escaped so that names containing '[', '*' or
    '?' are taken literally. Paths are built as ``directory + '/**/' + pattern``
    exactly as before, so paths already stored in the db still match. Matching
    follows glob's rules (case-sensitive on POSIX, so the default does not
    match '*.JPG' there). A failing picture is reported and skipped so the
    rest of the crawl continues.
    """
    for path in glob.iglob(glob.escape(directory) + '/**/' + pattern, recursive=True):
        print(path)
        try:
            process(path)
        except Exception as e:  # best effort: one bad file must not stop the crawl
            print('%s: %s' % (path, e), file=sys.stderr)
# create an annoy index for quick vector search
def create_index(dimensions=512, n_trees=256):
    """Build and return an angular Annoy index over all stored feature vectors.

    Each picture's features file (written by process() as *dimensions* native
    floats) is added under the picture's sqlite rowid, so search results map
    straight back to db rows. *n_trees* is passed to ``AnnoyIndex.build``.
    """
    vector_format = struct.Struct('f' * dimensions)
    index = AnnoyIndex(dimensions, 'angular')
    # fetch all rows first so the shared cursor is free while files are read
    rows = cur.execute('SELECT rowid, features FROM pictures').fetchall()
    for rowid, features_path in rows:
        with open(features_path, 'rb') as fp:
            index.add_item(rowid, vector_format.unpack(fp.read(vector_format.size)))
    index.build(n_trees)
    return index
# Page skeleton for every HTML response. The single %s receives the search
# form and the thumbnail links, so the template must contain no other '%'.
html = '''
<!doctype html>
<html lang="en-US">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>CLIP gallery</title>
</head>
<body>
%s
</body>
</html>'''
# / redirects to /0
@get('/')
def home():
    """Redirect the landing page to the similarity page of item 0."""
    # NOTE(review): SQLite rowids normally start at 1, so item 0 may not be a
    # stored picture; confirm this is the intended landing page
    redirect('/0')


_RESULT_LINK = '<a href="/%d"><img width="244" height="244" title="%s" src="%s"></a>'


def _render_page(item_ids, query=None):
    """Render the search form followed by a thumbnail link for each picture id.

    *query*, when given, pre-fills the search box. The query and picture paths
    are HTML-escaped so quotes or markup in them cannot break out of the
    attributes (the query comes straight from the request). Ids without a db
    row are skipped instead of crashing the request.
    """
    value = '' if query is None else ' value="%s"' % escape(query)
    parts = ['<form action="/search"><input type="text" name="q"%s><input type="submit"></form>' % value]
    for item_id in item_ids:
        row = cur.execute('SELECT path, thumbnail FROM pictures WHERE rowid = ?', (item_id,)).fetchone()
        if row is None:
            continue
        path, thumbnail = row
        parts.append(_RESULT_LINK % (item_id, escape(path), escape(thumbnail)))
    return html % ''.join(parts)


# return 100 pictures with similar features to picture with given id
@get('/<num:int>')
def similar(num):
    """Show the 100 pictures nearest to picture *num* in CLIP space."""
    try:
        neighbours = index.get_nns_by_item(num, 100)
    except IndexError:  # id outside the index
        abort(404, 'No picture with id %d' % num)
    return _render_page(neighbours)


# convert text to vector with clip and return 100 similar pictures
@get('/search')
def search():
    """Embed the ``q`` text query with CLIP and show the 100 closest pictures."""
    # getunicode decodes the parameter as UTF-8 (plain .get() yields latin-1
    # text) and supplies '' when q is absent instead of None
    query = request.params.getunicode('q', default='')
    # truncate over-long queries to CLIP's context length instead of raising
    tokens = clip.tokenize([query], truncate=True)
    text_onnx = tokens.detach().cpu().numpy().astype(np.int64)
    text_features = onnx_model.encode_text(text_onnx)[0]
    return _render_page(index.get_nns_by_vector(text_features, 100), query)


# allow access to thumbnails
@get("/thumbnails/<filepath:path>")
def serve_thumbnail(filepath):
    """Serve a generated thumbnail from the local thumbnails directory."""
    return static_file(filepath, root="thumbnails")
if __name__ == '__main__':
    # generate thumbnails, features and index for new pictures, then load
    # every stored feature vector into the in-memory search index
    crawl(directory)
    index = create_index()
    # run server at http://127.0.0.1:8080/
    run(host='127.0.0.1', port=8080)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment