Skip to content

Instantly share code, notes, and snippets.

@mdbecker
Created August 7, 2024 16:26
Show Gist options
  • Save mdbecker/eeb1222c82de2ee3281a8ca3b0503c7b to your computer and use it in GitHub Desktop.
Save mdbecker/eeb1222c82de2ee3281a8ca3b0503c7b to your computer and use it in GitHub Desktop.
FLUX.1 [schnell] on Mac MPS
import hashlib
import json
import random
from datetime import datetime, timezone

import diffusers
import piexif
import torch
from diffusers import FluxPipeline
from PIL import Image
# Swap diffusers' FLUX rotary-embedding helper for a wrapper that evaluates it
# on the CPU whenever the positions live on an MPS (Apple GPU) device.
_flux_module = diffusers.models.transformers.transformer_flux
_flux_rope = _flux_module.rope


def new_flux_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Call the original FLUX ``rope``, detouring through the CPU for MPS input.

    Args:
        pos: Position tensor forwarded to the original ``rope``.
        dim: Embedding dimension; must be even.
        theta: Base value forwarded unchanged to the original ``rope``.

    Returns:
        Whatever the original ``rope`` returns. For MPS input it is computed
        on the CPU and then moved back to ``pos.device``.
    """
    assert dim % 2 == 0, "The dimension must be even."
    if pos.device.type != "mps":
        return _flux_rope(pos, dim, theta)
    # NOTE(review): presumably some op inside the original rope is not
    # supported on MPS; the source only says this is "for Mac compatibility".
    on_cpu = _flux_rope(pos.to("cpu"), dim, theta)
    return on_cpu.to(device=pos.device)


# Install the wrapper at module level so later lookups of ``rope`` inside
# transformer_flux resolve to it.
_flux_module.rope = new_flux_rope
# EXIF requires UserComment to begin with an 8-byte character-code header;
# b"UNICODE\0" declares that the remaining bytes are UTF-16 text (this is the
# same prefix piexif.helper.UserComment.dump(..., encoding="unicode") writes).
_EXIF_UNICODE_PREFIX = b"UNICODE\x00"


def _format_user_comment(metadata, created):
    """Build the A1111/Civitai-style parameter text stored in UserComment.

    Args:
        metadata: Mapping with required keys ``prompt``, ``steps``,
            ``sampler``, ``guidance``, ``seed``, ``width``, ``height`` and
            optional ``model_version_id``, ``model_name``,
            ``model_version_name``.
        created: Timezone-aware creation time; rendered in UTC with a
            trailing ``Z`` (a naive value would be treated as local time).

    Returns:
        The formatted comment string.

    Raises:
        KeyError: if a required metadata key is missing.
    """
    created_utc = created.astimezone(timezone.utc).replace(tzinfo=None)
    civitai_resources = [{
        "type": "checkpoint",
        "modelVersionId": metadata.get('model_version_id', 699279),
        "modelName": metadata.get('model_name', 'FLUX'),
        "modelVersionName": metadata.get('model_version_name', 'Schnell'),
    }]
    parts = [
        f"{metadata['prompt']}\n",
        f"Steps: {metadata['steps']}, ",
        f"Sampler: {metadata['sampler']}, ",
        f"CFG scale: {metadata['guidance']}, ",
        f"Seed: {metadata['seed']}, ",
        f"Size: {metadata['width']}x{metadata['height']}, ",
        f"Created Date: {created_utc.isoformat()}Z, ",
        f"Civitai resources: {json.dumps(civitai_resources)},",
    ]
    return "".join(parts)


def _encode_user_comment(text):
    """Encode *text* as an EXIF UserComment value (header + UTF-16BE)."""
    return _EXIF_UNICODE_PREFIX + text.encode('utf-16be')


def add_metadata_to_image(image, metadata):
    """
    Build EXIF bytes carrying generation metadata for platforms like Civitai.

    The prompt, generation parameters, and model information are written into
    the EXIF UserComment tag with the character-code header the EXIF spec
    requires, so standard readers decode the full prompt.

    Args:
        image: Unused; kept for backward compatibility. The image is not
            modified. Attach the returned bytes yourself, for example via the
            ``exif=`` argument of ``Image.save``.
        metadata: Mapping with required keys ``prompt``, ``steps``,
            ``sampler``, ``guidance``, ``seed``, ``width``, ``height`` and
            optional ``model_version_id``, ``model_name``,
            ``model_version_name``.

    Returns:
        The EXIF blob produced by ``piexif.dump``.

    Raises:
        KeyError: if a required metadata key is missing.
    """
    user_comment = _format_user_comment(metadata, datetime.now(timezone.utc))
    exif_dict = {
        "Exif": {
            piexif.ExifIFD.UserComment: _encode_user_comment(user_comment),
        },
    }
    return piexif.dump(exif_dict)
# Load FLUX.1 [schnell] in bfloat16 and place the whole pipeline on MPS.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    revision="refs/pr/1",
    torch_dtype=torch.bfloat16,
).to("mps")

# Prompt plus a random 32-bit seed. The seed drives an MPS-side generator and
# is also recorded in the metadata and the output filename.
prompt = (
    "an obese rusty robot. The robot is chewing on a long microchip "
    "that has 'vram' printed on it"
)
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(device="mps")
generator.manual_seed(seed)

# Keyword arguments forwarded to the pipeline call.
params = dict(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=4,  # Schnell works well with just 4 steps
    guidance_scale=0.0,
    generator=generator,
    max_sequence_length=256,
)

# Run the pipeline and keep the first returned image.
out = pipe(**params).images[0]

# Describe the run in the mapping shape add_metadata_to_image reads.
metadata = dict(
    prompt=prompt,
    steps=params["num_inference_steps"],
    sampler="Euler",
    guidance=params["guidance_scale"],
    seed=seed,
    width=params["width"],
    height=params["height"],
    model_name="FLUX",
    model_version_name="Schnell",
    model_version_id=699279,
)
exif_bytes = add_metadata_to_image(out, metadata)

# Name the file "<first 10 hex chars of the prompt's MD5>_<seed>.jpg" and
# write it as a JPEG with the EXIF block attached.
prompt_tag = hashlib.md5(prompt.encode()).hexdigest()[:10]
filename = "{}_{}.jpg".format(prompt_tag, seed)
out.save(filename, "JPEG", quality=95, exif=exif_bytes)
print("Generated image saved as " + filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment