@twobob
Last active August 8, 2024 04:47
nod to Vvan gemert
# First, in your terminal.
#
# $ python3 -m virtualenv env
# $ source env/bin/activate
# $ pip install torch torchvision transformers sentencepiece protobuf accelerate
# $ pip install git+https://github.com/huggingface/diffusers.git
# $ pip install optimum-quanto==1.2.0
# $ pip install gradio
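#
# The script below wraps diffusers' FluxPipeline in a small Gradio app.
# To keep VRAM usage down, the FLUX transformer and the T5 text encoder
# are quantized to 8-bit float (qfloat8) with optimum-quanto, and the
# quantized modules are cached as .pt files next to the script so later
# runs skip the quantization step.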
import os
import torch
import gradio as gr
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from datetime import datetime
from PIL import Image

dtype = torch.bfloat16

MODEL_OPTIONS = {
    "FLUX.1-schnell": {
        "repo": "black-forest-labs/FLUX.1-schnell",
        "revision": "refs/pr/1",
        "default_steps": 4
    },
    "FLUX.1-dev": {
        "repo": "black-forest-labs/FLUX.1-dev",
        "revision": "fcb137eff9cee3c5fda4e908a1d3da105323a5bc",
        "default_steps": 15
    }
}
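# Note: FLUX.1-schnell is timestep-distilled and is meant to run in very
# few steps (around 1-4), while FLUX.1-dev is guidance-distilled and
# benefits from more steps, hence the different default_steps above.
# As far as I can tell, guidance_scale has little effect with schnell,
# since that variant was distilled with the guidance baked in.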

def load_or_quantize_model(model_class, model_name, quantized_path, bfl_repo, revision, **kwargs):
    # Reuse the cached qfloat8 weights if they exist; otherwise download,
    # quantize, freeze, and cache the module for the next run.
    if os.path.exists(quantized_path):
        print(f"Loading quantized {model_name} from {quantized_path}")
        return torch.load(quantized_path)
    else:
        print(f"Quantizing {model_name}")
        model = model_class.from_pretrained(bfl_repo, subfolder=model_name, torch_dtype=dtype, revision=revision, **kwargs)
        quantize(model, weights=qfloat8)
        freeze(model)
        torch.save(model, quantized_path)
        return model
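# Note: the cache round-trips the whole module through torch.save/torch.load.
# On torch >= 2.6, torch.load defaults to weights_only=True and refuses to
# unpickle full modules; for cache files you created yourself, load them
# explicitly with:
#
#   model = torch.load(quantized_path, weights_only=False)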

def load_model(model_option):
    bfl_repo = MODEL_OPTIONS[model_option]["repo"]
    revision = MODEL_OPTIONS[model_option]["revision"]
    # Load or create quantized models
    transformer = load_or_quantize_model(FluxTransformer2DModel, "transformer", f"quantized_transformer_{model_option}.pt", bfl_repo, revision)
    text_encoder_2 = load_or_quantize_model(T5EncoderModel, "text_encoder_2", f"quantized_text_encoder_2_{model_option}.pt", bfl_repo, revision)
    # Load the remaining (non-quantized) components
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", revision=revision)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder="tokenizer_2", revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype, revision=revision)
    # Build the pipeline without the quantized modules, then attach them
    # afterwards so the pipeline setup doesn't try to cast the qfloat8
    # weights to another dtype.
    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    # Offload idle components to CPU to reduce peak VRAM usage.
    pipe.enable_model_cpu_offload()
    return pipe

# Global variable to store the current pipeline
current_pipe = None

def generate(model_option, prompt, steps, guidance, width, height, seed, save_locally):
    global current_pipe
    # (Re)load the pipeline only when the selected model changes.
    if current_pipe is None or current_pipe.model_option != model_option:
        current_pipe = load_model(model_option)
        current_pipe.model_option = model_option
    # A seed of -1 means "pick a random seed".
    if seed == -1:
        seed = torch.seed()
    generator = torch.Generator().manual_seed(int(seed))
    image = current_pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=int(steps),
        generator=generator,
        guidance_scale=guidance,
    ).images[0]
    if save_locally:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flux_generation_{timestamp}.png"
        image.save(filename)
        print(f"Image saved as {filename}")
    return image

def update_steps(model_option):
    return MODEL_OPTIONS[model_option]["default_steps"]

with gr.Blocks() as demo:
    gr.Markdown("# FLUX Image Generation")
    gr.Markdown("Generate images using FLUX models. Choose between FLUX.1-schnell (faster, 4 steps) and FLUX.1-dev (higher quality, 15 steps).")
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Model", value="FLUX.1-schnell")
        steps_input = gr.Number(value=4, precision=0, label="Steps")
    prompt_input = gr.Textbox(label="Prompt")
    guidance_input = gr.Number(value=3.5, label="Guidance Scale")
    width_input = gr.Slider(0, 1920, value=1024, step=2, label="Width")
    height_input = gr.Slider(0, 1920, value=1024, step=2, label="Height")
    seed_input = gr.Number(value=-1, precision=0, label="Seed")
    save_locally_checkbox = gr.Checkbox(label="Save generations locally", value=False)
    generate_button = gr.Button("Generate")
    output_image = gr.Image(label="Generated Image")
    generate_button.click(
        generate,
        inputs=[model_dropdown, prompt_input, steps_input, guidance_input, width_input, height_input, seed_input, save_locally_checkbox],
        outputs=output_image
    )
    model_dropdown.change(
        update_steps,
        inputs=[model_dropdown],
        outputs=[steps_input]
    )

demo.launch(server_name="127.0.0.1")
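# To try it: save this file (the name is up to you, e.g. flux_gradio.py),
#
#   $ python flux_gradio.py
#
# then open http://127.0.0.1:7860 in a browser (Gradio's default port).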
@twobob (Author) commented Aug 7, 2024:

pip install [email protected] not 2.1.0
