Created September 1, 2025 14:14
Save cbensimon/8dc0ffcd7ee024d91333f6df01907916 to your computer and use it in GitHub Desktop.
Multi-compile + dispatch for ZeroGPU AoT compilation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@spaces.GPU
def compile():
    """Ahead-of-time compile ``pipe.transformer`` for two fixed resolutions.

    The pipeline is run once with ``width=832, height=480`` (landscape) and
    once with ``width=480, height=832`` (portrait). During each run,
    ``spaces.aoti_capture`` records the inputs that the pipeline feeds to
    ``pipe.transformer``. Each captured call is then exported with
    ``torch.export.export`` and compiled with ``spaces.aoti_compile``.

    Returns:
        tuple: ``(compiled_landscape, compiled_portrait)``. The portrait
        artifact is made to reuse the landscape artifact's ``weights``.

    NOTE: this name shadows the builtin ``compile`` inside this module. It is
    kept unchanged because the module-level call below depends on it.
    """
    # NOTE(review): capturing, exporting and compiling two graphs may exceed
    # the default ZeroGPU duration; consider ``@spaces.GPU(duration=...)``.

    # Record the transformer inputs for each target resolution. The literal
    # "prompt" is just an example input used to drive the pipeline.
    with spaces.aoti_capture(pipe.transformer) as call_landscape:
        pipe("prompt", width=832, height=480)
    with spaces.aoti_capture(pipe.transformer) as call_portrait:
        pipe("prompt", width=480, height=832)

    # Export the same module that was captured, so that the recorded
    # args/kwargs match the forward signature being exported.
    exported_landscape = torch.export.export(
        pipe.transformer,
        args=call_landscape.args,
        kwargs=call_landscape.kwargs,
    )
    exported_portrait = torch.export.export(
        pipe.transformer,
        args=call_portrait.args,
        kwargs=call_portrait.kwargs,
    )

    compiled_landscape = spaces.aoti_compile(exported_landscape)
    compiled_portrait = spaces.aoti_compile(exported_portrait)

    # The following line is very important as landscape and portrait weights
    # would duplicate when returning outside of `@spaces.GPU`
    compiled_portrait.weights = compiled_landscape.weights
    return compiled_landscape, compiled_portrait


compiled_landscape, compiled_portrait = compile()
def combined(*args, **kwargs):
    """Forward a transformer call to the compiled graph that fits its shape.

    The ``hidden_states`` keyword argument is inspected:
    - If its last dimension is strictly larger than its second-to-last
      dimension, the call goes to ``compiled_landscape``.
    - Otherwise, including when the two are equal, it goes to
      ``compiled_portrait``.

    Every positional and keyword argument is passed through untouched, and
    the selected callable's result is returned.

    NOTE(review): assumes the last two dims of ``hidden_states`` are
    (height, width) and that ``hidden_states`` is always passed by keyword
    (a positional-only call raises ``KeyError`` here) — confirm with caller.
    NOTE(review): square inputs are routed to the portrait graph, which was
    captured at a 480x832 request — confirm this is intended.
    """
    latents = kwargs['hidden_states']
    is_landscape = latents.shape[-1] > latents.shape[-2]
    target = compiled_landscape if is_landscape else compiled_portrait
    return target(*args, **kwargs)


# Register the dispatcher with ``spaces.aoti_apply`` for ``pipe.transformer``
# (presumably so later transformer calls are routed through it).
spaces.aoti_apply(combined, pipe.transformer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.