Last active
May 31, 2023 16:34
-
-
Save takuma104/93094f989ee89e4cd61af09f9d909e26 to your computer and use it in GitHub Desktop.
monkey_patch_minimum_test.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"machine_shape": "hm", | |
"gpuType": "T4", | |
"name": "monkey_patch_minimum_test.ipynb", | |
"authorship_tag": "ABX9TyNjZOFW4I2ayqE1wnNBP8OS", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"gpuClass": "standard" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/takuma104/93094f989ee89e4cd61af09f9d909e26/untitled11.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "dQwQ6YK0BaFf" | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The checks below are not expected to be device-specific, so fall back to\n", | |
"# CPU when CUDA is unavailable (e.g. a Colab runtime without a GPU).\n", | |
"device = 'cuda' if torch.cuda.is_available() else 'cpu'" | |
], | |
"metadata": { | |
"id": "PesDolhICF0L" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class TargetModule(torch.nn.Module):\n", | |
"    \"\"\"Minimal patch target: one nn.Linear(2, 2) held in a ModuleList.\n", | |
"\n", | |
"    The Linear is reachable through named_modules() as 'linears.0'.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        layers = [torch.nn.Linear(2, 2)]\n", | |
"        self.linears = torch.nn.ModuleList(layers)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Apply every layer of self.linears to x in sequence.\"\"\"\n", | |
"        out = x\n", | |
"        for layer in self.linears:\n", | |
"            out = layer(out)\n", | |
"        return out\n" | |
], | |
"metadata": { | |
"id": "McbX2HrlBdrp" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def test_monkey_patch_fix_closure():\n", | |
"    \"\"\"Monkey-patch every nn.Linear so its output is doubled and check idempotence.\n", | |
"\n", | |
"    Relies on the notebook globals `device` and `TargetModule`. Checks that:\n", | |
"      1. one patch yields exactly 2x the vanilla output,\n", | |
"      2. patching again first undoes the previous patch (2x, not 4x),\n", | |
"      3. undoing the patch restores the vanilla output.\n", | |
"    Raises AssertionError if any check fails.\n", | |
"    \"\"\"\n", | |
"    def make_new_forward(orig_forward):\n", | |
"        # orig_forward is bound as a factory argument, so each module keeps its\n", | |
"        # own original forward. A closure defined directly in the patch loop\n", | |
"        # would late-bind to the last module's forward instead.\n", | |
"        # NOTE: TargetModule as defined in this notebook has a single Linear,\n", | |
"        # so that late-binding bug could not surface in this test anyway.\n", | |
"        def new_forward(x):\n", | |
"            return orig_forward(x) * 2.0\n", | |
"        return new_forward\n", | |
"\n", | |
"    def unpatch_module(module):\n", | |
"        # Restore the forward saved by an earlier patch, if there is one.\n", | |
"        if hasattr(module, 'old_forward'):\n", | |
"            print('undo monkey-patch')\n", | |
"            module.forward = module.old_forward\n", | |
"            delattr(module, 'old_forward')\n", | |
"\n", | |
"    def monkey_patch(target):\n", | |
"        for name, module in target.named_modules():\n", | |
"            if isinstance(module, torch.nn.Linear):\n", | |
"                print(f'monkey patching to {name}')\n", | |
"                # Undo first so repeated calls do not stack the 2x wrapper.\n", | |
"                unpatch_module(module)\n", | |
"                old_forward = module.old_forward = module.forward\n", | |
"                module.forward = make_new_forward(old_forward)\n", | |
"\n", | |
"    torch.manual_seed(0)\n", | |
"    x = torch.randn((2, 2)).to(device)\n", | |
"    target = TargetModule().to(device)\n", | |
"    with torch.no_grad():\n", | |
"        print('')\n", | |
"        print('*' * 80)\n", | |
"        print('vanilla:')\n", | |
"\n", | |
"        y = target(x)\n", | |
"        print(y)\n", | |
"        assert y.shape == (2, 2)\n", | |
"\n", | |
"        monkey_patch(target)\n", | |
"\n", | |
"        yy = target(x)\n", | |
"        print(yy)\n", | |
"        assert torch.allclose(yy, y*2.0), \"fix closure: monkey patching failed\"\n", | |
"\n", | |
"        # Re-patching must replace the previous patch rather than wrap it.\n", | |
"        monkey_patch(target)\n", | |
"\n", | |
"        yyy = target(x)\n", | |
"        print(yyy)\n", | |
"        assert torch.allclose(yyy, y*2.0), \"fix closure: re-patching did not undo the previous patch\"\n", | |
"\n", | |
"        # Undoing the patch must bring back the vanilla output.\n", | |
"        for module in target.modules():\n", | |
"            unpatch_module(module)\n", | |
"\n", | |
"        y_restored = target(x)\n", | |
"        print(y_restored)\n", | |
"        assert torch.allclose(y_restored, y), \"fix closure: undoing the monkey patch failed\"\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "I3-XZvruQRwD" | |
}, | |
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Prints the intermediate tensors; raises AssertionError if any check fails.\n", | |
"test_monkey_patch_fix_closure();" | |
], | |
"metadata": { | |
"id": "y63bOY-CQlQX", | |
"outputId": "ce428186-1e12-4739-d6f5-6b2b16e0df88", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"********************************************************************************\n", | |
"vanilla:\n", | |
"tensor([[-0.8271, -0.7568],\n", | |
" [-0.4325, -0.0817]], device='cuda:0')\n", | |
"monkey patching to linears.0\n", | |
"tensor([[-1.6543, -1.5137],\n", | |
" [-0.8649, -0.1634]], device='cuda:0')\n", | |
"monkey patching to linears.0\n", | |
"undo monkey-patch\n", | |
"tensor([[-1.6543, -1.5137],\n", | |
" [-0.8649, -0.1634]], device='cuda:0')\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "aQU2IEcImwta" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment