@mehdidc · Created August 17, 2024
flops_mobileclip.py
import torch
import mobileclip
import pandas as pd
from torch.utils.flop_counter import FlopCounterMode
import open_clip
import fvcore.nn
model, _, preprocess = mobileclip.create_model_and_transforms(
    'mobileclip_s0',
    # pretrained='mobileclip_s0.pt'
)
# model, _, preprocess = mobileclip.create_model_and_transforms(
# 'mobileclip_s1',
# pretrained='mobileclip_s1.pt'
# )
# model, _, preprocess = mobileclip.create_model_and_transforms(
# 'mobileclip_s2',
# pretrained='mobileclip_s2.pt'
# )
# model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', cache_dir='./cache')
# model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k', cache_dir='./cache')
# model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion2b_s34b_b88k', cache_dir='./cache')
# model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP', cache_dir='./cache')
# model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-i18n-256', cache_dir='./cache')
# model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-L-16-SigLIP-256', cache_dir='./cache')
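# Quick sanity check (sketch): print the eval transforms so the image input
# size chosen below matches what the selected model expects (MobileCLIP-S0 and
# the SigLIP '-256' variants use 256x256, the ViT-B models above use 224x224).
print(preprocess)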
model.eval()
model = model.cuda()
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
batch_size = 1
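# FLOPs scale linearly with batch size, so batch_size=1 gives the per-sample cost.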
###### Image GFlops compute ######
# image_input_size = (3, 224, 224)
image_input_size = (3, 256, 256)
example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
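# FlopCounterMode prints a per-module breakdown table when the context exits;
# pass display=False to FlopCounterMode(...) to silence it.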
flop_counter = FlopCounterMode()
with flop_counter, torch.no_grad():
    model(image=example_input, text=None)
fca = fvcore.nn.FlopCountAnalysis(model, (example_input, None))
fca_total = fca.total()
# this counter returns GMacs; multiply by 2 to get GFlops
# see: https://gist.github.com/soumith/5f81c3d40d41bb9d08041431c656b233
# see also: https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505/9
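# Worked example of the MAC/FLOP convention: a Linear(512 -> 512) applied to one
# token does 512*512 = 262,144 multiply-accumulates (MACs), i.e. 524,288 FLOPs,
# since each MAC is one multiply plus one add (ignoring the bias term).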
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
total_flops = total_flops * 2
total_flops = round(total_flops / 1e9, 4)
print('Image GFlops (matching MobileCLIP): {}'.format(total_flops/4))
print('Image GMacs: {}'.format(total_flops/2))
print('Image GFlops: {}'.format(total_flops))
print('Image GFlops with fvcore: {}'.format(fca_total/1e9))
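# Note (per fvcore's documented convention): fvcore counts one fused
# multiply-add as a single flop, so its figure is closer to the GMacs line
# above than to the doubled GFlops number.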
###### Text GFlops compute ######
device = next(model.parameters()).device
text_input_size = (77,)
# text_input_size = (64,)
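# 77 is the CLIP (and MobileCLIP) context length; SigLIP models use 64.
# Token values don't change the op count, so dummy all-ones ids suffice.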
example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
flop_counter = FlopCounterMode()
with flop_counter, torch.no_grad():
    model(image=None, text=example_input)
# GMacs -> GFlops conversion as above (see the links there)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
total_flops = total_flops * 2
total_flops = round(total_flops / 1e9, 4)
fca = fvcore.nn.FlopCountAnalysis(model, (None, example_input))
fca_total = fca.total()
print("Text GFlops (matching MobileCLIP): {}".format(total_flops/4))
print("Text GMacs: {}".format(total_flops/2))
print("Text GFlops: {}".format(total_flops))
print('Text GFlops with fvcore: {}'.format(fca_total/1e9))
###### Full GFlops compute ######
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
flop_counter = FlopCounterMode()
with flop_counter, torch.no_grad():
    model(image=image_input, text=text_input)
# GMacs -> GFlops conversion as above (see the links there)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
total_flops = total_flops * 2
total_flops = round(total_flops / 1e9, 4)
fca = fvcore.nn.FlopCountAnalysis(model, (image_input, text_input))
fca_total = fca.total()
def count_params(model):
    return sum(p.numel() for p in model.parameters())
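# mobileclip exposes the image tower as `model.image_encoder`, while open_clip
# models use `model.visual`; count the whole model so the print label below
# stays accurate for every config above.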
total_params = count_params(model)
print("Full Model GFlops (matching MobileCLIP): {}".format(total_flops/2))
print("Full Model GMacs: {}".format(total_flops/2))
print("Full Model GFlops: {}".format(total_flops))
print("Full Model GFlops with fvcore: {}".format(fca_total / 1e9))
print(f"Full model Params (M): {total_params/1e6}")