Last active
April 20, 2018 00:07
-
-
Save jongwook/a51bb55e1bce849a7d6fc502339bc388 to your computer and use it in GitHub Desktop.
Rounding off the last 7 mantissa bits of the float32 weights in a saved Keras model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os

import h5py
import numpy as np
# float32 stores a 23-bit explicit mantissa; rounding off 7 of those bits
# (the script's historical default) keeps 16 of them.
_FLOAT32_MANTISSA_BITS = 23
DEFAULT_DROPPED_BITS = 7


def round_float32_bits(values, bits=DEFAULT_DROPPED_BITS):
    """Return a float32 copy of *values* with the lowest *bits* bits rounded off.

    The IEEE-754 bit pattern of every finite element is rounded to the nearest
    multiple of ``2**bits`` (halfway cases move away from zero), and the low
    *bits* bits are cleared. The add can carry out of the mantissa, so a value
    may round up to the next power of two. The very largest finite float32
    values can round up to infinity.

    Non-finite elements (inf and NaN) are returned bit-for-bit unchanged.
    Otherwise the integer add-and-mask would turn some NaNs into infinity or,
    by wrapping past 0xFFFFFFFF, into 0.0.

    Args:
        values: Array-like (or scalar) convertible to float32. Not modified.
        bits: Number of low-order bits to round off, 0 to 23 inclusive;
            0 just returns a float32 copy.

    Returns:
        A new float32 ``numpy.ndarray`` with the same shape as *values*
        (0-d for scalar input).

    Raises:
        ValueError: If *bits* is outside ``[0, 23]``.
    """
    if not 0 <= bits <= _FLOAT32_MANTISSA_BITS:
        raise ValueError('bits must be between 0 and {}, got {}'.format(
            _FLOAT32_MANTISSA_BITS, bits))
    result = np.array(values, dtype=np.float32)  # always a fresh copy
    if bits == 0:
        return result
    half = np.uint32(1 << (bits - 1))
    mask = np.uint32((0xFFFFFFFF << bits) & 0xFFFFFFFF)
    finite = np.isfinite(result)
    # Reinterpret (not convert) the float32 buffer as uint32. Writes through
    # this view update ``result`` in place. Restricting the add to finite
    # patterns also rules out uint32 wrap-around (only NaN patterns can wrap).
    raw = result.view(np.uint32)
    raw[finite] = (raw[finite] + half) & mask
    return result


def _copy_attrs(source, target, kind=None, name=None):
    """Copy every HDF5 attribute of *source* onto *target*, logging each one.

    With *kind* left as ``None``, the attributes are logged as file-level
    ("global") attributes. Otherwise the log line names the *kind* of object
    (``'dataset'`` or ``'group'``) and its path *name*.
    """
    for key, value in source.attrs.items():
        if kind is None:
            print('copying global attribute', key, '=>', value)
        else:
            print('copying attribute', key, '=>', value, 'of', kind, name)
        target.attrs[key] = value


def _copy_dataset(name, dataset, out, bits):
    """Copy *dataset* to path *name* in *out*, rounding it if it is float32.

    Datasets of any other dtype are copied unchanged rather than aborting the
    whole conversion. Attributes are copied in both cases.
    """
    data = dataset[()]  # Dataset.value was removed in h5py 3.0
    if dataset.dtype.kind == 'f' and dataset.dtype.itemsize == 4:
        data = round_float32_bits(data, bits)
        print('truncated', name, ', shape =', data.shape)
    else:
        print('copying', name, 'unchanged, dtype =', dataset.dtype)
    target = out.create_dataset(name, data=data, dtype=dataset.dtype)
    _copy_attrs(dataset, target, 'dataset', name)


def main(argv=None):
    """Write a copy of an HDF5 file (e.g. a saved Keras model) with rounded weights.

    Every float32 dataset has its lowest ``--bits`` bits rounded off.
    Groups, other datasets, and all attributes (file-level, group and
    dataset) are copied over.

    Args:
        argv: Command-line arguments without the program name. ``None``
            lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Round off the low-order bits of every float32 dataset '
                    'in an HDF5 file (e.g. a saved Keras model).')
    parser.add_argument('input')
    parser.add_argument('output')
    parser.add_argument(
        '--bits', type=int, default=DEFAULT_DROPPED_BITS,
        help='number of low-order float32 bits to round off '
             '(default: %(default)s)')
    args = parser.parse_args(argv)
    if not 0 <= args.bits <= _FLOAT32_MANTISSA_BITS:
        parser.error('--bits must be between 0 and {}'.format(
            _FLOAT32_MANTISSA_BITS))
    # Mode 'w' truncates the output file, so converting in place would wipe
    # the model before it was read.
    if os.path.realpath(args.input) == os.path.realpath(args.output):
        parser.error('input and output must be different files')

    # Open the input first and explicitly read-only. Without a mode, h5py < 3.0
    # used 'a', which could create or modify the input file.
    with h5py.File(args.input, 'r') as source, \
            h5py.File(args.output, 'w') as out:
        _copy_attrs(source, out)

        def visit(name, obj):
            # Must return None: a non-None value stops visititems early.
            if isinstance(obj, h5py.Dataset):
                _copy_dataset(name, obj, out, args.bits)
            elif isinstance(obj, h5py.Group):
                target = out.create_group(name)
                _copy_attrs(obj, target, 'group', name)
            else:
                print('skipping', name, '(not a group or dataset)')

        source.visititems(visit)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment