Last active
September 7, 2018 09:45
-
-
Save ecoopnet/20f504863f21cdcf9e318ddc6cd88f62 to your computer and use it in GitHub Desktop.
Deep Learning Tools (mainly for Keras)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import h5py | |
import numpy as np | |
import sys | |
def dump_h5py_file(item, level = 0, by_name = False):
    """Print the tree of groups and datasets under an h5py File or Group.

    Each member is printed on its own line, indented by one space per
    nesting level. A dataset line is followed by " ... " and the dataset's
    shape; a group line is followed by a recursive dump of its members.

    Args:
        item: an ``h5py.File`` or ``h5py.Group`` to walk.
        level: current nesting depth, used only for indentation. Callers
            normally leave it at 0.
        by_name: if True, print each member's absolute HDF5 path
            (``obj.name``) instead of its key relative to ``item``. The
            setting is applied to the whole subtree.
    """
    pre = level * ' '
    for key, obj in item.items():
        name = obj.name if by_name else str(key)
        if isinstance(obj, h5py.Dataset):
            # Dataset.shape is metadata; np.array(obj) would read the whole
            # dataset into memory just to report its shape.
            print(pre + name + " ... " + str(obj.shape))
        elif isinstance(obj, h5py.Group):  # h5py.File is a Group subclass
            print(pre + name)
            # Propagate by_name so nested members honor the same setting.
            dump_h5py_file(obj, level + 1, by_name)
        else:
            # Other HDF5 objects (e.g. a committed h5py.Datatype) have no
            # members to descend into; report them instead of crashing.
            print(pre + name + " (" + type(obj).__name__ + ")")
def main(argv=None):
    """Command-line entry point: dump the structure of one .h5 file.

    Args:
        argv: argument vector; ``sys.argv`` when None. Exactly one
            positional argument (the path of the file to dump) is expected.

    On a wrong argument count, prints usage to stderr and exits with
    status 1. The file is opened read-only, so it is never modified.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 2:
        print("Dump a .h5 file structure (without modifying the file)", file=sys.stderr)
        print("Usage: # %s filename" % argv[0], file=sys.stderr)
        # Non-zero status so shells and scripts can detect the misuse.
        sys.exit(1)
    path = argv[1]
    # Mode 'r' opens the file read-only.
    with h5py.File(path, 'r') as f:
        print("Dump ", path, ":")
        dump_h5py_file(f)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.utils import Sequence | |
import threading | |
class GeneratorSequence(Sequence):
    '''
    Wrap a Python generator (or any iterator) as a keras.utils.Sequence.

    ``__getitem__`` ignores the requested index and returns the next item
    from the wrapped generator. Items therefore come back in the
    generator's own order, whichever index Keras asks for.

    Args:
        generator: iterator yielding one batch per ``next()`` call.
        length: value reported by ``len()``, i.e. the number of batches
            per epoch. If the generator is countable, this can be
            ``len(generator)``.
        lock: if True, serialize ``next()`` calls with a re-entrant lock.
            Plain Python generators raise ``ValueError`` ("generator
            already executing") when advanced from two threads at once,
            so enable this if the sequence may be read from several threads.

    Usage:
        length = 100000  # length of dataset. If generator is countable, it can be `length = len(generator)`.
        train_generator = train_dataset.generate( ... )  # create uncountable generator.
        validate_generator = validate_dataset.generate( ... )  # create uncountable generator.
        train_sequence = GeneratorSequence(train_generator, length)
        validate_sequence = GeneratorSequence(validate_generator, length)
        model.fit_generator(
            # you can use sequence as if it is a generator.
            generator=train_sequence,
            validation_data=validate_sequence,
            ... )
    '''
    def __init__(self, generator, length, lock=False):
        # None means "no locking"; an RLock is created only on request.
        self.lock = threading.RLock() if lock else None
        self.length = length
        # Attribute name kept for backward compatibility with existing readers.
        self.train_generator = generator

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # idx is intentionally unused; see the class docstring.
        if self.lock is None:
            return next(self.train_generator)
        # `with` releases the lock even when next() raises (e.g. StopIteration
        # from an exhausted generator), avoiding a permanently held lock.
        with self.lock:
            return next(self.train_generator)

    def on_epoch_end(self):
        # Nothing to reset: the wrapped generator keeps its own state.
        pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.callbacks import Callback | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def filter_show(filters, nx=8, margin=3, scale=10, channel=0):
    """Plot one input channel of every filter in a conv kernel, in grayscale.

    Example:
        layer = model.layers[i]  # e.g. 0, 2, 6
        filter_show(layer.get_weights()[0], nx=16)

    Args:
        filters: kernel array unpacked as ``(FH, FW, C, FN)``: filter
            height, filter width, input channels, number of filters.
            This is the layout Keras appears to use; the book sample this
            code is based on used ``(FN, C, FH, FW)`` instead.
        nx: number of columns in the subplot grid. Enough rows are added
            to fit all FN filters.
        margin, scale: accepted for signature compatibility with the
            original sample code, but currently unused.
        channel: input-channel index shown for each filter. The default 0
            keeps the original behavior (e.g. the R channel when the
            layer's input is RGB).

    Displays the figure with ``plt.show()`` and returns nothing.

    c.f. https://gist.github.com/aidiary/07d530d5e08011832b12#file-draw_weight-py
         https://stackoverflow.com/questions/43305891/how-to-correctly-get-layer-weights-from-conv2d-in-keras
    """
    FH, FW, C, FN = filters.shape
    # Ceiling division: enough rows to hold FN filters at nx per row.
    ny = int(np.ceil(FN / nx))
    fig = plt.figure(figsize=(1, 1), dpi=200)
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    for i in range(FN):
        ax = fig.add_subplot(ny, nx, i + 1, xticks=[], yticks=[])
        # 2-D (FH, FW) slice of the i-th filter at one input channel.
        # Grayscale is used because an RGB view of filters[:, :, :, i] only
        # works when C == 3 (rarely true past the first conv layer), and
        # imshow would clip the negative weights.
        ax.imshow(filters[:, :, channel, i], cmap=plt.cm.gray_r, interpolation='nearest')
    plt.show()
# # keras の Callback 例 | |
# class VisualizeLayer(keras.callbacks.Callback): | |
# def on_epoch_end(self, epoch, logs={}): | |
# print("end epoch("+str(epoch)+")") | |
# for i in (0,2,6,8): | |
# print("layer["+str(i)+"]:") | |
# layer = model.layers[i] # 0, 2, 6 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment