# modal with FLUX
# ---
# BENTOAI:
# input as prompts_file: image_prompts.txt
# output-directory: "./flux/frame_XXXX.png"
# ---
# BENTOAI: This guide is based on Modal's excellent examples and tutorials: https://modal.com/docs/example
# In this guide, we'll run Flux to generate images from prompts using [Hugging Face's Diffusers](https://github.com/huggingface/diffusers).
# We'll use the [FLUX.1 model](https://huggingface.co/black-forest-labs/FLUX.1-dev) for this example.
# ## Setting up the image and dependencies
import time
from io import BytesIO
from pathlib import Path
import os
import modal
# We'll make use of the full [CUDA toolkit](https://modal.com/docs/guide/cuda)
# in this example, so we'll build our container image off of the `nvidia/cuda` base.
cuda_version = "12.4.0" # should be no greater than host CUDA version
flavor = "devel" # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])
# Now we install most of our dependencies with `apt` and `pip`.
# For Hugging Face's [Diffusers](https://github.com/huggingface/diffusers) library,
# the original example installs from GitHub source pinned to a specific commit;
# here we install the released package from PyPI via `uv` instead (the commit pin is kept below for reference).
# PyTorch added faster attention kernels for Hopper GPUs in version 2.5.
# diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"
flux_image = (
    cuda_dev_image.apt_install(
        "git",
        "libglib2.0-0",
        "libsm6",
        "libxrender1",
        "libxext6",
        "ffmpeg",
        "libgl1",
    )
    # Modified to use uv
    .pip_install("uv")
    .run_commands(
        # hf_transfer added so that HF_HUB_ENABLE_HF_TRANSFER below has its backing package
        "uv pip install --system --compile-bytecode torch invisible_watermark transformers"
        " huggingface_hub hf_transfer accelerate safetensors sentencepiece diffusers numpy"
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache", "TRANSFORMERS_OFFLINE": "0"})
)
# Later, we'll also use `torch.compile` to increase the speed further.
# Torch compilation needs to be re-executed when each new container starts,
# so we can turn on some extra caching to reduce compile times for later containers
# (the env vars are left commented out below, but the cache Volumes are still mounted in `app.cls`).
# flux_image = flux_image.env(
# {
# "TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache",
# "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
# }
# )
# Finally, we construct our Modal [App](https://modal.com/docs/reference/modal.App),
# set its default image to the one we just constructed,
# and import `FluxPipeline` for downloading and running Flux.1.
app = modal.App("bentoai-flux", image=flux_image)
with flux_image.imports():
    import torch
    from diffusers import FluxPipeline
# ## Defining a parameterized `Model` inference class
# Next, we map the model's setup and inference code onto Modal.
# 1. We run the model setup in the method decorated with `@modal.enter()`. This includes loading the
# weights and moving them to the GPU, along with an optional `torch.compile` step (see details below).
# The `@modal.enter()` decorator ensures that this method runs only once, when a new container starts,
# instead of in the path of every call.
# 2. We run the actual inference in methods decorated with `@modal.method()`.
MINUTES = 60 # seconds
VARIANT = "dev" # "schnell"or "dev", but note [dev] requires you to accept terms and conditions on HF
# BENTOAI: we added SECRETS to accept terms and conditions on HF
NUM_INFERENCE_STEPS = 50 # use ~50 for [dev], smaller for [schnell]
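# BENTOAI: the `huggingface-secret` referenced below is assumed to hold your Hugging Face token
# under the key `HF_TOKEN` (needed to download the gated [dev] weights). One way to create it
# with the Modal CLI (the token value here is a placeholder):
# ```bash
# modal secret create huggingface-secret HF_TOKEN=hf_xxxxxxxxxxxxxxxx
# ```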
@app.cls(
    gpu="L40S",  # 48GB VRAM
    # gpu="L4",  # 24GB VRAM ... need to investigate GGUF
    # BENTOAI: seems to be a good compromise of cost vs. performance
    container_idle_timeout=5 * MINUTES,
    timeout=10 * MINUTES,  # leave plenty of time for compilation
    # BENTOAI: we added SECRETS to accept terms and conditions on HF
    secrets=[modal.Secret.from_name("huggingface-secret")],  # add HF token from Modal secrets
    volumes={  # add Volumes to store serializable compilation artifacts
        "/cache": modal.Volume.from_name("hf-hub-cache", create_if_missing=True),
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name("triton-cache", create_if_missing=True),
        "/root/.inductor-cache": modal.Volume.from_name("inductor-cache", create_if_missing=True),
    },
)
class Model:
    compile: int = (  # see section on torch.compile below for details
        modal.parameter(default=0)
    )

    @modal.enter()
    def enter(self):
        pipe = FluxPipeline.from_pretrained(
            f"black-forest-labs/FLUX.1-{VARIANT}",
            torch_dtype=torch.bfloat16,
            # use_auth_token=os.environ["HF_TOKEN"],  # Use HF token from Modal secrets
        ).to("cuda")  # move model to GPU
        # BENTOAI: helps avoid some OOM VRAM issues
        pipe.enable_model_cpu_offload()  # Save VRAM
        self.pipe = pipe
        # BENTOAI: added for compilation that we will try later (see the sketch after this class)
        # self.pipe = optimize(pipe, compile=bool(self.compile))

    @modal.method()
    def inference(self, prompts: list[str]) -> list[bytes]:
        print(f"🎨 generating {len(prompts)} images...")
        outputs = self.pipe(
            prompts,
            output_type="pil",
            num_inference_steps=NUM_INFERENCE_STEPS,
            # BENTOAI: 9:16 images
            height=800,
            width=480,
            guidance_scale=3.5,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images
        image_bytes = []
        for img in outputs:
            byte_stream = BytesIO()
            img.save(byte_stream, format="PNG")
            image_bytes.append(byte_stream.getvalue())
        return image_bytes
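
# BENTOAI: a hedged sketch of the `optimize` helper referenced (but commented out) in `enter` above.
# It loosely follows the approach in Modal's FLUX example; the exact methods and compile flags are
# assumptions and may shift across diffusers/PyTorch versions. It is only meant to be called inside
# the container, where `torch` has been imported via `flux_image.imports()`.
def optimize(pipe, compile=True):
    # fuse the QKV projections in the transformer and VAE to reduce kernel launches
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # switch to channels_last memory format, which GPU kernels generally prefer
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    if not compile:
        return pipe

    # compile the two heaviest components; the first call pays the compilation cost,
    # which the inductor/triton cache Volumes above help amortize across containers
    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

    # trigger compilation with a throwaway prompt; use the same height/width as real calls,
    # since shape changes can trigger recompilation
    print("🔦 running torch compilation (this can take a while)...")
    pipe(
        "dummy prompt to trigger torch compilation",
        output_type="pil",
        num_inference_steps=1,
        height=800,
        width=480,
    )
    return pipe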
# ## Calling our inference function
# To generate images we just need to call the `Model`'s `inference` method
# with `.remote` appended to it.
# You can call `.inference.remote` from any Python environment that has access to your Modal credentials.
# The local environment will get back the images as bytes.
# Here, we wrap the call in a Modal [`local_entrypoint`](https://modal.com/docs/reference/modal.App#local_entrypoint)
# so that it can be run with `modal run`:
# ```bash
# modal run flux.py
# ```
# Here we call `inference` once on the full batch of prompts read from `prompts_file`
# and save each returned image to `flux/frame_XXXX.png`.
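# For reference, a minimal (made-up) `image_prompts.txt` could look like this; blank lines and
# lines starting with `**` are treated as headers and skipped by the entrypoint below:
# ```text
# ** scene 1 **
# A misty mountain lake at sunrise, cinematic lighting, ultra-detailed
# A red fox trotting through fresh snow, shallow depth of field
# ```
# The entrypoint arguments can also be passed on the command line, e.g. (assuming Modal's usual
# flag conversion for `local_entrypoint` parameters):
# ```bash
# modal run flux.py --prompts-file my_prompts.txt --compile
# ```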
@app.local_entrypoint()
def main(
    prompts_file: str = "image_prompts.txt",
    compile: bool = False,  # BENTOAI: added for compilation that we will try later
):
    # Read prompts from file, skipping blank lines and '**' header lines
    prompts = []
    with open(prompts_file, 'r') as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith('**'):
                prompts.append(line)
    if not prompts:
        print("No valid prompts found in the file")
        return
    print(f"🎨 Found {len(prompts)} prompts")

    model = Model(compile=int(compile))
    print("🎨 generating batch of images:")
    t0 = time.time()
    image_bytes_list = model.inference.remote(prompts)
    print(f"🎨 batch inference latency: {time.time() - t0:.2f} seconds")

    output_dir = Path(".") / "flux"
    output_dir.mkdir(exist_ok=True, parents=True)
    print("🎨 saving outputs to flux/frame_XXXX.png")
    for i, image_bytes in enumerate(image_bytes_list):
        frame_num = str(i + 1).zfill(4)
        output_path = output_dir / f"frame_{frame_num}.png"
        output_path.write_bytes(image_bytes)
        print(f"🎨 Generated frame_{frame_num}.png for prompt: {prompts[i][:100]}...")