Skip to content

Instantly share code, notes, and snippets.

@ypwhs
Last active October 8, 2019 09:40
Show Gist options
  • Save ypwhs/c54dc9031a408a31c71fa513c2da9945 to your computer and use it in GitHub Desktop.
Save ypwhs/c54dc9031a408a31c71fa513c2da9945 to your computer and use it in GitHub Desktop.
Keras 固定随机数种子重现相同的训练结果
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from keras.layers import *\n",
"from keras.models import *\n",
"from keras.datasets.mnist import load_data\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"(X_train, y_train), (X_test, y_test) = load_data()\n",
"X_train, X_test = X_train / 255.0, X_test / 255.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /Users/ypw/miniconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"WARNING:tensorflow:From /Users/ypw/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
"WARNING:tensorflow:From /Users/ypw/miniconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.cast instead.\n",
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/1\n",
"60000/60000 [==============================] - 2s 34us/step - loss: 1.1246 - acc: 0.7235 - val_loss: 0.6050 - val_acc: 0.8626\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x11edfb320>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(2019)\n",
"tf.set_random_seed(2019)\n",
"\n",
"model = Sequential()\n",
"model.add(Flatten(input_shape=X_train.shape[1:]))\n",
"model.add(Dense(512, activation='relu'))\n",
"model.add(Dropout(0.2))\n",
"model.add(Dense(10, activation='softmax'))\n",
"\n",
"model.compile(optimizer='sgd', \n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"model.fit(X_train, y_train, batch_size=128, epochs=1, validation_data=(X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.08332505, -0.04046969, -0.10420379, ..., 0.01150452,\n",
" 0.00657567, 0.09461053],\n",
" [-0.06537206, -0.00617846, -0.01010726, ..., -0.09795328,\n",
" 0.03855512, -0.09573275],\n",
" [ 0.06116908, -0.02009934, -0.0755541 , ..., 0.10054862,\n",
" 0.01570166, 0.1094622 ],\n",
" ...,\n",
" [ 0.02430227, 0.08767582, -0.14034536, ..., -0.0235549 ,\n",
" -0.06785435, 0.07870072],\n",
" [-0.1178989 , 0.04752532, 0.1322875 , ..., -0.01668982,\n",
" 0.03100156, 0.08343341],\n",
" [ 0.04044323, -0.08505493, 0.01293952, ..., -0.0128099 ,\n",
" 0.0589669 , 0.09185769]], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights1 = K.get_value(model.layers[-1].weights[0])\n",
"weights1"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/1\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 1.1246 - acc: 0.7235 - val_loss: 0.6050 - val_acc: 0.8626\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x140fe7f28>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(2019)\n",
"tf.set_random_seed(2019)\n",
"\n",
"model = Sequential()\n",
"model.add(Flatten(input_shape=X_train.shape[1:]))\n",
"model.add(Dense(512, activation='relu'))\n",
"model.add(Dropout(0.2))\n",
"model.add(Dense(10, activation='softmax'))\n",
"\n",
"model.compile(optimizer='sgd', \n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"model.fit(X_train, y_train, batch_size=128, epochs=1, validation_data=(X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.08332505, -0.04046969, -0.10420379, ..., 0.01150452,\n",
" 0.00657567, 0.09461053],\n",
" [-0.06537206, -0.00617846, -0.01010726, ..., -0.09795328,\n",
" 0.03855512, -0.09573275],\n",
" [ 0.06116908, -0.02009934, -0.0755541 , ..., 0.10054862,\n",
" 0.01570166, 0.1094622 ],\n",
" ...,\n",
" [ 0.02430227, 0.08767582, -0.14034536, ..., -0.0235549 ,\n",
" -0.06785435, 0.07870072],\n",
" [-0.1178989 , 0.04752532, 0.1322875 , ..., -0.01668982,\n",
" 0.03100156, 0.08343341],\n",
" [ 0.04044323, -0.08505493, 0.01293952, ..., -0.0128099 ,\n",
" 0.0589669 , 0.09185769]], dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights2 = K.get_value(model.layers[-1].weights[0])\n",
"weights2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(weights1 == weights2).all()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@yylonly
Copy link

yylonly commented May 15, 2018

我在 Google Colab 上无法重现相同的训练结果（见下方两次运行的输出）。请问你使用的 TensorFlow 和 Keras 版本是什么？

Train on 60000 samples, validate on 10000 samples
Epoch 1/2
60000/60000 [==============================] - 4s 61us/step - loss: 0.2117 - acc: 0.9365 - val_loss: 0.1123 - val_acc: 0.9645
Epoch 2/2
60000/60000 [==============================] - 3s 54us/step - loss: 0.0834 - acc: 0.9755 - val_loss: 0.0853 - val_acc: 0.9724
<keras.callbacks.History at 0x7f37e8751dd8>

Train on 60000 samples, validate on 10000 samples
Epoch 1/2
60000/60000 [==============================] - 4s 65us/step - loss: 0.2067 - acc: 0.9375 - val_loss: 0.1110 - val_acc: 0.9654
Epoch 2/2
60000/60000 [==============================] - 3s 53us/step - loss: 0.0805 - acc: 0.9759 - val_loss: 0.0831 - val_acc: 0.9735
<keras.callbacks.History at 0x7f37e87f2b38>

@ypwhs
Copy link
Author

ypwhs commented Apr 11, 2019

需要使用 CPU 运行。GPU 上的部分运算（例如 cuDNN 实现的某些算子、基于原子加法的归约）不保证确定性，因此即使固定了随机数种子，在 GPU 上两次训练的结果也可能不同。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment