Skip to content

Instantly share code, notes, and snippets.

@itsPreto
Created August 9, 2024 20:22
Show Gist options
  • Save itsPreto/038abe634dd10e9850768cd0b13ed169 to your computer and use it in GitHub Desktop.
Save itsPreto/038abe634dd10e9850768cd0b13ed169 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import json
import os
import re
import subprocess
from typing import Dict, List, Optional, Union

import requests
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
from rich.tree import Tree
class OllamaModelManager:
    """Client for a local Ollama server with a simple model-tagging layer.

    HTTP calls go to the Ollama REST API under ``base_url``; the tagged model
    listing shells out to ``ollama list``. User-defined tags are persisted as
    JSON in ``custom_tags.json`` in the current working directory.
    """

    # File (relative to the current working directory) holding custom tags.
    _CUSTOM_TAGS_FILE = 'custom_tags.json'

    # Column separator for `ollama list` rows: a tab (with any padding around
    # it) or a run of two or more spaces. Single spaces are NOT separators
    # because the SIZE ("4.7 GB") and MODIFIED ("2 days ago") values contain them.
    _COLUMN_SEP_RE = re.compile(r'\s*\t\s*|\s{2,}')

    # A size such as "4.7 GB", "274MB", "8.0B" or "4.7": a number plus an
    # optional alphabetic unit.
    _SIZE_RE = re.compile(r'^\s*(\d*\.?\d+)\s*([A-Za-z]*)\s*$')

    # Multiplier converting a number with the given (upper-cased) unit to the
    # GB scale used by the size tags. show_model_info passes the model's
    # parameter_size (presumably like "8.0B" / "137M"); a bare "B" is used as
    # a plain number and a bare "M" is divided by 1024, as the original did.
    _SIZE_UNITS_GB = {
        '': 1.0,
        'B': 1.0,
        'K': 1 / 1024 ** 2,
        'KB': 1 / 1024 ** 2,
        'M': 1 / 1024,
        'MB': 1 / 1024,
        'G': 1.0,
        'GB': 1.0,
        'T': 1024.0,
        'TB': 1024.0,
    }

    def __init__(self, base_url: str = "http://localhost:11434"):
        # Drop a trailing slash so f"{base_url}{endpoint}" never yields "//api".
        self.base_url = base_url.rstrip('/')
        self.console = Console()
        self.custom_tags = self.load_custom_tags()

    # ------------------------------------------------------------------ HTTP

    def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Optional[Union[Dict, List[Dict]]]:
        """Send ``data`` as JSON to ``endpoint`` and return the decoded reply.

        Returns a dict for a single JSON object, a list of dicts when the body
        is newline-delimited JSON (what streaming endpoints send), ``{}`` for
        an empty body, or None after printing an error if the request fails
        or the body is not valid JSON.
        """
        url = f"{self.base_url}{endpoint}"
        try:
            response = requests.request(method, url, json=data)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.console.print(f"[bold red]Error making request to {url}: {str(e)}[/bold red]")
            return None
        try:
            return self._parse_body(response.text)
        except ValueError as e:
            self.console.print(f"[bold red]Invalid JSON in response from {url}: {str(e)}[/bold red]")
            return None

    @staticmethod
    def _parse_body(text: str) -> Union[Dict, List[Dict]]:
        """Decode a response body that may be empty, one JSON value, or NDJSON.

        Raises ValueError if neither the whole body nor every non-blank line
        is valid JSON.
        """
        body = text.strip()
        if not body:
            # Some endpoints answer success with an empty body; response.json()
            # would raise on it and the call would be reported as a failure.
            return {}
        try:
            return json.loads(body)
        except ValueError:
            # Streaming replies are one JSON object per line.
            return [json.loads(line) for line in body.splitlines() if line.strip()]

    @staticmethod
    def _stream_error(result: Union[Dict, List[Dict]]) -> Optional[str]:
        """Return the first ``"error"`` message found in a reply, if any."""
        items = result if isinstance(result, list) else [result]
        for item in items:
            if isinstance(item, dict) and item.get('error'):
                return str(item['error'])
        return None

    def _report_outcome(self, result: Optional[Union[Dict, List[Dict]]], success_message: str) -> None:
        """Print ``success_message`` unless the call failed or the reply carries an error."""
        if result is None:
            return  # _request has already printed why.
        error = self._stream_error(result)
        if error:
            self.console.print(f"[bold red]Ollama reported an error: {error}[/bold red]")
        else:
            self.console.print(f"[green]{success_message}[/green]")

    def create_model(self, name: str, modelfile: str, stream: bool = True) -> Optional[Union[Dict, List[Dict]]]:
        """Create model ``name`` from Modelfile text via POST /api/create."""
        data = {"name": name, "modelfile": modelfile, "stream": stream}
        result = self._request("POST", "/api/create", data)
        self._report_outcome(result, f"Model {name} created successfully.")
        return result

    def check_blob_exists(self, digest: str) -> bool:
        """Return True if HEAD /api/blobs/<digest> answers 200; False otherwise or on error."""
        url = f"{self.base_url}/api/blobs/{digest}"
        try:
            response = requests.head(url)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def create_blob(self, digest: str, file_path: str) -> Optional[str]:
        """Upload ``file_path`` as blob ``digest``; return the response text, or None on error."""
        url = f"{self.base_url}/api/blobs/{digest}"
        try:
            with open(file_path, 'rb') as file:
                response = requests.post(url, data=file)
            response.raise_for_status()
            return response.text
        except (IOError, requests.exceptions.RequestException) as e:
            self.console.print(f"[bold red]Error creating blob: {str(e)}[/bold red]")
            return None

    def _tag_api_models(self, endpoint: str) -> Optional[Dict]:
        """GET ``endpoint`` and attach a 'tags' list to each entry of its 'models' array."""
        models = self._request("GET", endpoint)
        if isinstance(models, dict):
            for model in models.get('models', []):
                model['tags'] = self.add_tags(model)
        return models

    def list_local_models(self) -> Optional[Dict]:
        """Return GET /api/tags with a 'tags' list added to every model."""
        return self._tag_api_models("/api/tags")

    def show_model_info(self, name: str, verbose: bool = False) -> Optional[Dict]:
        """Return POST /api/show for ``name`` with a derived 'tags' list added."""
        data = {"name": name, "verbose": verbose}
        info = self._request("POST", "/api/show", data)
        if isinstance(info, dict) and info:
            details = info.get('details') or {}
            size = details.get('parameter_size') or details.get('size', 'Unknown')
            modified = info.get('modified_at', 'Unknown')
            info['tags'] = self.add_tags({'name': name, 'size': size, 'modified': modified})
        return info

    def copy_model(self, source: str, destination: str) -> bool:
        """Copy ``source`` to ``destination`` (custom tags included); return True on success."""
        data = {"source": source, "destination": destination}
        result = self._request("POST", "/api/copy", data)
        if result is None:
            return False
        self.console.print(f"[green]Model {source} copied to {destination} successfully.[/green]")
        if source in self.custom_tags:
            self.custom_tags[destination] = list(self.custom_tags[source])
            self.save_custom_tags()
        return True

    def delete_model(self, name: str) -> bool:
        """Delete model ``name`` and its custom tags; return True on success."""
        data = {"name": name}
        result = self._request("DELETE", "/api/delete", data)
        if result is None:
            return False
        self.console.print(f"[green]Model {name} deleted successfully.[/green]")
        if name in self.custom_tags:
            del self.custom_tags[name]
            self.save_custom_tags()
        return True

    def pull_model(self, name: str, insecure: bool = False, stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Pull ``name`` via POST /api/pull; returns the decoded reply (a list when streamed)."""
        data = {"name": name, "insecure": insecure, "stream": stream}
        result = self._request("POST", "/api/pull", data)
        self._report_outcome(result, f"Model {name} pulled successfully.")
        return result

    def push_model(self, name: str, insecure: bool = False, stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Push ``name`` via POST /api/push; returns the decoded reply (a list when streamed)."""
        data = {"name": name, "insecure": insecure, "stream": stream}
        result = self._request("POST", "/api/push", data)
        self._report_outcome(result, f"Model {name} pushed successfully.")
        return result

    def generate_embeddings(self, model: str, input: Union[str, List[str]], truncate: bool = True, options: Optional[Dict] = None, keep_alive: str = "5m") -> Optional[Dict]:
        """Return POST /api/embed for ``input`` (a string or list of strings).

        ``input`` shadows the builtin but is kept as the public parameter name.
        """
        data = {
            "model": model,
            "input": input,
            "truncate": truncate,
            "options": options or {},
            "keep_alive": keep_alive,
        }
        return self._request("POST", "/api/embed", data)

    def list_running_models(self) -> Optional[Dict]:
        """Return GET /api/ps with a 'tags' list added to every model."""
        return self._tag_api_models("/api/ps")

    # --------------------------------------------------------- `ollama list`

    def get_ollama_list(self) -> List[str]:
        """Return the non-blank data rows of `ollama list` (header dropped), or [] on error."""
        try:
            result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing `ollama` executable (FileNotFoundError).
            self.console.print(f"[bold red]Error running 'ollama list': {str(e)}[/bold red]")
            return []
        rows = result.stdout.strip().split('\n')[1:]  # Skip header
        return [row for row in rows if row.strip()]

    def parse_line(self, line: str) -> Dict[str, str]:
        """Split one `ollama list` row into name / id / size / modified.

        Returns an empty dict when the row does not have enough columns.
        """
        fields = [field for field in self._COLUMN_SEP_RE.split(line.strip()) if field]
        if len(fields) < 4:
            # Fallback for single-space separated output: re-join "<number> <unit>".
            tokens = line.split()
            if len(tokens) < 4:
                return {}
            name, model_id, *rest = tokens
            if len(rest) >= 2 and rest[1].upper() in self._SIZE_UNITS_GB:
                size, rest = f"{rest[0]} {rest[1]}", rest[2:]
            else:
                size, rest = rest[0], rest[1:]
            fields = [name, model_id, size, ' '.join(rest)]
        name, model_id, size, *rest = fields
        return {
            'name': name,
            'id': model_id,
            'size': size,
            'modified': ' '.join(rest),
        }

    @classmethod
    def _size_in_gb(cls, size) -> Optional[float]:
        """Convert a size (int/float bytes, or text like "4.7 GB") to GB; None if unparseable."""
        if isinstance(size, bool):
            return None
        if isinstance(size, (int, float)):
            # The HTTP API reports sizes as integer byte counts.
            return size / 1024 ** 3
        if not isinstance(size, str):
            return None
        match = cls._SIZE_RE.match(size)
        if match is None:
            return None
        number, unit = match.groups()
        factor = cls._SIZE_UNITS_GB.get(unit.upper())
        if factor is None:
            return None
        return float(number) * factor

    @classmethod
    def _size_tag(cls, size) -> str:
        """Bucket a size into tiny (<1), small (<3), medium (<7) or large GB."""
        size_gb = cls._size_in_gb(size)
        if size_gb is None:
            return 'unknown_size'
        if size_gb < 1:
            return 'tiny'
        if size_gb < 3:
            return 'small'
        if size_gb < 7:
            return 'medium'
        return 'large'

    @staticmethod
    def _recency_tag(modified: str) -> str:
        """Map a humanised age ("5 minutes ago", "2 days ago", ...) to a recency tag."""
        text = modified.strip().lower()
        if not text:
            return 'unknown_age'
        if any(unit in text for unit in ('second', 'minute', 'hour')):
            return 'recent'
        if 'days' in text:
            try:
                days = int(text.split()[0])
            except ValueError:
                return 'unknown_age'
            return 'week-old' if days <= 7 else 'older'
        return 'older'

    def add_tags(self, model: Dict) -> List[str]:
        """Return derived tags (size, recency, model type) plus the model's custom tags.

        Accepts rows from parse_line as well as model dicts from the HTTP API,
        whose 'size' is an integer and which may lack 'modified'.
        """
        name = str(model.get('name', ''))
        tags = [
            self._size_tag(model.get('size')),
            self._recency_tag(str(model.get('modified') or '')),
        ]
        lowered = name.lower()
        if 'embed' in lowered:
            tags.append('embedding')
        if any(prefix in lowered for prefix in ['llama', 'mistral', 'phi']):
            tags.append('llm')
        if 'code' in lowered:
            tags.append('code')
        tags.extend(self.custom_tags.get(name, []))
        return tags

    def filter_models(self, models: List[Dict[str, str]], tags: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """Keep models carrying at least one of ``tags``; no tags means keep all."""
        if not tags:
            return models
        return [model for model in models if any(tag in model.get('tags', []) for tag in tags)]

    def list_models_with_tags(self, tags: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """Return tagged rows of `ollama list`, optionally filtered by ``tags``."""
        models = []
        for line in self.get_ollama_list():
            model = self.parse_line(line)
            if model:  # skip rows parse_line could not interpret
                model['tags'] = self.add_tags(model)
                models.append(model)
        return self.filter_models(models, tags)

    def get_embedding_model(self) -> Optional[str]:
        """Return the name of the first model tagged 'embedding', or None."""
        embedding_models = self.list_models_with_tags(['embedding'])
        return embedding_models[0]['name'] if embedding_models else None

    def get_llm_model(self) -> Optional[str]:
        """Return the name of the first model tagged 'llm', or None."""
        llm_models = self.list_models_with_tags(['llm'])
        return llm_models[0]['name'] if llm_models else None

    # ------------------------------------------------------------- display

    def display_all_models(self, models: List[Dict[str, str]]):
        """Print every model with its size, age, derived tags and custom tags."""
        table = Table(title="All Models", box=box.DOUBLE_EDGE)
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Size", style="magenta")
        table.add_column("Modified", style="green")
        table.add_column("Tags", style="yellow")
        table.add_column("Custom Tags", style="red")
        for model in models:
            custom_tags = ", ".join(self.custom_tags.get(model['name'], []))
            table.add_row(
                model['name'],
                str(model.get('size', '')),
                str(model.get('modified', '')),
                ", ".join(model.get('tags', [])),
                custom_tags,
            )
        self.console.print(Panel(table, expand=False, border_style="bold white"))

    def display_filtered_models(self, models: List[Dict[str, str]], filter_tags: List[str]):
        """Print the models that matched ``filter_tags``."""
        table = Table(title=f"Models Tagged: {', '.join(filter_tags)}", box=box.SIMPLE_HEAVY)
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Size", style="magenta")
        table.add_column("Tags", style="yellow")
        table.add_column("Custom Tags", style="red")
        for model in models:
            custom_tags = ", ".join(self.custom_tags.get(model['name'], []))
            table.add_row(
                model['name'],
                str(model.get('size', '')),
                ", ".join(model.get('tags', [])),
                custom_tags,
            )
        self.console.print(Panel(table, expand=False, border_style="bold green"))

    def display_model_info(self, model_name: str, info: Dict):
        """Print /api/show output as a tree; missing sections render empty."""
        tree = Tree(f"[bold cyan]{model_name}[/bold cyan]")
        details = tree.add("Details")
        for key, value in (info.get('details') or {}).items():
            details.add(f"[yellow]{key}[/yellow]: {value}")
        modelfile = tree.add("Modelfile")
        modelfile.add(Syntax(info.get('modelfile', ''), "dockerfile", theme="monokai", line_numbers=True))
        parameters = tree.add("Parameters")
        for line in info.get('parameters', '').splitlines():
            parameters.add(line)
        tags = tree.add("Tags")
        tags.add(", ".join(info.get('tags', [])))
        self.console.print(Panel(tree, title="Model Information", expand=False, border_style="bold magenta"))

    def display_embeddings(self, model_name: str, embeddings: List[float]):
        """Print the first 10 values of an embedding vector."""
        table = Table(title=f"Embeddings from {model_name}", box=box.SIMPLE)
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Value", style="magenta")
        for i, value in enumerate(embeddings[:10]):  # Display first 10 values
            table.add_row(str(i), f"{value:.6f}")
        self.console.print(Panel(table, expand=False, border_style="bold yellow"))

    # ---------------------------------------------------------- custom tags

    def add_custom_tag(self, model_name: str, tag: str) -> bool:
        """Attach ``tag`` to ``model_name`` and persist; False if it was already there."""
        model_tags = self.custom_tags.setdefault(model_name, [])
        if tag in model_tags:
            return False
        model_tags.append(tag)
        self.save_custom_tags()
        return True

    def remove_custom_tag(self, model_name: str, tag: str) -> bool:
        """Detach ``tag`` from ``model_name`` and persist; False if it was not present."""
        model_tags = self.custom_tags.get(model_name)
        if not model_tags or tag not in model_tags:
            return False
        model_tags.remove(tag)
        if not model_tags:
            del self.custom_tags[model_name]
        self.save_custom_tags()
        return True

    def load_custom_tags(self) -> Dict[str, List[str]]:
        """Read custom tags from disk; a missing or unreadable file yields {}."""
        try:
            with open(self._CUSTOM_TAGS_FILE, 'r', encoding='utf-8') as f:
                tags = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:
            # A corrupt file should not stop the tool from starting; note that
            # the next save will overwrite it.
            self.console.print(f"[bold yellow]Ignoring unreadable {self._CUSTOM_TAGS_FILE}: {e}[/bold yellow]")
            return {}
        if not isinstance(tags, dict):
            self.console.print(f"[bold yellow]Ignoring malformed {self._CUSTOM_TAGS_FILE}[/bold yellow]")
            return {}
        return tags

    def save_custom_tags(self):
        """Write the custom tags back to disk as JSON."""
        with open(self._CUSTOM_TAGS_FILE, 'w', encoding='utf-8') as f:
            json.dump(self.custom_tags, f, indent=2)
def main():
    """Parse command-line flags and run each requested action in turn.

    The checks below run sequentially, so several flags may be combined in one
    invocation. With no flags at all, the local models are listed (as --list).
    """
    parser = argparse.ArgumentParser(description="Ollama Model Manager")
    parser.add_argument("--list", action="store_true", help="List all models")
    parser.add_argument("--add-tag", nargs=2, metavar=('MODEL', 'TAG'), help="Add a custom tag to a model")
    parser.add_argument("--remove-tag", nargs=2, metavar=('MODEL', 'TAG'), help="Remove a custom tag from a model")
    parser.add_argument("--filter", nargs='+', help="Filter models by tags")
    parser.add_argument("--info", metavar='MODEL', help="Show detailed information for a specific model")
    parser.add_argument("--embeddings", nargs=2, metavar=('MODEL', 'TEXT'), help="Generate embeddings for the given text using the specified model")
    parser.add_argument("--create", nargs=2, metavar=('NAME', 'MODELFILE'), help="Create a new model")
    parser.add_argument("--delete", metavar='MODEL', help="Delete a model")
    parser.add_argument("--copy", nargs=2, metavar=('SOURCE', 'DESTINATION'), help="Copy a model")
    parser.add_argument("--pull", metavar='MODEL', help="Pull a model")
    parser.add_argument("--push", metavar='MODEL', help="Push a model")
    args = parser.parse_args()

    api = OllamaModelManager()

    if args.list or not any(vars(args).values()):
        all_models = api.list_models_with_tags()
        api.display_all_models(all_models)

    if args.add_tag:
        model, tag = args.add_tag
        if api.add_custom_tag(model, tag):
            api.console.print(f"[green]Added tag '{tag}' to {model}[/green]")
        else:
            # add_custom_tag only refuses when the tag is already attached.
            api.console.print(f"[yellow]{model} already has tag '{tag}'[/yellow]")

    if args.remove_tag:
        model, tag = args.remove_tag
        if api.remove_custom_tag(model, tag):
            api.console.print(f"[green]Removed tag '{tag}' from {model}[/green]")
        else:
            # remove_custom_tag only refuses when the tag is not attached.
            api.console.print(f"[red]Failed to remove tag '{tag}' from {model}: no such custom tag[/red]")

    if args.filter:
        filtered_models = api.list_models_with_tags(args.filter)
        api.display_filtered_models(filtered_models, args.filter)

    if args.info:
        model_info = api.show_model_info(args.info)
        if model_info:
            api.display_model_info(args.info, model_info)
        else:
            api.console.print(f"[red]Unable to fetch information for model: {args.info}[/red]")

    if args.embeddings:
        model, text = args.embeddings
        response = api.generate_embeddings(model, text)
        # Guard against a reply without a non-empty 'embeddings' list.
        vectors = response.get('embeddings') if isinstance(response, dict) else None
        if vectors:
            api.display_embeddings(model, vectors[0])
        else:
            api.console.print(f"[red]Unable to generate embeddings using model: {model}[/red]")

    if args.create:
        name, modelfile = args.create
        try:
            with open(modelfile, 'r', encoding='utf-8') as f:
                modelfile_content = f.read()
        except OSError as e:
            api.console.print(f"[red]Unable to read Modelfile {modelfile}: {e}[/red]")
        else:
            api.create_model(name, modelfile_content)

    if args.delete:
        api.delete_model(args.delete)

    if args.copy:
        source, destination = args.copy
        api.copy_model(source, destination)

    if args.pull:
        api.pull_model(args.pull)

    if args.push:
        api.push_model(args.push)


if __name__ == "__main__":
    main()
@itsPreto
Copy link
Author

What do you get when you run `ollama list`? As long as you have at least one model downloaded it should work, because `--list` just runs `ollama list` under the hood.

@NEWbie0709
Copy link

This is what I get from `ollama list`:
image

@itsPreto
Copy link
Author

@NEWbie0709 can you update these 3 functions and tell me if you see something like the output below?

Screenshot 2024-08-16 at 1 53 11 PM
def parse_line(self, line: str) -> Dict[str, str]:
    """Parse one data row of `ollama list` into name / id / size / modified.

    The SIZE ("4.7 GB") and MODIFIED ("2 days ago") columns contain single
    spaces, so columns are split on tabs or runs of two or more spaces, not on
    every space (which would yield size "4.7" and modified "GB 2 days ago").
    If that does not produce four columns, fall back to whitespace tokens and
    re-join a "<number> <unit>" size.

    Returns an empty dict if the line doesn't have enough parts.
    """
    import re

    fields = [field for field in re.split(r'\s*\t\s*|\s{2,}', line.strip()) if field]
    if len(fields) < 4:
        tokens = line.split()
        if len(tokens) < 4:
            return {}  # Return an empty dict if the line doesn't have enough parts
        name, model_id, *rest = tokens
        if len(rest) >= 2 and rest[1].upper() in ('B', 'KB', 'MB', 'GB', 'TB'):
            size, rest = f"{rest[0]} {rest[1]}", rest[2:]
        else:
            size, rest = rest[0], rest[1:]
        fields = [name, model_id, size, ' '.join(rest)]
    name, model_id, size, *rest = fields
    return {
        'name': name,
        'id': model_id,
        'size': size,
        'modified': ' '.join(rest),
    }
def add_tags(self, model: Dict[str, str]) -> List[str]:
    """Derive descriptive tags for one parsed `ollama list` row.

    ``model`` is a row dict with 'name', 'size' and 'modified' keys. Tags are:
    a size bucket (tiny/small/medium/large, or 'unknown_size'), a recency
    bucket (recent/week-old/older, or 'unknown_age'), model-type tags from the
    name (embedding/llm/code), then any custom tags in ``self.custom_tags``.
    """
    tags = []

    # Size-based tags: normalise "4.7 GB", "274MB" or a bare "4.7" to gigabytes.
    size_str = str(model.get('size', '')).replace(' ', '').upper()  # Remove any spaces
    units_in_gb = {'KB': 1 / 1024 ** 2, 'MB': 1 / 1024, 'GB': 1.0}
    factor = 1.0
    for unit, unit_factor in units_in_gb.items():
        if size_str.endswith(unit):
            size_str, factor = size_str[:-len(unit)], unit_factor
            break
    try:
        size = float(size_str) * factor
    except ValueError:
        # Covers every branch, e.g. "n/a" or "xMB"; previously the MB/GB
        # branches let this ValueError escape.
        tags.append('unknown_size')
    else:
        if size < 1:
            tags.append('tiny')
        elif size < 3:
            tags.append('small')
        elif size < 7:
            tags.append('medium')
        else:
            tags.append('large')

    # Recency tags: "Less than a second ago", "5 minutes ago" and "About an
    # hour ago" all count as recent, not just "N hours ago".
    modified = str(model.get('modified', '')).lower()
    if any(unit in modified for unit in ('second', 'minute', 'hour')):
        tags.append('recent')
    elif 'days' in modified:
        try:
            days = int(modified.split()[0])
        except ValueError:
            tags.append('unknown_age')
        else:
            tags.append('week-old' if days <= 7 else 'older')
    else:
        tags.append('older')

    # Model type tags
    name = model['name'].lower()
    if 'embed' in name:
        tags.append('embedding')
    if any(prefix in name for prefix in ['llama', 'mistral', 'phi']):
        tags.append('llm')
    if 'code' in name:
        tags.append('code')

    # Add custom tags
    if model['name'] in self.custom_tags:
        tags.extend(self.custom_tags[model['name']])

    return tags
def list_models_with_tags(self, tags: Optional[List[str]] = None) -> List[Dict[str, str]]:
    """Return every parseable `ollama list` row with its tags, filtered by ``tags``.

    Rows for which parse_line returns an empty dict are dropped before tagging.
    """
    parsed_rows = (self.parse_line(row) for row in self.get_ollama_list())
    tagged = []
    for entry in parsed_rows:
        if not entry:
            continue
        entry['tags'] = self.add_tags(entry)
        tagged.append(entry)
    return self.filter_models(tagged, tags)

@NEWbie0709
Copy link

Sorry for the late reply. I changed the code, and it works perfectly. Thanks for the help!
image
image
Thank you!! :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment