Last active
January 31, 2025 20:28
-
-
Save lucataco/4d6b33904d7482847519aeaaca6628e4 to your computer and use it in GitHub Desktop.
Flux Schnell locally on MPS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Environment setup (run these in a shell before executing this script):
#   conda create -n flux python=3.11
#   conda activate flux
#   pip install torch==2.3.1
#   pip install diffusers==0.30.0 transformers==4.43.3
#   pip install sentencepiece==0.2.0 accelerate==0.33.0 protobuf==5.27.3
import torch | |
from diffusers import FluxPipeline | |
import diffusers | |
# Keep a handle on the library's own `rope` so the replacement wrapper
# defined below can still delegate to it after the module attribute is
# overwritten.
_transformer_flux_module = diffusers.models.transformers.transformer_flux
_flux_rope = _transformer_flux_module.rope
def new_flux_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Call the original ``rope``, detouring through the CPU for MPS inputs.

    When ``pos`` lives on an MPS device, it is copied to the CPU, the
    original ``rope`` (saved as ``_flux_rope``) is evaluated there, and the
    result is moved back to the device ``pos`` came from. For any other
    device the call is forwarded unchanged.

    Args:
        pos: Position tensor; its device decides which path is taken.
        dim: Passed through to the original ``rope``; must be even.
        theta: Passed through to the original ``rope`` unchanged.

    Returns:
        The tensor produced by the original ``rope``. On the MPS path it is
        placed on ``pos.device``.

    Raises:
        AssertionError: If ``dim`` is odd.
    """
    assert dim % 2 == 0, "The dimension must be even."

    origin = pos.device
    if origin.type != "mps":
        # Nothing special to do off MPS: plain pass-through.
        return _flux_rope(pos, dim, theta)

    # NOTE(review): presumably the original rope relies on an op or dtype
    # that the MPS backend does not support - confirm against the
    # installed diffusers/torch versions.
    computed_on_cpu = _flux_rope(pos.to("cpu"), dim, theta)
    return computed_on_cpu.to(device=origin)
# Replace the library's `rope` with the MPS-aware wrapper before the
# pipeline is created and run.
diffusers.models.transformers.transformer_flux.rope = new_flux_rope

# Load the FLUX.1-schnell weights (hub revision refs/pr/1) in bfloat16,
# then move the whole pipeline to the MPS device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    revision="refs/pr/1",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("mps")

prompt = "A cat holding a sign that says hello world"

# Run the pipeline once with the settings below and keep the first image.
generation = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    height=1024,
    width=1024,
    num_inference_steps=4,
    max_sequence_length=256,
)
out = generation.images[0]

# Write the result to the current working directory.
out.save("flux_image.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment