Last active
August 29, 2018 11:10
-
-
Save artcg/65b04c3fa43fab02d7ebb12e33075c32 to your computer and use it in GitHub Desktop.
Quick-and-dirty script to convert CUDA-serialized Torch7 (.t7) files into CPU-friendly versions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _t7_string(name):
    """Return *name* encoded as a Torch7 binary string: 4-byte length prefix + ASCII bytes.

    The little-endian length prefix matches the hand-written headers this
    script originally used (e.g. ``10000000`` == 16 for ``torch.CudaTensor``).
    NOTE(review): little-endian assumes the file was written on a
    little-endian host (e.g. x86) — confirm for other platforms.
    """
    raw = name.encode('ascii')
    return len(raw).to_bytes(4, 'little') + raw


def cuda2float(filename, output_filename=None):
    """Rewrite a CUDA-serialized Torch7 ``.t7`` file so it can be loaded without CUDA.

    Performs a raw byte-level search-and-replace of serialized class names:

    * ``torch.CudaTensor``                -> ``torch.FloatTensor``
    * ``torch.CudaStorage``               -> ``torch.FloatStorage``
    * ``cudnn.SpatialBatchNormalization`` -> ``nn.SpatialBatchNormalization``

    Each name is matched together with its 4-byte length prefix, and the
    prefix is rewritten to the new name's length. The file format is not
    parsed, so an identical byte sequence occurring inside raw tensor data
    would also be rewritten (unlikely, but possible).

    Args:
        filename: Path of the ``.t7`` file to read.
        output_filename: Path to write the converted bytes to. ``None``
            (the default) overwrites *filename* in place, as before.
    """
    # (CUDA class name, CPU replacement) pairs, applied in this order.
    replacements = (
        ('torch.CudaTensor', 'torch.FloatTensor'),
        ('torch.CudaStorage', 'torch.FloatStorage'),
        ('cudnn.SpatialBatchNormalization', 'nn.SpatialBatchNormalization'),
    )

    # Context managers guarantee the handles are closed even if I/O fails.
    with open(filename, 'rb') as src:
        data = src.read()

    for cuda_name, cpu_name in replacements:
        data = data.replace(_t7_string(cuda_name), _t7_string(cpu_name))

    target = filename if output_filename is None else output_filename
    with open(target, 'wb') as dst:
        dst.write(data)
if __name__ == '__main__':
    # Guarded so that importing this module does not rewrite files on disk.
    import sys

    # Convert every path given on the command line; with no arguments,
    # fall back to the original hard-coded example file.
    for filename in sys.argv[1:] or ['some_weights.t7']:
        cuda2float(filename)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment