@cbensimon
Created September 1, 2025 14:14
Multi-compile + dispatch for ZeroGPU AoT compilation
import spaces
import torch

# Assumes `pipe` is a pipeline defined earlier whose denoiser is `pipe.transformer`.
@spaces.GPU
def compile():
    # Record the transformer's call arguments once per target resolution.
    with spaces.aoti_capture(pipe.transformer) as call_landscape:
        pipe("prompt", width=832, height=480)
    with spaces.aoti_capture(pipe.transformer) as call_portrait:
        pipe("prompt", width=480, height=832)

    # Export one graph per resolution from the captured inputs.
    exported_landscape = torch.export.export(pipe.transformer, args=call_landscape.args, kwargs=call_landscape.kwargs)
    exported_portrait = torch.export.export(pipe.transformer, args=call_portrait.args, kwargs=call_portrait.kwargs)

    # Ahead-of-time compile each exported graph.
    compiled_landscape = spaces.aoti_compile(exported_landscape)
    compiled_portrait = spaces.aoti_compile(exported_portrait)

    # Share one set of weights between both compiled models. Without this line, returning
    # them from the `@spaces.GPU` function would carry two copies of the transformer weights.
    compiled_portrait.weights = compiled_landscape.weights

    return compiled_landscape, compiled_portrait

compiled_landscape, compiled_portrait = compile()
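
def combined(*args, **kwargs):
    # Route each call to the variant compiled for the latent's orientation, assuming a
    # [..., height, width] layout: width greater than height means landscape.
    hidden_states = kwargs['hidden_states']
    if hidden_states.shape[-1] > hidden_states.shape[-2]:
        return compiled_landscape(*args, **kwargs)
    else:
        return compiled_portrait(*args, **kwargs)

# Use `combined` as the forward pass of `pipe.transformer`.
spaces.aoti_apply(combined, pipe.transformer)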
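The `combined` function picks between two compiled variants by comparing the latent's width and height. The same multi-compile + dispatch idea extends to more than two resolutions if you look variants up by exact latent size. The following is a hypothetical sketch and does not use the `spaces` API. `make_shape_dispatcher` and the stand-in model functions are illustrative names, and plain functions replace compiled models. It also assumes that each variant was exported with static shapes and should only receive the latent (height, width) it was compiled for.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple

def make_shape_dispatcher(
    variants: Dict[Tuple[int, int], Callable[..., Any]],
) -> Callable[..., Any]:
    """Route each call to the variant compiled for the (height, width)
    in the last two dims of kwargs['hidden_states']."""
    def dispatch(*args: Any, **kwargs: Any) -> Any:
        key = tuple(kwargs["hidden_states"].shape[-2:])
        if key not in variants:
            # Fail loudly instead of feeding an uncompiled size to a graph.
            raise ValueError(
                f"no compiled variant for latent size {key}; "
                f"available: {sorted(variants)}"
            )
        return variants[key](*args, **kwargs)
    return dispatch

# Plain functions stand in for compiled models (hypothetical).
def landscape_model(*args: Any, **kwargs: Any) -> str:
    return "landscape"

def portrait_model(*args: Any, **kwargs: Any) -> str:
    return "portrait"

# Illustrative latent (height, width) keys; real values depend on the model.
dispatch = make_shape_dispatcher({(60, 104): landscape_model, (104, 60): portrait_model})

# SimpleNamespace stands in for a tensor: only `.shape` is read.
latent = SimpleNamespace(shape=(1, 16, 21, 60, 104))
print(dispatch(hidden_states=latent))  # landscape
```

The gist's width-versus-height comparison is simpler for two variants. An exact-size lookup, by contrast, gives an explicit error for a resolution that was never compiled.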
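The comment on `compiled_portrait.weights = compiled_landscape.weights` says the weights would otherwise be duplicated when the models are returned from the `@spaces.GPU` function. One plausible mechanism is that the returned objects are serialized to leave the GPU worker, but this is an assumption and the gist does not state it. Python's `pickle` writes an object that is referenced twice only once, so sharing a single weights object keeps it from being serialized twice. The sketch below demonstrates only that `pickle` behavior. It uses `SimpleNamespace` objects with a `bytes` blob as hypothetical stand-ins for compiled models and their weights.

```python
import pickle
from types import SimpleNamespace

WEIGHT_BYTES = 1_000_000  # illustrative size of one weights blob

def fake_compiled(weights: bytes) -> SimpleNamespace:
    # Hypothetical stand-in for a compiled model that owns a weights object.
    return SimpleNamespace(weights=weights)

# Without sharing: each variant holds its own copy of the weights.
separate = (fake_compiled(bytes(WEIGHT_BYTES)), fake_compiled(bytes(WEIGHT_BYTES)))

# With sharing: mirror `compiled_portrait.weights = compiled_landscape.weights`.
landscape = fake_compiled(bytes(WEIGHT_BYTES))
portrait = fake_compiled(bytes(WEIGHT_BYTES))
portrait.weights = landscape.weights
shared = (landscape, portrait)

# Pickle each pair in one call, like a function returning a 2-tuple.
separate_size = len(pickle.dumps(separate))
shared_size = len(pickle.dumps(shared))
print(separate_size > 2 * WEIGHT_BYTES, shared_size < 2 * WEIGHT_BYTES)  # True True

# Identity survives the round trip: both models point at one weights object.
restored = pickle.loads(pickle.dumps(shared))
print(restored[0].weights is restored[1].weights)  # True
```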