Skip to content

Instantly share code, notes, and snippets.

@maurapintor
Last active March 22, 2022 13:06
Show Gist options
  • Save maurapintor/f6bd32160ee30ffe84b5f83ccaa46242 to your computer and use it in GitHub Desktop.
Save maurapintor/f6bd32160ee30ffe84b5f83ccaa46242 to your computer and use it in GitHub Desktop.
Track queries on a PyTorch model
import torch
class PyTorchModelTracker:
    """Wrap a PyTorch model and record the inputs passed to ``__call__``.

    Every call through the wrapper appends the input to an internal list and
    adds ``x.shape[0]`` to a query counter before forwarding to the wrapped
    model. Any other attribute (methods, submodules, parameters, ...) is
    looked up on the wrapped model.
    """

    def __init__(self, model: torch.nn.Module):
        """Wrap *model* and start with an empty query history."""
        self._model = model
        self._func_counter = None  # initialised by reset()
        self._tracked_x = None  # initialised by reset()
        self.reset()

    def __getattr__(self, attr):
        """Expose the attributes of the wrapped PyTorch model on the wrapper.

        Python only calls this after normal lookup on the wrapper has failed,
        so the wrapper's own attributes always take precedence.

        Raises:
            AttributeError: if neither the wrapper nor the model has *attr*.
        """
        # Read `_model` from the instance dict directly: `self._model` would
        # re-enter __getattr__ endlessly on an instance whose __init__ has not
        # run (e.g. while copy/pickle reconstructs it).
        try:
            model = self.__dict__['_model']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        # Use getattr() rather than model.__getattribute__(): nn.Module serves
        # submodules, parameters and buffers from its own __getattr__, which a
        # direct __getattribute__ call skips.
        # Calls made through methods exposed here (e.g. `forward`) bypass
        # __call__ below and are therefore not tracked.
        return getattr(model, attr)

    def __call__(self, x, *args, **kwargs):
        """Record the query *x*, then forward it to the wrapped model.

        Extra positional/keyword arguments are passed to the model unchanged.

        Returns:
            Whatever the wrapped model returns for the call.
        """
        # NOTE(review): `x` is stored by reference (not detached or cloned), so
        # later in-place edits to the caller's tensor show up in the history.
        self._tracked_x.append(x)  # customize as you want
        # Assumes dim 0 is the batch dimension (each sample counts as one
        # query) — TODO confirm for non-batched inputs.
        # Rebind instead of `+=`: an in-place add would also mutate the tensor
        # object previously handed out by `func_counter`.
        self._func_counter = self._func_counter + x.shape[0]
        return self._model(x, *args, **kwargs)

    def reset(self):
        """Clear the recorded inputs and zero the query counter."""
        self._func_counter = torch.tensor(0)
        self._tracked_x = []

    @property
    def tracked_x(self):
        """All recorded inputs concatenated along dim 0.

        Returns an empty tensor when nothing has been recorded yet, since
        ``torch.cat`` rejects an empty list.
        """
        if not self._tracked_x:
            return torch.empty(0)
        return torch.cat(self._tracked_x)

    @tracked_x.setter
    def tracked_x(self, value):
        """Replace the history with a tensor or an iterable of tensors."""
        if isinstance(value, torch.Tensor):
            # Keep the internal representation a list so the getter's
            # torch.cat and __call__'s append keep working.
            self._tracked_x = [value]
        else:
            self._tracked_x = list(value)

    @property
    def func_counter(self):
        """Number of queries (summed ``x.shape[0]``) seen so far, as a tensor."""
        return self._func_counter

    @func_counter.setter
    def func_counter(self, value):
        """Overwrite the query counter; plain numbers are converted to tensors."""
        self._func_counter = torch.as_tensor(value)
if __name__ == '__main__':
    # Demo: wrap a torchvision ResNet-18, run one batch, show what was tracked.
    import torchvision

    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; confirm against the installed version.
    net = torchvision.models.resnet18(pretrained=True)
    tracker = PyTorchModelTracker(net)

    # One random input batch of shape (4, 3, 224, 224).
    batch = torch.randn(4, 3, 224, 224)
    tracker(batch)  # the output is not used; the call only records the query

    print(tracker.tracked_x)
    print(tracker.func_counter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment