Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created October 3, 2023 11:27
Show Gist options
  • Save laksjdjf/0abd398e9feb8686fd87dade337657db to your computer and use it in GitHub Desktop.
Save laksjdjf/0abd398e9feb8686fd87dade337657db to your computer and use it in GitHub Desktop.
# ref:https://github.com/tfernd/HyperTile
from einops import rearrange
'''
1. put this file in ComfyUI/custom_nodes
2. add the "Apply HyperTile" node (listed under the "loaders" category)
3. set nh and nw (2-4 is recommended; a value of 1 leaves that axis untiled, so nh = nw = 1 behaves the same as the original model)
'''
def to_tile(x, nh, nw, original_shape):
    """Split flattened spatial tokens into an ``nh`` x ``nw`` grid of tiles.

    Args:
        x: Token tensor laid out as ``b (H W) c`` with row-major spatial
            flattening, as required by the einops pattern below.
        nh: Number of tiles along the height axis.
        nw: Number of tiles along the width axis.
        original_shape: 4-sequence whose last two entries are the spatial
            height ``H`` and width ``W``; the first two entries are ignored.

    Returns:
        Tensor laid out as ``(b nh nw) (h w) c`` with ``h = H // nh`` and
        ``w = W // nw``: every tile becomes its own batch entry.

    Raises:
        ValueError: if ``H`` is not divisible by ``nh`` or ``W`` by ``nw``.
    """
    _, _, h, w = original_shape
    # Raise instead of ``assert``: assertions are stripped under ``python -O``,
    # which would turn this check into an opaque einops shape error.
    if h % nh != 0:
        raise ValueError("Height must be divisible by nh")
    if w % nw != 0:
        raise ValueError("Width must be divisible by nw")
    return rearrange(
        x,
        "b (nh h nw w) c -> (b nh nw) (h w) c",
        h=h // nh,
        w=w // nw,
        nh=nh,
        nw=nw,
    )
def from_tile(x, nh, nw, original_shape):
    """Stitch tiles back into flattened full-size spatial tokens.

    Inverse of ``to_tile``: converts ``(b nh nw) (h w) c`` into
    ``b (nh h nw w) c``, where ``h = H // nh`` and ``w = W // nw`` and
    ``H`` / ``W`` are the last two entries of ``original_shape``.
    """
    _, _, height, width = original_shape
    tile_h = height // nh
    tile_w = width // nw
    # One pattern does both the batch un-folding and the spatial re-assembly.
    return rearrange(
        x,
        "(b nh nw) (h w) c -> b (nh h nw w) c",
        nh=nh,
        nw=nw,
        h=tile_h,
        w=tile_w,
    )
class HyperTile:
    """ComfyUI node that applies HyperTile-style tiling to self-attention.

    When an attn1 (self-attention) call receives exactly ``H * W`` query
    tokens, where ``H`` and ``W`` are the last two entries of
    ``extra_options["original_shape"]``, its queries, keys and values are
    split into an ``nh`` x ``nw`` grid of tiles folded into the batch axis.
    Attention is therefore computed within each tile only, and the attn1
    output is stitched back together afterwards. Calls with any other token
    count pass through unchanged.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the model and the tile counts nh / nw."""
        return {
            "required": {
                "model": ("MODEL", ),
                "nh": ("INT", {
                    "default": 1,
                    "min": 1,  # Minimum value
                    "max": 64,  # Maximum value
                    "step": 1,  # Slider's step
                    "display": "number",  # Cosmetic only: "number" or "slider"
                }),
                "nw": ("INT", {
                    "default": 1,
                    "min": 1,  # Minimum value
                    "max": 64,  # Maximum value
                    "step": 1,  # Slider's step
                    "display": "number",  # Cosmetic only: "number" or "slider"
                }),
            }
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, nh, nw):
        """Return a clone of ``model`` with HyperTile attention patches set.

        Args:
            model: Model to patch. ``model.clone()`` is patched instead of
                ``model`` itself (assumes ``clone()`` yields an independent
                copy whose patches do not leak back into ``model``).
            nh: Number of tiles along the height axis.
            nw: Number of tiles along the width axis.

        Returns:
            A 1-tuple holding the patched clone, matching ``RETURN_TYPES``.
        """
        new_model = model.clone()
        # Whether the most recent attn1 call was tiled, so the output patch
        # knows whether to un-tile. Kept in this closure rather than on
        # ``self`` so that several models patched through the same node
        # instance cannot clobber each other's state.
        applied = False

        def attn1_patch(q, k, v, extra_options):
            nonlocal applied
            # Decide afresh on every call so a stale flag (e.g. left over
            # after an exception) can never un-tile an untiled output.
            applied = False
            original_shape = extra_options["original_shape"]
            _, _, h, w = original_shape
            _, qn, _ = q.shape
            # Only the call whose token count equals H * W is tiled; any
            # other token count is passed through unchanged.
            if qn == h * w:
                q = to_tile(q, nh, nw, original_shape)
                k = to_tile(k, nh, nw, original_shape)
                v = to_tile(v, nh, nw, original_shape)
                # Set only after all three tilings succeeded.
                applied = True
            return q, k, v

        def attn1_output_patch(out, extra_options):
            nonlocal applied
            if applied:
                out = from_tile(out, nh, nw, extra_options["original_shape"])
                applied = False
            return out

        new_model.set_model_attn1_patch(attn1_patch)
        new_model.set_model_attn1_output_patch(attn1_output_patch)
        return (new_model, )
# ComfyUI custom-node registration: internal node type name -> node class.
NODE_CLASS_MAPPINGS = {"HyperTile": HyperTile}

# Title shown for each registered node type, keyed by the same type name.
NODE_DISPLAY_NAME_MAPPINGS = dict(HyperTile="Apply HyperTile")

# Names exported by ``from <this module> import *``.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment