Created
May 5, 2023 00:32
-
-
Save nb312/f3b851a1503b5649352d3eae42054e41 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `keras` below is TensorFlow's bundled Keras, while the other
# names come from the standalone `keras` package. This presumably works only
# when both resolve to the same Keras installation/version -- confirm in the
# target environment.
from tensorflow import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
# Dataset constants: number of label classes and image height/width.
num_classes = 10
img_rows, img_cols = 28, 28

# Load the MNIST train/test splits as (images, integer labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add an explicit single-channel axis, placed where the backend's configured
# image data format expects it: (1, H, W) for 'channels_first', else (H, W, 1).
# Both splits are then reshaped to (num_samples, *input_shape).
channels_first = K.image_data_format() == 'channels_first'
input_shape = (1, img_rows, img_cols) if channels_first else (img_rows, img_cols, 1)
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Convert to float32 and divide by 255 (assumes 8-bit pixel values, so the
# result lies in [0.0, 1.0]).
x_train, x_test = (split.astype('float32') / 255 for split in (x_train, x_test))

# One-hot encode the integer labels into vectors of length num_classes.
y_train, y_test = (keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))
# Build the CNN: two convolution + max-pooling stages followed by a dense
# classifier head. Layers are listed in the same order the original added them.
model = Sequential([
    # Feature extractor: 32 then 64 3x3 filters, each followed by 2x2 max
    # pooling to downsample the feature maps.
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    # Dropout randomly zeroes 25% of the incoming activations during training
    # (it does not remove weights) to reduce over-fitting.
    Dropout(0.25),
    # Flatten the pooled feature maps into one vector per sample.
    Flatten(),
    # Fully connected hidden layer, regularized by another 25% dropout.
    Dense(128, activation='relu'),
    Dropout(0.25),
    # Output layer: one softmax probability per class.
    Dense(num_classes, activation='softmax'),
])
# Compile: categorical cross-entropy pairs with the one-hot labels and the
# softmax output layer; accuracy is tracked as an extra metric.
# Adadelta's default learning_rate in tf.keras / Keras 3 is 0.001, which makes
# this optimizer converge extremely slowly. The Keras documentation recommends
# 1.0 to match the original Adadelta formulation, so set it explicitly.
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(learning_rate=1.0),
              metrics=['accuracy'])

# Train for a fixed number of epochs in mini-batches of 128.
# NOTE(review): the test split is also passed as validation data, so the
# per-epoch val_* numbers are not an independent estimate if they are used to
# choose epochs/hyper-parameters; consider validation_split instead.
# TODO: epochs=100 may be more than needed now that the learning rate is 1.0.
model.fit(x_train, y_train,
          batch_size=128,
          epochs=100,
          verbose=1,
          validation_data=(x_test, y_test))

# Evaluate on the test split. With one compiled metric, evaluate() returns
# [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment