Last active
March 22, 2022 13:06
-
-
Save maurapintor/f6bd32160ee30ffe84b5f83ccaa46242 to your computer and use it in GitHub Desktop.
Track queries on a PyTorch model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class PyTorchModelTracker:
    """Wrap a PyTorch model and record every batch of inputs it is queried with.

    Calling the wrapper forwards to the wrapped model. It also appends a
    detached copy of the input to an internal history and adds ``x.shape[0]``
    to a running query counter.

    Any attribute not defined on the wrapper is looked up on the wrapped model,
    so the wrapper can mostly stand in for the model. Only ``wrapper(x)`` is
    tracked: calling a method such as ``wrapper.forward(x)`` reaches the model
    directly and is NOT counted.
    """

    def __init__(self, model):
        """Wrap *model* and start with an empty query history.

        Args:
            model: the callable to track (typically a ``torch.nn.Module``).
        """
        # Assign ``_model`` first: ``__getattr__`` delegates to it.
        self._model = model
        self._func_counter = None
        self._tracked_x = None
        self.reset()

    def __getattr__(self, attr):
        """Expose the wrapped model's attributes on the wrapper.

        Python calls this only after normal lookup on the wrapper has failed,
        so the wrapper's own attributes always take precedence.

        ``getattr`` is used instead of ``model.__getattribute__`` so that a
        ``torch.nn.Module``'s own ``__getattr__`` is honoured. That hook is
        where submodules, parameters and buffers are resolved.

        Raises:
            AttributeError: if the model has no attribute *attr*, or if
                ``_model`` has not been set on this instance yet (e.g. an
                instance being rebuilt by ``copy``/``pickle``).
        """
        # Read ``_model`` straight from ``__dict__``: accessing ``self._model``
        # on a half-initialised instance would re-enter ``__getattr__`` and
        # recurse until RecursionError.
        try:
            model = self.__dict__["_model"]
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(model, attr)

    def __call__(self, x, *args, **kwargs):
        """Record the query batch *x*, then forward the call to the model.

        Args:
            x: input batch. Assumed to be a tensor whose first dimension is
                the number of queries in the batch; ``x.shape[0]`` is added
                to the counter.
            *args, **kwargs: extra arguments forwarded unchanged to the model.

        Returns:
            Whatever the wrapped model returns.
        """
        # Detach + clone so the history neither keeps autograd graphs alive
        # nor silently changes if the caller later updates ``x`` in place.
        self._tracked_x.append(x.detach().clone())  # customize as you want
        # Rebind rather than ``+=``. An in-place add would also mutate any
        # tensor a caller previously obtained from ``func_counter``.
        self._func_counter = self._func_counter + x.shape[0]
        return self._model(x, *args, **kwargs)

    def reset(self):
        """Clear the recorded inputs and set the query counter back to zero."""
        self._func_counter = torch.tensor(0)
        self._tracked_x = []

    @property
    def tracked_x(self):
        """All recorded inputs, concatenated along the first dimension.

        Returns an empty tensor when nothing has been recorded, instead of
        letting ``torch.cat([])`` raise. Concatenation requires every recorded
        batch to share the same trailing dimensions.
        """
        if not self._tracked_x:
            return torch.empty(0)
        return torch.cat(self._tracked_x)

    @tracked_x.setter
    def tracked_x(self, value):
        """Replace the recorded history.

        *value* may be a single tensor (stored as one batch) or an iterable
        of tensors.
        """
        if isinstance(value, torch.Tensor):
            value = [value]
        self._tracked_x = list(value)

    @property
    def func_counter(self):
        """Total number of inputs (summed batch sizes) passed to the model."""
        return self._func_counter

    @func_counter.setter
    def func_counter(self, value):
        """Overwrite the query counter; *value* is converted with ``torch.as_tensor``."""
        self._func_counter = torch.as_tensor(value)
if __name__ == '__main__':
    # Demo: wrap a pretrained ResNet-18, query it once with a random batch and
    # print what the tracker recorded.
    import torchvision

    # torchvision >= 0.13 deprecated ``pretrained=True`` in favour of the
    # ``weights=`` enum; fall back to the old keyword on older releases.
    try:
        weights = torchvision.models.ResNet18_Weights.DEFAULT
    except AttributeError:
        model = torchvision.models.resnet18(pretrained=True)
    else:
        model = torchvision.models.resnet18(weights=weights)
    # Evaluation mode: in train mode every tracked query would also update
    # the BatchNorm running statistics of the model being queried.
    model.eval()

    wrapped = PyTorchModelTracker(model)
    inputs = torch.randn((4, 3, 224, 224))
    with torch.no_grad():  # gradients are not needed to count queries
        wrapped(inputs)
    print(wrapped.tracked_x)
    print(wrapped.func_counter)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment