Created
November 14, 2023 03:39
-
-
Save kohya-ss/3f774da220df102548093a7abc8538ed to your computer and use it in GitHub Desktop.
SDXLで高解像度での構図の破綻を軽減する (Reduce composition breakdown at high resolutions in SDXL)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    """UNet forward pass with "Deep Shrink" for SDXL.

    At early (high-noise) timesteps the feature map is bicubically
    downscaled once a given encoder depth is reached, then upscaled back
    at the matching decoder depth. This reduces composition breakdown
    when generating at resolutions higher than the training resolution.

    Interface is identical to the stock SDXL UNet ``forward``:

    Args:
        x: latent input, shape (batch, channels, H, W).
        timesteps: diffusion timestep(s); broadcast to the batch size.
        context: text-encoder hidden states for cross-attention.
        y: pooled/label conditioning embedding (same batch size/dtype as x).

    Returns:
        Output tensor from ``self.out``, same dtype as ``x``.
    """
    # Broadcast a scalar/singleton timestep tensor to the batch dimension.
    timesteps = timesteps.expand(x.shape[0])

    hs = []
    t_emb = get_timestep_embedding(timesteps, self.model_channels)
    t_emb = t_emb.to(x.dtype)
    emb = self.time_embed(t_emb)

    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
    emb = emb + self.label_emb(y)

    def call_module(module, h, emb, context):
        # Dispatch each sub-layer on its type: ResNet blocks consume the
        # time embedding, transformer blocks consume the text context.
        out = h
        for layer in module:
            if isinstance(layer, ResnetBlock2D):
                out = layer(out, emb)
            elif isinstance(layer, Transformer2DModel):
                out = layer(out, context)
            else:
                out = layer(out)
        return out

    # Deep-Shrink hyperparameters. Overridable via instance attributes
    # (e.g. model.ds_depth_1 = 2) so they can be configured at the model
    # level; the getattr defaults reproduce the original hard-coded values.
    #
    # Depth: deeper stabilizes the overall composition but can distort
    # characters; too shallow leaves fine details chaotic.
    ds_depth_1 = getattr(self, "ds_depth_1", 3)  # ~2-4 seems to work well
    ds_depth_2 = getattr(self, "ds_depth_2", 3)  # ds_depth_1 + 0..2 seems good
    # Timesteps: larger disturbs the composition, smaller disturbs detail.
    # ds_timestep_2 must be smaller than ds_timestep_1; it affects the
    # overall drawing quality.
    ds_timestep_1 = getattr(self, "ds_timestep_1", 900)
    ds_timestep_2 = getattr(self, "ds_timestep_2", 650)

    def should_resample(depth):
        # Shrink at ds_depth_1 during the earliest (noisiest) steps, and at
        # ds_depth_2 in the window between the two timestep thresholds.
        # Uses timesteps[0]: assumes all timesteps in the batch are equal.
        t = timesteps[0]
        return (depth == ds_depth_1 and t > ds_timestep_1) or (
            depth == ds_depth_2 and ds_timestep_1 > t and t > ds_timestep_2
        )

    h = x
    depth = 0  # current encoder depth
    for module in self.input_blocks:
        h = call_module(module, h, emb, context)
        hs.append(h)
        if should_resample(depth):
            # bicubic is required to avoid warping; align_corners seems to
            # matter little. The .float() round-trip supports bfloat16.
            h = F.interpolate(h.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to(h.dtype)
        depth += 1

    h = call_module(self.middle_block, h, emb, context)

    for module in self.output_blocks:
        depth -= 1
        if should_resample(depth):
            # Mirror upscale on the decoder side so skip shapes match.
            h = F.interpolate(h.float(), scale_factor=2.0, mode="bicubic", align_corners=False).to(h.dtype)
        h = torch.cat([h, hs.pop()], dim=1)
        h = call_module(module, h, emb, context)

    h = h.type(x.dtype)
    h = call_module(self.out, h, emb, context)
    return h
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Mr. Kohya, please advise how to implement this at the model level. I would like to try to assemble a checkpoint with a built-in UNet shrink for a 1.5 checkpoint.