Skip to content

Instantly share code, notes, and snippets.

@woolpeeker
Last active September 5, 2022 14:17
Show Gist options
  • Save woolpeeker/ccc9627f33d6d8a140592bfd554b7fe6 to your computer and use it in GitHub Desktop.
Save woolpeeker/ccc9627f33d6d8a140592bfd554b7fe6 to your computer and use it in GitHub Desktop.
adamw_vs_fused
"""
V100-32G
==================================================
model_name: vit_base_patch16_224
iter_num: 1000
bs: 16
adamw avg time: 12.739 ms
fused adam avg time: 3.343 ms
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import apex
import timm
# Let cuDNN pick the fastest kernels for the fixed input shape used below.
torch.backends.cudnn.benchmark = True

# ---- benchmark configuration ------------------------------------------------
bs = 16                               # batch size of the synthetic batch
iter_num = 1000                       # timed optimizer steps per optimizer
skip_steps = 10                       # intended untimed warm-up iterations
                                      # NOTE(review): confirm test() actually reads this
model_name = 'vit_base_patch16_224'   # timm architecture to instantiate

# ---- model and the two optimizers under comparison --------------------------
model = timm.create_model(model_name).cuda()
num_classes = model.num_classes

# Identical hyper-parameters for both optimizers so only the implementation differs.
_optim_kwargs = {'lr': 1e-6, 'weight_decay': 1e-5}
# NOTE: both optimizers wrap the very same parameter tensors, so whichever
# optimizer is benchmarked second starts from weights the first one updated.
adamw = torch.optim.AdamW(model.parameters(), **_optim_kwargs)
fused_adam = apex.optimizers.FusedAdam(model.parameters(), **_optim_kwargs)

# ---- one fixed synthetic batch, reused on every iteration ---------------------
inputs = torch.randn(bs, 3, 224, 224).half().cuda()
# Random class labels: index of the max of a random score row per sample.
targets = torch.randn([bs, num_classes]).argmax(dim=1).cuda()
scaler = torch.cuda.amp.GradScaler()
def test(optimizer, warmup_steps=None):
    """Benchmark ``optimizer.step()`` on the shared model and return its mean GPU time.

    Each iteration runs an autocast forward pass plus cross-entropy loss on the
    fixed module-level ``inputs`` / ``targets`` batch, then a loss-scaled
    backward pass. Only the ``optimizer.step()`` call is bracketed by a pair
    of CUDA timing events. The first ``warmup_steps`` iterations are untimed
    warm-up. The following ``iter_num`` iterations are averaged.

    Args:
        optimizer: optimizer over ``model.parameters()`` to benchmark.
        warmup_steps: number of untimed warm-up iterations. ``None`` (default)
            uses the module-level ``skip_steps`` setting, which was previously
            shadowed by a hard-coded local value and silently ignored.

    Returns:
        Average ``optimizer.step()`` time in milliseconds over ``iter_num``
        timed iterations. The value is also printed, labelled with the
        optimizer's class name.
    """
    warmup = skip_steps if warmup_steps is None else warmup_steps
    total_time_ms = 0.0
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for i in range(iter_num + warmup):
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
        scaler.scale(loss).backward()
        # Undo the loss scaling outside the timed region so the optimizer sees
        # true gradient magnitudes instead of permanently scaled ones.
        scaler.unscale_(optimizer)
        # optimizer.step() is called directly (not scaler.step) so that only the
        # optimizer's own work falls between the two events. As a consequence,
        # the step is taken even if the gradients contain inf/NaN.
        start_event.record()
        optimizer.step()
        end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded!
        if i >= warmup:
            total_time_ms += start_event.elapsed_time(end_event)
        # Let the scaler adapt its scale factor; this was never done before.
        scaler.update()
        optimizer.zero_grad()
    avg_time_ms = total_time_ms / iter_num
    # Label with the optimizer class (e.g. AdamW / FusedAdam). The old
    # hard-coded 'native_adamw' label was printed for every optimizer.
    print(f'{type(optimizer).__name__}: {avg_time_ms:.3f} ms')
    return avg_time_ms
# Benchmark both optimizers with the same timing loop, AdamW first.
adamw_time = test(adamw)
fused_adam_time = test(fused_adam)

# Summary in the same layout as the results recorded in the module docstring.
summary_lines = (
    '==================================================',
    f'model_name: {model_name}',
    f'iter_num: {iter_num}',
    f'bs: {bs}',
    f'adamw avg time: {adamw_time:.3f} ms',
    f'fused adam avg time: {fused_adam_time:.3f} ms',
)
for summary_line in summary_lines:
    print(summary_line)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment