@asomoza
Created April 12, 2025 06:52
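# Low-VRAM HiDream-I1 text-to-image generation with diffusers.
# The four text encoders (two CLIP models, T5, and Llama 3.1) are run one at a time and freed
# between stages; denoising uses the HiDream transformer with FP8 layerwise casting and
# leaf-level group offloading, and the latents are decoded at the end with the VAE loaded on its own.
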
import gc

import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    LlamaForCausalLM,
    PreTrainedTokenizerFast,
    T5EncoderModel,
    T5Tokenizer,
)

from diffusers import AutoencoderKL, HiDreamImagePipeline, HiDreamImageTransformer2DModel, UniPCMultistepScheduler
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.image_processor import VaeImageProcessor

repo_id = "HiDream-ai/HiDream-I1-Full"
llama_repo_id = "meta-llama/Llama-3.1-8B-Instruct"

device = torch.device("cuda")
torch_dtype = torch.bfloat16

prompt = "a dog at the left side of an old oak tree with a lake in the background, a boat is sailing in the lake with some people over it dancing and having a party, on the right side of the image in the sky, we can see a biplane with a sign attached to it that says 'wow this is awesome'"

# Release cached CUDA memory and report current usage between stages.
def flush(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    print(f"Current CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Current CUDA memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")


# Encode the prompt with all four text encoders, loading and freeing each one sequentially
# so that only a single encoder sits on the GPU at any time.
def encode_prompt(
    prompt, pipeline_repo_id, llama_repo_id, do_classifier_free_guidance=False, device=device, dtype=torch_dtype
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # First CLIP encoder: pooled projection embeddings
    tokenizer = CLIPTokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer")
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder", torch_dtype=torch_dtype
    ).to(device)

    prompt_embeds = get_clip_prompt_embeds(prompt, tokenizer, text_encoder)
    prompt_embeds_1 = prompt_embeds.clone().detach()

    # Move the encoder back to the CPU and release it before loading the next one
    text_encoder.to("cpu")
    del prompt_embeds
    del tokenizer
    del text_encoder
    flush(device)

    # Second CLIP encoder
    tokenizer = CLIPTokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer_2")
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder_2", torch_dtype=torch_dtype
    ).to(device)

    prompt_embeds = get_clip_prompt_embeds(prompt, tokenizer, text_encoder)
    prompt_embeds_2 = prompt_embeds.clone().detach()

    text_encoder.to("cpu")
    del prompt_embeds
    del tokenizer
    del text_encoder
    flush(device)

    # HiDream takes the two CLIP pooled embeddings concatenated along the feature dimension
    pooled_prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

    # T5 encoder: per-token sequence embeddings
    tokenizer = T5Tokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer_3")
    text_encoder = T5EncoderModel.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder_3", torch_dtype=torch_dtype
    ).to(device)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=128,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    attention_mask = text_inputs.attention_mask
    prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
    t5_prompt_embeds = prompt_embeds.clone().detach()

    del prompt_embeds
    del text_inputs
    del text_encoder
    del tokenizer
    flush(device)

    # Llama 3.1 as the fourth text encoder: HiDream consumes its intermediate hidden states
    tokenizer = PreTrainedTokenizerFast.from_pretrained(llama_repo_id)
    tokenizer.pad_token = tokenizer.eos_token
    text_encoder = LlamaForCausalLM.from_pretrained(
        llama_repo_id,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch_dtype,
    ).to(device)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=128,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    attention_mask = text_inputs.attention_mask
    outputs = text_encoder(
        text_input_ids.to(device),
        attention_mask=attention_mask.to(device),
        output_hidden_states=True,
        output_attentions=True,
    )

    # Stack every hidden state except the initial embedding-layer output
    prompt_embeds = outputs.hidden_states[1:]
    prompt_embeds = torch.stack(prompt_embeds, dim=0)
    llama3_prompt_embeds = prompt_embeds.clone().detach()

    del prompt_embeds
    del outputs
    del text_inputs
    del text_encoder
    del tokenizer
    flush(device)

    prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]

    # Negative embeddings are left as None: denoise() runs with guidance_scale=0.0
    embeds = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "negative_prompt_embeds": None,
        "negative_pooled_prompt_embeds": None,
    }
    return embeds


def get_clip_prompt_embeds(prompt, tokenizer, text_encoder):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # CLIPTextModelWithProjection returns the pooled, projected text embedding first
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    prompt_embeds = prompt_embeds[0]
    return prompt_embeds


def denoise(embeddings, device=device, dtype=torch_dtype):
    scheduler = UniPCMultistepScheduler(
        flow_shift=3.0,
        prediction_type="flow_prediction",
        use_flow_sigmas=True,
    )

    # The Dev transformer checkpoint is used for denoising; the rest of the pipeline comes from the Full repo
    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        "HiDream-ai/HiDream-I1-Dev", subfolder="transformer", torch_dtype=torch_dtype
    )
    # Store the transformer weights in FP8 and upcast per layer at compute time
    transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch_dtype)
    # Keep weights on the CPU and stream them to the GPU at leaf-module granularity as they are needed
    apply_group_offloading(
        transformer,
        onload_device=device,
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        low_cpu_mem_usage=True,
    )

    # Text encoders and VAE are skipped: embeddings are precomputed and the latents are decoded afterwards
    pipe = HiDreamImagePipeline.from_pretrained(
        repo_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        text_encoder_3=None,
        tokenizer_3=None,
        text_encoder_4=None,
        tokenizer_4=None,
        transformer=transformer,
        scheduler=scheduler,
        vae=None,
        torch_dtype=torch_dtype,
    )

    latents = pipe(
        **embeddings,
        height=1024,
        width=1024,
        guidance_scale=0.0,
        num_inference_steps=28,
        generator=torch.Generator(device).manual_seed(0),
        output_type="latent",
        return_dict=False,
    )

    del pipe
    flush(device)
    return latents


with torch.no_grad():
    embeddings = encode_prompt(prompt, repo_id, llama_repo_id, device=device, dtype=torch_dtype)
    latents = denoise(embeddings, device=device, dtype=torch_dtype)[0]

# Decode the latents with the VAE loaded on its own, then post-process to a PIL image
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch_dtype).to(device)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
with torch.no_grad():
    image = vae.decode(latents, return_dict=False)[0]

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image = image_processor.postprocess(image, output_type="pil")[0]
image.save("output.png")