import gc

import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    LlamaForCausalLM,
    PreTrainedTokenizerFast,
    T5EncoderModel,
    T5Tokenizer,
)

from diffusers import AutoencoderKL, HiDreamImagePipeline, HiDreamImageTransformer2DModel, UniPCMultistepScheduler
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.image_processor import VaeImageProcessor
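
# Overview: each text encoder (two CLIPs, a T5 and Llama-3.1) is loaded onto the GPU on its own,
# run once, moved back to CPU and deleted, so they are never all resident in VRAM at the same time.
# The transformer and the VAE are likewise only loaded for their own stage of the process.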
repo_id = "HiDream-ai/HiDream-I1-Full"
llama_repo_id = "meta-llama/Llama-3.1-8B-Instruct"
device = torch.device("cuda")
torch_dtype = torch.bfloat16

prompt = "a dog at the left side of an old oak tree with a lake in the background, a boat is sailing in the lake with some people over it dancing and having a party, on the right side of the image in the sky, we can see a biplane with a sign attached to it that says 'wow this is awesome'"
def flush(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    print(f"Current CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Current CUDA memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
def encode_prompt(
    prompt, pipeline_repo_id, llama_repo_id, do_classifier_free_guidance=False, device=device, dtype=torch_dtype
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
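    # First CLIP text encoder: contributes the first half of the pooled prompt embedding.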
    tokenizer = CLIPTokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer")
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder", torch_dtype=torch_dtype
    ).to(device)

    prompt_embeds = get_clip_prompt_embeds(prompt, tokenizer, text_encoder)
    prompt_embeds_1 = prompt_embeds.clone().detach()

    text_encoder.to("cpu")
    del prompt_embeds
    del tokenizer
    del text_encoder
    flush(device)
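    # Second CLIP text encoder: contributes the second half of the pooled prompt embedding,
    # concatenated with the first along the last dimension.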
    tokenizer = CLIPTokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer_2")
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder_2", torch_dtype=torch_dtype
    ).to(device)

    prompt_embeds = get_clip_prompt_embeds(prompt, tokenizer, text_encoder)
    prompt_embeds_2 = prompt_embeds.clone().detach()

    text_encoder.to("cpu")
    del prompt_embeds
    del tokenizer
    del text_encoder
    flush(device)

    pooled_prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
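    # T5 encoder: provides the first of the two sequence-level prompt embeddings.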
    tokenizer = T5Tokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer_3")
    text_encoder = T5EncoderModel.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder_3", torch_dtype=torch_dtype
    ).to(device)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=128,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    attention_mask = text_inputs.attention_mask

    prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
    t5_prompt_embeds = prompt_embeds.clone().detach()

    del prompt_embeds
    del text_inputs
    del text_encoder
    del tokenizer
    flush(device)
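    # Llama-3.1 used as a text encoder: the hidden states of every layer after the embedding
    # layer are stacked along a new leading dimension and used as the second sequence-level embedding.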
    tokenizer = PreTrainedTokenizerFast.from_pretrained(llama_repo_id)
    tokenizer.pad_token = tokenizer.eos_token
    text_encoder = LlamaForCausalLM.from_pretrained(
        llama_repo_id,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch_dtype,
    ).to(device)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=128,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    attention_mask = text_inputs.attention_mask

    outputs = text_encoder(
        text_input_ids.to(device),
        attention_mask=attention_mask.to(device),
        output_hidden_states=True,
        output_attentions=True,
    )
    prompt_embeds = outputs.hidden_states[1:]
    prompt_embeds = torch.stack(prompt_embeds, dim=0)
    llama3_prompt_embeds = prompt_embeds.clone().detach()

    del prompt_embeds
    del outputs
    del text_inputs
    del text_encoder
    del tokenizer
    flush(device)
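    # Package everything under the keyword names that are unpacked into the pipeline call below.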
    prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
    embeds = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "negative_prompt_embeds": None,
        "negative_pooled_prompt_embeds": None,
    }
    return embeds
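
# CLIPTextModelWithProjection returns the projected pooled embedding as its first output;
# this is what gets concatenated into pooled_prompt_embeds above.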
def get_clip_prompt_embeds(prompt, tokenizer, text_encoder):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    prompt_embeds = prompt_embeds[0]
    return prompt_embeds
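
# Run the denoising loop with only the transformer and the scheduler loaded. The transformer
# weights are stored in float8 via layerwise casting and group-offloaded to the CPU at the leaf
# level, which keeps peak VRAM low at the cost of extra host-to-device transfers.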
def denoise(embeddings, device=device, dtype=torch_dtype):
    scheduler = UniPCMultistepScheduler(
        flow_shift=3.0,
        prediction_type="flow_prediction",
        use_flow_sigmas=True,
    )
    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        "HiDream-ai/HiDream-I1-Dev", subfolder="transformer", torch_dtype=torch_dtype
    )
    transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch_dtype)
    apply_group_offloading(
        transformer,
        onload_device=device,
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        low_cpu_mem_usage=True,
    )
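    # Build the pipeline without text encoders, tokenizers, or the VAE; the precomputed embeddings
    # are passed in directly and the latents are decoded separately at the end of the script.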
    pipe = HiDreamImagePipeline.from_pretrained(
        repo_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        text_encoder_3=None,
        tokenizer_3=None,
        text_encoder_4=None,
        tokenizer_4=None,
        transformer=transformer,
        scheduler=scheduler,
        vae=None,
        torch_dtype=torch_dtype,
    )

    latents = pipe(
        **embeddings,
        height=1024,
        width=1024,
        guidance_scale=0.0,
        num_inference_steps=28,
        generator=torch.Generator(device).manual_seed(0),
        output_type="latent",
        return_dict=False,
    )

    del pipe
    flush(device)
    return latents
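
# Main flow: encode the prompt, denoise to latents, then decode with the VAE as a separate stage.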
with torch.no_grad():
    embeddings = encode_prompt(prompt, repo_id, llama_repo_id, device=device, dtype=torch_dtype)
    latents = denoise(embeddings, device=device, dtype=torch_dtype)[0]

vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch_dtype).to(device)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

with torch.no_grad():
    image = vae.decode(latents, return_dict=False)[0]

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image = image_processor.postprocess(image, output_type="pil")[0]
image.save("output.png")