Created
December 1, 2017 21:36
-
-
Save chirag1992m/4c1f2cb27d7c138a4dc76aeddfe940c2 to your computer and use it in GitHub Desktop.
weight_transfer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import keras | |
def pyt_to_keras(pytorch_model, keras_model):
    """
    Transfer weights from a PyTorch model into a Keras model (TensorFlow
    backend) that has the same architecture.

    The Keras model is modified in place and also returned.

    Assumptions:
        1. Corresponding layers have the same name in both models, i.e. a
           Keras layer named ``conv_1`` is filled from the PyTorch entries
           ``conv_1.weight`` and ``conv_1.bias``.
        2. A KeyError is raised when a supported layer exists in the Keras
           model but not in the PyTorch model. Layers that exist only in the
           PyTorch model, and Keras layers of unsupported types, are ignored.

    Implemented layers:
        1. 2D convolutional layer (class name ending in ``Conv2D``)
        2. Fully connected layer (class name ending in ``Dense``)

    Eg:
        class PyNet(nn.Module):
            def __init__(self):
                super(PyNet, self).__init__()
                self.conv_1 = nn.Conv2d(in_channels=1, out_channels=2,
                                        kernel_size=(3, 4), padding=0)

            def forward(self, x):
                return self.conv_1(nn.ZeroPad2d((1, 2, 1, 1))(x))

        pyt_model = PyNet()

        a = keras.Input(shape=(5, 6, 1), name='input')
        b = keras.layers.Conv2D(2, (3, 4), activation='linear', padding='same',
                                name='conv_1', bias_initializer='random_uniform')(a)
        keras_model = keras.models.Model(inputs=a, outputs=b)

        keras_model = pyt_to_keras(pyt_model, keras_model)

    Args:
        :param pytorch_model: PyTorch model with the same architecture, or a
            state dict already extracted from one (for example
            ``checkpoint['state_dict']``).
        :type pytorch_model: torch.nn.Module or dict
        :param keras_model: Model loaded in Keras.
        :type keras_model: keras.models.Model
        :return: The Keras model with weights transferred from the PyTorch model
        :rtype: keras.models.Model
        :raises KeyError: if a supported Keras layer has no matching
            ``<name>.weight`` (or ``<name>.bias`` when the layer uses a bias)
            entry in the PyTorch state dict.
    """
    def to_numpy(tensor):
        # detach() + cpu() make the conversion work for tensors that live on
        # a GPU or still carry autograd state; plain .numpy() raises for those.
        return tensor.detach().cpu().numpy()

    # Accept either a module or an already-extracted state dict.
    if callable(getattr(pytorch_model, 'state_dict', None)):
        pyt_state_dict = pytorch_model.state_dict()
    else:
        pyt_state_dict = pytorch_model

    for layer in keras_model.layers:
        layer_type = type(layer).__name__
        if layer_type.endswith('Conv2D'):
            # Keras 2D convolution kernel: height * width * input channels * output channels
            # PyTorch 2D convolution kernel: output channels * input channels * height * width
            # NOTE(review): the suffix match also accepts other classes whose
            # name ends in 'Conv2D'; only the plain kernel/bias layout is
            # handled here -- confirm before using with such layers.
            axes = (2, 3, 1, 0)
        elif layer_type.endswith('Dense'):
            # Keras linear kernel: input neurons * output neurons
            # PyTorch linear kernel: output neurons * input neurons
            axes = (1, 0)
        else:
            continue

        name = layer.name
        new_weights = [np.transpose(to_numpy(pyt_state_dict[name + '.weight']), axes)]
        # Layers created with use_bias=False hold only a kernel, so passing a
        # bias to set_weights would be rejected.
        if getattr(layer, 'use_bias', True):
            new_weights.append(to_numpy(pyt_state_dict[name + '.bias']))
        layer.set_weights(new_weights)

    return keras_model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I've got a model checkpoint,
.pth.tar
so to extract the weights I'll have to use the ['state_dict']
key. So I stored it in a variable x and then passed it to the function, but it's not working.