Skip to content

Instantly share code, notes, and snippets.

@lucataco
Last active January 31, 2025 20:28
Show Gist options
  • Save lucataco/4d6b33904d7482847519aeaaca6628e4 to your computer and use it in GitHub Desktop.
Flux Schnell locally on MPS
# conda create -n flux python=3.11
# conda activate flux
# pip install torch==2.3.1
# pip install diffusers==0.30.0 transformers==4.43.3
# pip install sentencepiece==0.2.0 accelerate==0.33.0 protobuf==5.27.3
import torch
from diffusers import FluxPipeline
import diffusers
# Handle on diffusers' stock Flux RoPE helper, captured before it is replaced
# below so the wrapper can still delegate to the original implementation.
_flux_rope = diffusers.models.transformers.transformer_flux.rope


def new_flux_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute Flux rotary position embeddings, routing MPS inputs through the CPU.

    Same signature as the diffusers ``rope`` it replaces. If ``pos`` is on an
    MPS device, the original ``rope`` is run on a CPU copy of ``pos`` and the
    result is moved back to ``pos.device``. Inputs on any other device are
    handed to the original ``rope`` unchanged.

    Raises:
        AssertionError: if ``dim`` is odd.
    """
    assert dim % 2 == 0, "The dimension must be even."
    home = pos.device
    if home.type != "mps":
        return _flux_rope(pos, dim, theta)
    # NOTE(review): presumably the stock rope uses an op or dtype (e.g. float64)
    # that the MPS backend does not support -- confirm against the pinned
    # diffusers version.
    on_cpu = _flux_rope(pos.to("cpu"), dim, theta)
    return on_cpu.to(device=home)


# Install the wrapper on the diffusers module. This presumes the Flux transformer
# looks ``rope`` up from this module at call time -- verify if diffusers is upgraded.
diffusers.models.transformers.transformer_flux.rope = new_flux_rope
# Load FLUX.1-schnell from the Hugging Face Hub, pinned to the 'refs/pr/1'
# revision, with bfloat16 weights, then move the whole pipeline to MPS.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    revision='refs/pr/1',
    torch_dtype=torch.bfloat16,
).to("mps")

prompt = "A cat holding a sign that says hello world"

# Sampling settings for a single 1024x1024 image in 4 inference steps.
# NOTE(review): guidance_scale=0 and max_sequence_length=256 are presumably the
# values recommended for the schnell checkpoint -- confirm with its model card.
generation_kwargs = dict(
    prompt=prompt,
    guidance_scale=0.0,
    height=1024,
    width=1024,
    num_inference_steps=4,
    max_sequence_length=256,
)
result = pipe(**generation_kwargs)

# Keep only the first generated image and write it to the working directory.
out = result.images[0]
out.save("flux_image.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment