Skip to content

Instantly share code, notes, and snippets.

@ivanstepanovftw
Created March 5, 2025 05:48
Show Gist options
  • Save ivanstepanovftw/1247a59c4e85eb68d5a2623dcf7f75bc to your computer and use it in GitHub Desktop.
import copy
import itertools
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import plotly.graph_objs as go
import plotly.offline as pyo
# =============================================================================
# Unified CNN Model
#
# The network stacks four stages of nn.PixelUnshuffle followed by a 1x1
# convolution (each stage halves the spatial resolution).  For a 32x32 input,
# the stages plus adaptive average pooling yield an embedding of shape
# (bs, emb_dim, 1, 1), which is flattened to (bs, emb_dim).  Logits are then
# computed as emb @ unembed, where unembed is a learnable matrix of shape
# (emb_dim, num_classes).
# =============================================================================
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        """Build the convolutional feature extractor and the unembed matrix."""
        super(CNN, self).__init__()
        # Channel widths per stage: 2**(2*i + 3) -> 8, 32, 128, 512.
        emb_dim = 1
        stages = []
        for i in range(4):
            emb_dim = 2 ** (2 * i + 3)
            stages.append(nn.Sequential(
                nn.PixelUnshuffle(downscale_factor=2),
                nn.LazyConv2d(emb_dim, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
            ))
        self.features = nn.Sequential(
            # Input: (bs, 1, 32, 32)
            *stages,
            nn.AdaptiveAvgPool2d(1),  # -> (bs, emb_dim, 1, 1)
            nn.Flatten(),             # -> (bs, emb_dim)
        )
        # Learnable unembedding matrix: shape (emb_dim, num_classes).
        self.unembed = nn.Parameter(torch.randn(emb_dim, num_classes) * 0.01)

    def forward(self, x):
        """Return (embeddings, logits) for a batch x of shape (bs, 1, 32, 32)."""
        embeddings = self.features(x)       # -> (bs, emb_dim)
        logits = embeddings @ self.unembed  # -> (bs, num_classes)
        return embeddings, logits
# =============================================================================
# Loss functions
# =============================================================================
def cross_entropy_loss(logits, targets):
    """Mean cross-entropy of *logits* against integer class *targets*."""
    return F.cross_entropy(logits, targets)
def harmonic_loss(embeddings, targets, unembed, n=2):
    """Harmonic loss: negative log of the probability assigned to the target
    class, where class probabilities are proportional to the inverse n-th
    power of the Euclidean distance to each unembed class vector."""
    batch = embeddings.size(0)
    # (bs, num_classes): distance from each embedding to each class vector.
    distances = torch.cdist(embeddings, unembed.t(), p=2)
    weights = 1.0 / (distances ** n + 1e-8)  # epsilon avoids division by zero
    class_probs = weights / weights.sum(dim=1, keepdim=True)
    picked = class_probs[torch.arange(batch), targets]
    return -torch.log(picked + 1e-8).mean()
# =============================================================================
# Utility: L2 norm of model parameters
# =============================================================================
def get_params_l2(model):
    """Return the global L2 norm over all parameters of *model* as a float."""
    squared_sum = sum(p.pow(2).sum() for p in model.parameters())
    return squared_sum.sqrt().item()
# =============================================================================
# Training loop function (for a single method)
#
# Trains the model for max_steps steps with the selected loss ('ce' or 'h').
# Every validate_every steps a validation pass runs, recording:
#   - metrics into the `metrics` dict,
#   - checkpoints as tuples (step, state_dict, avg_val_err).
#
# Returns (metrics, best_state, checkpoints).
# =============================================================================
def train_model(model, optimizer, loss_type, train_batches, val_loader, device,
                max_steps, validate_every, n_harmonic=2):
    """Train *model* with either cross-entropy ('ce') or harmonic loss.

    Both losses are computed at every step for logging purposes, but only the
    one selected by *loss_type* is backpropagated.  Returns a tuple
    (metrics, best_state, checkpoints) where best_state is the state_dict with
    the lowest average validation error seen during training.
    """
    metrics = {
        'steps': [],
        "train_ce": [],
        "train_h": [],
        "train_err": [],
        "val_ce": [],
        "val_h": [],
        "val_err": [],
        "l2": []
    }
    best_val_err = float('inf')
    best_state = None
    checkpoints = []  # List of tuples: (step, state_dict, avg_val_err)
    header = f"{'Step':8} | {'Method':8} | {'Train CE':12} | {'Train H':12} | {'Train E':12} | " \
             f"{'Val CE':12} | {'Val H':12} | {'Val E':12} | {'Val L2':12}"
    print(header)
    print("-" * len(header))
    # Cycle over the fixed batch list so both methods see the same batch order.
    train_cycle = itertools.cycle(train_batches)
    for step in range(max_steps):
        x, y = next(train_cycle)
        x, y = x.to(device), y.to(device)
        model.train()
        optimizer.zero_grad()
        emb, logits = model(x)
        # Compute both losses (for logging) ...
        ce_loss = cross_entropy_loss(logits, y)
        h_loss = harmonic_loss(emb, y, model.unembed, n=n_harmonic)
        # ... but backpropagate only the selected one.
        loss = ce_loss if loss_type == 'ce' else h_loss
        loss.backward()
        optimizer.step()
        # Training-batch error rate under the cross-entropy (argmax-logit) rule.
        ce_preds = logits.argmax(dim=1)
        ce_train_err = (ce_preds != y).float().mean().item()
        # Training-batch error rate under the harmonic (distance-based) rule.
        dists = torch.cdist(emb, model.unembed.t(), p=2)
        inv_dn = 1.0 / (dists**n_harmonic + 1e-8)
        probs = inv_dn / inv_dn.sum(dim=1, keepdim=True)
        h_preds = probs.argmax(dim=1)
        h_train_err = (h_preds != y).float().mean().item()
        train_err_avg = (ce_train_err + h_train_err) / 2
        l2_norm = get_params_l2(model)
        # Periodic validation pass.
        if (step + 1) % validate_every == 0:
            model.eval()
            total_ce_val_loss = 0.0
            total_h_val_loss = 0.0
            total_ce_val_err = 0.0
            total_h_val_err = 0.0
            count = 0
            with torch.no_grad():
                for vx, vy in val_loader:
                    vx, vy = vx.to(device), vy.to(device)
                    emb_val, logits_val = model(vx)
                    ce_loss_val = cross_entropy_loss(logits_val, vy)
                    ce_preds_val = logits_val.argmax(dim=1)
                    ce_err_val = (ce_preds_val != vy).float().mean().item()
                    dists_val = torch.cdist(emb_val, model.unembed.t(), p=2)
                    inv_dn_val = 1.0 / (dists_val**n_harmonic + 1e-8)
                    probs_val = inv_dn_val / inv_dn_val.sum(dim=1, keepdim=True)
                    h_preds_val = probs_val.argmax(dim=1)
                    h_loss_val = harmonic_loss(emb_val, vy, model.unembed, n=n_harmonic)
                    h_err_val = (h_preds_val != vy).float().mean().item()
                    # Weight per-batch stats by batch size for correct averages
                    # (the last validation batch may be smaller).
                    total_ce_val_loss += ce_loss_val.item() * vx.size(0)
                    total_h_val_loss += h_loss_val.item() * vx.size(0)
                    total_ce_val_err += ce_err_val * vx.size(0)
                    total_h_val_err += h_err_val * vx.size(0)
                    count += vx.size(0)
            avg_ce_val_loss = total_ce_val_loss / count
            avg_h_val_loss = total_h_val_loss / count
            avg_ce_val_err = total_ce_val_err / count
            avg_h_val_err = total_h_val_err / count
            val_err_avg = (avg_ce_val_err + avg_h_val_err) / 2
            metrics['steps'].append(step + 1)
            metrics["train_ce"].append(ce_loss.item())
            metrics["train_h"].append(h_loss.item())
            metrics["train_err"].append(train_err_avg)
            metrics["val_ce"].append(avg_ce_val_loss)
            metrics["val_h"].append(avg_h_val_loss)
            metrics["val_err"].append(val_err_avg)
            metrics["l2"].append(l2_norm)
            # Deep-copy so later optimizer steps don't mutate saved checkpoints.
            checkpoints.append((step + 1, copy.deepcopy(model.state_dict()), val_err_avg))
            if val_err_avg < best_val_err:
                best_val_err = val_err_avg
                best_state = copy.deepcopy(model.state_dict())
            print(f"{step+1:8d} | {loss_type.upper():8} | "
                  f"{ce_loss.item():12.4f} | {h_loss.item():12.4f} | {train_err_avg:12.4f} | "
                  f"{avg_ce_val_loss:12.4f} | {avg_h_val_loss:12.4f} | {val_err_avg:12.4f} | {l2_norm:12.4f}")
    return metrics, best_state, checkpoints
# =============================================================================
# Ensembling helper - computes the accuracy of an ensemble of models
# =============================================================================
def ensemble_accuracy(models, loss_type, test_loader, n_harmonic, device):
    """Accuracy of an ensemble that averages per-model class probabilities.

    For 'ce' each model contributes softmax(logits); otherwise probabilities
    come from inverse n-th-power distances to the model's unembed vectors.
    """
    model_count = len(models)
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            summed_probs = None
            for member in models:
                member.eval()
                emb, logits = member(x)
                if loss_type == 'ce':
                    member_probs = F.softmax(logits, dim=1)
                else:
                    d = torch.cdist(emb, member.unembed.t(), p=2)
                    w = 1.0 / (d ** n_harmonic + 1e-8)
                    member_probs = w / w.sum(dim=1, keepdim=True)
                summed_probs = member_probs if summed_probs is None else summed_probs + member_probs
            summed_probs /= model_count
            preds = summed_probs.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total
# =============================================================================
# Average the weights (state_dicts) from a list of checkpoints
# =============================================================================
def average_state_dicts(state_dicts):
    """Element-wise mean of a list of state_dicts sharing the same keys."""
    count = len(state_dicts)
    return {key: sum(sd[key] for sd in state_dicts) / count
            for key in state_dicts[0]}
# =============================================================================
# Main training and evaluation
# =============================================================================
def main():
    """Train CE and harmonic-loss models on FashionMNIST, evaluate single/
    ensemble/weight-averaged variants on the test set, and write an
    interactive metric plot to loss_comparison.html."""
    # Fix randomness - both methods get the same batch order and initial weights.
    torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    # Hyperparameters
    batch_size = 64
    lr = 1e-3
    # NOTE(review): 1e-999 underflows to the float 0.0, so AdamW effectively
    # runs with no weight decay - confirm this is intentional.
    weight_decay = 1e-999
    max_steps = 5000
    validate_every = 32
    num_classes = 10
    n_harmonic = 2
    print(f"{batch_size=}, {lr=}, {weight_decay=}, {max_steps=}, {validate_every=}, {n_harmonic=}")
    # Transforms: pad 28x28 images to 32x32, convert to tensor and normalize.
    transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # Data: FashionMNIST
    dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    # Split into training and validation sets (55k/5k).
    train_data, val_data = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Materialize train_loader as a list so both methods use the same batch order.
    train_batches = list(train_loader)
    # Create the base model and deep copies for both methods.
    base_model = CNN(num_classes=num_classes).to(device)
    # Initialize lazy weights before calculating number of parameters
    with torch.no_grad():
        base_model(torch.empty(1, 1, 32, 32).to(device))
    print("Model architecture:")
    print(base_model)
    print("Number of parameters:", sum(p.numel() for p in base_model.parameters()))
    cnn_ce = copy.deepcopy(base_model)
    cnn_h = copy.deepcopy(base_model)
    # ----- Train the cross-entropy model -----
    print("\nTraining Cross-Entropy model:")
    ce_optimizer = optim.AdamW(cnn_ce.parameters(), lr=lr, weight_decay=weight_decay)
    ce_metrics, ce_best_state, ce_checkpoints = train_model(
        cnn_ce, ce_optimizer, "ce", train_batches, val_loader, device,
        max_steps, validate_every, n_harmonic)
    # ----- Train the harmonic-loss model -----
    print("\nTraining Harmonic model:")
    h_optimizer = optim.AdamW(cnn_h.parameters(), lr=lr, weight_decay=weight_decay)
    h_metrics, h_best_state, h_checkpoints = train_model(
        cnn_h, h_optimizer, "h", train_batches, val_loader, device,
        max_steps, validate_every, n_harmonic)
    # ----- Build ensembles.
    # For each method:
    #   1. best_model: the model with the lowest validation error (already
    #      saved as ce_best_state / h_best_state)
    #   2. last10: ensemble of the 10 most recent checkpoints (or all, if fewer)
    #   3. best10: ensemble of the 10 checkpoints with the lowest validation error
    #   4. avg_last10: model obtained by averaging the weights of the last 10 checkpoints
    #   5. avg_best10: model obtained by averaging the weights of the best 10 checkpoints
    def get_ensemble_models(base_model, checkpoints_list, indices):
        """Load each selected checkpoint into a fresh deep copy of *base_model*."""
        models = []
        for idx in indices:
            model_copy = copy.deepcopy(base_model)
            state = checkpoints_list[idx][1]
            model_copy.load_state_dict(state)
            models.append(model_copy)
        return models
    # NOTE(review): this shadows the identical module-level average_state_dicts.
    def average_state_dicts(state_dicts):
        """Element-wise average of a list of state_dicts with identical keys."""
        avg_state = {}
        for key in state_dicts[0]:
            avg_state[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
        return avg_state
    def get_avg_state(checkpoints_list, indices):
        """Average the state_dicts of the checkpoints at *indices*."""
        state_dicts = [checkpoints_list[i][1] for i in indices]
        return average_state_dicts(state_dicts)
    # Checkpoint selections for CE
    ce_total = len(ce_checkpoints)
    last10_indices = list(range(max(0, ce_total - 10), ce_total))
    sorted_by_err = sorted(ce_checkpoints, key=lambda x: x[2])
    best10_indices = [ce_checkpoints.index(c) for c in sorted_by_err[:10]]
    # Checkpoint selections for H
    h_total = len(h_checkpoints)
    last10_indices_h = list(range(max(0, h_total - 10), h_total))
    sorted_by_err_h = sorted(h_checkpoints, key=lambda x: x[2])
    best10_indices_h = [h_checkpoints.index(c) for c in sorted_by_err_h[:10]]
    base_ce_model = CNN(num_classes=num_classes).to(device)
    base_h_model = CNN(num_classes=num_classes).to(device)
    ce_last10_models = get_ensemble_models(base_ce_model, ce_checkpoints, last10_indices)
    ce_best10_models = get_ensemble_models(base_ce_model, ce_checkpoints, best10_indices)
    h_last10_models = get_ensemble_models(base_h_model, h_checkpoints, last10_indices_h)
    h_best10_models = get_ensemble_models(base_h_model, h_checkpoints, best10_indices_h)
    # Weight-averaging of the last and the best checkpoints
    avg_state_ce_last10 = get_avg_state(ce_checkpoints, last10_indices)
    avg_state_ce_best10 = get_avg_state(ce_checkpoints, best10_indices)
    avg_state_h_last10 = get_avg_state(h_checkpoints, last10_indices_h)
    avg_state_h_best10 = get_avg_state(h_checkpoints, best10_indices_h)
    # Create fresh models and load the averaged weights
    avg_ce_last10_model = CNN(num_classes=num_classes).to(device)
    avg_ce_last10_model.load_state_dict(avg_state_ce_last10)
    avg_ce_best10_model = CNN(num_classes=num_classes).to(device)
    avg_ce_best10_model.load_state_dict(avg_state_ce_best10)
    avg_h_last10_model = CNN(num_classes=num_classes).to(device)
    avg_h_last10_model.load_state_dict(avg_state_h_last10)
    avg_h_best10_model = CNN(num_classes=num_classes).to(device)
    avg_h_best10_model.load_state_dict(avg_state_h_best10)
    # ----- Final testing: use the weights with the lowest validation error -----
    cnn_ce.load_state_dict(ce_best_state)
    cnn_h.load_state_dict(h_best_state)
    def test_accuracy(model, loss_type):
        """Test-set accuracy of a single model under the given decision rule."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                emb, logits = model(x)
                if loss_type == 'ce':
                    preds = logits.argmax(dim=1)
                else:
                    # Harmonic rule: highest inverse-power-distance probability.
                    dists = torch.cdist(emb, model.unembed.t(), p=2)
                    inv_dn = 1.0 / (dists ** n_harmonic + 1e-8)
                    probs = inv_dn / inv_dn.sum(dim=1, keepdim=True)
                    preds = probs.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        return correct / total
    ce_best_acc = test_accuracy(cnn_ce, "ce")
    h_best_acc = test_accuracy(cnn_h, "h")
    ce_last10_acc = ensemble_accuracy(ce_last10_models, "ce", test_loader, n_harmonic, device)
    ce_best10_acc = ensemble_accuracy(ce_best10_models, "ce", test_loader, n_harmonic, device)
    h_last10_acc = ensemble_accuracy(h_last10_models, "h", test_loader, n_harmonic, device)
    h_best10_acc = ensemble_accuracy(h_best10_models, "h", test_loader, n_harmonic, device)
    avg_ce_last10_acc = test_accuracy(avg_ce_last10_model, "ce")
    avg_ce_best10_acc = test_accuracy(avg_ce_best10_model, "ce")
    avg_h_last10_acc = test_accuracy(avg_h_last10_model, "h")
    avg_h_best10_acc = test_accuracy(avg_h_best10_model, "h")
    print("\nFinal Test Accuracy (CE):")
    print(f" Best model (lowest val error): {ce_best_acc * 100:.2f}%")
    print(f" Ensemble of last 10 models: {ce_last10_acc * 100:.2f}%")
    print(f" Ensemble of best 10 models: {ce_best10_acc * 100:.2f}%")
    print(f" Avg model (last 10 weights): {avg_ce_last10_acc * 100:.2f}%")
    print(f" Avg model (best 10 weights): {avg_ce_best10_acc * 100:.2f}%")
    print("\nFinal Test Accuracy (H):")
    print(f" Best model (lowest val error): {h_best_acc * 100:.2f}%")
    print(f" Ensemble of last 10 models: {h_last10_acc * 100:.2f}%")
    print(f" Ensemble of best 10 models: {h_best10_acc * 100:.2f}%")
    print(f" Avg model (last 10 weights): {avg_h_last10_acc * 100:.2f}%")
    print(f" Avg model (best 10 weights): {avg_h_best10_acc * 100:.2f}%")
    # ----- Generate the plot and save it to loss_comparison.html -----
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["train_ce"], mode='lines+markers', name='CE, Train, CE Loss (to minimize)', line=dict(color='blue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["train_h"], mode='lines+markers', name='CE, Train, H Loss', line=dict(color='cyan')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["train_err"], mode='lines+markers', name='CE, Train, Err', line=dict(color='navy')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["val_ce"], mode='lines+markers', name='CE, Val, CE Loss', line=dict(color='royalblue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["val_h"], mode='lines+markers', name='CE, Val, H Loss', line=dict(color='lightblue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["val_err"], mode='lines+markers', name='CE, Val, Err', line=dict(color='mediumblue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["l2"], mode='lines+markers', name='CE, L2 Weights Norm', line=dict(color='darkblue')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["train_ce"], mode='lines+markers', name='H, Train, CE Loss', line=dict(color='red')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["train_h"], mode='lines+markers', name='H, Train, H Loss (to minimize)', line=dict(color='orange')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["train_err"], mode='lines+markers', name='H, Train, Err', line=dict(color='darkred')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["val_ce"], mode='lines+markers', name='H, Val, CE Loss', line=dict(color='crimson')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["val_h"], mode='lines+markers', name='H, Val, H Loss', line=dict(color='tomato')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["val_err"], mode='lines+markers', name='H, Val, Err', line=dict(color='firebrick')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["l2"], mode='lines+markers', name='H, L2 Weights Norm', line=dict(color='maroon')))
    fig.update_layout(title="Loss and Error Comparison",
                      xaxis_title="Training Step",
                      yaxis_title="Metric Value")
    pyo.plot(fig, filename='loss_comparison.html', auto_open=True)
# Script entry point.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment