Last active
November 23, 2023 18:00
-
-
Save kalradivyanshu/4adff01df55f7abfc74ee773422b97da to your computer and use it in GitHub Desktop.
How to get a model's total number of trainable parameters in TFLite and PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def count_parameters_tflite(model, name_keywords=('weight', 'bias')):
    """Print and return the parameter count of a TFLite model, matched by tensor name.

    Every tensor whose name contains one of ``name_keywords`` contributes the
    product of its shape to the total. This is a name-based heuristic: the
    tensor details carry no "trainable" flag, so which tensors match depends
    on how the converter named them (NOTE(review): confirm the names for your
    model by inspecting ``model.get_tensor_details()``).

    Args:
        model: An object with a ``get_tensor_details()`` method returning a
            list of dicts that have ``'name'`` and ``'shape'`` keys
            (presumably a ``tf.lite.Interpreter`` after ``allocate_tensors()``).
        name_keywords: Substrings that mark a tensor as a parameter. Defaults
            to ``('weight', 'bias')``. A single string is treated as one
            keyword rather than as a sequence of characters.

    Returns:
        The total number of elements across matching tensors, as a Python int.
    """
    if isinstance(name_keywords, str):
        # Avoid iterating a bare string character by character.
        name_keywords = (name_keywords,)

    total_params = 0
    for detail in model.get_tensor_details():
        name = detail['name']
        if any(keyword in name for keyword in name_keywords):
            # np.prod of an empty shape (a scalar tensor) is 1. Converting to
            # int keeps the total a plain Python int rather than a numpy
            # scalar, or a float when the shape is an empty list.
            total_params += int(np.prod(detail['shape']))
    print(f'Total trainable parameters: {total_params}')
    # Return the count, as count_parameters_pytorch does, so callers can use it.
    return total_params
# https://medium.com/the-owl/how-to-get-model-summary-in-pytorch-57db7824d1e3 | |
from prettytable import PrettyTable | |
def count_parameters_pytorch(model):
    """Print a per-parameter table of trainable parameter counts and return the total.

    Args:
        model: An object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs (presumably a ``torch.nn.Module``).

    Returns:
        The sum of ``numel()`` over every parameter whose ``requires_grad``
        is true.
    """
    summary = PrettyTable(["Modules", "Parameters"])
    running_total = 0
    for param_name, tensor in model.named_parameters():
        # Frozen parameters (requires_grad is false) are left out of both
        # the table and the total.
        if tensor.requires_grad:
            count = tensor.numel()
            summary.add_row([param_name, count])
            running_total += count
    print(summary)
    print(f"Total Trainable Params: {running_total}")
    return running_total
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment