Created
March 5, 2025 05:48
-
-
Save ivanstepanovftw/1247a59c4e85eb68d5a2623dcf7f75bc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import itertools | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader, random_split | |
from torchvision import datasets, transforms | |
import plotly.graph_objs as go | |
import plotly.offline as pyo | |
# =============================================================================
# Unified CNN Model
#
# The network downsamples with nn.PixelUnshuffle blocks, each followed by a
# 1x1 convolution (stride 1) and ReLU. The input image is padded to 32x32;
# after four such blocks and adaptive average pooling we obtain an embedding
# of shape (bs, emb_dim, 1, 1), which is flattened to (bs, emb_dim).
# Logits are computed as emb @ unembed, where unembed is a learnable
# (emb_dim, num_classes) matrix.
# =============================================================================
class CNN(nn.Module):
    """Small CNN that returns both an embedding and class logits.

    The input (bs, 1, 32, 32) is downsampled four times: each block applies
    nn.PixelUnshuffle (spatial /2, channels x4), then a 1x1 lazy convolution
    and ReLU. Adaptive average pooling plus flatten yield a (bs, emb_dim)
    embedding, and logits are ``embedding @ unembed``.

    NOTE: the original code set ``emb_dim = 1`` and relied on a walrus
    assignment inside an unpacked generator to overwrite it; the explicit
    loop below produces the same architecture (final emb_dim = 512) without
    the hidden side effect.
    """

    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        blocks = []
        emb_dim = 1  # overwritten below; kept so the attribute math is explicit
        for i in range(4):
            emb_dim = 2 ** (2 * i + 3)  # channel progression: 8, 32, 128, 512
            blocks.append(nn.Sequential(
                nn.PixelUnshuffle(downscale_factor=2),  # (C, H, W) -> (4C, H/2, W/2)
                nn.LazyConv2d(emb_dim, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
            ))
        self.features = nn.Sequential(
            *blocks,
            nn.AdaptiveAvgPool2d(1),  # -> (bs, emb_dim, 1, 1)
            nn.Flatten(),             # -> (bs, emb_dim)
        )
        # Learnable unembedding matrix: shape (emb_dim, num_classes)
        self.unembed = nn.Parameter(torch.randn(emb_dim, num_classes) * 0.01)

    def forward(self, x):
        """Return (embeddings, logits) for input x of shape (bs, 1, 32, 32)."""
        embeddings = self.features(x)       # -> (bs, emb_dim)
        logits = embeddings @ self.unembed  # -> (bs, num_classes)
        return embeddings, logits
# ============================================================================= | |
# Loss functions | |
# ============================================================================= | |
def cross_entropy_loss(logits, targets):
    """Mean cross-entropy between raw logits and integer class targets."""
    return F.cross_entropy(logits, targets)
def harmonic_loss(embeddings, targets, unembed, n=2):
    """Harmonic loss: class probabilities proportional to 1 / distance**n.

    Each column of `unembed` is treated as a class prototype; the probability
    of a class is the (normalized) inverse n-th power of the Euclidean
    distance from the embedding to that prototype. Returns the mean negative
    log-probability of the correct class.
    """
    eps = 1e-8
    class_vectors = unembed.t()                              # (num_classes, emb_dim)
    dists = torch.cdist(embeddings, class_vectors, p=2)      # (bs, num_classes)
    weights = (dists ** n + eps).reciprocal()
    probs = weights / weights.sum(dim=1, keepdim=True)
    batch_idx = torch.arange(embeddings.size(0))
    return -(probs[batch_idx, targets] + eps).log().mean()
# ============================================================================= | |
# Utility: L2 norm of model parameters | |
# ============================================================================= | |
def get_params_l2(model):
    """Return the global L2 norm over all parameters of `model` as a float."""
    squared_sum = sum(p.pow(2).sum() for p in model.parameters())
    return squared_sum.sqrt().item()
# =============================================================================
# Training loop function (for a single method)
#
# Trains the model for max_steps steps with the chosen loss ('ce' or 'h').
# Every validate_every steps a validation pass runs, storing:
#   - metrics in the `metrics` dict,
#   - checkpoints: (step, state_dict, avg_val_err).
#
# Returns (metrics, best_state, checkpoints).
# =============================================================================
def train_model(model, optimizer, loss_type, train_batches, val_loader, device,
                max_steps, validate_every, n_harmonic=2):
    """Train `model` for `max_steps` optimizer steps using the chosen loss.

    loss_type 'ce' backpropagates cross-entropy; any other value
    backpropagates the harmonic loss. Both losses and both decision-rule
    error rates are tracked regardless of which one trains the model.
    Every `validate_every` steps the model is evaluated on `val_loader`,
    metrics are appended, and a checkpoint (step, state_dict, avg_val_err)
    is stored.

    Returns (metrics, best_state, checkpoints) where best_state is a deep
    copy of the state_dict with the lowest average validation error seen.
    """
    metrics = {
        'steps': [],
        "train_ce": [],
        "train_h": [],
        "train_err": [],
        "val_ce": [],
        "val_h": [],
        "val_err": [],
        "l2": []
    }
    best_val_err = float('inf')
    best_state = None
    checkpoints = []  # list of (step, state_dict, avg_val_err)
    header = f"{'Step':8} | {'Method':8} | {'Train CE':12} | {'Train H':12} | {'Train E':12} | " \
             f"{'Val CE':12} | {'Val H':12} | {'Val E':12} | {'Val L2':12}"
    print(header)
    print("-" * len(header))
    train_cycle = itertools.cycle(train_batches)
    for step in range(max_steps):
        x, y = next(train_cycle)
        x, y = x.to(device), y.to(device)
        model.train()
        optimizer.zero_grad()
        emb, logits = model(x)
        # Compute both losses (only the selected one is backpropagated).
        ce_loss = cross_entropy_loss(logits, y)
        h_loss = harmonic_loss(emb, y, model.unembed, n=n_harmonic)
        # Backward pass uses the method's own loss.
        loss = ce_loss if loss_type == 'ce' else h_loss
        loss.backward()
        optimizer.step()
        # Training-batch error rates under both decision rules.
        ce_preds = logits.argmax(dim=1)
        ce_train_err = (ce_preds != y).float().mean().item()
        dists = torch.cdist(emb, model.unembed.t(), p=2)
        inv_dn = 1.0 / (dists**n_harmonic + 1e-8)
        probs = inv_dn / inv_dn.sum(dim=1, keepdim=True)
        h_preds = probs.argmax(dim=1)
        h_train_err = (h_preds != y).float().mean().item()
        train_err_avg = (ce_train_err + h_train_err) / 2
        l2_norm = get_params_l2(model)
        # Periodic validation pass.
        if (step + 1) % validate_every == 0:
            model.eval()
            total_ce_val_loss = 0.0
            total_h_val_loss = 0.0
            total_ce_val_err = 0.0
            total_h_val_err = 0.0
            count = 0
            with torch.no_grad():
                for vx, vy in val_loader:
                    vx, vy = vx.to(device), vy.to(device)
                    emb_val, logits_val = model(vx)
                    ce_loss_val = cross_entropy_loss(logits_val, vy)
                    ce_preds_val = logits_val.argmax(dim=1)
                    ce_err_val = (ce_preds_val != vy).float().mean().item()
                    dists_val = torch.cdist(emb_val, model.unembed.t(), p=2)
                    inv_dn_val = 1.0 / (dists_val**n_harmonic + 1e-8)
                    probs_val = inv_dn_val / inv_dn_val.sum(dim=1, keepdim=True)
                    h_preds_val = probs_val.argmax(dim=1)
                    h_loss_val = harmonic_loss(emb_val, vy, model.unembed, n=n_harmonic)
                    h_err_val = (h_preds_val != vy).float().mean().item()
                    # Weight per-batch stats by batch size for correct averages.
                    total_ce_val_loss += ce_loss_val.item() * vx.size(0)
                    total_h_val_loss += h_loss_val.item() * vx.size(0)
                    total_ce_val_err += ce_err_val * vx.size(0)
                    total_h_val_err += h_err_val * vx.size(0)
                    count += vx.size(0)
            avg_ce_val_loss = total_ce_val_loss / count
            avg_h_val_loss = total_h_val_loss / count
            avg_ce_val_err = total_ce_val_err / count
            avg_h_val_err = total_h_val_err / count
            val_err_avg = (avg_ce_val_err + avg_h_val_err) / 2
            metrics['steps'].append(step + 1)
            metrics["train_ce"].append(ce_loss.item())
            metrics["train_h"].append(h_loss.item())
            metrics["train_err"].append(train_err_avg)
            metrics["val_ce"].append(avg_ce_val_loss)
            metrics["val_h"].append(avg_h_val_loss)
            metrics["val_err"].append(val_err_avg)
            metrics["l2"].append(l2_norm)
            # Deep copy so later optimizer steps cannot mutate the checkpoint.
            checkpoints.append((step + 1, copy.deepcopy(model.state_dict()), val_err_avg))
            if val_err_avg < best_val_err:
                best_val_err = val_err_avg
                best_state = copy.deepcopy(model.state_dict())
            print(f"{step+1:8d} | {loss_type.upper():8} | "
                  f"{ce_loss.item():12.4f} | {h_loss.item():12.4f} | {train_err_avg:12.4f} | "
                  f"{avg_ce_val_loss:12.4f} | {avg_h_val_loss:12.4f} | {val_err_avg:12.4f} | {l2_norm:12.4f}")
    return metrics, best_state, checkpoints
# =============================================================================
# Ensembling helper — computes the accuracy of an ensemble of models
# =============================================================================
def ensemble_accuracy(models, loss_type, test_loader, n_harmonic, device):
    """Accuracy of a probability-averaging ensemble on `test_loader`.

    For loss_type 'ce' each model contributes softmax(logits); otherwise it
    contributes the harmonic-loss probabilities (normalized inverse n-th
    power of the distance from the embedding to each unembed column).
    Per-model probabilities are averaged and the argmax is the prediction.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            prob_sum = None
            for model in models:
                model.eval()
                emb, logits = model(x)
                if loss_type == 'ce':
                    probs = F.softmax(logits, dim=1)
                else:
                    dists = torch.cdist(emb, model.unembed.t(), p=2)
                    inv_dn = 1.0 / (dists ** n_harmonic + 1e-8)
                    probs = inv_dn / inv_dn.sum(dim=1, keepdim=True)
                prob_sum = probs if prob_sum is None else prob_sum + probs
            mean_probs = prob_sum / len(models)
            preds = mean_probs.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total
# =============================================================================
# Average weights (state_dicts) across a list of checkpoints
# =============================================================================
def average_state_dicts(state_dicts):
    """Element-wise average of state_dicts sharing identical keys and shapes."""
    n = len(state_dicts)
    return {key: sum(sd[key] for sd in state_dicts) / n
            for key in state_dicts[0]}
# ============================================================================= | |
# Main training and evaluation | |
# ============================================================================= | |
def main():
    """Train identical CNNs with cross-entropy vs. harmonic loss, compare
    single-best / ensembled / weight-averaged checkpoints on the test set,
    and write a Plotly comparison chart to loss_comparison.html."""
    # Fix randomness so both methods see the same batch order and initial weights.
    torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    # Hyperparameters
    batch_size = 64
    lr = 1e-3
    # NOTE(review): the original literal 1e-999 underflows to exactly 0.0, so
    # no weight decay was ever applied; making that explicit preserves behavior.
    weight_decay = 0.0
    max_steps = 5000
    validate_every = 32
    num_classes = 10
    n_harmonic = 2
    print(f"{batch_size=}, {lr=}, {weight_decay=}, {max_steps=}, {validate_every=}, {n_harmonic=}")
    # Transforms: pad 28x28 FashionMNIST to 32x32, convert to tensor, normalize.
    transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # Data: FashionMNIST
    dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    # Train/validation split (55k/5k)
    train_data, val_data = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Materialize the training batches so both methods consume the same order.
    train_batches = list(train_loader)
    # Base model; deep copies give both methods identical initial weights.
    base_model = CNN(num_classes=num_classes).to(device)
    # Initialize lazy weights before counting parameters.
    with torch.no_grad():
        base_model(torch.empty(1, 1, 32, 32).to(device))
    print("Model architecture:")
    print(base_model)
    print("Number of parameters:", sum(p.numel() for p in base_model.parameters()))
    cnn_ce = copy.deepcopy(base_model)
    cnn_h = copy.deepcopy(base_model)
    # ----- Train the cross-entropy model -----
    print("\nTraining Cross-Entropy model:")
    ce_optimizer = optim.AdamW(cnn_ce.parameters(), lr=lr, weight_decay=weight_decay)
    ce_metrics, ce_best_state, ce_checkpoints = train_model(
        cnn_ce, ce_optimizer, "ce", train_batches, val_loader, device,
        max_steps, validate_every, n_harmonic)
    # ----- Train the harmonic-loss model -----
    print("\nTraining Harmonic model:")
    h_optimizer = optim.AdamW(cnn_h.parameters(), lr=lr, weight_decay=weight_decay)
    h_metrics, h_best_state, h_checkpoints = train_model(
        cnn_h, h_optimizer, "h", train_batches, val_loader, device,
        max_steps, validate_every, n_harmonic)
    # ----- Build ensembles per method:
    #  1. best model: lowest validation error (ce_best_state / h_best_state)
    #  2. last10: ensemble of the 10 most recent checkpoints (or all if fewer)
    #  3. best10: ensemble of the 10 checkpoints with lowest validation error
    #  4. avg_last10 / avg_best10: single models with averaged weights

    def get_ensemble_models(base_model, checkpoints_list, indices):
        # One model copy per selected checkpoint, with its weights loaded.
        models = []
        for idx in indices:
            model_copy = copy.deepcopy(base_model)
            model_copy.load_state_dict(checkpoints_list[idx][1])
            models.append(model_copy)
        return models

    def get_avg_state(checkpoints_list, indices):
        # Average the selected checkpoints' weights via the module-level helper
        # (the original redefined average_state_dicts here, shadowing it).
        return average_state_dicts([checkpoints_list[i][1] for i in indices])

    # Checkpoint indices for CE. Sorting indices directly replaces the original
    # O(n^2) list.index() lookup on tuples containing state_dicts; the stable
    # sort yields exactly the same index order.
    ce_total = len(ce_checkpoints)
    last10_indices = list(range(max(0, ce_total - 10), ce_total))
    best10_indices = sorted(range(ce_total), key=lambda i: ce_checkpoints[i][2])[:10]
    # Checkpoint indices for H
    h_total = len(h_checkpoints)
    last10_indices_h = list(range(max(0, h_total - 10), h_total))
    best10_indices_h = sorted(range(h_total), key=lambda i: h_checkpoints[i][2])[:10]
    base_ce_model = CNN(num_classes=num_classes).to(device)
    base_h_model = CNN(num_classes=num_classes).to(device)
    ce_last10_models = get_ensemble_models(base_ce_model, ce_checkpoints, last10_indices)
    ce_best10_models = get_ensemble_models(base_ce_model, ce_checkpoints, best10_indices)
    h_last10_models = get_ensemble_models(base_h_model, h_checkpoints, last10_indices_h)
    h_best10_models = get_ensemble_models(base_h_model, h_checkpoints, best10_indices_h)
    # Fresh models loaded with weight-averaged state from last/best checkpoints.
    avg_ce_last10_model = CNN(num_classes=num_classes).to(device)
    avg_ce_last10_model.load_state_dict(get_avg_state(ce_checkpoints, last10_indices))
    avg_ce_best10_model = CNN(num_classes=num_classes).to(device)
    avg_ce_best10_model.load_state_dict(get_avg_state(ce_checkpoints, best10_indices))
    avg_h_last10_model = CNN(num_classes=num_classes).to(device)
    avg_h_last10_model.load_state_dict(get_avg_state(h_checkpoints, last10_indices_h))
    avg_h_best10_model = CNN(num_classes=num_classes).to(device)
    avg_h_best10_model.load_state_dict(get_avg_state(h_checkpoints, best10_indices_h))
    # ----- Final test: load the lowest-validation-error weights -----
    cnn_ce.load_state_dict(ce_best_state)
    cnn_h.load_state_dict(h_best_state)

    def test_accuracy(model, loss_type):
        # Single-model test accuracy under the method's own decision rule.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                emb, logits = model(x)
                if loss_type == 'ce':
                    preds = logits.argmax(dim=1)
                else:
                    dists = torch.cdist(emb, model.unembed.t(), p=2)
                    inv_dn = 1.0 / (dists ** n_harmonic + 1e-8)
                    probs = inv_dn / inv_dn.sum(dim=1, keepdim=True)
                    preds = probs.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        return correct / total

    ce_best_acc = test_accuracy(cnn_ce, "ce")
    h_best_acc = test_accuracy(cnn_h, "h")
    ce_last10_acc = ensemble_accuracy(ce_last10_models, "ce", test_loader, n_harmonic, device)
    ce_best10_acc = ensemble_accuracy(ce_best10_models, "ce", test_loader, n_harmonic, device)
    h_last10_acc = ensemble_accuracy(h_last10_models, "h", test_loader, n_harmonic, device)
    h_best10_acc = ensemble_accuracy(h_best10_models, "h", test_loader, n_harmonic, device)
    avg_ce_last10_acc = test_accuracy(avg_ce_last10_model, "ce")
    avg_ce_best10_acc = test_accuracy(avg_ce_best10_model, "ce")
    avg_h_last10_acc = test_accuracy(avg_h_last10_model, "h")
    avg_h_best10_acc = test_accuracy(avg_h_best10_model, "h")
    print("\nFinal Test Accuracy (CE):")
    print(f"  Best model (lowest val error): {ce_best_acc * 100:.2f}%")
    print(f"  Ensemble of last 10 models:    {ce_last10_acc * 100:.2f}%")
    print(f"  Ensemble of best 10 models:    {ce_best10_acc * 100:.2f}%")
    print(f"  Avg model (last 10 weights):   {avg_ce_last10_acc * 100:.2f}%")
    print(f"  Avg model (best 10 weights):   {avg_ce_best10_acc * 100:.2f}%")
    print("\nFinal Test Accuracy (H):")
    print(f"  Best model (lowest val error): {h_best_acc * 100:.2f}%")
    print(f"  Ensemble of last 10 models:    {h_last10_acc * 100:.2f}%")
    print(f"  Ensemble of best 10 models:    {h_best10_acc * 100:.2f}%")
    print(f"  Avg model (last 10 weights):   {avg_h_last10_acc * 100:.2f}%")
    print(f"  Avg model (best 10 weights):   {avg_h_best10_acc * 100:.2f}%")
    # ----- Generate the comparison chart and save it to loss_comparison.html -----
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["train_ce"], mode='lines+markers', name='CE, Train, CE Loss (to minimize)', line=dict(color='blue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["train_h"], mode='lines+markers', name='CE, Train, H Loss', line=dict(color='cyan')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["train_err"], mode='lines+markers', name='CE, Train, Err', line=dict(color='navy')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["val_ce"], mode='lines+markers', name='CE, Val, CE Loss', line=dict(color='royalblue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["val_h"], mode='lines+markers', name='CE, Val, H Loss', line=dict(color='lightblue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["val_err"], mode='lines+markers', name='CE, Val, Err', line=dict(color='mediumblue')))
    fig.add_trace(go.Scatter(x=ce_metrics['steps'], y=ce_metrics["l2"], mode='lines+markers', name='CE, L2 Weights Norm', line=dict(color='darkblue')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["train_ce"], mode='lines+markers', name='H, Train, CE Loss', line=dict(color='red')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["train_h"], mode='lines+markers', name='H, Train, H Loss (to minimize)', line=dict(color='orange')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["train_err"], mode='lines+markers', name='H, Train, Err', line=dict(color='darkred')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["val_ce"], mode='lines+markers', name='H, Val, CE Loss', line=dict(color='crimson')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["val_h"], mode='lines+markers', name='H, Val, H Loss', line=dict(color='tomato')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["val_err"], mode='lines+markers', name='H, Val, Err', line=dict(color='firebrick')))
    fig.add_trace(go.Scatter(x=h_metrics['steps'], y=h_metrics["l2"], mode='lines+markers', name='H, L2 Weights Norm', line=dict(color='maroon')))
    fig.update_layout(title="Loss and Error Comparison",
                      xaxis_title="Training Step",
                      yaxis_title="Metric Value")
    pyo.plot(fig, filename='loss_comparison.html', auto_open=True)
# Script entry point: run the full train/evaluate/plot pipeline.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment