@chirag1992m
Created December 1, 2017 21:36
weight_transfer
```python
import numpy as np
import torch
import keras


def pyt_to_keras(pytorch_model, keras_model):
    """
    Given a PyTorch model, this method transfers its weights to
    a Keras model (with the TensorFlow backend) that has the same architecture.

    Assumptions:
    1. Corresponding layers have the same names in both models.
    2. A KeyError is raised when a Conv2D or Dense layer in the Keras model has
       no matching '<name>.weight' / '<name>.bias' entries in the PyTorch model;
       Keras layers of any other type are ignored.

    Implemented layers:
    1. 2D convolutional layer
    2. Fully connected layer

    Example:
        import torch.nn as nn

        class PyNet(nn.Module):
            def __init__(self):
                super(PyNet, self).__init__()
                self.conv_1 = nn.Conv2d(in_channels=1, out_channels=2,
                                        kernel_size=(3, 4), padding=0)

            def forward(self, x):
                # ZeroPad2d takes (left, right, top, bottom); this split
                # reproduces Keras padding='same' for a (3, 4) kernel.
                return self.conv_1(nn.ZeroPad2d((1, 2, 1, 1))(x))

        pyt_model = PyNet()
        a = keras.Input(shape=(5, 6, 1), name='input')
        b = keras.layers.Conv2D(2, (3, 4), activation='linear', padding='same',
                                name='conv_1', bias_initializer='random_uniform')(a)
        keras_model = keras.models.Model(inputs=a, outputs=b)
        keras_model = pyt_to_keras(pyt_model, keras_model)

    Args:
    :param pytorch_model: Similar model in PyTorch
    :type pytorch_model: torch.nn.Module
    :param keras_model: Model loaded in Keras
    :type keras_model: keras.models.Model
    :return: Keras model with weights transferred from the PyTorch model
    :rtype: keras.models.Model
    """
    pyt_state_dict = pytorch_model.state_dict()
    for layer in keras_model.layers:
        if type(layer).__name__.endswith('Conv2D'):
            # Keras 2D convolutional layer: height * width * input channels * output channels
            # PyTorch 2D convolutional layer: output channels * input channels * height * width
            name = layer.name
            # .cpu() lets this also work when the PyTorch model lives on a GPU
            weights = np.transpose(pyt_state_dict[name + '.weight'].cpu().numpy(), (2, 3, 1, 0))
            bias = pyt_state_dict[name + '.bias'].cpu().numpy()
            layer.set_weights([weights, bias])
        elif type(layer).__name__.endswith('Dense'):
            # Keras linear layer: input neurons * output neurons
            # PyTorch linear layer: output neurons * input neurons
            name = layer.name
            weights = np.transpose(pyt_state_dict[name + '.weight'].cpu().numpy(), (1, 0))
            bias = pyt_state_dict[name + '.bias'].cpu().numpy()
            layer.set_weights([weights, bias])
    return keras_model
```
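The transfer rests on two conventions: PyTorch stores Conv2d weights as (out, in, height, width) and Linear weights as (out, in), while Keras expects (height, width, in, out) and (in, out); and the docstring example's `nn.ZeroPad2d((1, 2, 1, 1))` stands in for Keras `padding='same'`. The NumPy-only sketch that follows checks both under stated assumptions: stride 1, no dilation, a single unbatched image, and TensorFlow's rule that an odd amount of 'same' padding puts the extra pixel on the bottom/right. The helper names (`same_padding`, `conv_nchw`, `conv_nhwc`, `compare_conv_layouts`, `compare_dense_layouts`) are hypothetical and not part of the gist.

```python
import numpy as np


def same_padding(kernel_size):
    """(top, bottom, left, right) zero padding that matches Keras/TensorFlow
    padding='same' for stride 1 (assumed): any odd extra pixel goes bottom/right."""
    kh, kw = kernel_size
    top, left = (kh - 1) // 2, (kw - 1) // 2
    return top, kh - 1 - top, left, kw - 1 - left


def conv_nchw(x, w, b):
    """Valid cross-correlation, PyTorch layout: x is (in, H, W), w is (out, in, kh, kw)."""
    c_out, _, kh, kw = w.shape
    _, h, wd = x.shape
    out = np.empty((c_out, h - kh + 1, wd - kw + 1))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw]  # (in, kh, kw)
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return out


def conv_nhwc(x, w, b):
    """Valid cross-correlation, Keras layout: x is (H, W, in), w is (kh, kw, in, out)."""
    kh, kw, _, c_out = w.shape
    h, wd, _ = x.shape
    out = np.empty((h - kh + 1, wd - kw + 1, c_out))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[i:i + kh, j:j + kw, :]  # (kh, kw, in)
            out[i, j, :] = np.tensordot(patch, w, axes=([0, 1, 2], [0, 1, 2])) + b
    return out


def compare_conv_layouts(seed=0):
    """Mirror the docstring example: 1 -> 2 channels, (3, 4) kernel, 5x6 input."""
    rng = np.random.default_rng(seed)
    w_pt = rng.standard_normal((2, 1, 3, 4))    # PyTorch Conv2d weight layout
    b = rng.standard_normal(2)
    w_keras = np.transpose(w_pt, (2, 3, 1, 0))  # same permutation as pyt_to_keras
    x = rng.standard_normal((1, 5, 6))          # one image, channels first
    top, bottom, left, right = same_padding((3, 4))
    x_pad = np.pad(x, ((0, 0), (top, bottom), (left, right)))  # zero padding
    y_pt = conv_nchw(x_pad, w_pt, b)
    y_keras = conv_nhwc(np.transpose(x_pad, (1, 2, 0)), w_keras, b)
    return y_pt, y_keras


def compare_dense_layouts(seed=0):
    """PyTorch Linear computes W @ v + b with W of shape (out, in);
    Keras Dense computes v @ K + b with K of shape (in, out)."""
    rng = np.random.default_rng(seed)
    w_pt = rng.standard_normal((3, 4))
    b = rng.standard_normal(3)
    v = rng.standard_normal(4)
    return w_pt @ v + b, v @ np.transpose(w_pt, (1, 0)) + b


y_pt, y_keras = compare_conv_layouts()
print(same_padding((3, 4)))  # (1, 1, 1, 2)
print(y_pt.shape, np.allclose(np.transpose(y_pt, (1, 2, 0)), y_keras))
lin_pt, lin_keras = compare_dense_layouts()
print(np.allclose(lin_pt, lin_keras))
```

For the (3, 4) kernel, `same_padding` gives top 1, bottom 1, left 1, right 2, which is exactly `nn.ZeroPad2d((1, 2, 1, 1))` once reordered to (left, right, top, bottom), and the padded 5x6 input produces a 5x6 output as 'same' requires. No kernel flip is needed because both frameworks implement convolution as cross-correlation.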
@debadridtt

I've got a model checkpoint (.pth.tar), so to extract the weights I have to use the ['state_dict'] key. I stored that in a variable x and then called the function, but it's not working.
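`pyt_to_keras` only ever calls `pytorch_model.state_dict()` and then indexes the result by `'<layer name>.weight'` and `'<layer name>.bias'`. If `x` (the dictionary stored under `checkpoint['state_dict']`) is passed directly, that call raises an AttributeError, because a plain dict has no `state_dict()` method. The usual route is to rebuild the PyTorch model, call `model.load_state_dict(checkpoint['state_dict'])`, and pass the model. If the model class is not at hand, a minimal sketch of a hypothetical adapter (`CheckpointAsModule`, not part of the gist or of PyTorch) can expose the checkpoint through that single method; it assumes the checkpoint keys already match the Keras layer names.

```python
class CheckpointAsModule:
    """Hypothetical adapter: exposes a raw checkpoint dict through the one
    method pyt_to_keras() calls on its PyTorch argument, state_dict()."""

    def __init__(self, checkpoint, key='state_dict'):
        # Training checkpoints often nest the weights under a key such as
        # 'state_dict'; a bare state dict is used as-is.
        if isinstance(checkpoint, dict) and key in checkpoint:
            checkpoint = checkpoint[key]
        self._state_dict = checkpoint

    def state_dict(self):
        return self._state_dict


# Usage (the file name is a hypothetical example):
#   checkpoint = torch.load('model.pth.tar', map_location='cpu')
#   keras_model = pyt_to_keras(CheckpointAsModule(checkpoint), keras_model)
```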
