Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created November 14, 2023 03:39
Show Gist options
  • Save kohya-ss/3f774da220df102548093a7abc8538ed to your computer and use it in GitHub Desktop.
Save kohya-ss/3f774da220df102548093a7abc8538ed to your computer and use it in GitHub Desktop.
SDXLで高解像度での構図の破綻を軽減する (Reduce composition breakdown in SDXL at high resolutions)
def forward(
    self,
    x,
    timesteps=None,
    context=None,
    y=None,
    *,
    ds_depth_1=3,
    ds_depth_2=3,
    ds_timestep_1=900,
    ds_timestep_2=650,
    ds_ratio=0.5,
    **kwargs,
):
    """Run the UNet, optionally shrinking the feature map at one input depth.

    Intended to reduce composition breakdown at high resolutions. Depending on
    the current timestep, the hidden state is downscaled by ``ds_ratio`` right
    after one input block (its un-shrunk output is still kept as the skip
    connection). On the way back up, whenever the hidden state's spatial size
    differs from the skip connection it is about to be concatenated with, it
    is bicubic-resized to exactly that size. This restores the shrunk map, and
    also absorbs off-by-one sizes caused by odd dimensions, which previously
    made ``torch.cat`` fail.

    Stage selection (only ``timesteps[0]`` is consulted, so the whole batch is
    presumably expected to share one timestep):
      * ``t > ds_timestep_1``                  -> shrink after depth ``ds_depth_1``
      * ``ds_timestep_2 < t <= ds_timestep_1`` -> shrink after depth ``ds_depth_2``
      * otherwise                              -> no shrink

    Args:
        x: Latent input; its batch size and dtype must match ``y``.
        timesteps: Timestep tensor, broadcast to the batch dimension.
        context: Passed to every ``Transformer2DModel`` layer.
        y: Passed through ``self.label_emb`` and added to the time embedding.
        ds_depth_1: Input depth shrunk in the high-timestep stage (``None``
            disables it). Deeper stabilizes overall composition but distorts
            characters; too shallow makes fine detail chaotic. About 2-4 works.
        ds_depth_2: Input depth shrunk in the second stage (``None`` disables
            it). About ``ds_depth_1`` + 0 to +2 works.
        ds_timestep_1: Upper threshold. Larger values disturb composition,
            smaller values disturb detail.
        ds_timestep_2: Lower threshold; must be smaller than ``ds_timestep_1``.
            Affects the drawing (shapes/proportions).
        ds_ratio: Spatial scale factor used when shrinking.
        **kwargs: Accepted and ignored.

    Returns:
        The result of ``self.out`` applied to the final hidden state, which is
        first cast back to ``x.dtype``.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in batch size or dtype.
    """
    # broadcast timesteps to batch dimension
    timesteps = timesteps.expand(x.shape[0])
    t_emb = get_timestep_embedding(timesteps, self.model_channels)
    t_emb = t_emb.to(x.dtype)
    emb = self.time_embed(t_emb)

    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
    emb = emb + self.label_emb(y)

    def call_module(module, h, emb, context):
        # Feed each layer the extra input its type expects.
        x = h
        for layer in module:
            if isinstance(layer, ResnetBlock2D):
                x = layer(x, emb)
            elif isinstance(layer, Transformer2DModel):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

    def _bicubic(t, **size_or_scale):
        # Anything other than bicubic distorts the image; align_corners seems to
        # matter little. Upcast to float32 for bfloat16 support, then cast back.
        return F.interpolate(t.float(), mode="bicubic", align_corners=False, **size_or_scale).to(t.dtype)

    # Pick the (at most one) depth to shrink at for this timestep. The stages are
    # contiguous, so t == ds_timestep_1 now falls into stage 2 instead of a gap.
    t0 = timesteps[0].item()
    if t0 > ds_timestep_1:
        shrink_depth = ds_depth_1
    elif ds_timestep_2 < t0 <= ds_timestep_1:
        shrink_depth = ds_depth_2
    else:
        shrink_depth = None

    hs = []
    h = x
    for depth, module in enumerate(self.input_blocks):
        h = call_module(module, h, emb, context)
        hs.append(h)  # the skip connection keeps the un-shrunk size
        if depth == shrink_depth:
            h = _bicubic(h, scale_factor=ds_ratio)

    h = call_module(self.middle_block, h, emb, context)

    for module in self.output_blocks:
        skip = hs.pop()
        # Restore the shrunk map (and fix odd-size off-by-one) by resizing to the
        # exact skip size; for even sizes this equals a 2x bicubic upscale.
        if h.shape[-2:] != skip.shape[-2:]:
            h = _bicubic(h, size=skip.shape[-2:])
        h = torch.cat([h, skip], dim=1)
        h = call_module(module, h, emb, context)

    h = h.type(x.dtype)
    h = call_module(self.out, h, emb, context)
    return h
@recoilme
Copy link

Mr. Kohya, please advise on how to implement this at the model level. I would like to try assembling an SD 1.5 checkpoint with a built-in UNet shrink.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment