Skip to content

Instantly share code, notes, and snippets.

@CoffeeVampir3
Created February 13, 2024 20:11
Show Gist options
  • Save CoffeeVampir3/9638e59103cf7b6b55a898265778969d to your computer and use it in GitHub Desktop.
Save CoffeeVampir3/9638e59103cf7b6b55a898265778969d to your computer and use it in GitHub Desktop.
Kohya SDXL style run
#!/bin/bash
# Config Start
# Configurations

# --- Paths -----------------------------------------------------------------
# Base checkpoint to finetune.
ckpt="checkpoint path here"
# Folder containing folders with repeats_conceptname.
image_dir="training path here, should have images as subfolder"
# Optional; just point this to an empty folder if you don't care.
reg_dir=""
# Safetensors output folder.
output="output path here"

# --- Schedule --------------------------------------------------------------
# Batch size and epoch count; both feed the max-training-steps calculation,
# and the batch size is also passed to the trainer.
train_batch_size=5
num_epochs=15
save_every_n_epochs=1
scheduler="linear"
# Multiplied by the max training steps to obtain the warmup step count.
lr_warmup_ratio=0

# --- Learning rates --------------------------------------------------------
# learning_rate is not referenced by the launch command in this script; the
# rates handed to the trainer are unet_lr and text_encoder_lr.
learning_rate=1
unet_lr=1
text_encoder_lr=0

# --- Network ---------------------------------------------------------------
# Passed to the trainer as both --network_dim and --network_alpha.
network_dim=32
# Config End
# ---------------------------------------------------------------------------
# Dataset accounting: derive the step counts handed to the trainer from the
# "<repeats>_<concept>" sub-folders of $image_dir.
# ---------------------------------------------------------------------------

#######################################
# Count image files (png/bmp/gif/jpg/jpeg/webp, case-insensitive) anywhere
# below a directory.
# Arguments: $1 - directory to search
# Outputs:   the number of matching files on stdout
#######################################
count_images() {
  find "$1" -type f \( -iname "*.png" -o -iname "*.bmp" -o -iname "*.gif" \
    -o -iname "*.jpg" -o -iname "*.jpeg" -o -iname "*.webp" \) | wc -l
}

#######################################
# Sum repeats * image-count over every "<repeats>_<concept>" folder directly
# inside a directory, printing one progress line per folder.
# Globals:   total (written) - images with repeats; reset to 0 on every call
# Arguments: $1 - directory holding the "<repeats>_<concept>" folders
# Outputs:   progress on stdout, errors on stderr
# Returns:   0 on success, 1 if $1 is not a directory
# Note:      must be called directly (not inside $(...)), otherwise the
#            update of "total" is lost in the subshell.
#######################################
count_image_repeats() {
  local root=$1
  # Split on the FIRST underscore only, so a concept name may itself contain
  # underscores (e.g. "10_my_concept"); the repeat count must be all digits.
  # NOTE(review): the trainer is presumed to accept such names - confirm.
  local -r name_pattern='^([0-9]+)_(.+)$'
  local dir name repeats concept imgs img_repeats

  total=0
  if [[ ! -d "$root" ]]; then
    printf 'Image directory "%s" does not exist or is not a directory.\n' "$root" >&2
    return 1
  fi

  while IFS= read -r -d '' dir; do
    name=${dir##*/}
    if [[ ! "$name" =~ $name_pattern ]]; then
      echo "Directory name $name does not follow expected format."
      continue
    fi
    # 10# forces base 10 so a leading zero ("08_x") is not parsed as octal.
    repeats=$((10#${BASH_REMATCH[1]}))
    concept=${BASH_REMATCH[2]}
    echo "Processing: $name, Repeats: $repeats, Concept: $concept"
    # The arithmetic wrapper strips any padding some wc implementations emit.
    imgs=$(( $(count_images "$dir") ))
    img_repeats=$((repeats * imgs))
    printf '\t%s: %s repeats * %s images = %s\n' \
      "$concept" "$repeats" "$imgs" "$img_repeats"
    total=$((total + img_repeats))
  done < <(find "$root" -mindepth 1 -maxdepth 1 -type d -print0)
  return 0
}

#######################################
# Compute images / batch_size * epochs using integer arithmetic.
# Arguments: $1 - images with repeats, $2 - batch size, $3 - epochs
# Outputs:   the step count on stdout, errors on stderr
# Returns:   0 on success, 1 on invalid input (including a zero batch size,
#            which would otherwise be a division by zero)
#######################################
compute_max_train_steps() {
  local images=$1 batch=$2 epochs=$3
  local -r natural='^[0-9]+$' positive='^0*[1-9][0-9]*$'

  if [[ ! "$batch" =~ $positive ]]; then
    printf 'train_batch_size must be a positive integer (got "%s").\n' "$batch" >&2
    return 1
  fi
  if [[ ! "$images" =~ $natural || ! "$epochs" =~ $natural ]]; then
    printf 'Image total and num_epochs must be non-negative integers (got "%s", "%s").\n' \
      "$images" "$epochs" >&2
    return 1
  fi
  printf '%d\n' "$(( 10#$images / 10#$batch * 10#$epochs ))"
}

#######################################
# Compute steps * ratio, truncated to a whole number of steps.
# Uses fixed-point arithmetic (6 fractional digits) so the result is always
# an integer and no external calculator is required.
# Arguments: $1 - max training steps, $2 - warmup ratio as a plain decimal
#            such as 0, 0.05 or .1
# Outputs:   the warmup step count on stdout, errors on stderr
# Returns:   0 on success, 1 on invalid input
#######################################
compute_warmup_steps() {
  local steps=$1 ratio=$2
  local -r natural='^[0-9]+$' decimal='^([0-9]*)\.?([0-9]*)$'
  local whole frac

  if [[ ! "$steps" =~ $natural ]]; then
    printf 'Max training steps must be a non-negative integer (got "%s").\n' "$steps" >&2
    return 1
  fi
  # Require at least one digit so "" and "." are rejected.
  if [[ "$ratio" != *[0-9]* || ! "$ratio" =~ $decimal ]]; then
    printf 'lr_warmup_ratio must be a non-negative decimal (got "%s").\n' "$ratio" >&2
    return 1
  fi
  whole=${BASH_REMATCH[1]:-0}
  frac=${BASH_REMATCH[2]}

  # Normalise the fraction to exactly 6 digits (truncate, then right-pad).
  frac=${frac:0:6}
  while (( ${#frac} < 6 )); do
    frac+=0
  done
  printf '%d\n' "$(( 10#$steps * (10#$whole * 1000000 + 10#$frac) / 1000000 ))"
}

#######################################
# Scan $image_dir and derive the step counts used by the launch command.
# A problem is reported on stderr and the affected value falls back to 0
# instead of aborting, so the summary below is always printed.
# Globals:   image_dir, train_batch_size, num_epochs, lr_warmup_ratio (read)
#            total, mts, lr_warmup_steps (written)
# Outputs:   directory listing, per-folder progress and a summary on stdout
#######################################
plan_training_steps() {
  echo "Image directory: $image_dir"
  if [[ -d "$image_dir" ]]; then
    ls -- "$image_dir"
  fi
  count_image_repeats "$image_dir"

  # Calculations based on total images (with repeats). The accumulated total
  # is used as-is; it must not be replaced by a single folder's image count.
  if ! mts=$(compute_max_train_steps "$total" "$train_batch_size" "$num_epochs"); then
    mts=0
  fi
  if ! lr_warmup_steps=$(compute_warmup_steps "$mts" "$lr_warmup_ratio"); then
    lr_warmup_steps=0
  fi

  echo "Total images with repeats: $total"
  echo "Max training steps $total / $train_batch_size * $num_epochs = $mts"
  echo "LR Warmup Steps: $lr_warmup_steps"
}

plan_training_steps
# Activate the virtual environment (path is relative to the current
# directory). A failure is reported but does not stop the script, so the
# launch still proceeds with whatever environment is already active.
if ! source ./venv/bin/activate; then
  echo "Warning: could not activate ./venv; launching with the current environment." >&2
fi

#######################################
# Fill the global array train_args with the options passed to
# sdxl_train_network.py. An array keeps every option a single word even if a
# configured value contains spaces or glob characters.
# Globals:   ckpt, image_dir, reg_dir, output, train_batch_size, scheduler,
#            lr_warmup_steps, mts, save_every_n_epochs, text_encoder_lr,
#            unet_lr, network_dim (read); train_args (written)
#######################################
build_train_args() {
  train_args=(
    --cache_latents
    --enable_bucket
    --network_train_unet_only
    --min_bucket_reso=512
    --max_bucket_reso=2048
    --bucket_reso_steps=256
    --max_data_loader_n_workers=1
    --persistent_data_loader_workers
    --pretrained_model_name_or_path="$ckpt"
    --train_data_dir="$image_dir"
    --reg_data_dir="$reg_dir"
    --resolution=1024,1024
    --optimizer_type="Prodigy"
    --optimizer_args "weight_decay=0.01" "betas=0.9,0.999" "d_coef=0.8" "use_bias_correction=True" "safeguard_warmup=False"
    --output_dir="$output"
    --train_batch_size="$train_batch_size"
    --lr_scheduler="$scheduler"
    --lr_warmup_steps="$lr_warmup_steps"
    --max_train_steps="$mts"
    --multires_noise_discount=0.3
    --prior_loss_weight=1
    --gradient_checkpointing
    --xformers
    --mixed_precision=bf16
    --save_every_n_epochs="$save_every_n_epochs"
    --seed=1234
    --save_precision=bf16
    --logging_dir=""
    --caption_extension=.txt
    --save_model_as=safetensors
    --network_module=networks.lora
    --text_encoder_lr="$text_encoder_lr"
    --unet_lr="$unet_lr"
    --network_dim="$network_dim"
    # Alpha is deliberately set to the same value as the network dimension.
    --network_alpha="$network_dim"
    --output_name="yd_sdxl_linearv6"
  )
}

# Launch training script with parameters
build_train_args
accelerate launch --num_cpu_threads_per_process 16 sdxl_train_network.py \
  "${train_args[@]}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment