Forked from kohya-ss/forward_of_sdxl_original_unet.py
Created
February 21, 2024 13:02
-
-
Save AmesianX/133831676e24c8812e52da492cd4d074 to your computer and use it in GitHub Desktop.
SDXLで高解像度での構図の破綻を軽減する
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    """Run the U-Net forward pass with a timestep-dependent extra downsample.

    At one encoder depth (chosen from the current timestep) the hidden
    state is shrunk to half resolution with bicubic interpolation; the
    decoder restores it at the mirrored depth, right before the matching
    skip connection is concatenated. Per the original author's notes this
    is meant to reduce composition breakdown at high resolution.

    Args:
        x: Input tensor. Bicubic interpolation is applied to the hidden
            state derived from it, so it is presumably 4-D (N, C, H, W).
        timesteps: Tensor of timesteps; expanded to the batch size of
            ``x``. Only ``timesteps[0]`` is used to pick the shrink stage,
            so the whole batch shares one decision.
        context: Passed to every ``Transformer2DModel`` layer.
        y: Conditioning tensor fed to ``self.label_emb``; must match ``x``
            in batch size and dtype.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        The output of ``self.out`` applied to the decoded hidden state.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in batch size or dtype.
    """
    # Broadcast timesteps to the batch dimension.
    timesteps = timesteps.expand(x.shape[0])

    hs = []
    t_emb = get_timestep_embedding(timesteps, self.model_channels)
    t_emb = t_emb.to(x.dtype)
    emb = self.time_embed(t_emb)

    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
    emb = emb + self.label_emb(y)

    def call_module(module, h, emb, context):
        """Apply each layer of ``module``, passing the extra input it needs."""
        for layer in module:
            if isinstance(layer, ResnetBlock2D):
                h = layer(h, emb)
            elif isinstance(layer, Transformer2DModel):
                h = layer(h, context)
            else:
                h = layer(h)
        return h

    h = x

    # Downsample depth.
    # Deeper stabilises the overall composition but distorts characters;
    # too shallow makes fine details chaotic.
    ds_depth_1 = 3  # roughly 2-4 seems to work well
    ds_depth_2 = 3  # roughly depth_1 + 0..2 seems to work well

    # Downsample timestep thresholds.
    # Larger values disturb the composition, smaller values disturb details.
    ds_timestep_1 = 900
    ds_timestep_2 = 650  # must be smaller than ds_timestep_1; affects the drawing

    # The stage decision does not change inside the loops, so make it once.
    # Stage 1 covers t > ds_timestep_1 and stage 2 covers
    # ds_timestep_2 < t <= ds_timestep_1. The upper bound of stage 2 is
    # inclusive so that t == ds_timestep_1 does not fall between the stages.
    first_timestep = timesteps[0]
    if first_timestep > ds_timestep_1:
        shrink_depth = ds_depth_1
    elif first_timestep > ds_timestep_2:
        shrink_depth = ds_depth_2
    else:
        shrink_depth = None

    for depth, module in enumerate(self.input_blocks):
        h = call_module(module, h, emb, context)
        hs.append(h)
        if depth == shrink_depth:
            # Anything other than bicubic distorts the result; align_corners
            # seems to have little effect. Interpolate in float32 and cast
            # back so that bfloat16 inputs work.
            h = F.interpolate(
                h.float(), scale_factor=0.5, mode="bicubic", align_corners=False
            ).to(h.dtype)

    h = call_module(self.middle_block, h, emb, context)

    for module in self.output_blocks:
        # Index of the skip tensor that is about to be consumed; mirrors the
        # encoder depth at which it was stored.
        depth = len(hs) - 1
        if depth == shrink_depth:
            # Restore to the exact spatial size of the skip tensor rather
            # than a fixed 2x: halving an odd size floors, so doubling would
            # not match and the concatenation below would fail.
            skip_size = tuple(hs[-1].shape[-2:])
            h = F.interpolate(
                h.float(), size=skip_size, mode="bicubic", align_corners=False
            ).to(h.dtype)
        h = torch.cat([h, hs.pop()], dim=1)
        h = call_module(module, h, emb, context)

    h = h.type(x.dtype)
    h = call_module(self.out, h, emb, context)
    return h
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment