Skip to content

Instantly share code, notes, and snippets.

@johntips
Created July 26, 2017 20:58
Show Gist options
Save johntips/9d40e647cf8b3118ee078dee0500f130 to your computer and use it in GitHub Desktop.
"""Plot a grid of randomly chosen CIFAR-10 training images, one row per class."""
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import cifar10

# NOTE: an earlier version converted each sample with ``scipy.misc.toimage``
# before plotting. That function was removed in SciPy 1.2, so importing it
# fails on current SciPy. It was also unnecessary: ``plt.imshow`` renders an
# (H, W, 3) uint8 image array directly, and toimage left uint8 data unscaled.

if __name__ == '__main__':
    # Load the CIFAR-10 dataset.
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    print(X_train.shape, y_train.shape)
    print(X_test.shape, y_test.shape)

    # Draw an nclasses x nsamples grid: row r holds random images of class r.
    nclasses = 10
    nsamples = 10
    # Each label is stored as a one-element row (the original indexed it as
    # y_train[i][0]); flatten so the labels compare element-wise in one go.
    labels = y_train.ravel()
    for target_class in range(nclasses):
        # Indices of every training image labelled target_class.
        class_idx = np.flatnonzero(labels == target_class)
        # Shuffle so each run shows a different random selection.
        np.random.shuffle(class_idx)
        for col, idx in enumerate(class_idx[:nsamples]):
            # Position is derived from (row, col), so a class with fewer than
            # nsamples images cannot shift later classes into the wrong row.
            plt.subplot(nclasses, nsamples, target_class * nsamples + col + 1)
            plt.imshow(X_train[idx])
            plt.axis('off')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment