Created
October 3, 2023 11:27
-
-
Save laksjdjf/0abd398e9feb8686fd87dade337657db to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ref:https://github.com/tfernd/HyperTile | |
from einops import rearrange | |
''' | |
1. put this file in ComfyUI/custom_nodes | |
2. load node from <loader> | |
3. set nh and nw (2-4 is recommended. if 1 is set, it will be the same as original) | |
''' | |
def to_tile(x, nh, nw, original_shape):
    """Split a flattened spatial token sequence into ``nh * nw`` tiles.

    Args:
        x: rank-3 array/tensor ``(batch, h * w, channels)`` whose token axis is
            the row-major flattening of the ``h x w`` grid taken from
            ``original_shape`` (the einops pattern below decomposes it as
            ``(nh, tile_h, nw, tile_w)``).
        nh: number of tiles along the height axis; must divide ``h``.
        nw: number of tiles along the width axis; must divide ``w``.
        original_shape: 4-tuple whose last two entries are ``h`` and ``w``.

    Returns:
        ``(batch * nh * nw, (h // nh) * (w // nw), channels)``: every tile is
        moved into the batch axis, so later attention treats tiles independently.

    Raises:
        ValueError: if ``h`` is not divisible by ``nh`` or ``w`` by ``nw``.
    """
    _, _, h, w = original_shape
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if h % nh != 0:
        raise ValueError(f"Height must be divisible by nh (h={h}, nh={nh})")
    if w % nw != 0:
        raise ValueError(f"Width must be divisible by nw (w={w}, nw={nw})")
    return rearrange(
        x,
        "b (nh h nw w) c -> (b nh nw) (h w) c",
        h=h // nh,
        w=w // nw,
        nh=nh,
        nw=nw,
    )
def from_tile(x, nh, nw, original_shape):
    """Inverse of ``to_tile``: stitch ``nh * nw`` tiles back into one sequence.

    Args:
        x: rank-3 array/tensor ``(batch * nh * nw, tile_h * tile_w, channels)``
            with tiles ordered along the leading axis as ``to_tile`` emits them.
        nh: number of tiles along the height axis.
        nw: number of tiles along the width axis.
        original_shape: 4-tuple whose last two entries are the full grid
            height ``h`` and width ``w``; ``tile_h = h // nh``,
            ``tile_w = w // nw``.

    Returns:
        ``(batch, nh * tile_h * nw * tile_w, channels)`` with tokens back in
        row-major grid order.
    """
    _, _, grid_h, grid_w = original_shape
    tile_h, tile_w = grid_h // nh, grid_w // nw
    # A single pattern both pulls the tile indices out of the batch axis and
    # interleaves them back into the flattened grid.
    return rearrange(
        x,
        "(b nh nw) (h w) c -> b (nh h nw w) c",
        nh=nh,
        nw=nw,
        h=tile_h,
        w=tile_w,
    )
class HyperTile:
    """ComfyUI node that patches a model so self-attention (attn1) runs on tiles.

    Based on https://github.com/tfernd/HyperTile. The attn1 token sequence is
    split into ``nh * nw`` spatial tiles that are folded into the batch axis,
    attention runs on them, and the output is stitched back into one sequence.
    ``nh = nw = 1`` leaves attention unchanged.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a model plus the tile counts ``nh``/``nw``."""

        def tile_count():
            # Fresh spec per input so "nh" and "nw" never share one mutable dict.
            return ("INT", {
                "default": 1,
                "min": 1,  # Minimum value
                "max": 64,  # Maximum value
                "step": 1,  # Slider's step
                "display": "number",  # Cosmetic only: display as "number" or "slider"
            })

        return {
            "required": {
                "model": ("MODEL", ),
                "nh": tile_count(),
                "nw": tile_count(),
            }
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, nh, nw):
        """Return a clone of ``model`` whose attn1 calls are tiled.

        Args:
            model: model object; only ``clone()``, ``set_model_attn1_patch`` and
                ``set_model_attn1_output_patch`` are used here.
            nh: number of tiles along the height axis.
            nw: number of tiles along the width axis.

        Returns:
            One-element tuple holding the patched clone.
        """
        new_model = model.clone()
        # Whether the latest attn1 input was tiled, so the matching output patch
        # knows to stitch it back. Local to this ``apply`` call (not stored on
        # ``self``) so models patched by the same node instance never share it.
        tiled = False

        def attn1_patch(q, k, v, extra_options):
            nonlocal tiled
            original_shape = extra_options["original_shape"]
            _, _, h, w = original_shape
            _, qn, _ = q.shape
            # Tile only when the sequence is the full h*w grid of original_shape
            # (presumably the highest-resolution attention level — confirm) and
            # the grid splits evenly. A non-divisible size falls back to plain
            # attention instead of raising mid-sampling. Recomputed on every
            # call so a stale True can never leak into a later call.
            tiled = (
                nh * nw > 1
                and qn == h * w
                and h % nh == 0
                and w % nw == 0
            )
            if tiled:
                # Assumes k/v carry q's token count (self-attention), as the
                # original implementation did.
                q = to_tile(q, nh, nw, original_shape)
                k = to_tile(k, nh, nw, original_shape)
                v = to_tile(v, nh, nw, original_shape)
            return q, k, v

        def attn1_output_patch(out, extra_options):
            nonlocal tiled
            if tiled:
                out = from_tile(out, nh, nw, extra_options["original_shape"])
                tiled = False
            return out

        new_model.set_model_attn1_patch(attn1_patch)
        new_model.set_model_attn1_output_patch(attn1_output_patch)
        return (new_model, )
# Node registry for this custom-node module (presumably read by ComfyUI's
# custom_nodes loader): internal node id -> class, and id -> UI label.
NODE_CLASS_MAPPINGS = {"HyperTile": HyperTile}
NODE_DISPLAY_NAME_MAPPINGS = {"HyperTile": "Apply HyperTile"}

# Only the two mappings are exported from this module.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment