Created
August 9, 2024 20:22
-
-
Save itsPreto/038abe634dd10e9850768cd0b13ed169 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import requests | |
import json | |
from typing import List, Dict, Union, Optional | |
import subprocess | |
import argparse | |
from rich.console import Console | |
from rich.table import Table | |
from rich.panel import Panel | |
from rich.text import Text | |
from rich import box | |
from rich.syntax import Syntax | |
from rich.tree import Tree | |
import os | |
class OllamaModelManager:
    """Manage models on a local Ollama server.

    Wraps the Ollama REST API (create/copy/delete/pull/push/show/embed)
    and the ``ollama list`` CLI, and layers a tagging system on top:
    tags derived automatically from each model's size, age and name,
    plus user-defined tags persisted to ``custom_tags.json`` in the
    current working directory.
    """

    def __init__(self, base_url: str = "http://localhost:11434"):
        """Point the manager at the Ollama server at *base_url*."""
        self.base_url = base_url
        self.console = Console()
        self.custom_tags = self.load_custom_tags()

    def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Union[Dict, List[Dict], None]:
        """Send an HTTP request to the Ollama API.

        Returns the parsed JSON body, a list of objects for NDJSON
        (streaming) responses, ``{}`` for an empty success body, or
        ``None`` on any request failure.
        """
        url = f"{self.base_url}{endpoint}"
        try:
            response = requests.request(method, url, json=data)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.console.print(f"[bold red]Error making request to {url}: {str(e)}[/bold red]")
            return None
        # Endpoints such as /api/copy and /api/delete answer with an empty
        # body on success; response.json() would raise on those.
        if not response.content:
            return {}
        try:
            return response.json()
        except ValueError:
            # Streaming endpoints return NDJSON: one JSON object per line.
            return [json.loads(line) for line in response.text.splitlines() if line.strip()]

    def create_model(self, name: str, modelfile: str, stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Create a model named *name* from the given Modelfile content."""
        data = {"name": name, "modelfile": modelfile, "stream": stream}
        result = self._request("POST", "/api/create", data)
        if result is not None:
            self.console.print(f"[green]Model {name} created successfully.[/green]")
        return result

    def check_blob_exists(self, digest: str) -> bool:
        """Return True if a blob with *digest* already exists on the server."""
        url = f"{self.base_url}/api/blobs/{digest}"
        try:
            response = requests.head(url)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def create_blob(self, digest: str, file_path: str) -> Optional[str]:
        """Upload the file at *file_path* as a blob with the given *digest*.

        Returns the server's response text, or None on failure.
        """
        url = f"{self.base_url}/api/blobs/{digest}"
        try:
            with open(file_path, 'rb') as file:
                response = requests.post(url, data=file)
            response.raise_for_status()
            return response.text
        except (IOError, requests.exceptions.RequestException) as e:
            self.console.print(f"[bold red]Error creating blob: {str(e)}[/bold red]")
            return None

    def list_local_models(self) -> Optional[Dict]:
        """Fetch installed models from /api/tags, annotating each with tags."""
        models = self._request("GET", "/api/tags")
        if models:
            for model in models.get('models', []):
                model['tags'] = self.add_tags(model)
        return models

    def show_model_info(self, name: str, verbose: bool = False) -> Optional[Dict]:
        """Fetch detailed information about *name* from /api/show."""
        info = self._request("POST", "/api/show", {"name": name, "verbose": verbose})
        if info:
            details = info.get('details', {})
            # Prefer the human-readable parameter size (e.g. "8.0B"); fall
            # back to a raw size field when it is absent.
            size = details.get('parameter_size') or details.get('size', 'Unknown')
            modified = info.get('modified_at', 'Unknown')
            info['tags'] = self.add_tags({'name': name, 'size': size, 'modified': modified})
        return info

    def copy_model(self, source: str, destination: str) -> bool:
        """Copy *source* to *destination*, duplicating its custom tags."""
        result = self._request("POST", "/api/copy", {"source": source, "destination": destination})
        if result is not None:
            self.console.print(f"[green]Model {source} copied to {destination} successfully.[/green]")
            if source in self.custom_tags:
                self.custom_tags[destination] = self.custom_tags[source].copy()
                self.save_custom_tags()
        return result is not None

    def delete_model(self, name: str) -> bool:
        """Delete *name* from the server and drop its custom tags."""
        result = self._request("DELETE", "/api/delete", {"name": name})
        if result is not None:
            self.console.print(f"[green]Model {name} deleted successfully.[/green]")
            if name in self.custom_tags:
                del self.custom_tags[name]
                self.save_custom_tags()
        return result is not None

    def pull_model(self, name: str, insecure: bool = False, stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Pull *name* from a model registry."""
        data = {"name": name, "insecure": insecure, "stream": stream}
        result = self._request("POST", "/api/pull", data)
        if result is not None:
            self.console.print(f"[green]Model {name} pulled successfully.[/green]")
        return result

    def push_model(self, name: str, insecure: bool = False, stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Push *name* to a model registry."""
        data = {"name": name, "insecure": insecure, "stream": stream}
        result = self._request("POST", "/api/push", data)
        if result is not None:
            self.console.print(f"[green]Model {name} pushed successfully.[/green]")
        return result

    def generate_embeddings(self, model: str, input: Union[str, List[str]], truncate: bool = True, options: Optional[Dict] = None, keep_alive: str = "5m") -> Optional[Dict]:
        """Generate embeddings for *input* (a string or list of strings)."""
        data = {
            "model": model,
            "input": input,
            "truncate": truncate,
            "options": options or {},
            "keep_alive": keep_alive,
        }
        return self._request("POST", "/api/embed", data)

    def list_running_models(self) -> Optional[Dict]:
        """Fetch currently loaded models from /api/ps, annotated with tags."""
        models = self._request("GET", "/api/ps")
        if models:
            for model in models.get('models', []):
                model['tags'] = self.add_tags(model)
        return models

    def get_ollama_list(self) -> List[str]:
        """Run `ollama list` and return its data rows (header stripped)."""
        try:
            result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True)
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            # FileNotFoundError: the `ollama` binary is not on PATH.
            self.console.print(f"[bold red]Error running 'ollama list': {str(e)}[/bold red]")
            return []
        return result.stdout.strip().split('\n')[1:]  # skip the column header

    def parse_line(self, line: str) -> Dict[str, str]:
        """Parse one data row of `ollama list` output.

        Columns are whitespace separated, and the SIZE column itself
        contains a space (e.g. "4.7 GB"), so a unit token following the
        size is re-attached to it. Returns {} for blank or malformed rows.
        """
        parts = line.split()
        if len(parts) < 4:
            return {}
        name, model_id, size = parts[0], parts[1], parts[2]
        rest = parts[3:]
        if rest and rest[0] in ('B', 'KB', 'MB', 'GB', 'TB'):
            size = f"{size} {rest[0]}"
            rest = rest[1:]
        return {
            'name': name,
            'id': model_id,
            'size': size,
            'modified': ' '.join(rest),
        }

    @staticmethod
    def _size_in_gb(size) -> Optional[float]:
        """Normalise *size* to gigabytes.

        Accepts raw byte counts (numbers, as returned by the REST API) or
        strings such as "4.7 GB", "274MB", "8.0B" or "4.7". A bare number
        or a trailing lone "B" is read as-is (GB / billions of parameters),
        matching the tier thresholds below. Returns None when unparseable.
        """
        if size is None:
            return None
        if isinstance(size, (int, float)):
            return float(size) / (1024 ** 3)  # REST API reports bytes
        text = str(size).replace(' ', '').upper()
        conversions = (
            ('GB', 1.0),
            ('MB', 1.0 / 1024),
            ('KB', 1.0 / 1024 ** 2),
            ('G', 1.0),
            ('M', 1.0 / 1024),
            ('K', 1.0 / 1024 ** 2),
            ('B', 1.0),  # parameter counts like "8.0B": same tiers as GB
        )
        for suffix, factor in conversions:
            if text.endswith(suffix):
                try:
                    return float(text[:-len(suffix)]) * factor
                except ValueError:
                    return None
        try:
            return float(text)  # bare number: assume GB
        except ValueError:
            return None

    def add_tags(self, model: Dict) -> List[str]:
        """Build the automatic + custom tag list for *model*.

        Works with rows from `ollama list` (string sizes, 'modified') as
        well as REST API entries (byte-count sizes, 'modified_at').
        """
        tags: List[str] = []

        # Size tier, normalised to GB.
        size_gb = self._size_in_gb(model.get('size'))
        if size_gb is None:
            tags.append('unknown_size')
        elif size_gb < 1:
            tags.append('tiny')
        elif size_gb < 3:
            tags.append('small')
        elif size_gb < 7:
            tags.append('medium')
        else:
            tags.append('large')

        # Recency, based on the human-readable "modified" text.
        modified = str(model.get('modified') or model.get('modified_at') or '')
        if 'hours' in modified:
            tags.append('recent')
        elif 'days' in modified:
            try:
                days = int(modified.split()[0])
            except ValueError:
                tags.append('older')
            else:
                tags.append('week-old' if days <= 7 else 'older')
        else:
            tags.append('older')

        # Model family, inferred from the name.
        name = str(model.get('name', '')).lower()
        if 'embed' in name:
            tags.append('embedding')
        if any(family in name for family in ('llama', 'mistral', 'phi')):
            tags.append('llm')
        if 'code' in name:
            tags.append('code')

        # User-defined tags.
        tags.extend(self.custom_tags.get(model.get('name'), []))
        return tags

    def filter_models(self, models: List[Dict[str, str]], tags: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """Return the models carrying at least one of *tags* (all if falsy)."""
        if not tags:
            return models
        return [model for model in models if any(tag in model['tags'] for tag in tags)]

    def list_models_with_tags(self, tags: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """List installed models (via `ollama list`) annotated and filtered by *tags*."""
        models = []
        for line in self.get_ollama_list():
            model = self.parse_line(line)
            if not model:
                continue  # skip malformed / empty rows
            model['tags'] = self.add_tags(model)
            models.append(model)
        return self.filter_models(models, tags)

    def get_embedding_model(self) -> Optional[str]:
        """Name of the first installed embedding model, or None."""
        embedding_models = self.list_models_with_tags(['embedding'])
        return embedding_models[0]['name'] if embedding_models else None

    def get_llm_model(self) -> Optional[str]:
        """Name of the first installed general-purpose LLM, or None."""
        llm_models = self.list_models_with_tags(['llm'])
        return llm_models[0]['name'] if llm_models else None

    def display_all_models(self, models: List[Dict[str, str]]):
        """Render *models* as a name/size/modified/tags table."""
        table = Table(title="All Models", box=box.DOUBLE_EDGE)
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Size", style="magenta")
        table.add_column("Modified", style="green")
        table.add_column("Tags", style="yellow")
        table.add_column("Custom Tags", style="red")
        for model in models:
            custom_tags = ", ".join(self.custom_tags.get(model['name'], []))
            table.add_row(
                model['name'],
                model['size'],
                model['modified'],
                ", ".join(model['tags']),
                custom_tags,
            )
        self.console.print(Panel(table, expand=False, border_style="bold white"))

    def display_filtered_models(self, models: List[Dict[str, str]], filter_tags: List[str]):
        """Render *models* (pre-filtered by *filter_tags*) as a table."""
        table = Table(title=f"Models Tagged: {', '.join(filter_tags)}", box=box.SIMPLE_HEAVY)
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Size", style="magenta")
        table.add_column("Tags", style="yellow")
        table.add_column("Custom Tags", style="red")
        for model in models:
            custom_tags = ", ".join(self.custom_tags.get(model['name'], []))
            table.add_row(
                model['name'],
                model['size'],
                ", ".join(model['tags']),
                custom_tags,
            )
        self.console.print(Panel(table, expand=False, border_style="bold green"))

    def display_model_info(self, model_name: str, info: Dict):
        """Render /api/show output for *model_name* as a tree."""
        tree = Tree(f"[bold cyan]{model_name}[/bold cyan]")
        details = tree.add("Details")
        for key, value in info.get('details', {}).items():
            details.add(f"[yellow]{key}[/yellow]: {value}")
        modelfile = tree.add("Modelfile")
        modelfile.add(Syntax(info.get('modelfile', ''), "dockerfile", theme="monokai", line_numbers=True))
        parameters = tree.add("Parameters")
        for line in info.get('parameters', '').split('\n'):
            parameters.add(line)
        tags = tree.add("Tags")
        tags.add(", ".join(info.get('tags', [])))
        self.console.print(Panel(tree, title="Model Information", expand=False, border_style="bold magenta"))

    def display_embeddings(self, model_name: str, embeddings: List[float]):
        """Show the first 10 components of an embedding vector."""
        table = Table(title=f"Embeddings from {model_name}", box=box.SIMPLE)
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Value", style="magenta")
        for i, value in enumerate(embeddings[:10]):  # first 10 values only
            table.add_row(str(i), f"{value:.6f}")
        self.console.print(Panel(table, expand=False, border_style="bold yellow"))

    def add_custom_tag(self, model_name: str, tag: str) -> bool:
        """Attach *tag* to *model_name*; returns False if already present."""
        model_tags = self.custom_tags.setdefault(model_name, [])
        if tag in model_tags:
            return False
        model_tags.append(tag)
        self.save_custom_tags()
        return True

    def remove_custom_tag(self, model_name: str, tag: str) -> bool:
        """Detach *tag* from *model_name*; returns False if not present."""
        model_tags = self.custom_tags.get(model_name)
        if not model_tags or tag not in model_tags:
            return False
        model_tags.remove(tag)
        if not model_tags:
            del self.custom_tags[model_name]  # drop empty entries
        self.save_custom_tags()
        return True

    def load_custom_tags(self) -> Dict[str, List[str]]:
        """Read custom tags from custom_tags.json (missing/corrupt file -> {})."""
        try:
            with open('custom_tags.json', 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # A missing or corrupted tags file should not stop the manager.
            return {}

    def save_custom_tags(self):
        """Persist the custom tag mapping to custom_tags.json."""
        with open('custom_tags.json', 'w') as f:
            json.dump(self.custom_tags, f, indent=2)
def main():
    """Command-line entry point for the Ollama model manager."""
    parser = argparse.ArgumentParser(description="Ollama Model Manager")
    parser.add_argument("--list", action="store_true", help="List all models")
    parser.add_argument("--add-tag", nargs=2, metavar=('MODEL', 'TAG'), help="Add a custom tag to a model")
    parser.add_argument("--remove-tag", nargs=2, metavar=('MODEL', 'TAG'), help="Remove a custom tag from a model")
    parser.add_argument("--filter", nargs='+', help="Filter models by tags")
    parser.add_argument("--info", metavar='MODEL', help="Show detailed information for a specific model")
    parser.add_argument("--embeddings", nargs=2, metavar=('MODEL', 'TEXT'), help="Generate embeddings for the given text using the specified model")
    parser.add_argument("--create", nargs=2, metavar=('NAME', 'MODELFILE'), help="Create a new model")
    parser.add_argument("--delete", metavar='MODEL', help="Delete a model")
    parser.add_argument("--copy", nargs=2, metavar=('SOURCE', 'DESTINATION'), help="Copy a model")
    parser.add_argument("--pull", metavar='MODEL', help="Pull a model")
    parser.add_argument("--push", metavar='MODEL', help="Push a model")
    args = parser.parse_args()

    manager = OllamaModelManager()

    # With no flags at all, default to listing everything.
    no_flags_given = not any(vars(args).values())
    if args.list or no_flags_given:
        manager.display_all_models(manager.list_models_with_tags())

    if args.add_tag:
        model, tag = args.add_tag
        if manager.add_custom_tag(model, tag):
            manager.console.print(f"[green]Added tag '{tag}' to {model}[/green]")
        else:
            manager.console.print(f"[red]Failed to add tag '{tag}' to {model}[/red]")

    if args.remove_tag:
        model, tag = args.remove_tag
        if manager.remove_custom_tag(model, tag):
            manager.console.print(f"[green]Removed tag '{tag}' from {model}[/green]")
        else:
            manager.console.print(f"[red]Failed to remove tag '{tag}' from {model}[/red]")

    if args.filter:
        manager.display_filtered_models(manager.list_models_with_tags(args.filter), args.filter)

    if args.info:
        model_info = manager.show_model_info(args.info)
        if model_info:
            manager.display_model_info(args.info, model_info)
        else:
            manager.console.print(f"[red]Unable to fetch information for model: {args.info}[/red]")

    if args.embeddings:
        model, text = args.embeddings
        embeddings = manager.generate_embeddings(model, text)
        if embeddings:
            manager.display_embeddings(model, embeddings['embeddings'][0])
        else:
            manager.console.print(f"[red]Unable to generate embeddings using model: {model}[/red]")

    if args.create:
        name, modelfile_path = args.create
        with open(modelfile_path, 'r') as f:
            manager.create_model(name, f.read())

    if args.delete:
        manager.delete_model(args.delete)

    if args.copy:
        manager.copy_model(*args.copy)

    if args.pull:
        manager.pull_model(args.pull)

    if args.push:
        manager.push_model(args.push)


if __name__ == "__main__":
    main()
@NEWbie0709 can you update these 3 functions and tell me if you see something like the output below?
def parse_line(self, line: str) -> Dict[str, str]:
    """Split one row of `ollama list` output into its named fields.

    Rows with fewer than four whitespace-separated columns (blank or
    malformed lines) yield an empty dict.
    """
    columns = line.split()
    if len(columns) < 4:
        return {}
    name, model_id, size, *modified = columns
    return {
        'name': name,
        'id': model_id,
        'size': size,
        'modified': ' '.join(modified),
    }
def add_tags(self, model: Dict[str, str]) -> List[str]:
    """Derive descriptive tags for *model* from its size, age and name,
    then append any user-defined custom tags."""
    tags: List[str] = []

    # --- size tier, normalised to gigabytes ---
    raw = model['size'].replace(' ', '')
    gigabytes = 0.0
    if raw.endswith('MB'):
        gigabytes = float(raw[:-2]) / 1024
    elif raw.endswith('GB'):
        gigabytes = float(raw[:-2])
    else:
        try:
            gigabytes = float(raw)
        except ValueError:
            tags.append('unknown_size')
    if gigabytes > 0:
        for limit, label in ((1, 'tiny'), (3, 'small'), (7, 'medium')):
            if gigabytes < limit:
                tags.append(label)
                break
        else:
            tags.append('large')

    # --- recency ---
    modified = model['modified']
    if 'hours' in modified:
        tags.append('recent')
    elif 'days' in modified:
        try:
            age = int(modified.split()[0])
        except ValueError:
            tags.append('unknown_age')
        else:
            tags.append('week-old' if age <= 7 else 'older')
    else:
        tags.append('older')

    # --- model family, inferred from the name ---
    lowered = model['name'].lower()
    if 'embed' in lowered:
        tags.append('embedding')
    if any(family in lowered for family in ('llama', 'mistral', 'phi')):
        tags.append('llm')
    if 'code' in lowered:
        tags.append('code')

    # --- user-defined tags ---
    tags.extend(self.custom_tags.get(model['name'], []))
    return tags
def list_models_with_tags(self, tags: Optional[List[str]] = None) -> List[Dict[str, str]]:
    """Collect locally installed models (via `ollama list`), annotate each
    with derived tags, and optionally keep only those matching *tags*."""
    annotated = []
    for raw in self.get_ollama_list():
        entry = self.parse_line(raw)
        if not entry:
            continue  # skip rows parse_line could not understand
        entry['tags'] = self.add_tags(entry)
        annotated.append(entry)
    return self.filter_models(annotated, tags)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
What do you get when you run
ollama list
? Because as long as you have at least 1 model downloaded it should work, since --list
is just running ollama list
under the hood.