Skip to content

Instantly share code, notes, and snippets.

@razvanab
Last active August 21, 2024 19:36
Show Gist options
  • Save razvanab/25f011ef7484905676ab281b224a82bc to your computer and use it in GitHub Desktop.
Save razvanab/25f011ef7484905676ab281b224a82bc to your computer and use it in GitHub Desktop.
# A local browser for the Stable Diffusion (SD) model files you have downloaded from Civitai.
# Improvements are welcome.
# Created with the help of various LLMs.
import streamlit as st
import dateutil.parser
import pandas as pd
from datetime import datetime
import os
import json
import requests
from urllib.parse import quote
from tkinter import Tk, filedialog
from time import time
# Configure the Streamlit page title and use the wide layout.
st.set_page_config(page_title="Stable Diffusion Model Browser", layout="wide")

# Seed session state only for keys that are not present yet. Defaults are
# factories so the empty DataFrame is only built when "models" is missing.
_SESSION_DEFAULTS = {
    "models": lambda: pd.DataFrame(
        columns=["id", "name", "type", "version", "isOfficial", "downloads", "rating", "createdAt", "path", "thumbnail"]
    ),
    "model_path": lambda: "",
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Function to open a directory selection dialog
def select_directory():
    """Open a native folder-picker dialog and return the chosen directory.

    A hidden Tk root window is created only to host the dialog and is always
    destroyed afterwards, even if the dialog raises.

    Returns:
        The selected directory path, or "" if the user cancelled.
    """
    root = Tk()
    root.withdraw()  # Hide the empty main tkinter window; only the dialog is shown.
    # Keep the dialog above other windows (such as the browser running the app)
    # so it does not open hidden behind them.
    root.wm_attributes("-topmost", 1)
    try:
        directory = filedialog.askdirectory(parent=root)
    finally:
        # Always release the Tk root, even if the dialog fails.
        root.destroy()
    # Normalize any falsy "cancelled" result to an empty string.
    return directory or ""
# Civitai API functions
def search_model_on_civitai(model_name, timeout=15):
    """Search Civitai for a model by name and return the top hit.

    Args:
        model_name: Free-text query; callers pass a local file name without
            its extension.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The first entry of the response's ``items`` list, or None when nothing
        matched or the request/response was unusable. Request and decode
        failures are also reported in the Streamlit UI via ``st.error``.
    """
    encoded_name = quote(model_name)
    url = f"https://civitai.com/api/v1/models?limit=1&query={encoded_name}"
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSONDecodeError is not a RequestException.
        st.error(f"Error fetching data from Civitai: {e}")
        return None
    # Tolerate payloads without an "items" list instead of raising KeyError.
    items = data.get("items") if isinstance(data, dict) else None
    return items[0] if items else None
def get_model_details(model_id, timeout=15):
    """Fetch the full Civitai record for a single model.

    Args:
        model_id: Civitai model id, as returned by ``search_model_on_civitai``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON response, or None if the request or JSON decoding
        failed. The failure is reported in the Streamlit UI via ``st.error``.
    """
    url = f"https://civitai.com/api/v1/models/{model_id}"
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older requests versions.
        st.error(f"Error fetching model details: {e}")
        return None
# File extensions treated as model weights when scanning a directory tree.
_MODEL_EXTENSIONS = ('.ckpt', '.safetensors', '.pt')
# Column order of the table returned by scan_models (same as the initial
# session-state DataFrame created at start-up).
_MODEL_COLUMNS = ["id", "name", "type", "version", "isOfficial", "downloads", "rating", "createdAt", "path", "thumbnail"]


def _file_created_date(file_path):
    """Return ``os.path.getctime(file_path)`` as a local ``YYYY-MM-DD`` string.

    Note: on Unix, ``getctime`` is the last metadata change, not the creation time.
    """
    return datetime.fromtimestamp(os.path.getctime(file_path)).strftime("%Y-%m-%d")


def _first_dict(items):
    """Return ``items[0]`` if *items* is a non-empty list of dicts, else ``{}``.

    Lets nested lookups such as ``modelVersions[0].images[0].url`` fall back to
    defaults instead of raising IndexError on empty lists.
    """
    if isinstance(items, list) and items and isinstance(items[0], dict):
        return items[0]
    return {}


def _civitai_created_date(model_details, first_version, file_path):
    """Return the Civitai creation date as ``YYYY-MM-DD``, else the file's date."""
    # The snake_case key is kept for compatibility. The other Civitai fields read
    # here are camelCase (modelVersions, downloadCount), so camelCase keys are
    # tried too.
    raw = (
        model_details.get("created_at")
        or model_details.get("createdAt")
        or first_version.get("createdAt")
    )
    if raw:
        try:
            return dateutil.parser.parse(str(raw)).strftime("%Y-%m-%d")
        except (ValueError, OverflowError):
            # dateutil's ParserError subclasses ValueError.
            pass
    return _file_created_date(file_path)


def _civitai_record(civitai_model, model_details, model_name, file_path, placeholder_thumbnail):
    """Build a table row for a local file that matched a Civitai search result.

    *model_details* may be None (the details request failed). In that case the
    row uses the file date, the placeholder thumbnail and an "Unknown" version.
    """
    stats = civitai_model.get("stats") or {}
    if model_details:
        # NOTE(review): presumably the newest version is listed first — confirm with the API docs.
        first_version = _first_dict(model_details.get("modelVersions"))
        created_at = _civitai_created_date(model_details, first_version, file_path)
        thumbnail = _first_dict(first_version.get("images")).get("url", "https://picsum.photos/300/300")
        version = first_version.get("name", "Unknown")
    else:
        created_at = _file_created_date(file_path)
        thumbnail = placeholder_thumbnail
        version = "Unknown"
    return {
        "id": civitai_model["id"],
        "name": civitai_model.get("name", model_name),
        "type": civitai_model.get("type", "Unknown"),
        "version": version,
        "isOfficial": False,
        "downloads": stats.get("downloadCount", 0),
        # `or 0` also maps an explicit null rating to 0, keeping the column numeric.
        "rating": stats.get("rating") or 0,
        "createdAt": created_at,
        "path": file_path,
        "thumbnail": thumbnail,
    }


def _local_record(model_id, model_name, file_path, placeholder_thumbnail):
    """Build a table row for a local file with no Civitai match."""
    return {
        # NOTE(review): this running counter may collide with real Civitai ids in the same table.
        "id": model_id,
        "name": model_name,
        "type": "Unknown",
        "version": "Unknown",
        "isOfficial": False,
        "downloads": 0,
        "rating": 0,
        "createdAt": _file_created_date(file_path),
        "path": file_path,
        "thumbnail": placeholder_thumbnail,
    }


def scan_models(path):
    """Recursively scan *path* for model files and describe each one.

    Each file ending in .ckpt, .safetensors or .pt is looked up on Civitai by
    its base name. Matches are filled in from the Civitai search result and
    model details. Files with no match get placeholder values. Progress and
    timing are written to the Streamlit page.

    Args:
        path: Root directory to walk.

    Returns:
        A DataFrame with one row per model file and the columns in
        ``_MODEL_COLUMNS``. The columns are present even when no files were found.
    """
    models = []
    start_time = time()
    # Scan models directory
    with st.spinner("Scanning models..."):
        for root, _dirs, files in os.walk(path):
            for file in files:
                if not file.endswith(_MODEL_EXTENSIONS):
                    continue
                file_path = os.path.join(root, file)
                # A listed entry may not resolve (e.g. a broken symlink).
                if not os.path.exists(file_path):
                    st.warning(f"File not found: {file_path}")
                    continue
                model_name = os.path.splitext(file)[0]
                st.write(f"Scanning: {file_path}")  # Status update
                # The 1-based row number doubles as a random-image seed and a local id.
                row_number = len(models) + 1
                placeholder_thumbnail = "https://picsum.photos/300/300?random=" + str(row_number)
                civitai_model = search_model_on_civitai(model_name)
                if civitai_model:
                    model_details = get_model_details(civitai_model['id'])
                    models.append(_civitai_record(civitai_model, model_details, model_name, file_path, placeholder_thumbnail))
                else:
                    # Handle models not found on Civitai
                    models.append(_local_record(row_number, model_name, file_path, placeholder_thumbnail))
    elapsed_time = time() - start_time
    st.write(f"Scanning completed in {elapsed_time:.2f} seconds")
    # Pin the columns so an empty scan still has the schema the UI filters rely on.
    return pd.DataFrame(models, columns=_MODEL_COLUMNS)
# Sidebar for options
st.sidebar.title("Options")


def _load_models_into_state(path, status_message, spinner_message, success_message):
    """Scan *path* into ``st.session_state.models`` and report progress in the sidebar.

    Shows "Invalid path" instead if *path* is not an existing directory.
    """
    # isdir rather than exists: os.walk on a regular file silently finds nothing.
    if not os.path.isdir(path):
        st.sidebar.error("Invalid path")
        return
    st.sidebar.write(status_message)
    with st.spinner(spinner_message):
        st.session_state.models = scan_models(path)
    st.sidebar.success(success_message)


# Add a button to select model directory
if st.sidebar.button('Select Model Directory'):
    selected_directory = select_directory()
    # An empty result means the dialog was cancelled; leave state untouched.
    if selected_directory:
        st.session_state.model_path = selected_directory
        _load_models_into_state(
            selected_directory,
            "Scanning models and fetching details from Civitai...",
            "Scanning models...",
            "Models scanned successfully!",
        )

# Add an input field for manual path entry
model_path = st.sidebar.text_input("Or enter model path manually", st.session_state.model_path)
# Rescan only when the typed path differs from the one already loaded.
if model_path and model_path != st.session_state.model_path:
    st.session_state.model_path = model_path
    _load_models_into_state(
        model_path,
        "Scanning models and fetching details from Civitai...",
        "Scanning models...",
        "Models scanned successfully!",
    )

# Refresh button
if st.sidebar.button("Refresh Models"):
    _load_models_into_state(
        st.session_state.model_path,
        "Refreshing models and fetching details from Civitai...",
        "Refreshing models...",
        "Models refreshed successfully!",
    )
# Main content
st.title("Stable Diffusion Model Browser")

# Search bar
search_term = st.text_input("Search models", "")

# Work on a frame that is guaranteed to have every expected column. A scan that
# found no files can leave a column-less DataFrame in session state, and the
# column lookups below would then raise KeyError.
_models_view = st.session_state.models.reindex(
    columns=["id", "name", "type", "version", "isOfficial", "downloads", "rating", "createdAt", "path", "thumbnail"]
)

# Filters
col1, col2, col3, col4 = st.columns(4)
with col1:
    type_filter = st.selectbox("Type", ["All"] + list(_models_view['type'].unique()))
with col2:
    version_filter = st.selectbox("Version", ["All"] + list(_models_view['version'].unique()))
with col3:
    source_filter = st.multiselect("Source", ["Official", "Community"])
with col4:
    sort_by = st.selectbox("Sort by", ["Name", "Type", "Version", "Downloads", "Rating", "Created Date"])

# Apply filters
filtered_models = _models_view
if search_term:
    # regex=False matches the query literally; as a regex, "(" would raise and
    # "." would match any character. astype(str) keeps the .str accessor usable
    # even if the column is not object dtype (e.g. an empty, reindexed frame).
    name_matches = filtered_models['name'].astype(str).str.contains(search_term, case=False, regex=False)
    filtered_models = filtered_models[name_matches]
if type_filter != "All":
    filtered_models = filtered_models[filtered_models['type'] == type_filter]
if version_filter != "All":
    filtered_models = filtered_models[filtered_models['version'] == version_filter]
# Selecting both sources, or neither, applies no source filter.
if "Official" in source_filter and "Community" not in source_filter:
    filtered_models = filtered_models[filtered_models['isOfficial'].eq(True)]
elif "Community" in source_filter and "Official" not in source_filter:
    filtered_models = filtered_models[filtered_models['isOfficial'].eq(False)]

# Apply sorting: each option maps to (column, ascending). Text columns sort A→Z;
# numeric and date columns sort largest/newest first.
_SORT_SPECS = {
    "Name": ("name", True),
    "Type": ("type", True),
    "Version": ("version", True),
    "Downloads": ("downloads", False),
    "Rating": ("rating", False),
    "Created Date": ("createdAt", False),
}
_sort_column, _sort_ascending = _SORT_SPECS[sort_by]
filtered_models = filtered_models.sort_values(_sort_column, ascending=_sort_ascending)
# Display models in a grid of at most three columns (always at least one).
num_cols = max(1, min(len(filtered_models), 3))
cols = st.columns(num_cols)
# Display each model. The card's column comes from the row's *position*, not its
# DataFrame index label: after filtering and sorting the labels are arbitrary,
# so using them would pile cards into one column and scramble the sort order.
for position, (_, model) in enumerate(filtered_models.iterrows()):
    with cols[position % num_cols]:
        st.image(model['thumbnail'], use_column_width=True)
        st.subheader(model['name'])
        st.write(f"Type: {model['type']}")
        st.write(f"Version: {model['version']}")
        st.write(f"Downloads: {model['downloads']:,}")
        rating = model['rating']
        # A missing/null rating cannot be formatted with ":.2f".
        st.write(f"Rating: {rating:.2f}" if pd.notna(rating) else "Rating: N/A")
        st.write(f"Created: {model['createdAt']}")
        st.write(f"Path: {model['path']}")
# Save models data to JSON file.
# The path input is rendered on every run. A widget created inside
# `if st.button(...)` exists only during the single rerun triggered by the
# click, so the user could never edit the path before the save happened.
file_path = st.text_input("Enter path to save data", "models_data.json")
if st.button("Save Models Data"):
    if not file_path:
        st.error("Please enter a path to save the models data to.")
    else:
        try:
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(st.session_state.models.to_dict(orient="records"), f)
        except (OSError, TypeError, ValueError) as e:
            # OSError: cannot open/write the file. TypeError/ValueError: json.dump rejected the data.
            st.error(f"Error saving models data: {e}")
        else:
            st.success("Models data saved successfully!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment