import os
import torch
import transformers
device = "cpu"
config = {
    "_name_or_path": "/fsx/loubna/checkpoints/cosmo2_1T/500000",
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "eos_token_id": 0,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 2048,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 24,
    "num_key_value_heads": 32,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 10000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "float32",
    "transformers_version": "4.39.3",
    "use_cache": True,
    "vocab_size": 49152,
}
config.update(
    {
        "_from_model_config": True,
        "bos_token_id": 0,
        "eos_token_id": 0,
        "transformers_version": "4.39.3",
        "num_hidden_layers": 1,
    }
)
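# Overriding num_hidden_layers to 1 shrinks the model to a single decoder layer,
# so the randomly initialized model builds and exports quickly for this smoke test.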
conf = transformers.LlamaConfig(**config)
model = transformers.LlamaForCausalLM(conf)
model.eval()
batch_size = 2
sequence_length = 30
vocab_size = config["vocab_size"]
max_position_embeddings = config["max_position_embeddings"]
num_heads = model.config.num_attention_heads
past_sequence_length = 16 # Example past sequence length
head_dim = model.config.hidden_size // num_heads
num_layers = model.config.num_hidden_layers
total_sequence_length = sequence_length + past_sequence_length
dim = (batch_size, sequence_length)
input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length)).to(torch.int64)
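# The attention mask must cover the cached tokens as well as the new ones,
# so it spans total_sequence_length rather than sequence_length.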
attention_mask = torch.ones((batch_size, total_sequence_length), device=device)
position_ids = torch.randint(0, max_position_embeddings, (batch_size, sequence_length)).to(torch.int64)
# Past key values: each tensor has shape (batch_size, num_heads, past_sequence_length, head_dim).
# In total there are 2 * num_layers past tensors (one key and one value per layer).
# Generate random past key values.
past_key_values = []
for _ in range(num_layers):
    past_key = torch.rand(batch_size, num_heads, past_sequence_length, head_dim, device=device)
    past_value = torch.rand(batch_size, num_heads, past_sequence_length, head_dim, device=device)
    past_key_values.append((past_key, past_value))
# Combine all inputs
inputs = (input_ids, attention_mask, position_ids, tuple(past_key_values))
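# Optional sanity check (a minimal sketch, not part of the export itself): run the
# eager model once with the same positional inputs the exporter will use, to confirm
# the inputs are mutually consistent before tracing.
with torch.no_grad():
    eager_outputs = model(*inputs)
# Logits should have shape (batch_size, sequence_length, vocab_size).
assert eager_outputs.logits.shape == (batch_size, sequence_length, vocab_size)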
batch_dim = torch.export.Dim("batch_size")
sequence_dim = torch.export.Dim("sequence_length", max=128)
total_sequence_dim = sequence_dim + 16  # sequence_length + past_sequence_length (16)
program = torch.onnx.export(
    model,
    inputs,
    dynamo=True,
    fallback=False,
    report=True,
    dynamic_shapes={
        "input_ids": {0: batch_dim, 1: sequence_dim},
        "attention_mask": {0: batch_dim, 1: total_sequence_dim},
        "position_ids": {0: batch_dim, 1: sequence_dim},
        # One (key, value) pair because num_hidden_layers was overridden to 1;
        # the empty dicts leave the past key/value shapes static.
        "past_key_values": (({}, {}),),
    },
    dump_exported_program=True,
)
root = os.path.dirname(os.path.abspath(__file__))
program.save(os.path.join(root, "model.onnx"))
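# Optional: load the exported model back with ONNX Runtime as a smoke test. This is a
# minimal sketch; it assumes onnxruntime is installed and that the graph inputs appear
# in the same order as the PyTorch inputs above.
import onnxruntime as ort

session = ort.InferenceSession(os.path.join(root, "model.onnx"))
flat_inputs = [input_ids, attention_mask, position_ids] + [t for kv in past_key_values for t in kv]
feeds = {inp.name: tensor.numpy() for inp, tensor in zip(session.get_inputs(), flat_inputs)}
outputs = session.run(None, feeds)
print([output.shape for output in outputs])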