FLUX.1 [schnell] on Mac MPS
import hashlib
import json
import random
from datetime import datetime

import diffusers
import piexif
import torch
from diffusers import FluxPipeline
from PIL import Image
# Patch the rope function for Mac/MPS compatibility: the rope computation uses
# torch.float64, which the MPS backend does not support, so run it on the CPU
# and move the result back to the original device.
_flux_rope = diffusers.models.transformers.transformer_flux.rope


def new_flux_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."
    if pos.device.type == "mps":
        return _flux_rope(pos.to("cpu"), dim, theta).to(device=pos.device)
    else:
        return _flux_rope(pos, dim, theta)


diffusers.models.transformers.transformer_flux.rope = new_flux_rope
def add_metadata_to_image(image, metadata):
    """
    Build EXIF metadata for the image so platforms like Civitai can read it.

    Encodes the prompt, generation parameters, and model information into the
    EXIF UserComment field and returns the serialized EXIF bytes.
    """
    user_comment = f"{metadata['prompt']}\n"
    user_comment += f"Steps: {metadata['steps']}, "
    user_comment += f"Sampler: {metadata['sampler']}, "
    user_comment += f"CFG scale: {metadata['guidance']}, "
    user_comment += f"Seed: {metadata['seed']}, "
    user_comment += f"Size: {metadata['width']}x{metadata['height']}, "
    user_comment += f"Created Date: {datetime.utcnow().isoformat()}Z, "

    civitai_resources = [{
        "type": "checkpoint",
        "modelVersionId": metadata.get('model_version_id', 699279),
        "modelName": metadata.get('model_name', 'FLUX'),
        "modelVersionName": metadata.get('model_version_name', 'Schnell')
    }]
    user_comment += f"Civitai resources: {json.dumps(civitai_resources)},"

    exif_dict = {
        "Exif": {
            piexif.ExifIFD.UserComment: user_comment.encode('utf-16be'),
        },
    }
    return piexif.dump(exif_dict)
# Load the Flux Schnell model
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    revision='refs/pr/1',
    torch_dtype=torch.bfloat16
).to("mps")
# Set up the prompt and parameters
prompt = "an obese rusty robot. The robot is chewing on a long microchip that has 'vram' printed on it"
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(device="mps").manual_seed(seed)

# Define generation parameters
params = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "num_inference_steps": 4,  # Schnell version works well with just 4 steps
    "guidance_scale": 0.0,
    "generator": generator,
    "max_sequence_length": 256,
}
# Generate the image
out = pipe(**params).images[0]

# Prepare metadata for the image
metadata = {
    "prompt": prompt,
    "steps": params["num_inference_steps"],
    "sampler": "Euler",
    "guidance": params["guidance_scale"],
    "seed": seed,
    "width": params["width"],
    "height": params["height"],
    "model_name": "FLUX",
    "model_version_name": "Schnell",
    "model_version_id": 699279,
}

# Build the EXIF bytes carrying the metadata
exif_bytes = add_metadata_to_image(out, metadata)

# Save the generated image with a unique filename and the metadata attached
filename = f"{hashlib.md5(prompt.encode()).hexdigest()[:10]}_{seed}.jpg"
out.save(filename, "JPEG", quality=95, exif=exif_bytes)
print(f"Generated image saved as {filename}")