Nod to Vvan Gemert
# First, in your terminal.
#
# $ python3 -m virtualenv env
# $ source env/bin/activate
# $ pip install torch torchvision transformers sentencepiece protobuf accelerate
# $ pip install git+https://github.com/huggingface/diffusers.git
# $ pip install [email protected]
# $ pip install gradio
import os
import torch
import gradio as gr
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from datetime import datetime
from PIL import Image
dtype = torch.bfloat16

MODEL_OPTIONS = {
    "FLUX.1-schnell": {
        "repo": "black-forest-labs/FLUX.1-schnell",
        "revision": "refs/pr/1",
        "default_steps": 4
    },
    "FLUX.1-dev": {
        "repo": "black-forest-labs/FLUX.1-dev",
        "revision": "fcb137eff9cee3c5fda4e908a1d3da105323a5bc",
        "default_steps": 15
    }
}
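
# Quantize a sub-model to qfloat8 on first use and cache the result to disk,
# so later runs can skip the slow quantize/freeze step and simply torch.load it.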
def load_or_quantize_model(model_class, model_name, quantized_path, bfl_repo, revision, **kwargs):
    if os.path.exists(quantized_path):
        print(f"Loading quantized {model_name} from {quantized_path}")
        return torch.load(quantized_path)
    else:
        print(f"Quantizing {model_name}")
        model = model_class.from_pretrained(bfl_repo, subfolder=model_name, torch_dtype=dtype, revision=revision, **kwargs)
        quantize(model, weights=qfloat8)
        freeze(model)
        torch.save(model, quantized_path)
        return model
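
# Assemble a FluxPipeline: the transformer and the T5 text encoder are the
# memory-heavy parts, so they go through load_or_quantize_model above; the CLIP
# text encoder, tokenizers, scheduler and VAE are loaded as-is in bfloat16.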
def load_model(model_option):
    bfl_repo = MODEL_OPTIONS[model_option]["repo"]
    revision = MODEL_OPTIONS[model_option]["revision"]

    # Load or create quantized models
    transformer = load_or_quantize_model(FluxTransformer2DModel, "transformer", f"quantized_transformer_{model_option}.pt", bfl_repo, revision)
    text_encoder_2 = load_or_quantize_model(T5EncoderModel, "text_encoder_2", f"quantized_text_encoder_2_{model_option}.pt", bfl_repo, revision)

    # Load other components
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", revision=revision)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype, revision=revision)

    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    pipe.enable_model_cpu_offload()
    return pipe

# Global variable to store the current pipeline
current_pipe = None
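
# Generate one image. The pipeline is loaded lazily and kept in the global
# current_pipe; it is only rebuilt when the model dropdown selection changes.
# A seed of -1 means "pick a random seed".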
def generate(model_option, prompt, steps, guidance, width, height, seed, save_locally):
    global current_pipe
    if current_pipe is None or current_pipe.model_option != model_option:
        current_pipe = load_model(model_option)
        current_pipe.model_option = model_option

    if seed == -1:
        seed = torch.seed()
    generator = torch.Generator().manual_seed(int(seed))

    image = current_pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=int(steps),
        generator=generator,
        guidance_scale=guidance,
    ).images[0]

    if save_locally:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flux_generation_{timestamp}.png"
        image.save(filename)
        print(f"Image saved as {filename}")

    return image

def update_steps(model_option):
    return MODEL_OPTIONS[model_option]["default_steps"]
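
# Gradio UI: model and step selectors, prompt, guidance, size sliders, seed and a
# "save locally" toggle. Changing the model also resets the step count to that
# model's default via update_steps.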
with gr.Blocks() as demo:
    gr.Markdown("# FLUX Image Generation")
    gr.Markdown("Generate images using FLUX models. Choose between FLUX.1-schnell (faster, 4 steps) and FLUX.1-dev (higher quality, 15 steps).")

    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Model", value="FLUX.1-schnell")
        steps_input = gr.Number(value=4, precision=0, label="Steps")

    prompt_input = gr.Textbox(label="Prompt")
    guidance_input = gr.Number(value=3.5, label="Guidance Scale")
    width_input = gr.Slider(0, 1920, value=1024, step=2, label="Width")
    height_input = gr.Slider(0, 1920, value=1024, step=2, label="Height")
    seed_input = gr.Number(value=-1, precision=0, label="Seed")
    save_locally_checkbox = gr.Checkbox(label="Save generations locally", value=False)
    generate_button = gr.Button("Generate")
    output_image = gr.Image(label="Generated Image")

    generate_button.click(
        generate,
        inputs=[model_dropdown, prompt_input, steps_input, guidance_input, width_input, height_input, seed_input, save_locally_checkbox],
        outputs=output_image
    )
    model_dropdown.change(
        update_steps,
        inputs=[model_dropdown],
        outputs=[steps_input]
    )

demo.launch(server_name="127.0.0.1")
pip install [email protected] not 2.1.0