Created
February 26, 2025 20:22
-
-
Save deseven/48468051dfffd4a955dafd910d7cd0ea to your computer and use it in GitHub Desktop.
Flux:Schnell pipe for OpenWebUI with images saving
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
title: FLUX Schnell Manifold Function for Black Forest Lab Image Generation Models | |
author: Balaxxe, credit to mobilestack and bgeneto | |
author_url: https://github.com/jaim12005/open-webui-flux-1.1-pro-ultra | |
funding_url: https://github.com/open-webui | |
version: 1.5 | |
license: MIT | |
requirements: pydantic>=2.0.0, aiohttp>=3.8.1 | |
environment_variables: | |
- REPLICATE_API_TOKEN (required) | |
- FLUX_GO_FAST (optional, default: true) | |
- FLUX_DISABLE_SAFETY (optional, default: false) | |
- FLUX_SEED (optional) | |
- FLUX_ASPECT_RATIO (optional, default: "1:1") | |
- FLUX_OUTPUT_FORMAT (optional, default: "webp") | |
- FLUX_OUTPUT_QUALITY (optional, default: 80) | |
- FLUX_NUM_OUTPUTS (optional, default: 1) | |
supported providers: replicate.com | |
NOTE: Due to the asynchronous nature of the Replicate API, each image generation will make 2-3 (rare occasion 4) API requests: | |
1. Initial request to start generation | |
2. Follow-up request(s) to check completion status | |
This is normal behavior and required by the API design. You will typically see only 2 requests after the first generation. | |
NOTE: Model and/or API provider request? Shoot me a link on the Open WebUI discord - same username :) | |
""" | |
from typing import Dict, AsyncIterator, Optional, Literal, cast, Union, List | |
from pydantic import BaseModel, Field | |
import os | |
import base64 | |
import aiohttp | |
import asyncio | |
from datetime import datetime | |
import json | |
import uuid | |
import time | |
from open_webui.config import UPLOAD_DIR | |
from open_webui.models.files import FileForm, FileModel, FileModelResponse, Files | |
from open_webui.storage.provider import Storage | |
from io import BytesIO | |
# Allowed values for the FLUX_ASPECT_RATIO valve (sent to Replicate as
# "aspect_ratio").
AspectRatioType = Literal[
    "1:1",
    "16:9",
    "21:9",
    "3:2",
    "2:3",
    "4:5",
    "5:4",
    "3:4",
    "4:3",
    "9:16",
    "9:21",
]

# Allowed values for the FLUX_OUTPUT_FORMAT valve (sent as "output_format").
OutputFormatType = Literal["webp", "jpg", "png"]

# Megapixel options; not referenced elsewhere in this module.
MegapixelsType = Literal["1", "0.25"]
class Pipe:
    """OpenWebUI pipe that generates images with FLUX Schnell on Replicate.

    Generated images are downloaded from Replicate, stored through OpenWebUI's
    storage provider and file table, and shown in the chat as markdown images
    pointing at ``/api/v1/files/<id>/content``.
    """

    # Prediction statuses after which a prediction is finished; checked both
    # on the initial POST response and while polling.
    _TERMINAL_STATUSES = ("succeeded", "failed", "canceled")

    class Valves(BaseModel):
        """Pipe settings; the FLUX_* values become Replicate model inputs."""

        REPLICATE_API_TOKEN: str = Field(
            default="", description="Your Replicate API token"
        )
        FLUX_GO_FAST: bool = Field(default=True, description="Enable fast mode")
        FLUX_DISABLE_SAFETY: bool = Field(
            default=False, description="Disable the built-in safety checker (API only)"
        )
        FLUX_SEED: Optional[int] = Field(
            default=None, description="Random seed for reproducible generations"
        )
        FLUX_ASPECT_RATIO: AspectRatioType = Field(
            default="1:1", description="Output image aspect ratio"
        )
        FLUX_OUTPUT_FORMAT: OutputFormatType = Field(
            default="webp", description="Output image format"
        )
        FLUX_OUTPUT_QUALITY: int = Field(
            default=80, description="Output image quality (1-100)"
        )
        FLUX_NUM_OUTPUTS: int = Field(
            default=1, description="Number of images to generate"
        )

    def __init__(self):
        """Initialise pipe metadata and seed the valves from environment variables."""
        self.type = "pipe"
        self.id = "flux_schnell"
        self.name = "Flux Schnell"
        self.MODEL_URL = "https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions"
        seed_raw = (os.getenv("FLUX_SEED") or "").strip()
        env_values = {
            "REPLICATE_API_TOKEN": os.getenv("REPLICATE_API_TOKEN", ""),
            "FLUX_SEED": int(seed_raw) if seed_raw else None,
            "FLUX_ASPECT_RATIO": os.getenv("FLUX_ASPECT_RATIO", "1:1"),
            "FLUX_OUTPUT_FORMAT": os.getenv("FLUX_OUTPUT_FORMAT", "webp"),
            # bool("false") is True, so flags are parsed explicitly.
            "FLUX_GO_FAST": self._env_flag("FLUX_GO_FAST", True),
            "FLUX_DISABLE_SAFETY": self._env_flag("FLUX_DISABLE_SAFETY", False),
            "FLUX_NUM_OUTPUTS": int(os.getenv("FLUX_NUM_OUTPUTS", "1")),
            "FLUX_OUTPUT_QUALITY": int(os.getenv("FLUX_OUTPUT_QUALITY", "80")),
        }
        # Drop unset values so the Valves field defaults apply instead.
        self.valves = self.Valves(
            **{k: v for k, v in env_values.items() if v is not None}
        )

    @staticmethod
    def _env_flag(name: str, default: bool) -> bool:
        """Read a boolean environment variable.

        Unset or blank values yield ``default``; otherwise "1", "true", "yes"
        and "on" (case-insensitive) are True and anything else is False.
        """
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    def _get_status(self, message: str) -> str:
        """Format a status message with timestamp."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        return f"[{timestamp}] {message}"

    def _create_file(
        self, image_data: bytes, content_type: str, prompt: str, __user__: dict
    ) -> Optional[str]:
        """Save image bytes via OpenWebUI storage and register them as a file.

        Args:
            image_data: Raw image bytes.
            content_type: MIME type such as ``image/webp``; its subtype becomes
                the file extension.
            prompt: Stored in the file metadata under ``data.prompt``.
            __user__: OpenWebUI user dict; its ``"id"`` owns the file record.

        Returns:
            The new file id, or None if any step fails (the error is printed).
        """
        try:
            file_id = str(uuid.uuid4())
            # "image/webp; charset=..." -> "webp"
            file_ext = content_type.split("/")[-1].split(";")[0].strip()
            original_name = f"flux_image.{file_ext}"
            stored_filename = f"{file_id}_{original_name}"
            _, file_path = Storage.upload_file(BytesIO(image_data), stored_filename)
            meta = {
                "name": original_name,
                "content_type": content_type,
                "size": len(image_data),
                "data": {"prompt": prompt},
            }
            file_form = FileForm(
                id=file_id, filename=original_name, meta=meta, path=file_path
            )
            file = Files.insert_new_file(__user__["id"], file_form)
            if file is None:
                print("Error saving file: database insert returned no record")
                return None
            return file.id
        except Exception as e:
            print(f"Error saving file: {e}")
            return None

    def _get_file_url(self, file_id: str) -> str:
        """Return the OpenWebUI API path serving the content of ``file_id``."""
        return f"/api/v1/files/{file_id}/content"

    async def _process_image(
        self, url_or_data: str, prompt: str, params: Dict, stream: bool = True
    ) -> Union[str, List[str]]:
        """Turn an image URL or data URI into an ``<img>`` tag or SSE chunks.

        Not called by ``pipe`` in this module; kept for interface compatibility.

        Args:
            url_or_data: An ``http(s)`` URL, which is downloaded and inlined as
                a base64 data URI; any other value is used verbatim as ``src``.
            prompt: Unused here.
            params: Unused here.
            stream: When False, return a single ``<img>`` tag. Otherwise return
                the SSE lines for the tag, a stop chunk and ``[DONE]``.
        """
        if url_or_data.startswith("http"):
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url_or_data, timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    response.raise_for_status()
                    image_data = base64.b64encode(await response.read()).decode("utf-8")
                    content_type = response.headers.get(
                        "Content-Type", f"image/{self.valves.FLUX_OUTPUT_FORMAT}"
                    )
            image_url = f"data:{content_type};base64,{image_data}"
        else:
            image_url = url_or_data
        if not stream:
            return f'<img src="{image_url}" alt="Generated Image" />'
        return [
            self._create_sse_chunk(
                f'<img src="{image_url}" alt="Generated Image" style="max-width: 100%; height: auto;" />'
            ),
            self._create_sse_chunk({}, finish_reason="stop"),
            "data: [DONE]\n\n",
        ]

    def _create_sse_chunk(
        self,
        content: Union[str, Dict],
        content_type: str = "text/html",
        finish_reason: Optional[str] = None,
    ) -> str:
        """Create a Server-Sent Events chunk in chat.completion.chunk shape.

        When ``finish_reason`` is set the delta is empty; otherwise it carries
        ``content`` and ``content_type`` as an assistant message.
        """
        chunk_data = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": "flux-schnell",
            "choices": [
                {
                    "delta": (
                        {}
                        if finish_reason
                        else {
                            "role": "assistant",
                            "content": content,
                            "content_type": content_type,
                        }
                    ),
                    "index": 0,
                    "finish_reason": finish_reason,
                }
            ],
        }
        return f"data: {json.dumps(chunk_data)}\n\n"

    async def _wait_for_completion(
        self,
        prediction_url: str,
        __event_emitter__=None,
        max_attempts: int = 3,
        initial_delay: float = 0.5,
        poll_interval: float = 0.3,
    ) -> Dict:
        """Poll a Replicate prediction URL until it reaches a terminal status.

        Args:
            prediction_url: The prediction's ``urls.get`` endpoint.
            __event_emitter__: Accepted for compatibility; not used here.
            max_attempts: Number of GET requests to make (at least 1).
            initial_delay: Seconds to sleep before the first GET.
            poll_interval: Seconds to sleep before each later GET.

        Returns:
            The prediction JSON once its status is succeeded/failed/canceled.

        Raises:
            aiohttp.ClientResponseError: On a non-2xx poll response.
            Exception: If the prediction is still unfinished after all attempts.
        """
        headers = {
            "Authorization": f"Token {self.valves.REPLICATE_API_TOKEN}",
            "Accept": "application/json",
            "Prefer": "wait=30",
        }
        attempts = max(1, max_attempts)
        status = None
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=35)
        ) as session:
            for attempt in range(attempts):
                await asyncio.sleep(initial_delay if attempt == 0 else poll_interval)
                async with session.get(prediction_url, headers=headers) as response:
                    response.raise_for_status()
                    result = await response.json()
                status = result.get("status")
                if status in self._TERMINAL_STATUSES:
                    return result
        raise Exception(
            f"Generation incomplete: status is still '{status}' after {attempts} checks"
        )

    @staticmethod
    def _extract_prompt(body: Dict) -> str:
        """Return the stripped text of the last message in ``body``, or ""."""
        messages = body.get("messages") or []
        if not messages:
            return ""
        content = (messages[-1] or {}).get("content") or ""
        if isinstance(content, list):
            # NOTE(review): assumes OpenAI-style content parts such as
            # {"type": "text", "text": ...}; only the text parts form the prompt.
            content = " ".join(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return str(content).strip()

    @staticmethod
    async def _emit_status(emitter, description: str, done: bool) -> None:
        """Send a status event through ``emitter`` when one is provided."""
        if emitter:
            await emitter(
                {"type": "status", "data": {"description": description, "done": done}}
            )

    async def _save_outputs(
        self, session, result: Dict, prompt: str, __user__: Dict
    ) -> List[str]:
        """Download every output of a finished prediction and store it.

        Outputs that fail to download (non-200) or to save are skipped.

        Returns:
            OpenWebUI file URLs for the images that were saved.
        """
        outputs = result.get("output") or []
        if isinstance(outputs, str):
            outputs = [outputs]
        saved_urls: List[str] = []
        for output_url in outputs:
            async with session.get(output_url) as response:
                if response.status != 200:
                    print(f"Skipping output {output_url}: HTTP {response.status}")
                    continue
                image_data = await response.read()
                content_type = response.headers.get(
                    "Content-Type", f"image/{self.valves.FLUX_OUTPUT_FORMAT}"
                )
            file_id = self._create_file(image_data, content_type, prompt, __user__)
            if file_id:
                saved_urls.append(self._get_file_url(file_id))
        return saved_urls

    async def pipe(
        self, body: Dict, __user__: Dict, __event_emitter__=None
    ) -> AsyncIterator[str]:
        """Generate images for the last chat message and publish them.

        With an event emitter, the markdown images are sent as a "message"
        event, a success status follows, and "" is yielded. Without one, the
        markdown itself is yielded so the images still reach the chat. Errors
        are reported as "Error: ..." (also emitted as a status when possible).
        """
        if not self.valves.REPLICATE_API_TOKEN:
            yield "Error: REPLICATE_API_TOKEN is required"
            return
        try:
            prompt = self._extract_prompt(body)
            if not prompt:
                yield "Error: No prompt provided"
                return
            input_params = {
                "prompt": prompt,
                "go_fast": self.valves.FLUX_GO_FAST,
                "num_outputs": self.valves.FLUX_NUM_OUTPUTS,
                "aspect_ratio": self.valves.FLUX_ASPECT_RATIO,
                "output_format": self.valves.FLUX_OUTPUT_FORMAT,
                "output_quality": self.valves.FLUX_OUTPUT_QUALITY,
                "disable_safety_checker": self.valves.FLUX_DISABLE_SAFETY,
            }
            if self.valves.FLUX_SEED is not None:
                input_params["seed"] = self.valves.FLUX_SEED
            await self._emit_status(
                __event_emitter__, "Starting Flux Schnell generation...", done=False
            )
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.MODEL_URL,
                    headers={
                        "Authorization": f"Token {self.valves.REPLICATE_API_TOKEN}",
                        "Content-Type": "application/json",
                        "Prefer": "wait=30",
                    },
                    json={"input": input_params},
                    timeout=aiohttp.ClientTimeout(total=35),
                ) as response:
                    if response.status >= 400:
                        # Surface Replicate's own explanation (bad token,
                        # invalid input, ...) instead of a bare status code.
                        detail = await response.text()
                        raise Exception(
                            f"Replicate API error {response.status}: {detail[:500]}"
                        )
                    prediction = await response.json()
                # "Prefer: wait" lets the POST return a finished prediction;
                # only poll when it has not finished yet.
                if prediction.get("status") in self._TERMINAL_STATUSES:
                    result = prediction
                else:
                    result = await self._wait_for_completion(
                        prediction["urls"]["get"], __event_emitter__
                    )
                if result.get("status") != "succeeded":
                    raise Exception(
                        f"Generation failed: {result.get('error') or 'Unknown error'}"
                    )
                saved_urls = await self._save_outputs(session, result, prompt, __user__)
            if not saved_urls:
                raise Exception("No images were saved")
            content = "\n".join(f"![Generated Image]({url})" for url in saved_urls)
            if __event_emitter__:
                await __event_emitter__(
                    {
                        "type": "message",
                        "data": {
                            "content": content,
                            "content_type": "text/markdown",
                        },
                    }
                )
                await self._emit_status(
                    __event_emitter__, "Image generated successfully!", done=True
                )
                yield ""
            else:
                yield content
        except Exception as e:
            error_msg = f"Error: {str(e)}"
            await self._emit_status(__event_emitter__, error_msg, done=True)
            yield error_msg
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment