modal with FLUX
# ---
# BENTOAI:
# input as prompts_file: image_prompts.txt
# output-directory: "./flux/frame_XXXX.png"
# ---
# BENTOAI: This guide is based on Modal's excellent examples and tutorials: https://modal.com/docs/example
# In this guide, we'll run Flux to generate images from prompts using [Hugging Face's Diffusers](https://github.com/huggingface/diffusers).
# We'll use the [FLUX.1 model](https://huggingface.co/black-forest-labs/FLUX.1-dev) for this example.
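# The script expects a plain-text prompts file with one prompt per line;
# lines starting with `**` are treated as headers and skipped. The file name and
# contents below are hypothetical examples:
# ```text
# **Scene 1**
# A neon-lit alley at night in the rain, cinematic lighting, portrait orientation
# A lone astronaut crossing red dunes at dusk, volumetric haze, portrait orientation
# ```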
# ## Setting up the image and dependencies
import time
from io import BytesIO
from pathlib import Path
import os

import modal

# We'll make use of the full [CUDA toolkit](https://modal.com/docs/guide/cuda)
# in this example, so we'll build our container image off of the `nvidia/cuda` base.
cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
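# resolves to the image tag "12.4.0-devel-ubuntu22.04"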
cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])
# Now we install most of our dependencies with `apt` and `pip`.
# Modal's original example installs Hugging Face's [Diffusers](https://github.com/huggingface/diffusers)
# from GitHub source pinned to a specific commit; here we install the PyPI release instead.
# PyTorch added faster attention kernels for Hopper GPUs in version 2.5.
# diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"
flux_image = (
    cuda_dev_image.apt_install(
        "git",
        "libglib2.0-0",
        "libsm6",
        "libxrender1",
        "libxext6",
        "ffmpeg",
        "libgl1",
    )
    # Modified to use uv
    .pip_install("uv")
    .run_commands(
        # hf_transfer backs the HF_HUB_ENABLE_HF_TRANSFER=1 setting below
        "uv pip install --system --compile-bytecode torch invisible_watermark"
        " transformers huggingface_hub hf_transfer accelerate safetensors"
        " sentencepiece diffusers numpy"
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache", "TRANSFORMERS_OFFLINE": "0"})
)
# Later, we'll also use `torch.compile` to increase the speed further.
# Torch compilation needs to be re-executed when each new container starts,
# so we turn on some extra caching to reduce compile times for later containers.
# flux_image = flux_image.env(
#     {
#         "TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache",
#         "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
#     }
# )
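# (The Volumes mounted below at /root/.inductor-cache, /root/.triton, and /root/.nv
# are what persist these caches across containers.)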
# Finally, we construct our Modal [App](https://modal.com/docs/reference/modal.App),
# set its default image to the one we just constructed,
# and import `FluxPipeline` for downloading and running Flux.1.
app = modal.App("bentoai-flux", image=flux_image)

with flux_image.imports():
    import torch
    from diffusers import FluxPipeline
# ## Defining a parameterized `Model` inference class
# Next, we map the model's setup and inference code onto Modal.
# 1. We run the model setup in the method decorated with `@modal.enter()`. This includes loading the
#    weights and moving them to the GPU, along with an optional `torch.compile` step (see details below).
#    The `@modal.enter()` decorator ensures that this method runs only once, when a new container starts,
#    instead of in the path of every call.
# 2. We run the actual inference in methods decorated with `@modal.method()`.
MINUTES = 60  # seconds
VARIANT = "dev"  # "schnell" or "dev", but note [dev] requires you to accept terms and conditions on HF
# BENTOAI: we added SECRETS to accept terms and conditions on HF
NUM_INFERENCE_STEPS = 50  # use ~50 for [dev], smaller for [schnell]
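# BENTOAI: to use the gated [dev] weights, accept the license on the model's Hugging Face
# page, then store a token with access in a Modal secret named `huggingface-secret`.
# One way to create it is via the Modal CLI (the token value below is a placeholder):
# ```bash
# modal secret create huggingface-secret HF_TOKEN=hf_...
# ```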
@app.cls(
    gpu="L40S",  # 48GB VRAM
    # gpu="L4",  # 24GB VRAM ... need to investigate GGUF
    # BENTOAI: seems to be a good compromise, cost vs. performance
    container_idle_timeout=5 * MINUTES,
    timeout=10 * MINUTES,  # leave plenty of time for compilation
    # BENTOAI: we added SECRETS to accept terms and conditions on HF
    secrets=[modal.Secret.from_name("huggingface-secret")],  # add HF token from Modal secrets
    volumes={  # add Volumes to store serializable compilation artifacts
        "/cache": modal.Volume.from_name("hf-hub-cache", create_if_missing=True),
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name("triton-cache", create_if_missing=True),
        "/root/.inductor-cache": modal.Volume.from_name("inductor-cache", create_if_missing=True),
    },
)
class Model:
    compile: int = modal.parameter(default=0)  # see section on torch.compile below for details

    @modal.enter()
    def enter(self):
        pipe = FluxPipeline.from_pretrained(
            f"black-forest-labs/FLUX.1-{VARIANT}",
            torch_dtype=torch.bfloat16,
            # use_auth_token=os.environ["HF_TOKEN"],  # HF token from the Modal secret
        )
        # BENTOAI: helps avoid some OOM VRAM issues.
        # Note: enable_model_cpu_offload() manages device placement itself,
        # so we skip the usual .to("cuda") call.
        pipe.enable_model_cpu_offload()  # save VRAM
        self.pipe = pipe
        # BENTOAI: added for compilation that we will try later
        # self.pipe = optimize(pipe, compile=bool(self.compile))
    @modal.method()
    def inference(self, prompts: list[str]) -> list[bytes]:
        print(f"🎨 generating {len(prompts)} images...")
        outputs = self.pipe(
            prompts,
            output_type="pil",
            num_inference_steps=NUM_INFERENCE_STEPS,
            # BENTOAI: portrait 480x800 (close to 9:16)
            height=800,
            width=480,
            guidance_scale=3.5,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images
        image_bytes = []
        for img in outputs:
            byte_stream = BytesIO()
            img.save(byte_stream, format="PNG")
            image_bytes.append(byte_stream.getvalue())
        return image_bytes
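# BENTOAI: a minimal sketch of the `optimize` helper referenced (commented out) in
# `enter()` above, modeled on Modal's Flux example; treat it as an untested assumption.
# Note that `torch.compile` may interact poorly with `enable_model_cpu_offload()`,
# so you would likely keep the whole pipeline on the GPU when enabling this path.
def optimize(pipe, compile=True):
    # fuse the QKV projections in the transformer and VAE to reduce kernel launches
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # switch to channels_last memory layout, which PyTorch's CUDA kernels prefer
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    if not compile:
        return pipe

    # compile the compute-heavy modules; the first call triggers (slow) compilation
    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
    return pipe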
# ## Calling our inference function
# To generate images we just need to call the `Model`'s `inference` method
# with `.remote` appended to it.
# You can call `.inference.remote` from any Python environment that has access to your Modal credentials.
# The local environment will get the images back as bytes.
# Here, we wrap the call in a Modal [`local_entrypoint`](https://modal.com/docs/reference/modal.App#local_entrypoint)
# so that it can be run with `modal run`:
# ```bash
# modal run flux.py
# ```
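# Modal turns the `local_entrypoint` parameters below into CLI flags, so you can
# point at a different prompts file or enable compilation (file name is an example):
# ```bash
# modal run flux.py --prompts-file my_prompts.txt --compile
# ```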
# Modal's original example calls its generate method twice to demonstrate how much faster
# inference is after the cold start (clients received images in about 1.2 seconds in their tests);
# here we call `inference` once on the whole batch of prompts.
# We save the output bytes to PNG files under `./flux/`.
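# If you prefer to call the class from another Python process instead of `modal run`,
# a minimal sketch (assuming the app was deployed with `modal deploy flux.py` and a
# recent Modal client) looks like:
# ```python
# import modal
#
# Model = modal.Cls.from_name("bentoai-flux", "Model")
# images = Model().inference.remote(["a watercolor lighthouse at dawn"])
# open("out.png", "wb").write(images[0])
# ```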
@app.local_entrypoint()
def main(
    prompts_file: str = "image_prompts.txt",
    compile: bool = False,  # BENTOAI: added for compilation that we will try later
):
    # Read prompts from file, one per line, skipping blanks and '**'-prefixed header lines
    prompts = []
    with open(prompts_file, "r") as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith("**"):
                prompts.append(line)

    if not prompts:
        print("No valid prompts found in the file")
        return

    print(f"🎨 Found {len(prompts)} prompts")
    model = Model(compile=int(compile))

    print("🎨 generating batch of images:")
    t0 = time.time()
    image_bytes_list = model.inference.remote(prompts)
    print(f"🎨 batch inference latency: {time.time() - t0:.2f} seconds")

    output_dir = Path(".") / "flux"
    output_dir.mkdir(exist_ok=True, parents=True)
    print("🎨 saving outputs to flux/frame_XXXX.png")
    for i, image_bytes in enumerate(image_bytes_list):
        frame_num = str(i + 1).zfill(4)
        output_path = output_dir / f"frame_{frame_num}.png"
        output_path.write_bytes(image_bytes)
        print(f"🎨 Generated frame_{frame_num}.png for prompt: {prompts[i][:100]}...")