monkey_patch_minimum_test.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "T4",
      "name": "monkey_patch_minimum_test.ipynb",
      "authorship_tag": "ABX9TyMFttedlOM6CLYig1i5WIyI",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/takuma104/894dff4e48a7e1dbebedcff136da5956/untitled11.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "dQwQ6YK0BaFf"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "device = 'cuda' # Changing this to 'cpu' still results in the vanilla case failing."
      ],
      "metadata": {
        "id": "PesDolhICF0L"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class TargetModule(torch.nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.linears = torch.nn.ModuleList([\n",
        "            torch.nn.Linear(2, 2),\n",
        "            torch.nn.Linear(2, 2), # If you comment out this line, the vanilla case will succeed.\n",
        "        ])\n",
        "    def forward(self, x):\n",
        "        for module in self.linears:\n",
        "            x = module(x)\n",
        "        return x\n"
      ],
      "metadata": {
        "id": "McbX2HrlBdrp"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def test_monkey_patch_instance_method():\n",
        "    def monkey_patch(target):\n",
        "        for name, module in target.named_modules():\n",
        "            if isinstance(module, torch.nn.Linear):\n",
        "                print(f'monkey patching to {name}')\n",
        "                module.old_forward = module.forward\n",
        "                def new_forward(self, x):\n",
        "                    return self.old_forward(x)\n",
        "                module.forward = new_forward.__get__(module)\n",
        "\n",
        "    torch.manual_seed(0)\n",
        "    x = torch.randn((2, 2)).to(device)\n",
        "    target = TargetModule().to(device)\n",
        "    with torch.no_grad():\n",
        "        print('')\n",
        "        print('*' * 80)\n",
        "        print('instance_method:')\n",
        "\n",
        "        y = target(x)\n",
        "        print(y)\n",
        "        assert y.shape == (2, 2)\n",
        "\n",
        "        monkey_patch(target)\n",
        "\n",
        "        yy = target(x)\n",
        "        print(yy)\n",
        "        assert torch.allclose(yy, y), \"instance_method: monkey patching failed\"\n"
      ],
      "metadata": {
        "id": "RENlHKFNBh6x"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def test_monkey_patch_fix_closure():\n",
        "    def monkey_patch(target):\n",
        "        for name, module in target.named_modules():\n",
        "            if isinstance(module, torch.nn.Linear):\n",
        "                print(f'monkey patching to {name}')\n",
        "                old_forward = module.forward\n",
        "                def make_new_forward(old_forward):\n",
        "                    def new_forward(x):\n",
        "                        return old_forward(x)\n",
        "                    return new_forward\n",
        "                module.forward = make_new_forward(old_forward)\n",
        "\n",
        "    torch.manual_seed(0)\n",
        "    x = torch.randn((2, 2)).to(device)\n",
        "    target = TargetModule().to(device)\n",
        "    with torch.no_grad():\n",
        "        print('')\n",
        "        print('*' * 80)\n",
        "        print('vanilla:')\n",
        "\n",
        "        y = target(x)\n",
        "        print(y)\n",
        "        assert y.shape == (2, 2)\n",
        "\n",
        "        monkey_patch(target)\n",
        "\n",
        "        yy = target(x)\n",
        "        print(yy)\n",
        "        assert torch.allclose(yy, y), \"fix closure: monkey patching failed\""
      ],
      "metadata": {
        "id": "I3-XZvruQRwD"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def test_monkey_patch_vanilla():\n",
        "    def monkey_patch(target):\n",
        "        for name, module in target.named_modules():\n",
        "            if isinstance(module, torch.nn.Linear):\n",
        "                print(f'monkey patching to {name}')\n",
        "                old_forward = module.forward\n",
        "                def new_forward(x):\n",
        "                    return old_forward(x)\n",
        "                module.forward = new_forward\n",
        "\n",
        "    torch.manual_seed(0)\n",
        "    x = torch.randn((2, 2)).to(device)\n",
        "    target = TargetModule().to(device)\n",
        "    with torch.no_grad():\n",
        "        print('')\n",
        "        print('*' * 80)\n",
        "        print('vanilla:')\n",
        "\n",
        "        y = target(x)\n",
        "        print(y)\n",
        "        assert y.shape == (2, 2)\n",
        "\n",
        "        monkey_patch(target)\n",
        "\n",
        "        yy = target(x)\n",
        "        print(yy)\n",
        "        assert torch.allclose(yy, y), \"vanilla: monkey patching failed\"\n"
      ],
      "metadata": {
        "id": "DGjU6E2jBgsi"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "test_monkey_patch_instance_method()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0AG9kwGoBsQ5",
        "outputId": "70185998-3a8c-4bde-8bb1-cc02d425d6e8"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "********************************************************************************\n",
            "instance_method:\n",
            "tensor([[-0.2581, -0.8602],\n",
            "        [-0.3555, -0.4635]], device='cuda:0')\n",
            "monkey patching to linears.0\n",
            "monkey patching to linears.1\n",
            "tensor([[-0.2581, -0.8602],\n",
            "        [-0.3555, -0.4635]], device='cuda:0')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "test_monkey_patch_fix_closure()"
      ],
      "metadata": {
        "id": "y63bOY-CQlQX",
        "outputId": "d54b493e-9620-4d82-a352-4197c0cd81d4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "********************************************************************************\n",
            "vanilla:\n",
            "tensor([[-0.2581, -0.8602],\n",
            "        [-0.3555, -0.4635]], device='cuda:0')\n",
            "monkey patching to linears.0\n",
            "monkey patching to linears.1\n",
            "tensor([[-0.2581, -0.8602],\n",
            "        [-0.3555, -0.4635]], device='cuda:0')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "test_monkey_patch_vanilla()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 423
        },
        "id": "Y0vUq9ZxBu6R",
        "outputId": "8c2f6af9-0717-4be1-9bfb-f72010d23fb7"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "********************************************************************************\n",
            "vanilla:\n",
            "tensor([[-0.2581, -0.8602],\n",
            "        [-0.3555, -0.4635]], device='cuda:0')\n",
            "monkey patching to linears.0\n",
            "monkey patching to linears.1\n",
            "tensor([[-0.2065, -0.5703],\n",
            "        [-0.5468, -0.5470]], device='cuda:0')\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "AssertionError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-9-05af1f1bc236>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_monkey_patch_vanilla\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m<ipython-input-6-95b097a22274>\u001b[0m in \u001b[0;36mtest_monkey_patch_vanilla\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0myy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m     \u001b[0;32massert\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"vanilla: monkey patching failed\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;31mAssertionError\u001b[0m: vanilla: monkey patching failed"
          ]
        }
      ]
    }
  ]
}
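
The failing "vanilla" case above comes down to Python's late-binding closures rather than anything PyTorch-specific: each new_forward captures the variable old_forward, so after the loop every patched Linear calls the forward captured on the last iteration. Below is a minimal plain-Python sketch of the same behaviour, illustrative only; the function and variable names (patch_vanilla, patch_with_factory, add_one, times_ten) are made up for this sketch and are not part of the notebook.

def patch_vanilla(funcs):
    patched = []
    for f in funcs:
        old = f
        def new(x):
            return old(x)   # 'old' is looked up when new() runs, not when it is defined
        patched.append(new)
    return patched          # every entry now calls the function captured last

def patch_with_factory(funcs):
    patched = []
    for f in funcs:
        def make(old):      # 'old' becomes a parameter, so its value is bound per iteration
            def new(x):
                return old(x)
            return new
        patched.append(make(f))
    return patched

add_one = lambda x: x + 1
times_ten = lambda x: x * 10

print([g(1) for g in patch_vanilla([add_one, times_ten])])       # [10, 10]: both call times_ten
print([g(1) for g in patch_with_factory([add_one, times_ten])])  # [2, 10]: each keeps its own

The notebook's instance_method variant avoids the problem in a different way: it stores the original forward on the module itself (module.old_forward) and binds new_forward to the instance with __get__, so the lookup goes through self on the correct module at call time.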