Created
March 21, 2023 02:24
-
-
Save DSCF-1224/8a7204fecf61cda5ef43fa11ce8faaef to your computer and use it in GitHub Desktop.
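Python の pkl ファイルを GitHub からダウンロードし、格納された各配列を簡易バイナリ形式 (.bin) に変換して保存する (Download a Python pickle (.pkl) file from GitHub, convert each array it contains into a simple binary (.bin) format, and save the results.)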
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
import os | |
import pickle | |
import requests | |
def convert_data(targetData, fileName):
    """Write an array to a simple binary file, plus a ``.npy`` copy for checking.

    Layout of ``fileName``:

    * the number of dimensions, as a 4-byte little-endian unsigned integer;
    * the length of each dimension (outermost first), each as a 4-byte
      little-endian unsigned integer;
    * the raw element bytes in C (row-major) order.

    For 1-D and 2-D arrays this is exactly the layout produced previously.
    Arrays of any other rank (including 0-D scalars) are now written
    completely instead of only their rank header.

    NOTE: the element dtype is not recorded in the ``.bin`` file, so the
    reader must already know it. The ``.npy`` copy does record it. The
    element bytes are written in the array's own byte order (native for
    ordinary arrays), while the header integers are always little-endian.

    Args:
        targetData: an ndarray or anything ``numpy.asarray`` accepts.
        fileName: path of the binary file to (over)write. ``numpy.save``
            appends ``.npy`` to this name for the verification copy unless
            it already ends in ``.npy``.

    Raises:
        OverflowError: if the rank or a dimension length does not fit in
            4 unsigned bytes.
    """
    # Accept plain array-likes too (lists etc.); ndarrays pass through as-is.
    arrayData = numpy.asarray(targetData)

    # Output of npy file for comparison and verification
    numpy.save(fileName, arrayData)

    with open(fileName, 'wb') as writingFileStream:
        # Write the dimension (rank) of the array
        writingFileStream.write(arrayData.ndim.to_bytes(4, 'little'))
        # Write the number of elements along each dimension
        for dimLength in arrayData.shape:
            writingFileStream.write(dimLength.to_bytes(4, 'little'))
        # Write every element in row-major order. tobytes() copies
        # non-contiguous views (slices, transposes) into C order, so the
        # buffer-protocol write no longer fails for them.
        writingFileStream.write(arrayData.tobytes(order='C'))
if __name__ == '__main__':
    # URL of the file you want to download.
    # A regular GitHub "blob" page URL is rewritten into the matching
    # raw.githubusercontent.com URL, so the response body is the file itself.
    targetFileURL = (
        "https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/ch03/sample_weight.pkl"
        .replace("github.com", "raw.githubusercontent.com")
        .replace("/blob", "")
    )
    # filename you want to download & save (the last component of the URL)
    targetFileName = os.path.basename(targetFileURL)

    if os.path.isfile(targetFileName):
        # The local copy is written only after every conversion below has
        # finished, so its presence is used as a "work already done" marker.
        print("The target file `" + targetFileName + "` already exists.")
    else:
        # A timeout keeps the script from hanging forever on a stalled connection.
        requestResponse = requests.get(url=targetFileURL, timeout=60)
        # Fail with a clear HTTP error (404, rate limiting, ...) instead of
        # letting pickle choke on an HTML error page.
        requestResponse.raise_for_status()

        # SECURITY NOTE: unpickling can execute arbitrary code. This is only
        # acceptable because the URL points at a fixed, trusted repository;
        # never point it at untrusted content.
        targetFileData = pickle.loads(requestResponse.content)

        # Convert each weight / bias array into `sample_weight_<key>.bin`
        for arrayKey in ('b1', 'b2', 'b3', 'W1', 'W2', 'W3'):
            convert_data(targetFileData[arrayKey], 'sample_weight_' + arrayKey + '.bin')

        # Keep an exact copy of the downloaded file. Writing the received bytes
        # verbatim avoids re-pickling, which could change the pickle protocol
        # or representation.
        with open(targetFileName, 'wb') as writingFileStream:
            writingFileStream.write(requestResponse.content)
# EOF |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment