import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, BlockMask
import dataclasses

# torch.use_deterministic_algorithms(True)
torch._dynamo.config.cache_size_limit = 128

@dataclasses.dataclass
class Config:
    bos_token_id: int = 151643
    eos_token_id: int = 151643
    hidden_act: str = "silu"
    hidden_size: int = 1536
    initializer_range: float = 0.02
    intermediate_size: int = 8960
    max_position_embeddings: int = 131072
    max_window_layers: int = 21
    model_type: str = "qwen2"
    num_attention_heads: int = 12
    num_hidden_layers: int = 28
    num_key_value_heads: int = 2
    rms_norm_eps: float = 1e-6
    rope_theta: int = 10000
    sliding_window: int = 4096
    tie_word_embeddings: bool = False
    dtype: torch.dtype = torch.bfloat16
    use_cache: bool = True
    use_mrope: bool = False
    use_sliding_window: bool = False
    vocab_size: int = 151936
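
# Note: these defaults match a ~1.5B Qwen2-family checkpoint (hidden_size 1536, 12 query
# heads, 2 KV heads, 28 layers), giving head_dim = hidden_size // num_attention_heads = 128.
# from_llamaish_hf_model below overwrites all of them from the HF config, so they only
# matter if Config() is instantiated directly.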

# @torch.compile(fullgraph=True)
def rmsnorm(x, w, eps=1e-6):
    ori_dtype = x.dtype
    x = x.to(torch.float32)
    variance = (x * x).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return w.expand_as(x) * x.to(ori_dtype)
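
# rmsnorm above computes w * x / sqrt(mean(x^2) + eps): the reduction is done in float32
# and the result is cast back to the input dtype. This is the standard RMSNorm used by
# Llama/Qwen-style models (no mean subtraction or bias, unlike LayerNorm).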

# @torch.compile(fullgraph=True)
def rotary(x, pos_idx, rope_theta):
    # x: [batch, nheads, seqlen, headdim] or [batch, seqlen, headdim]
    # pos_idx: [seqlen] or [batch, seqlen]
    ori_dtype = x.dtype
    x = x.to(torch.float32)
    shape = x.shape
    x = x.view(-1, shape[-2], shape[-1]) if len(shape) == 4 else x
    dim = x.shape[-1]
    half_dim = dim // 2
    # One frequency per rotated pair (k, k + half_dim), matching the HF rotate_half
    # convention: inv_freq[k] = rope_theta ** (-2k / dim) for k in [0, half_dim).
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, half_dim, device=x.device, dtype=x.dtype) / half_dim))
    freqs = pos_idx.view(-1, 1) * inv_freq  # [seqlen, half_dim]
    cos = torch.cos(freqs)
    sin = torch.sin(freqs)
    # Rotate the first and second halves of the head dim against each other.
    x1, x2 = x[..., :half_dim], x[..., half_dim:]
    out = torch.cat([
        x1 * cos.unsqueeze(0) - x2 * sin.unsqueeze(0),
        x2 * cos.unsqueeze(0) + x1 * sin.unsqueeze(0),
    ], dim=-1)
    return out.view(shape).to(ori_dtype)
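
# Illustrative check of the rotary convention (hypothetical shapes, not executed here):
# with q laid out as [batch, nheads, seqlen, headdim], the result is intended to match
# HF's apply_rotary_pos_emb, i.e. q * cos + rotate_half(q) * sin with
# inv_freq[k] = rope_theta ** (-2k / headdim):
#
#   q = torch.randn(1, 12, 8, 128)
#   q_rot = rotary(q, torch.arange(8), 10000.0)   # same shape, rotated per position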

PRONT = print


class FFFormerLayer(nn.Module):
    up_gate: nn.Linear
    down: nn.Linear
    o: nn.Linear
    ilnw: nn.Parameter
    alnw: nn.Parameter

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.up_gate = nn.Linear(cfg.hidden_size, cfg.intermediate_size * 2, bias=False, dtype=cfg.dtype)
        self.down = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False, dtype=cfg.dtype)
        # Fix QKV dimensions
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.num_heads = cfg.num_attention_heads
        self.num_kv_heads = cfg.num_key_value_heads
        # Combined QKV with merged biases
        total_qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_kv_heads)
        self.qkv = nn.Linear(cfg.hidden_size, total_qkv_dim, bias=True, dtype=cfg.dtype)
        # Fix split points
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.k_start = self.q_size
        self.v_start = self.k_start + self.kv_size
        self.o = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False, dtype=cfg.dtype)
        self.ilnw = nn.Parameter(torch.ones([cfg.hidden_size], dtype=torch.float32))
        self.alnw = nn.Parameter(torch.ones([cfg.hidden_size], dtype=torch.float32))

    def forward(self, x, kcache, vcache, write_idxs, position_idxs, attn_mask=None, causal=True):
        N, T, C = x.shape
        # Apply layer norm and QKV projection
        x_ln = rmsnorm(x, self.ilnw, self.cfg.rms_norm_eps)
        # PRONT(f"my input{x_ln}")
        # PRONT(f"my qkv {self.qkv.weight}")
        qkv = self.qkv(x_ln.to(self.cfg.dtype))  # Bias is automatically added by nn.Linear
        # Split QKV
        q = qkv[..., :self.q_size]
        k = qkv[..., self.k_start:self.v_start]
        v = qkv[..., self.v_start:]
        # Reshape and transpose
        # q = q.reshape(N, T * (self.num_heads // self.num_kv_heads), self.num_kv_heads, -1).transpose(1, 2).contiguous()  # [N, num_heads, T, head_dim]
        q = q.reshape(N, T, self.num_heads, -1).transpose(1, 2).contiguous()  # [N, num_heads, T, head_dim]
        k = k.reshape(N, T, self.num_kv_heads, -1).transpose(1, 2).contiguous()  # [N, num_kv_heads, T, head_dim]
        v = v.reshape(N, T, self.num_kv_heads, -1).transpose(1, 2).contiguous()  # [N, num_kv_heads, T, head_dim]
        # PRONT(f"my q{q[0, 0, 0, :10]}")
        # q = rotary(q, position_idxs.repeat_interleave((self.num_heads // self.num_kv_heads), 0), self.cfg.rope_theta)
        # print(position_idxs)
        q = rotary(q, position_idxs, self.cfg.rope_theta)
        k = rotary(k, position_idxs, self.cfg.rope_theta)
        # PRONT(f"my qrot{q}")
        if kcache is None:
            kcache, vcache = k, v
        else:
            # print(kcache.shape, k.shape, write_idxs.shape, write_idxs)
            batch_idx = torch.arange(N, device=kcache.device)[:, None, None]  # (N, 1, 1)
            head_idx = torch.arange(self.num_kv_heads, device=kcache.device)[None, :, None]  # (1, H, 1)
            write_idxs = write_idxs[:, None, :]  # (N, 1, S)
            kcache[batch_idx, head_idx, write_idxs, :] = k
            vcache[batch_idx, head_idx, write_idxs, :] = v
            # kcache.scatter_(2, write_idxs, k)
            # vcache.scatter_(2, write_idxs, v)
        # attn = flex_attention(q, kcache, vcache, block_mask=attn_mask)
        if T > 1 or causal:
            attn = F.scaled_dot_product_attention(q, kcache, vcache, is_causal=causal, enable_gqa=True)
            attn = attn.transpose(1, 2).reshape(N, T, C)
        else:
            q = q.reshape(N, self.num_kv_heads, self.num_heads // self.num_kv_heads, -1)
            attn = F.scaled_dot_product_attention(q, kcache, vcache, is_causal=False)
            attn = attn.reshape(N, self.num_heads, T, -1).transpose(1, 2).reshape(N, T, C)
            # attn = attn.transpose(1, 2).reshape(N, T, C)
        # PRONT(f"my attn{attn}")
        x = x + self.o(attn)
        up_gate = self.up_gate(rmsnorm(x, self.alnw, self.cfg.rms_norm_eps).to(self.cfg.dtype))
        mid = F.silu(up_gate[..., :self.cfg.intermediate_size]) * up_gate[..., self.cfg.intermediate_size:]
        x = x + self.down(mid)
        # PRONT(f"my output{x}")
        # globals()["PRONT"] = lambda x: x
        return x
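
# Notes on the attention paths in FFFormerLayer.forward:
# - The packed qkv projection lays its output out as [q | k | v], which is why the HF
#   converter below concatenates q_proj / k_proj / v_proj weights and biases in that order.
# - For the causal/prefill path (T > 1 or causal=True), enable_gqa=True lets
#   scaled_dot_product_attention broadcast the num_key_value_heads KV heads across the
#   num_attention_heads query heads.
# - For single-token decode, the query heads are instead folded into a per-KV-head
#   "sequence" dim (N, num_kv_heads, heads_per_kv, head_dim), so a plain non-causal SDPA
#   call implements grouped-query attention without the enable_gqa flag.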

class FFFormer(nn.Module):
    embed: nn.Embedding
    unembed: nn.Linear
    oln: nn.Parameter
    layers: nn.ModuleList  # [FFFormerLayer]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size, dtype=cfg.dtype)
        self.unembed = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False, dtype=cfg.dtype)
        self.oln = nn.Parameter(torch.ones([cfg.hidden_size], dtype=torch.float32))
        self.layers = nn.ModuleList([FFFormerLayer(cfg) for _ in range(cfg.num_hidden_layers)])
    def make_mask(self, input_ids):
        N, T = input_ids.shape
        qpkv = self.cfg.num_attention_heads // self.cfg.num_key_value_heads

        def causalish(b, h, q_idx, kv_idx):
            return (q_idx // qpkv) >= kv_idx

        mask = create_block_mask(causalish, N, None, T * qpkv, T)
        return mask
        # Unreachable dense-mask alternative, kept for reference:
        # q_heads_per_kv = self.cfg.num_attention_heads // self.cfg.num_key_value_heads
        # attn_mask = torch.ones(N, T, T, dtype=torch.bool, device="cuda").tril(diagonal=0).repeat_interleave(q_heads_per_kv, 1)
        # return attn_mask
    @torch.compile(dynamic=False)
    def forward(self, input_ids, kcache, vcache, write_idxs, position_idxs, attn_mask, causal=True):
        x = self.embed(input_ids).to(self.cfg.dtype)
        for i, l in enumerate(self.layers):
            kcache_i, vcache_i = (None, None) if kcache is None else (kcache[i], vcache[i])
            x = l(x, kcache_i, vcache_i, write_idxs, position_idxs, attn_mask, causal=causal)
        x = rmsnorm(x, self.oln, self.cfg.rms_norm_eps).to(self.cfg.dtype)
        return self.unembed(x)
    @staticmethod
    def from_llamaish_hf_model(hf_cfg, state_dict):
        # hf_cfg / state_dict come from a HF Qwen2-style model,
        # e.g. hf_model.config and hf_model.state_dict().
        cfg = Config(
            vocab_size=hf_cfg.vocab_size,
            hidden_size=hf_cfg.hidden_size,
            num_hidden_layers=hf_cfg.num_hidden_layers,
            num_attention_heads=hf_cfg.num_attention_heads,
            num_key_value_heads=hf_cfg.num_key_value_heads,
            rms_norm_eps=hf_cfg.rms_norm_eps,
            dtype=torch.bfloat16,
            intermediate_size=hf_cfg.intermediate_size,
            max_position_embeddings=hf_cfg.max_position_embeddings,
            max_window_layers=hf_cfg.max_window_layers,
            sliding_window=hf_cfg.sliding_window,
            tie_word_embeddings=hf_cfg.tie_word_embeddings,
            use_cache=hf_cfg.use_cache,
            use_mrope=hf_cfg.use_mrope,
            use_sliding_window=hf_cfg.use_sliding_window,
            initializer_range=hf_cfg.initializer_range,
            hidden_act=hf_cfg.hidden_act,
            bos_token_id=hf_cfg.bos_token_id,
            eos_token_id=hf_cfg.eos_token_id,
            rope_theta=hf_cfg.rope_theta,
        )
        with torch.device('meta'):
            model = FFFormer(cfg)
        new_state = {}
        # Track which source keys we've handled
        handled_keys = set()
        # Direct mappings with regex patterns
        key_maps = {
            r'model\.embed_tokens\.weight': ('embed.weight', cfg.dtype),
            r'model\.norm\.weight': ('oln', cfg.dtype),
            r'lm_head\.weight': ('unembed.weight', cfg.dtype),
            r'model\.layers\.(\d+)\.input_layernorm\.weight': lambda m: (f'layers.{m.group(1)}.ilnw', cfg.dtype),
            r'model\.layers\.(\d+)\.post_attention_layernorm\.weight': lambda m: (f'layers.{m.group(1)}.alnw', cfg.dtype),
            r'model\.layers\.(\d+)\.self_attn\.o_proj\.weight': lambda m: (f'layers.{m.group(1)}.o.weight', cfg.dtype),
            r'model\.layers\.(\d+)\.mlp\.down_proj\.weight': lambda m: (f'layers.{m.group(1)}.down.weight', cfg.dtype),
        }
        # First pass - direct mappings and collecting weights for merging
        qkv_weights = {}
        qkv_biases = {}
        upgate_weights = {}
        import re
        for k, v in state_dict.items():
            # Try direct mappings first
            mapped = False
            for pattern, target in key_maps.items():
                match = re.match(pattern, k)
                if match:
                    if callable(target):
                        new_key, dtype = target(match)
                    else:
                        new_key, dtype = target
                    new_state[new_key] = v.to(dtype).detach().clone()
                    handled_keys.add(k)
                    mapped = True
                    break
            if mapped:
                continue
            # Collect QKV weights and biases
            qkv_match = re.match(r'model\.layers\.(\d+)\.self_attn\.(q|k|v)_proj\.(weight|bias)', k)
            if qkv_match:
                layer, qkv_type, param_type = qkv_match.groups()
                if param_type == 'weight':
                    qkv_weights.setdefault(layer, {})[qkv_type] = v
                else:
                    qkv_biases.setdefault(layer, {})[qkv_type] = v
                handled_keys.add(k)
                continue
            # Collect up/gate weights
            upgate_match = re.match(r'model\.layers\.(\d+)\.mlp\.(up|gate)_proj\.weight', k)
            if upgate_match:
                layer, proj_type = upgate_match.groups()
                upgate_weights.setdefault(layer, {})[proj_type] = v
                handled_keys.add(k)
                continue
        # Check for unhandled source keys
        unhandled_keys = set(state_dict.keys()) - handled_keys
        if unhandled_keys:
            raise ValueError(f"Found unhandled keys in source state dict: {unhandled_keys}")
        # Second pass - merge collected weights
        for layer in range(cfg.num_hidden_layers):
            layer = str(layer)
            # Merge QKV weights and set biases
            if layer in qkv_weights and layer in qkv_biases:
                qkv_w_dict = qkv_weights[layer]
                qkv_b_dict = qkv_biases[layer]
                if not all(k in qkv_w_dict for k in ['q', 'k', 'v']) or not all(k in qkv_b_dict for k in ['q', 'k', 'v']):
                    raise ValueError(f"Missing QKV weights/biases for layer {layer}")
                # Merge weights and biases
                q_w, k_w, v_w = qkv_w_dict['q'], qkv_w_dict['k'], qkv_w_dict['v']
                q_b, k_b, v_b = qkv_b_dict['q'], qkv_b_dict['k'], qkv_b_dict['v']
                # Merge weights
                qkv_merged = torch.cat([q_w, k_w, v_w], dim=0)
                new_state[f'layers.{layer}.qkv.weight'] = qkv_merged
                # Merge biases
                qkv_bias_merged = torch.cat([q_b, k_b, v_b])
                new_state[f'layers.{layer}.qkv.bias'] = qkv_bias_merged
            else:
                raise ValueError(f"No QKV weights/biases found for layer {layer}")
            # Merge up/gate
            if layer in upgate_weights:
                upgate_dict = upgate_weights[layer]
                if not all(k in upgate_dict for k in ['gate', 'up']):
                    raise ValueError(f"Missing up/gate weights for layer {layer}")
                gate, up = upgate_dict['gate'], upgate_dict['up']
                upgate_merged = torch.cat([gate, up], dim=0)
                new_state[f'layers.{layer}.up_gate.weight'] = upgate_merged
            else:
                raise ValueError(f"No up/gate weights found for layer {layer}")
        # Verify all required keys are present in new state dict
        expected_keys = set(model.state_dict().keys())
        missing_keys = expected_keys - set(new_state.keys())
        if missing_keys:
            raise ValueError(f"Missing required keys in converted state dict: {missing_keys}")
        # Load the state dict
        new_state = {k: v.to(cfg.dtype) for k, v in new_state.items()}
        model.load_state_dict(new_state, strict=True, assign=True)
        for k, p in model.named_parameters():
            assert p.device.type != 'meta', k
        return model
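
# Illustrative usage of the converter above (model name is just an example; requires the
# `transformers` package and enough memory for the checkpoint):
#
#   from transformers import AutoModelForCausalLM
#   hf = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B", torch_dtype="auto")
#   ff = FFFormer.from_llamaish_hf_model(hf.config, hf.state_dict()).cuda()
#
# The converted model expects the caller to allocate and slice the KV cache explicitly
# (see generate() below) rather than using HF's cache objects.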

import tqdm


@torch.inference_mode()
def generate(model: FFFormer, input_ids: torch.Tensor, new_tokens: int):
    prefill_length = input_ids.shape[1]
    head_dim = model.cfg.hidden_size // model.cfg.num_attention_heads
    kcache = torch.full((model.cfg.num_hidden_layers, input_ids.shape[0], model.cfg.num_key_value_heads, input_ids.shape[1] + new_tokens, head_dim), fill_value=float("-inf"), dtype=model.cfg.dtype).cuda()
    vcache = torch.full((model.cfg.num_hidden_layers, input_ids.shape[0], model.cfg.num_key_value_heads, input_ids.shape[1] + new_tokens, head_dim), fill_value=float("-inf"), dtype=model.cfg.dtype).cuda()
    prefill_write_idxs = torch.arange(input_ids.shape[1])[None, :].cuda()
    prefill_position_idxs = torch.arange(input_ids.shape[1]).cuda()
    mask = model.make_mask(input_ids)
    prefill_logits = model(input_ids, kcache[:, :, :, :prefill_length], vcache[:, :, :, :prefill_length], prefill_write_idxs, prefill_position_idxs, attn_mask=None, causal=True)
    last_output_token = torch.argmax(prefill_logits[:, -1:], dim=-1)
    all_output_tokens = [last_output_token]
    write_idxs = prefill_write_idxs[:, -1:] + 1
    position_idxs = prefill_position_idxs[-1:] + 1
    for i in range(new_tokens):
        # Round the visible cache length up to a multiple of 256 so torch.compile
        # sees a small number of distinct shapes instead of one per decode step.
        chunk_length = 256 * ((prefill_length + i + 256 - 1) // 256)
        # print(chunk_length)  # debug
        output_logits = model(last_output_token, kcache[:, :, :, :chunk_length], vcache[:, :, :, :chunk_length], write_idxs, position_idxs, attn_mask=None, causal=False)
        # cur_tokens = torch.cat([input_ids] + all_output_tokens, dim=1)
        # output_logits = model(cur_tokens, None, None, None, position_idxs, attn_mask=mask, causal=False)
        # Only one token was fed in, so take the last (only) position.
        output_logits = F.softmax(output_logits[:, -1:], dim=-1)
        # next_output_token = torch.multinomial(output_logits, 1)
        next_output_token = torch.argmax(output_logits, dim=-1)
        write_idxs = write_idxs.add_(1)
        position_idxs = position_idxs.add_(1)
        last_output_token = next_output_token
        all_output_tokens.append(next_output_token)
    return torch.cat(all_output_tokens, dim=1)
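
# Illustrative use of generate() (assumes the tokenizer and converted model set up in the
# benchmark block below):
#
#   ids = tokenizer(["Hello."], return_tensors="pt")["input_ids"].cuda()
#   new = generate(our_model, ids, 32)   # one token from prefill plus one per decode step
#   print(tokenizer.decode(torch.cat([ids, new], dim=1)[0]))
#
# Decoding is greedy (argmax); sampling would need the commented-out torch.multinomial
# path plus temperature handling.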

if True or __name__ == "__main__":
    import time
    import torch.cuda
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def benchmark_model(model: FFFormer, inputs, num_runs=50, warmup=10):
        # Move everything to GPU and clear cache
        input_ids = inputs["input_ids"].cuda()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        head_dim = model.cfg.hidden_size // model.cfg.num_attention_heads
        kcache = torch.zeros(model.cfg.num_hidden_layers, input_ids.shape[0], model.cfg.num_key_value_heads, input_ids.shape[1], head_dim, dtype=model.cfg.dtype).cuda()
        vcache = torch.zeros(model.cfg.num_hidden_layers, input_ids.shape[0], model.cfg.num_key_value_heads, input_ids.shape[1], head_dim, dtype=model.cfg.dtype).cuda()
        write_idxs = torch.arange(input_ids.shape[1])[None, :].cuda()
        position_idxs = torch.arange(input_ids.shape[1]).cuda()
        mask = model.make_mask(input_ids)
        # Warmup runs
        for _ in range(warmup):
            with torch.no_grad():
                _ = model(input_ids, kcache, vcache, write_idxs, position_idxs, mask)
        # Measure GPU memory
        torch.cuda.synchronize()
        memory_allocated = torch.cuda.memory_allocated() / 1024**2  # MB
        memory_reserved = torch.cuda.memory_reserved() / 1024**2  # MB
        # Benchmark runs
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.no_grad():
                _ = model(input_ids, kcache, vcache, write_idxs, position_idxs, mask)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_runs
        return avg_time, memory_allocated, memory_reserved
    def benchmark_hf_model(model, inputs, num_runs=50, warmup=10):
        # Move everything to GPU and clear cache
        inputs = {k: v.cuda() for k, v in inputs.items()}
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        # Warmup runs
        for _ in range(warmup):
            with torch.no_grad():
                _ = model(**inputs)
        # Measure GPU memory
        torch.cuda.synchronize()
        memory_allocated = torch.cuda.memory_allocated() / 1024**2  # MB
        memory_reserved = torch.cuda.memory_reserved() / 1024**2  # MB
        # Benchmark runs
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.no_grad():
                _ = model(**inputs)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_runs
        return avg_time, memory_allocated, memory_reserved
| print("\n=== Starting Benchmark ===") | |
| # Load models | |
| model_name = "Qwen/Qwen2.5-7B" | |
| # model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" | |
| hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").cuda() | |
| our_model = FFFormer.from_llamaish_hf_model(hf_model.config, hf_model.state_dict()).cuda() | |
| hf_model.eval() | |
| our_model.eval() | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # Prepare inputs of different lengths | |
    sequences = [
        "Hello.<think>",  # Short
        "The quick brown fox jumps over the lazy dog. " * 16,  # Medium
        "The quick brown fox jumps over the lazy dog. " * 128,  # Long
        "The quick brown fox jumps over the lazy dog. " * 1280,  # Very long
    ]
    torch.set_grad_enabled(False)
    inputs = tokenizer([sequences[0]], return_tensors="pt", add_special_tokens=True)
    inputs = {k: v.cuda() for k, v in inputs.items()}
    # res = tokenizer.decode(hf_model.generate(**inputs, max_new_tokens=10, do_sample=False)[0])
    # our_res = tokenizer.decode(torch.cat([inputs["input_ids"], generate(our_model, inputs["input_ids"], 10)], dim=1)[0])
    # print("hf", res)
    # print("us", our_res)
| print("\nBenchmarking with different sequence lengths:") | |
| print(f"{'Sequence Length':<15} {'HF Time (ms)':<15} {'Our Time (ms)':<15} {'HF Memory (MB)':<15} {'Our Memory (MB)':<15} {'Speedup':<10}") | |
| print("-" * 80) | |
    for seq in sequences:
        torch.cuda.empty_cache()
        torch._dynamo.reset()
        inputs = tokenizer([seq], return_tensors="pt", add_special_tokens=True)
        seq_len = inputs["input_ids"].shape[1]
        inputs = {k: v.cuda() for k, v in inputs.items()}
        cat_inputs = inputs["input_ids"].repeat(32, 1)
        gen_tok_warmup = generate(our_model, cat_inputs, 10000)
        torch.cuda.synchronize()
        t = time.perf_counter()
        gen_tok = torch.cat([cat_inputs, generate(our_model, cat_inputs, 10000)], dim=1)
        gen_dt = time.perf_counter() - t
        print(f"prompt len {seq_len}: generation of 32x{gen_tok.shape[1]} toks took {gen_dt:.2f}s, "
              f"per-sequence: {gen_tok.shape[1] / gen_dt:.1f} tok/s, batch total: {gen_tok.numel()} toks")