Created
July 15, 2024 14:17
-
-
Save craiga/ac09b21908f7ddbab0bb9e899c8b07b2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Get images from PhotoPrism and label them.""" | |
import json | |
import logging | |
import tempfile | |
import time | |
from pprint import pformat | |
import boto3 | |
import click | |
import click_log | |
import httpx | |
import ollama as ollamalib | |
# Module-level logger shared by every command below; it is handed to click_log
# so that the verbosity option attached to the CLI group controls this logger.
logger = logging.getLogger(__name__)
click_log.basic_config(logger)
class SessionIDAuth(httpx.Auth):
    """httpx auth hook that stamps each request with a session ID header."""

    def __init__(self, session_id):
        """Remember the session ID to send with every request."""
        self.session_id = session_id

    def auth_flow(self, request):
        """Set the ``X-Session-ID`` header, then hand the request back."""
        headers = request.headers
        headers["X-Session-ID"] = self.session_id
        yield request
def image_search(session, api_url, query, order):
    """Yield every photo matching ``query``, fetching one page at a time.

    Requests ``<api_url>/photos`` with a fixed page size of 100 and a growing
    offset. Iteration stops at the first page that comes back empty. HTTP
    errors surface through ``raise_for_status()`` on the response.
    """
    batch_size = 100
    start = 0
    while True:
        reply = session.get(
            api_url + "/photos",
            params={"count": batch_size, "offset": start, "q": query, "order": order},
        )
        reply.raise_for_status()
        batch = reply.json()
        if batch:
            for photo in batch:
                yield photo
            start += batch_size
            continue
        # An empty page means the result set is exhausted.
        return
@click.group()
# NOTE(review): "--username" used to declare "-u" too, colliding with
# "--api-url" below. With the pinned click 8.1 the later declaration appears to
# win, so "-u" already resolved to "--api-url" and the username short flag was
# dead. The shadowed flag is dropped here so working invocations keep working;
# confirm before re-assigning "-u".
@click.option("--username")
@click.option("-p", "--password")
# Required: the callback concatenates onto api_url, so a missing value would
# otherwise fail with a TypeError instead of a usage message.
@click.option("-u", "--api-url", required=True)
@click.option(
    "-s", "--sleep-between-labels", "--sleep", default=0, show_default=True, type=int
)
@click.option("--sleep-between-images", default=0, show_default=True, type=int)
@click.option("-q", "--query", default="original:*", show_default=True)
@click.option("-o", "--order", default="added", show_default=True)
@click_log.simple_verbosity_option(logger)
@click.pass_context
def cli(
    context,
    username,
    password,
    api_url,
    sleep_between_labels,
    sleep_between_images,
    query,
    order,
):
    """Get images from PhotoPrism and label them."""
    # The docstring above is kept short and unchanged because click shows it
    # as the command's help text.
    #
    # Group callback: logs in against <api_url>/session, builds an
    # authenticated httpx client and stores everything the sub-commands need
    # in context.obj (api_url, download_token, order, query, session and the
    # two sleep intervals).
    logger.debug("Establish an httpx session.")
    response = httpx.post(
        api_url + "/session",
        json={"username": username, "password": password},
        timeout=None,
    )
    response.raise_for_status()
    response_data = response.json()

    session = httpx.Client(auth=SessionIDAuth(response_data["id"]), timeout=None)
    # Release the client's connections when click tears the context down;
    # previously the client was never closed.
    context.call_on_close(session.close)

    download_token = response_data["config"]["downloadToken"]

    context.ensure_object(dict)
    context.obj["api_url"] = api_url
    context.obj["download_token"] = download_token
    context.obj["order"] = order
    context.obj["query"] = query
    context.obj["session"] = session
    context.obj["sleep_between_images"] = sleep_between_images
    context.obj["sleep_between_labels"] = sleep_between_labels
@cli.command()
@click.pass_context
def rekognition(context):
    """Label images using Amazon Rekognition."""
    # The docstring above is left unchanged because click shows it as help.
    #
    # For every photo returned by the search: skip it if it already carries
    # the "Rekognition" marker label, otherwise download it, ask Rekognition
    # for labels and post each one back. Download and Rekognition failures
    # are logged and the photo is skipped; a failed label POST still raises.
    api_url = context.obj["api_url"]
    download_token = context.obj["download_token"]
    order = context.obj["order"]
    query = context.obj["query"]
    session = context.obj["session"]
    sleep_between_images = context.obj["sleep_between_images"]
    sleep_between_labels = context.obj["sleep_between_labels"]

    # Named *_client so the local no longer shadows this function's own name.
    rekognition_client = boto3.client("rekognition")

    for search_result in image_search(session, api_url, query, order):
        uid = search_result["UID"]
        logger.info("Processing uid:%s…", uid)

        info_response = session.get(api_url + "/photos/" + uid)
        info_response.raise_for_status()
        info = info_response.json()
        if any(entry["Label"]["Name"] == "Rekognition" for entry in info["Labels"]):
            logger.info("Skipping as it already has Rekognition label.")
            continue

        try:
            download_response = session.get(
                api_url + "/dl/" + search_result["Hash"], params={"t": download_token}
            )
            download_response.raise_for_status()
        except Exception as exc:
            # This handler guards the download, not the Rekognition call; the
            # message previously blamed Rekognition.
            logger.warning(
                "Error while downloading image, moving on to next image.",
                exc_info=exc,
            )
            continue

        try:
            rekognition_response = rekognition_client.detect_labels(
                Image={"Bytes": download_response.content}
            )
        except Exception as exc:
            logger.warning(
                "Error while calling Rekognition, moving on to next image.",
                exc_info=exc,
            )
            continue

        # The trailing ("Rekognition", 100) entry is the marker label that the
        # skip check above looks for on later runs.
        labels = [
            (found["Name"], found["Confidence"])
            for found in rekognition_response["Labels"]
        ] + [("Rekognition", 100)]

        for label, confidence in labels:
            uncertainty = int(100 - confidence)
            logger.info("Adding label %s with uncertainty of %s.", label, uncertainty)
            label_response = session.post(
                api_url + "/photos/" + uid + "/label",
                json={"Name": label, "Priority": 0, "Uncertainty": uncertainty},
            )
            label_response.raise_for_status()
            time.sleep(sleep_between_labels)

        time.sleep(sleep_between_images)
@cli.command()
@click.pass_context
@click.option("--url", default="http://localhost:11434", show_default=True, type=str)
@click.option("--model", default="llava", show_default=True, type=str)
@click.option(
    "--prompt",
    default=(
        "Your job is to generate tags for images, along with a confidence score for"
        " each tag. Generate as many tags as possible. Tags are generally one or two"
        " words long. Confidence should be a score from 0 to 100 of how confident you"
        " are in that tag. Return results as a JSON object with the tags as the keys"
        " and the confidence as the values."
    ),
    show_default=True,
    type=str,
)
def ollama(context, url, model, prompt):
    """Label images using Ollama."""
    # The docstring above is left unchanged because click shows it as help.
    #
    # For every photo returned by the search: download it, send it to the
    # Ollama model with the prompt, parse the JSON reply as {tag: confidence}
    # and post each tag back as a label. Download failures, undecodable
    # replies and unusable confidences are logged and skipped; a failed label
    # POST still raises.
    api_url = context.obj["api_url"]
    download_token = context.obj["download_token"]
    order = context.obj["order"]
    query = context.obj["query"]
    session = context.obj["session"]
    sleep_between_images = context.obj["sleep_between_images"]
    sleep_between_labels = context.obj["sleep_between_labels"]

    # The client only depends on --url, so build it once rather than per image.
    ollama_client = ollamalib.Client(host=url)

    for search_result in image_search(session, api_url, query, order):
        uid = search_result["UID"]
        logger.info("Processing uid:%s…", uid)

        try:
            download_response = session.get(
                api_url + "/dl/" + search_result["Hash"], params={"t": download_token}
            )
            download_response.raise_for_status()
        except Exception as exc:
            # The message previously blamed Rekognition, which this command
            # never calls.
            logger.warning(
                "Error while downloading image, moving on to next image.",
                exc_info=exc,
            )
            continue

        with tempfile.NamedTemporaryFile() as tmp_file:
            tmp_file.write(download_response.content)
            # Ollama is given the file *path*, so the buffered bytes must be
            # flushed first or it may read an empty or truncated image.
            tmp_file.flush()
            ollama_response = ollama_client.generate(
                model=model, format="json", images=[tmp_file.name], prompt=prompt
            )

        raw_response = ollama_response["response"]
        logger.debug("Got response from Ollama:\n%s", pformat(raw_response))

        try:
            labels = json.loads(raw_response)
        except json.decoder.JSONDecodeError as exc:
            exc.add_note(f'String trying to be decoded: "{raw_response}"')
            logger.warning(
                "Error decoding JSON response from Ollama, moving on to next image.",
                exc_info=exc,
            )
            continue

        # The model is asked for an object but valid JSON can be any type.
        if not isinstance(labels, dict):
            logger.warning(
                "JSON response from Ollama is not an object, moving on to next image."
            )
            continue

        for label, confidence in labels.items():
            try:
                uncertainty = int(100 - float(confidence))
            except (TypeError, ValueError, OverflowError):
                # Model output is untrusted: skip this tag rather than abort
                # the whole run on a non-numeric confidence.
                logger.warning(
                    "Skipping label %s: confidence %r is not a usable number.",
                    label,
                    confidence,
                )
                continue
            logger.info("Adding label %s with uncertainty of %s.", label, uncertainty)
            label_response = session.post(
                api_url + "/photos/" + uid + "/label",
                json={"Name": label, "Priority": 0, "Uncertainty": uncertainty},
            )
            label_response.raise_for_status()
            time.sleep(sleep_between_labels)

        time.sleep(sleep_between_images)
# Script entry point: run the click command group only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
boto3 | |
click | |
click-log | |
pip-tools | |
httpx | |
ollama |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# This file is autogenerated by pip-compile with Python 3.12 | |
# by the following command: | |
# | |
# pip-compile --allow-unsafe --strip-extras | |
# | |
anyio==4.4.0 | |
# via httpx | |
boto3==1.34.115 | |
# via -r requirements.in | |
botocore==1.34.115 | |
# via | |
# boto3 | |
# s3transfer | |
build==1.2.1 | |
# via pip-tools | |
certifi==2024.2.2 | |
# via | |
# httpcore | |
# httpx | |
click==8.1.7 | |
# via | |
# -r requirements.in | |
# click-log | |
# pip-tools | |
click-log==0.4.0 | |
# via -r requirements.in | |
h11==0.14.0 | |
# via httpcore | |
httpcore==1.0.5 | |
# via httpx | |
httpx==0.27.0 | |
# via | |
# -r requirements.in | |
# ollama | |
idna==3.7 | |
# via | |
# anyio | |
# httpx | |
jmespath==1.0.1 | |
# via | |
# boto3 | |
# botocore | |
ollama==0.2.1 | |
# via -r requirements.in | |
packaging==24.0 | |
# via build | |
pip-tools==7.4.1 | |
# via -r requirements.in | |
pyproject-hooks==1.1.0 | |
# via | |
# build | |
# pip-tools | |
python-dateutil==2.9.0.post0 | |
# via botocore | |
s3transfer==0.10.1 | |
# via boto3 | |
six==1.16.0 | |
# via python-dateutil | |
sniffio==1.3.1 | |
# via | |
# anyio | |
# httpx | |
urllib3==2.2.1 | |
# via botocore | |
wheel==0.43.0 | |
# via pip-tools | |
# The following packages are considered to be unsafe in a requirements file: | |
pip==24.0 | |
# via pip-tools | |
setuptools==70.0.0 | |
# via pip-tools |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment