"""Export gpt-oss-20b directly to .fx format using torch.export."""
import os
import sys
import traceback
import warnings

import torch
from transformers import AutoModelForCausalLM

# Suppress warnings
warnings.filterwarnings('ignore')
print("=" * 80)
print("TORCH.EXPORT - DIRECT .FX EXPORT (NO INTERMEDIATE .pt2)")
print("=" * 80)
# Model path
model_path = os.path.expanduser(
"~/.cache/huggingface/hub/models--openai--gpt-oss-20b/snapshots/"
"6cee5e81ee83917806bbde320786a8fb61efebee"
)
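# Note: this snapshot hash pins one specific local download; with a
# different cache state, passing the hub ID "openai/gpt-oss-20b" to
# from_pretrained works as well.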
print("\n1. Loading model...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
model.eval()
print(" ✅ Model loaded")

# Step 1: Create ExportWrapper with a clean signature
print("\n2. Creating ExportWrapper...")


class ExportWrapper(torch.nn.Module):
    """
    Clean wrapper for torch.export that:
    1. Has an explicit signature (no *args/**kwargs)
    2. Bypasses the @check_model_inputs decorator
    3. Forces the dense MoE path
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass that returns only logits.

        Args:
            input_ids: [batch_size, seq_len] tensor of token IDs

        Returns:
            logits: [batch_size, seq_len, vocab_size] tensor
        """
        # Call the undecorated forward, bypassing @check_model_inputs.
        # Decorators applied with functools.wraps expose the original
        # function via __wrapped__, so it is called here as an unbound
        # function with the module passed explicitly as self.
        outputs = self.model.forward.__wrapped__(
            self.model,
            input_ids,
            attention_mask=None,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cache_position=None,
        )
        # Return only the logits for simplicity
        return outputs.logits


wrapper = ExportWrapper(model)
print(" ✅ ExportWrapper created")

# Step 2: Define dynamic shapes
print("\n3. Defining dynamic shapes...")
try:
    from torch.export import Dim

    # Static batch, dynamic sequence only (batch > 1 causes ConstraintViolationError)
    seq_dim = Dim("seq", min=1, max=512)  # Allow sequence lengths 1-512
    # Map the dynamic dim onto the input tensors - sequence dimension only
    dynamic_shapes = {
        "input_ids": {1: seq_dim},  # Only seq is dynamic; batch is static at 1
    }
    print(" Batch dimension: static (1)")
    print(f" Sequence dimension: {seq_dim}")
    print(" ✅ Dynamic shapes defined")
except ImportError as e:
    print(f" ❌ ERROR: torch.export not available: {e}")
    print(" This requires PyTorch 2.1+")
    sys.exit(1)
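
# For reference: making the batch dimension dynamic as well would look
# like the lines below, but per the comment above it raised
# ConstraintViolationError for this model, so batch stays static at 1.
#
#   batch_dim = Dim("batch", min=1, max=8)   # hypothetical bounds
#   dynamic_shapes = {"input_ids": {0: batch_dim, 1: seq_dim}}
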
# Step 3: Create example inputs
print("\n4. Creating example inputs...")
example_input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
print(f" Example shape: {example_input_ids.shape}")

# Step 4: Attempt export
print("\n5. Attempting torch.export...")
print(" This may take a few minutes...")
try:
    # Export with torch.export
    exported_program = torch.export.export(
        wrapper,
        (example_input_ids,),  # args tuple
        dynamic_shapes=dynamic_shapes,
        strict=True,  # Strict mode for better error messages
    )
    print(" ✅ Export succeeded!")
    print(f" Graph nodes: {len(list(exported_program.graph.nodes))}")

    # Step 5: Extract a runnable FX GraphModule directly (no .pt2 intermediate).
    # Note: exported_program.graph_module is the raw graph whose parameters
    # are lifted into explicit placeholder inputs, so it cannot be called
    # with input_ids alone; .module() re-attaches the state and returns a
    # GraphModule callable with the original signature.
    print("\n6. Extracting FX GraphModule...")
    fx_graph_module = exported_program.module()
    print(" ✅ Extracted GraphModule")
    print(f" Graph nodes: {len(list(fx_graph_module.graph.nodes))}")

    # Step 6: Verify forward pass
    print("\n7. Verifying forward pass equivalence...")
    # Test with various sequence lengths (batch=1 only)
    test_inputs = [
        torch.tensor([[1]], dtype=torch.long),                              # batch=1, seq=1
        torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long),                  # batch=1, seq=5
        torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long),         # batch=1, seq=8
        torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.long),  # batch=1, seq=10
    ]
    all_match = True
    for i, test_input in enumerate(test_inputs):
        with torch.no_grad():
            # Original wrapper
            orig_logits = wrapper(test_input)
            # Exported FX GraphModule
            fx_logits = fx_graph_module(test_input)
            # Compare
            max_diff = (orig_logits - fx_logits).abs().max().item()
            mean_diff = (orig_logits - fx_logits).abs().mean().item()
        # Tolerance for bfloat16
        atol = 1e-2
        matches = max_diff < atol
        status = "✅" if matches else "❌"
        print(f" Test {i+1} {test_input.shape}: {status} "
              f"max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}")
        if not matches:
            all_match = False
    if all_match:
        print("\n ✅ All forward pass tests passed!")
    else:
        print("\n ⚠️ Some forward pass tests showed differences (may be numerical)")

    # Step 7: Save the FX GraphModule directly (no .pt2 file)
    print("\n8. Saving FX GraphModule directly to .fx...")
    fx_path = "gpt_oss_exported_direct.fx"
    torch.save(fx_graph_module, fx_path)
    print(f" ✅ Saved FX GraphModule to {fx_path}")

    # Also save a human-readable graph dump
    graph_txt_path = "gpt_oss_exported_direct_graph.txt"
    with open(graph_txt_path, 'w') as f:
        f.write(str(fx_graph_module.graph))
    print(f" ✅ Saved graph representation to {graph_txt_path}")

    print("\n" + "=" * 80)
    print("SUCCESS: Model exported directly to .fx format!")
    print("=" * 80)
    print("\nFiles created:")
    print(f" - {fx_path} (FX GraphModule)")
    print(f" - {graph_txt_path} (human-readable graph)")
    print("\nNo intermediate .pt2 file was created.")
    print("\nLoad the FX GraphModule with:")
    print(f" fx_module = torch.load('{fx_path}', weights_only=False)")
except Exception as e:
    print(f"\n ❌ Export failed: {e}")
    print("\nDebugging info:")
    traceback.print_exc()
    print("\n" + "=" * 80)
    print("Export failed - see error above")
    print("=" * 80)