# Source: GitHub Gist woolpeeker/ccc9627f33d6d8a140592bfd554b7fe6 ("adamw_vs_fused"),
# last active September 5, 2022 14:17.
# GitHub notice: this file may contain hidden or bidirectional Unicode text that can be
# interpreted or compiled differently than it appears; to review it, open the file in an
# editor that reveals hidden Unicode characters.
""" | |
V100-32G | |
================================================== | |
model_name: vit_base_patch16_224 | |
iter_num: 1000 | |
bs: 16 | |
adamw avg time: 12.739 ms | |
fused adam avg time: 3.343 ms | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import apex | |
import timm | |
torch.backends.cudnn.benchmark = True | |
bs = 16 | |
iter_num = 1000 | |
skip_steps = 10 | |
model_name = 'vit_base_patch16_224' | |
model = timm.create_model(model_name).cuda() | |
num_classes = model.num_classes | |
adamw = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-5) | |
fused_adam = apex.optimizers.FusedAdam(model.parameters(), lr=1e-6, weight_decay=1e-5) | |
inputs = torch.randn([bs, 3, 224, 224]).half().cuda() | |
targets = torch.argmax(torch.randn([bs, num_classes]), dim=1).cuda() | |
scaler = torch.cuda.amp.GradScaler() | |
def test(optimizer, num_iters=None, num_warmup=None) -> float:
    """Return the mean GPU time, in milliseconds, of ``optimizer.step()``.

    Runs ``num_warmup + num_iters`` iterations on the module-level ``model``
    with the fixed ``inputs``/``targets`` batch. Each iteration does a forward
    pass under autocast, a scaled backward pass, ``optimizer.step()`` and
    ``optimizer.zero_grad()``. Only ``step()`` is timed, bracketed by CUDA
    events; the first ``num_warmup`` iterations are discarded as warm-up.

    Args:
        optimizer: optimizer over ``model``'s parameters to benchmark.
        num_iters: number of timed iterations; defaults to the module-level
            ``iter_num``.
        num_warmup: number of untimed warm-up iterations; defaults to the
            module-level ``skip_steps`` (previously a hard-coded local 12
            silently overrode that setting).

    Returns:
        Average ``step()`` time in milliseconds over the timed iterations.

    Raises:
        ValueError: if ``num_iters`` is not positive.
    """
    if num_iters is None:
        num_iters = iter_num
    if num_warmup is None:
        num_warmup = skip_steps
    if num_iters <= 0:
        raise ValueError(f'num_iters must be positive, got {num_iters}')

    total_time_ms = 0.0
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for i in range(num_warmup + num_iters):
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
        # The loss is scaled, but scaler.step()/update() are never called, so the
        # optimizer sees scaled gradients; this does not affect the timing.
        scaler.scale(loss).backward()
        start_event.record()
        optimizer.step()
        end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded!
        if i >= num_warmup:
            total_time_ms += start_event.elapsed_time(end_event)
        optimizer.zero_grad()
    avg_time_ms = total_time_ms / num_iters
    # Label with the actual optimizer class instead of a hard-coded 'native_adamw'.
    print(f'{type(optimizer).__name__}: {avg_time_ms:.3f} ms')
    return avg_time_ms
# Time both optimizers back to back on the same model, then print a summary.
adamw_time = test(adamw)
fused_adam_time = test(fused_adam)

summary = '\n'.join([
    '==================================================',
    f'model_name: {model_name}',
    f'iter_num: {iter_num}',
    f'bs: {bs}',
    f'adamw avg time: {adamw_time:.3f} ms',
    f'fused adam avg time: {fused_adam_time:.3f} ms',
])
print(summary)
# To comment on the original gist, sign in to GitHub (or sign up for a free account).