Created
May 4, 2023 06:44
-
-
Save maharjun/48b4c583572ecaf655a52fd56b420f9b to your computer and use it in GitHub Desktop.
Torch Dill Shim
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is a shim for dill to be used with torch (namely that when used in a project | |
that pickles torch objects, dill should be imported from this module). | |
for example:: | |
from utils.dillshim import dill | |
The purpose of this shim is register the pickling and unpickling logic | |
for certain native pytorch types such as torch random generators that | |
otherwise cannot be pickled by dill, as well as to be able to unpickle | |
objects that were created in different devices | |
""" | |
############################################################################### | |
# BSD 3-Clause License | |
# | |
# Copyright (c) 2023, maharjun | |
# | |
# Redistribution and use in source and binary forms, with or without | |
# modification, are permitted provided that the following conditions are met: | |
# | |
# 1. Redistributions of source code must retain the above copyright notice, this | |
# list of conditions and the following disclaimer. | |
# | |
# 2. Redistributions in binary form must reproduce the above copyright notice, | |
# this list of conditions and the following disclaimer in the documentation | |
# and/or other materials provided with the distribution. | |
# | |
# 3. Neither the name of the copyright holder nor the names of its | |
# contributors may be used to endorse or promote products derived from | |
# this software without specific prior written permission. | |
# | |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
############################################################################### | |
import torch | |
import io | |
import dill | |
def get_device_name(device: torch.device):
    """
    Get name of specified device object.

    Parameters
    ----------
    device : torch.device
        The device whose name is requested.

    Returns
    -------
    str
        ``'<type>:<index>'`` when the device carries an explicit index
        (including index 0), otherwise just ``'<type>'``.

    Raises
    ------
    AssertionError
        If ``device`` is not a ``torch.device`` instance.
    """
    assert isinstance(device, torch.device), "device must be a torch.device"
    # Compare against None instead of relying on truthiness: index 0 is a
    # valid explicit index and must not be silently dropped (a bare 'cuda'
    # would refer to the *current* device rather than device 0).
    if device.index is not None:
        return f'{device.type}:{device.index}'
    return device.type
def _recreate_generator(gen_state: torch.Tensor, gen_device): | |
return_gen: torch.Generator = torch.Generator(device=gen_device) | |
return_gen.set_state(gen_state) | |
return return_gen | |
@dill.register(torch.Generator)
def _save_generator(pickler, gen):
    """
    Pickling hook registered with dill for ``torch.Generator`` objects.

    The generator is reduced to a call to ``_recreate_generator`` with the
    generator's state tensor and the name of its device.

    Parameters
    ----------
    pickler
        The pickler instance performing the save; its ``save_reduce`` method
        is used to emit the reduction.
    gen : torch.Generator
        The generator being pickled.
    """
    reduce_args = (gen.get_state(), get_device_name(gen.device))
    return pickler.save_reduce(_recreate_generator, reduce_args, obj=gen)
class device_unpickler(dill.Unpickler):
    """
    This is an extension of the dill unpickler that unpickles tensors onto the
    device specified in the member variable ``device``.

    ``device`` may be a ``torch.device`` or a device string; it is handed to
    ``torch.load`` as ``map_location``. When ``device`` is ``None`` (the
    default) the unpickler behaves exactly like ``dill.Unpickler``.

    As with any pickle-based loader, only use this on data you trust.

    Examples
    --------
    One can set the device in the class member `device` and unpickle a file as below::

        from utils.generic.dillshim import device_unpickler

        device_unpickler.device = torch.device('cpu')
        with open('pickle_file.p', 'rb') as fin:
            values = device_unpickler(fin).load()

    One may also set the device for each instance of the device_unpickler as follows::

        from utils.generic.dillshim import device_unpickler

        with open('pickle_file.p', 'rb') as fin:
            unpickler = device_unpickler(fin)
            unpickler.device = torch.device('cpu')
            values = unpickler.load()
    """

    # Target device for unpickled torch storages; None disables remapping.
    # May be overridden on the class or on an individual instance.
    device = None

    def find_class(self, module, name):
        """
        Resolve a pickled global, intercepting ``torch.storage._load_from_bytes``.

        When ``self.device`` is set and the requested global is
        ``torch.storage._load_from_bytes``, a replacement callable is returned
        that loads the serialized bytes with ``torch.load`` using
        ``map_location=self.device``. Every other lookup is delegated to the
        parent unpickler.
        """
        is_storage_loader = module == 'torch.storage' and name == '_load_from_bytes'
        if self.device is None or not is_storage_loader:
            return super().find_class(module, name)

        def _load_onto_device(b):
            # self.device is read at call time, so a device assigned after
            # find_class has run is still honoured. The device is passed
            # through unchanged: map_location accepts both torch.device and
            # str, and an explicit index (including 0) is preserved.
            # NOTE(review): recent torch releases changed torch.load's default
            # for `weights_only`; confirm this call still loads storages as
            # intended on the torch version in use.
            return torch.load(io.BytesIO(b), map_location=self.device)

        return _load_onto_device
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment