Last active
August 21, 2024 19:36
-
-
Save razvanab/25f011ef7484905676ab281b224a82bc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# An application that functions as a local SD model browser for the Civitai models that you downloaded. | |
# If you think you can improve this to be better, please do. | |
# Created with various LLM models. | |
import streamlit as st | |
import dateutil.parser | |
import pandas as pd | |
from datetime import datetime | |
import os | |
import json | |
import requests | |
from urllib.parse import quote | |
from tkinter import Tk, filedialog | |
from time import time | |
# Configure the page before any other Streamlit command in this script runs.
st.set_page_config(page_title="Stable Diffusion Model Browser", layout="wide")

# Streamlit re-executes this script on every interaction, so per-session
# state is seeded only for keys that are not present yet. Each default is
# built lazily by a factory so existing values are never replaced.
_SESSION_DEFAULTS = {
    "models": lambda: pd.DataFrame(
        columns=["id", "name", "type", "version", "isOfficial", "downloads",
                 "rating", "createdAt", "path", "thumbnail"]
    ),
    "model_path": lambda: "",
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Function to open a directory selection dialog
def select_directory():
    """Open a native folder picker and return the directory the user chose.

    Returns:
        The selected directory path, or ``""`` if the dialog was cancelled.
    """
    root = Tk()
    try:
        root.withdraw()  # Hide the empty main tkinter window; only the dialog is needed.
        # Keep the dialog above other windows; otherwise it can open behind
        # the browser that hosts the app and look like the app has frozen.
        root.wm_attributes("-topmost", True)
        directory = filedialog.askdirectory(parent=root)
    finally:
        # Tear down the Tk interpreter even if the dialog raises.
        root.destroy()
    # askdirectory may return an empty tuple instead of "" on cancel on some
    # platforms; normalise so callers always get a string.
    return directory or ""
# Civitai API functions
def search_model_on_civitai(model_name, timeout=10):
    """Search Civitai for a model by name and return the first match.

    Args:
        model_name: Text to search for (the scanner passes the local file
            name without its extension).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The first entry of the response's ``items`` list, or ``None`` when
        nothing matched or the request failed (the error is shown in the UI).
    """
    encoded_name = quote(model_name)
    url = f"https://civitai.com/api/v1/models?limit=1&query={encoded_name}"
    try:
        # Without a timeout a stalled connection would hang the whole scan.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers malformed JSON bodies on older requests releases.
        st.error(f"Error fetching data from Civitai: {e}")
        return None
    # Tolerate payloads without an "items" key instead of raising KeyError.
    items = data.get('items') if isinstance(data, dict) else None
    return items[0] if items else None
def get_model_details(model_id, timeout=10):
    """Fetch the full Civitai record for one model.

    Args:
        model_id: Civitai model id, as returned by the search endpoint.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON response, or ``None`` if the request failed or the
        body was not valid JSON (the error is shown in the UI).
    """
    url = f"https://civitai.com/api/v1/models/{model_id}"
    try:
        # Without a timeout a stalled connection would hang the whole scan.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers malformed JSON bodies on older requests releases.
        st.error(f"Error fetching model details: {e}")
        return None
# Column order of the models DataFrame (same list as the empty DataFrame
# seeded into session state).
_MODEL_COLUMNS = ["id", "name", "type", "version", "isOfficial", "downloads",
                  "rating", "createdAt", "path", "thumbnail"]
# File extensions treated as model weights; matched case-insensitively.
_MODEL_EXTENSIONS = ('.ckpt', '.safetensors', '.pt')


def _first_item(items):
    """Return ``items[0]`` when *items* is a non-empty list of dicts, else ``{}``.

    Guards the ``modelVersions`` / ``images`` lookups: a ``[{}]`` default
    in ``dict.get`` does not help when the key is present but the list is
    empty, and ``[0]`` would then raise ``IndexError``.
    """
    if isinstance(items, list) and items and isinstance(items[0], dict):
        return items[0]
    return {}


def _file_created_date(file_path):
    """Return the file's ``os.path.getctime`` timestamp formatted as YYYY-MM-DD."""
    return datetime.fromtimestamp(os.path.getctime(file_path)).strftime("%Y-%m-%d")


def _created_date(raw, file_path):
    """Format *raw* (a date string from Civitai) as YYYY-MM-DD.

    Falls back to the file's ctime when *raw* is missing or cannot be parsed.
    """
    if raw:
        try:
            return dateutil.parser.parse(raw).strftime("%Y-%m-%d")
        except (ValueError, OverflowError):
            # dateutil's ParserError subclasses ValueError; OverflowError is
            # raised for dates too large for the platform.
            pass
    return _file_created_date(file_path)


def _placeholder_thumbnail(seed):
    """Return a placeholder image URL; *seed* varies the image per model."""
    return "https://picsum.photos/300/300?random=" + str(seed)


def _build_model_record(model_name, file_path, seed):
    """Describe one model file as a row dict, enriched from Civitai when possible.

    *seed* is the 1-based position of this file in the scan. It becomes the
    row id for models not found on Civitai and seeds placeholder thumbnails.
    """
    civitai_model = search_model_on_civitai(model_name)
    if not civitai_model:
        # Handle models not found on Civitai: use local information only.
        return {
            "id": seed,
            "name": model_name,
            "type": "Unknown",
            "version": "Unknown",
            "isOfficial": False,
            "downloads": 0,
            "rating": 0,
            "createdAt": _file_created_date(file_path),
            "path": file_path,
            "thumbnail": _placeholder_thumbnail(seed),
        }

    # The detail lookup may fail (returns None); fall back to defaults
    # instead of calling .get on None.
    details = get_model_details(civitai_model['id']) or {}
    # NOTE(review): assumes modelVersions[0] is the version of interest and
    # that Civitai reports timestamps under camelCase "createdAt" (matching
    # the other camelCase keys used here) — confirm against the API docs.
    # The snake_case key is still checked for backward compatibility.
    latest_version = _first_item(details.get('modelVersions'))
    preview_image = _first_item(latest_version.get('images'))
    stats = civitai_model.get('stats') or {}
    raw_created = (latest_version.get('createdAt')
                   or details.get('createdAt')
                   or details.get('created_at'))
    return {
        "id": civitai_model['id'],
        "name": civitai_model.get('name', model_name),
        "type": civitai_model.get('type', "Unknown"),
        "version": latest_version.get('name') or "Unknown",
        "isOfficial": False,
        "downloads": stats.get('downloadCount') or 0,
        # `or 0` keeps a null rating from breaking the ":.2f" display format.
        "rating": stats.get('rating') or 0,
        "createdAt": _created_date(raw_created, file_path),
        "path": file_path,
        "thumbnail": preview_image.get('url') or _placeholder_thumbnail(seed),
    }


def scan_models(path):
    """Walk *path* for model files and describe each one, using Civitai data where found.

    Every file whose extension is in ``_MODEL_EXTENSIONS`` is looked up on
    Civitai by its file name. Files with no match are still listed, using
    local information only. Progress is written to the Streamlit page.

    Args:
        path: Directory to scan recursively.

    Returns:
        A DataFrame with columns ``_MODEL_COLUMNS``, one row per model file.
        If no files are found it is empty but still has those columns, so
        the filter widgets can index it safely.
    """
    models = []
    start_time = time()
    # Scan models directory
    with st.spinner("Scanning models..."):
        for root, _dirs, files in os.walk(path):
            for file in files:
                if not file.lower().endswith(_MODEL_EXTENSIONS):
                    continue
                file_path = os.path.join(root, file)
                # os.walk can list entries that do not resolve (e.g. broken
                # symlinks); skip them rather than fail on getctime later.
                if not os.path.exists(file_path):
                    st.warning(f"File not found: {file_path}")
                    continue
                model_name = os.path.splitext(file)[0]
                st.write(f"Scanning: {file_path}")  # Status update
                models.append(_build_model_record(model_name, file_path, len(models) + 1))
    elapsed_time = time() - start_time
    st.write(f"Scanning completed in {elapsed_time:.2f} seconds")
    # Explicit columns so an empty scan still yields the expected schema.
    return pd.DataFrame(models, columns=_MODEL_COLUMNS)
def _scan_into_session(path, progress_message, spinner_message, success_message):
    """Scan *path* into ``st.session_state.models`` and report the outcome in the sidebar.

    Only directories are accepted: ``os.walk`` on a plain file yields
    nothing, so a file path is reported as invalid rather than silently
    producing an empty model list.
    """
    if os.path.isdir(path):
        st.sidebar.write(progress_message)
        with st.spinner(spinner_message):
            st.session_state.models = scan_models(path)
        st.sidebar.success(success_message)
    else:
        st.sidebar.error("Invalid path")


# Sidebar for options
st.sidebar.title("Options")

# Add a button to select model directory
if st.sidebar.button('Select Model Directory'):
    selected_directory = select_directory()
    if selected_directory:
        st.session_state.model_path = selected_directory
        _scan_into_session(
            selected_directory,
            "Scanning models and fetching details from Civitai...",
            "Scanning models...",
            "Models scanned successfully!",
        )

# Add an input field for manual path entry
model_path = st.sidebar.text_input("Or enter model path manually", st.session_state.model_path).strip()
# Streamlit reruns the whole script on every interaction; only rescan when
# the entered path differs from the one already loaded.
if model_path and model_path != st.session_state.model_path:
    st.session_state.model_path = model_path
    _scan_into_session(
        model_path,
        "Scanning models and fetching details from Civitai...",
        "Scanning models...",
        "Models scanned successfully!",
    )

# Refresh button
if st.sidebar.button("Refresh Models"):
    if not st.session_state.model_path:
        st.sidebar.error("No model directory selected")
    else:
        _scan_into_session(
            st.session_state.model_path,
            "Refreshing models and fetching details from Civitai...",
            "Refreshing models...",
            "Models refreshed successfully!",
        )
# Main content
st.title("Stable Diffusion Model Browser")

# Search bar
search_term = st.text_input("Search models", "")

# Filters
col1, col2, col3, col4 = st.columns(4)
with col1:
    type_filter = st.selectbox("Type", ["All"] + list(st.session_state.models['type'].unique()))
with col2:
    version_filter = st.selectbox("Version", ["All"] + list(st.session_state.models['version'].unique()))
with col3:
    source_filter = st.multiselect("Source", ["Official", "Community"])
with col4:
    sort_by = st.selectbox("Sort by", ["Name", "Type", "Version", "Downloads", "Rating", "Created Date"])

# Apply filters
filtered_models = st.session_state.models
if search_term:
    # regex=False: match the query as literal text, so input such as "(" or
    # "+" does not raise a regex error; na=False drops rows with no name.
    name_matches = filtered_models['name'].astype(str).str.contains(
        search_term, case=False, regex=False, na=False
    )
    filtered_models = filtered_models[name_matches]
if type_filter != "All":
    filtered_models = filtered_models[filtered_models['type'] == type_filter]
if version_filter != "All":
    filtered_models = filtered_models[filtered_models['version'] == version_filter]
# Selecting both sources, or neither, means "show everything".
if "Official" in source_filter and "Community" not in source_filter:
    filtered_models = filtered_models[filtered_models['isOfficial'].eq(True)]
elif "Community" in source_filter and "Official" not in source_filter:
    filtered_models = filtered_models[filtered_models['isOfficial'].eq(False)]

# Apply sorting: sort option -> (column, ascending). Numeric and date
# columns sort descending so the largest/newest come first.
_SORT_SPECS = {
    "Name": ("name", True),
    "Type": ("type", True),
    "Version": ("version", True),
    "Downloads": ("downloads", False),
    "Rating": ("rating", False),
    "Created Date": ("createdAt", False),
}
sort_column, sort_ascending = _SORT_SPECS[sort_by]
filtered_models = filtered_models.sort_values(sort_column, ascending=sort_ascending)

# Display models in a grid of up to 3 columns (at least 1, even when empty).
num_cols = max(1, min(len(filtered_models), 3))
cols = st.columns(num_cols)
# Display each model. Use the display position, not the DataFrame index
# label: after filtering/sorting the labels are no longer 0..n-1, which
# would spread cards unevenly across the columns.
for position, (_, model) in enumerate(filtered_models.iterrows()):
    with cols[position % num_cols]:
        # NOTE(review): use_column_width is deprecated in newer Streamlit
        # releases in favour of use_container_width — confirm installed version.
        st.image(model['thumbnail'], use_column_width=True)
        st.subheader(model['name'])
        st.write(f"Type: {model['type']}")
        st.write(f"Version: {model['version']}")
        st.write(f"Downloads: {model['downloads']:,}")
        st.write(f"Rating: {model['rating']:.2f}")
        st.write(f"Created: {model['createdAt']}")
        st.write(f"Path: {model['path']}")
# Save models data to JSON file.
# The path input must live outside the button branch: st.button is True
# only on the rerun triggered by its click, so an input nested inside it
# vanished as soon as it was edited and the default path was always used.
file_path = st.text_input("Enter path to save data", "models_data.json")
if st.button("Save Models Data"):
    if not file_path:
        st.error("Please enter a path to save data")
    else:
        try:
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(st.session_state.models.to_dict(orient="records"), f, indent=2)
            st.success("Models data saved successfully!")
        except (OSError, TypeError, ValueError) as e:
            # OSError: file cannot be written; TypeError/ValueError: a value
            # in the records is not JSON-serialisable.
            st.error(f"Error saving models data: {e}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment