Skip to content

Instantly share code, notes, and snippets.

@bplunkert
Last active July 8, 2024 11:27
Show Gist options
  • Save bplunkert/ade30e70a20d3fc06a2664e56f9be83d to your computer and use it in GitHub Desktop.
Save bplunkert/ade30e70a20d3fc06a2664e56f9be83d to your computer and use it in GitHub Desktop.
A simple image classifier that uses LLaVA via an Ollama endpoint to perform local image classification
#!/usr/bin/env python3
import base64
import json
import os
import requests
def find_image_files(directory,
                     extensions=('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')):
    """Return the names of image files found directly inside *directory*.

    Args:
        directory: Path of the directory to scan (no recursion).
        extensions: Iterable of file extensions (with leading dot) that count
            as images. Defaults to the common raster formats; matching is
            case-insensitive.

    Returns:
        List of bare filenames (not full paths) whose extension matches.
    """
    # Normalise once into a set for O(1) membership tests.
    allowed = {ext.lower() for ext in extensions}
    return [name for name in os.listdir(directory)
            if os.path.splitext(name)[1].lower() in allowed]
def query_ollama(file_path,
                 url="http://inference:11434/api/generate",
                 model="llava",
                 prompt='How many cats are in the image? Return your output in the format: ```{ "count" : n }```',
                 timeout=120):
    """Send an image to an Ollama /api/generate endpoint and return its reply.

    Args:
        file_path: Path of the image file to classify.
        url: Ollama generate endpoint (default preserves the original
            hard-coded host).
        model: Model name to run (default "llava").
        prompt: Instruction sent alongside the image.
        timeout: Seconds before the HTTP request is abandoned; without it a
            stuck endpoint would hang this call forever.

    Returns:
        The decoded JSON response dict on HTTP 200, otherwise the string
        ``"Error: <status>"`` (callers rely on this error-string contract).
    """
    # Ollama expects images as base64 strings in the "images" array.
    with open(file_path, 'rb') as file:
        encoded_image = base64.b64encode(file.read()).decode('utf-8')
    data = {
        "model": model,
        "prompt": prompt,
        "images": [encoded_image],
        "stream": False,  # single complete response instead of a token stream
        "format": "json"  # ask the model to emit JSON so the count is parseable
    }
    response = requests.post(url, json=data, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    return f"Error: {response.status_code}"
import time
def main():
    """Count cats in every image in the current directory via Ollama.

    Each image is retried (with a short pause) until the model returns
    parseable JSON containing a "count" key, up to a fixed attempt limit —
    LLaVA does not always honour the requested output format, and an HTTP
    failure makes query_ollama return an "Error: <status>" string.
    """
    max_attempts = 5  # bound the retry loop so a permanently-bad file can't spin forever
    image_files = find_image_files(os.getcwd())
    for file in image_files:
        for _attempt in range(max_attempts):
            response = query_ollama(file)
            try:
                count = json.loads(response['response'])['count']
            except (KeyError, TypeError, json.JSONDecodeError):
                # TypeError: response was the "Error: <status>" string, so
                # string indexing with 'response' fails. KeyError /
                # JSONDecodeError: malformed or incomplete model output.
                print(f"Error occurred with file {file}, retrying...")
                time.sleep(1)  # brief backoff before hitting the endpoint again
                continue
            print(f"Number of cats in picture {file}: {count}")
            break
        else:
            print(f"Giving up on file {file} after {max_attempts} attempts")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
import base64
import os
import requests
def find_image_files(directory):
    """Return the filenames in *directory* that carry an image extension.

    Only the top level of *directory* is listed; matching is
    case-insensitive on the extension.
    """
    allowed = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
    matches = []
    for entry in os.listdir(directory):
        _, extension = os.path.splitext(entry)
        if extension.lower() in allowed:
            matches.append(entry)
    return matches
# Send one image to the Ollama generate endpoint and return the model's reply.
def query_ollama(file_path):
    """Ask the llava model to describe the image at *file_path*.

    Returns the decoded JSON response dict on HTTP 200; otherwise the
    string ``"Error: <status>"``.
    """
    # The endpoint expects images as base64-encoded strings.
    with open(file_path, 'rb') as image_file:
        payload_image = base64.b64encode(image_file.read()).decode('utf-8')

    request_body = {
        "model": "llava",
        "prompt": "What is the image?",
        "images": [payload_image],
        "stream": False  # one complete reply rather than a token stream
    }
    reply = requests.post("http://inference:11434/api/generate", json=request_body)
    if reply.status_code != 200:
        return f"Error: {reply.status_code}"
    return reply.json()
def main():
    """Print the model's description of every image in the working directory."""
    for image_name in find_image_files(os.getcwd()):
        result = query_ollama(image_name)
        print(f"Response for {image_name}: {result['response']}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment