Skip to content

Instantly share code, notes, and snippets.

@kalradivyanshu
Last active November 23, 2023 18:00
Show Gist options
  • Save kalradivyanshu/4adff01df55f7abfc74ee773422b97da to your computer and use it in GitHub Desktop.
Save kalradivyanshu/4adff01df55f7abfc74ee773422b97da to your computer and use it in GitHub Desktop.
How to get the total number of trainable parameters in a TFLite or PyTorch model
import numpy as np
def count_parameters_tflite(model, keywords=("weight", "bias")):
    """Count, print and return the parameter elements of a TFLite model.

    This relies on a tensor-name heuristic rather than any trainable flag.
    A tensor is counted when any string in ``keywords`` occurs as a
    case-sensitive substring of its ``'name'``. Its contribution is the
    product of its ``'shape'``; a scalar with an empty shape counts as 1.

    NOTE(review): converted models may not name their weight tensors with
    these substrings. If the total looks too small, inspect
    ``model.get_tensor_details()`` and pass suitable ``keywords``.

    Args:
        model: Object exposing ``get_tensor_details()`` (presumably a
            ``tf.lite.Interpreter``). Each entry must carry ``'name'`` and
            ``'shape'`` keys.
        keywords: Iterable of substrings that mark a tensor as a parameter.
            Pass a tuple or list, not a bare string, which would be iterated
            character by character. The default ``("weight", "bias")``
            matches the original behaviour.

    Returns:
        int: Total element count of the matching tensors. The same value
        is also printed.
    """
    total_params = 0
    for detail in model.get_tensor_details():
        name = detail['name']
        if any(key in name for key in keywords):
            # Multiply in int64 so a large tensor cannot overflow a 32-bit
            # platform integer. int() keeps the running total a plain
            # Python int.
            total_params += int(np.prod(detail['shape'], dtype=np.int64))
    print(f'Total trainable parameters: {total_params}')
    return total_params
# https://medium.com/the-owl/how-to-get-model-summary-in-pytorch-57db7824d1e3
from prettytable import PrettyTable
def count_parameters_pytorch(model):
    """Tabulate, print and return the trainable parameter count of a model.

    Each parameter yielded by ``model.named_parameters()`` whose
    ``requires_grad`` is true becomes one table row: its name and its
    ``numel()``. Frozen parameters are skipped. The table is printed,
    followed by a line with the grand total.

    Returns:
        The sum of ``numel()`` over all trainable parameters.
    """
    summary = PrettyTable(["Modules", "Parameters"])
    running_total = 0
    trainable = (
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for param_name, tensor in trainable:
        element_count = tensor.numel()
        summary.add_row([param_name, element_count])
        running_total += element_count
    print(summary)
    print(f"Total Trainable Params: {running_total}")
    return running_total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment