| """Export gpt-oss-20b directly to .fx format using torch.export.""" | |
| import torch | |
| from transformers import AutoModelForCausalLM | |
| import os | |
| # Suppress warnings | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| print("=" * 80) | |
| print("TORCH.EXPORT - DIRECT .FX EXPORT (NO INTERMEDIATE .pt2)") | |
| print("=" * 80) | |

# Model path
model_path = os.path.expanduser(
    "~/.cache/huggingface/hub/models--openai--gpt-oss-20b/snapshots/"
    "6cee5e81ee83917806bbde320786a8fb61efebee"
)
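# (Assumption: the snapshot hash above is specific to one machine's cache.
#  Passing the hub id "openai/gpt-oss-20b" to from_pretrained instead would
#  resolve the snapshot automatically, at the cost of a hub lookup.)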
| print("\n1. Loading model...") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.bfloat16, | |
| low_cpu_mem_usage=True, | |
| ) | |
| model.eval() | |
| print(" ✅ Model loaded") | |

# Step 1: Create ExportWrapper with a clean signature
print("\n2. Creating ExportWrapper...")

class ExportWrapper(torch.nn.Module):
    """
    Clean wrapper for torch.export that:
    1. Has an explicit signature (no *args/**kwargs)
    2. Bypasses the @check_model_inputs decorator
    3. Forces the dense MoE path
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass that returns only logits.

        Args:
            input_ids: [batch_size, seq_len] tensor of token IDs

        Returns:
            logits: [batch_size, seq_len, vocab_size] tensor
        """
        # Call the undecorated forward, bypassing @check_model_inputs
        outputs = self.model.forward.__wrapped__(
            self.model,
            input_ids,
            attention_mask=None,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cache_position=None,
        )
        # Return only logits for simplicity
        return outputs.logits

wrapper = ExportWrapper(model)
print("   ✅ ExportWrapper created")

# Step 2: Define dynamic shapes
print("\n3. Defining dynamic shapes...")
try:
    from torch.export import Dim

    # Static batch, dynamic sequence only (batch > 1 causes ConstraintViolationError)
    seq_dim = Dim("seq", min=1, max=512)  # Allow sequence lengths 1-512

    # Map to input tensors - sequence dimension only
    dynamic_shapes = {
        "input_ids": {1: seq_dim},  # Only seq is dynamic; batch is static at 1
    }
    print("   Batch dimension: static (1)")
    print(f"   Sequence dimension: {seq_dim}")
    print("   ✅ Dynamic shapes defined")
except ImportError as e:
    print(f"   ❌ ERROR: torch.export not available: {e}")
    print("   This requires PyTorch 2.1+")
    raise SystemExit(1)
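# For reference, a dynamic batch dimension would be declared as in the
# sketch below; per the comment above, it is left out because tracing this
# model with a symbolic batch reportedly raises ConstraintViolationError:
#
#   batch_dim = Dim("batch", min=1, max=8)
#   dynamic_shapes = {"input_ids": {0: batch_dim, 1: seq_dim}}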

# Step 3: Create example inputs
print("\n4. Creating example inputs...")
example_input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
print(f"   Example shape: {example_input_ids.shape}")

# Step 4: Attempt export
print("\n5. Attempting torch.export...")
print("   This may take a few minutes...")
try:
    # Export with torch.export
    exported_program = torch.export.export(
        wrapper,
        (example_input_ids,),  # args tuple
        dynamic_shapes=dynamic_shapes,
        strict=True,  # Strict mode for better error messages
    )
    print("   ✅ Export succeeded!")
    print(f"   Graph nodes: {len(list(exported_program.graph.nodes))}")
    # Step 5: Extract a runnable FX GraphModule directly (no .pt2 intermediate).
    # Note: exported_program.graph_module is the raw functional graph with
    # parameters lifted to placeholder inputs, so it cannot be called with
    # input_ids alone; exported_program.module() re-inlines the weights and
    # returns a GraphModule that is callable with the original signature.
    print("\n6. Extracting FX GraphModule...")
    fx_graph_module = exported_program.module()
    print("   ✅ Extracted GraphModule")
    print(f"   Graph nodes: {len(list(fx_graph_module.graph.nodes))}")

    # Step 6: Verify forward pass
    print("\n7. Verifying forward pass equivalence...")

    # Test with various sequence lengths (batch=1 only)
    test_inputs = [
        torch.tensor([[1]], dtype=torch.long),                              # batch=1, seq=1
        torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long),                  # batch=1, seq=5
        torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long),         # batch=1, seq=8
        torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.long),  # batch=1, seq=10
    ]

    all_match = True
    for i, test_input in enumerate(test_inputs):
        with torch.no_grad():
            # Original
            orig_logits = wrapper(test_input)
            # Exported FX GraphModule
            fx_logits = fx_graph_module(test_input)

        # Compare
        max_diff = (orig_logits - fx_logits).abs().max().item()
        mean_diff = (orig_logits - fx_logits).abs().mean().item()

        # Tolerance for bfloat16
        atol = 1e-2
        matches = max_diff < atol
        status = "✅" if matches else "❌"
        print(f"   Test {i+1} {test_input.shape}: {status} "
              f"max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}")
        if not matches:
            all_match = False

    if all_match:
        print("\n   ✅ All forward pass tests passed!")
    else:
        print("\n   ⚠️ Some forward pass tests showed differences (may be numerical)")

    # Step 7: Save FX GraphModule directly (no .pt2 file)
    print("\n8. Saving FX GraphModule directly to .fx...")
    fx_path = "gpt_oss_exported_direct.fx"
    torch.save(fx_graph_module, fx_path)
    print(f"   ✅ Saved FX GraphModule to {fx_path}")

    # Also save a human-readable graph
    graph_txt_path = "gpt_oss_exported_direct_graph.txt"
    with open(graph_txt_path, 'w') as f:
        f.write(str(fx_graph_module.graph))
    print(f"   ✅ Saved graph representation to {graph_txt_path}")

    print("\n" + "=" * 80)
    print("SUCCESS: Model exported directly to .fx format!")
    print("=" * 80)
    print("\nFiles created:")
    print(f"  - {fx_path} (FX GraphModule)")
    print(f"  - {graph_txt_path} (Human-readable graph)")
    print("\nNo intermediate .pt2 file was created.")
    print("\nLoad the FX GraphModule with:")
    print(f"  fx_module = torch.load('{fx_path}', weights_only=False)")

except Exception as e:
    print(f"\n   ❌ Export failed: {e}")
    print("\nDebugging info:")
    import traceback
    traceback.print_exc()
    print("\n" + "=" * 80)
    print("Export failed - see error above")
    print("=" * 80)