Created
September 29, 2016 02:11
-
-
Save xccds/a03747ddbf2dea77ba4c1a95ca926852 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 加载包" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Packages loaded\n" | |
] | |
} | |
], | |
"source": [ | |
"import scipy.io\n", | |
"import numpy as np \n", | |
"import os \n", | |
"import scipy.misc \n", | |
"import matplotlib.pyplot as plt \n", | |
"import tensorflow as tf\n", | |
"%matplotlib inline \n", | |
"print (\"Packages loaded\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 参数设置与网络结构定义" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"IMAGE_W = 800 \n", | |
"IMAGE_H = 600 \n", | |
"cwd = os.getcwd()\n", | |
"# 内容图片文档\n", | |
"CONTENT_IMG = cwd + \"/images/Taipei101.jpg\"\n", | |
"# 风格图片文档\n", | |
"STYLE_IMG = cwd + \"/images/StarryNight.jpg\"\n", | |
"# 输出结果的目录和文档名\n", | |
"OUTOUT_DIR = './images'\n", | |
"OUTPUT_IMG = 'results.png'\n", | |
"# VGG模型文件\n", | |
"VGG_MODEL = cwd + \"/data/imagenet-vgg-verydeep-19.mat\"\n", | |
"INI_NOISE_RATIO = 0.7\n", | |
"STYLE_STRENGTH = 500\n", | |
"ITERATION = 5000\n", | |
"\n", | |
"CONTENT_LAYERS =[('conv4_2',1.)]\n", | |
"STYLE_LAYERS=[('conv1_1',1.),('conv2_1',1.5),('conv3_1',2.),('conv4_1',2.5),('conv5_1',3.)]\n", | |
"\n", | |
"\n", | |
"MEAN_VALUES = np.array([123, 117, 104]).reshape((1,1,1,3))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 定义前向计算函数,如果是conv层则计算卷积,如果是pool则进行池化\n", | |
"def build_net(ntype, nin, nwb=None):\n", | |
" if ntype == 'conv':\n", | |
" return tf.nn.relu(tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME')+ nwb[1])\n", | |
" elif ntype == 'pool':\n", | |
" return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1],\n", | |
" strides=[1, 2, 2, 1], padding='SAME')\n", | |
"\n", | |
"# 从VGG模型中提取参数\n", | |
"def get_weight_bias(vgg_layers, i,):\n", | |
" weights = vgg_layers[i][0][0][0][0][0]\n", | |
" weights = tf.constant(weights)\n", | |
" bias = vgg_layers[i][0][0][0][0][1]\n", | |
" bias = tf.constant(np.reshape(bias, (bias.size)))\n", | |
" return weights, bias\n", | |
"\n", | |
"# 构建VGG模型网络结构,从现成的VGG模型文档中读取参数\n", | |
"# 以conv1_1层参数为例,长下面这个样子\n", | |
"# (<tf.Tensor 'Const_83:0' shape=(3, 3, 3, 64) dtype=float32>,\n", | |
"# <tf.Tensor 'Const_84:0' shape=(64,) dtype=float32>)\n", | |
"# conv1_1层输出长下面这个样子\n", | |
"# <tf.Tensor 'Relu_32:0' shape=(1, 600, 800, 64) dtype=float32>\n", | |
"\n", | |
"def build_vgg19(path):\n", | |
" net = {}\n", | |
" vgg_rawnet = scipy.io.loadmat(path)\n", | |
" vgg_layers = vgg_rawnet['layers'][0]\n", | |
" net['input'] = tf.Variable(np.zeros((1, IMAGE_H, IMAGE_W, 3)).astype('float32'))\n", | |
" net['conv1_1'] = build_net('conv',net['input'],get_weight_bias(vgg_layers,0))\n", | |
" net['conv1_2'] = build_net('conv',net['conv1_1'],get_weight_bias(vgg_layers,2))\n", | |
" net['pool1'] = build_net('pool',net['conv1_2'])\n", | |
" net['conv2_1'] = build_net('conv',net['pool1'],get_weight_bias(vgg_layers,5))\n", | |
" net['conv2_2'] = build_net('conv',net['conv2_1'],get_weight_bias(vgg_layers,7))\n", | |
" net['pool2'] = build_net('pool',net['conv2_2'])\n", | |
" net['conv3_1'] = build_net('conv',net['pool2'],get_weight_bias(vgg_layers,10))\n", | |
" net['conv3_2'] = build_net('conv',net['conv3_1'],get_weight_bias(vgg_layers,12))\n", | |
" net['conv3_3'] = build_net('conv',net['conv3_2'],get_weight_bias(vgg_layers,14))\n", | |
" net['conv3_4'] = build_net('conv',net['conv3_3'],get_weight_bias(vgg_layers,16))\n", | |
" net['pool3'] = build_net('pool',net['conv3_4'])\n", | |
" net['conv4_1'] = build_net('conv',net['pool3'],get_weight_bias(vgg_layers,19))\n", | |
" net['conv4_2'] = build_net('conv',net['conv4_1'],get_weight_bias(vgg_layers,21))\n", | |
" net['conv4_3'] = build_net('conv',net['conv4_2'],get_weight_bias(vgg_layers,23))\n", | |
" net['conv4_4'] = build_net('conv',net['conv4_3'],get_weight_bias(vgg_layers,25))\n", | |
" net['pool4'] = build_net('pool',net['conv4_4'])\n", | |
" net['conv5_1'] = build_net('conv',net['pool4'],get_weight_bias(vgg_layers,28))\n", | |
" net['conv5_2'] = build_net('conv',net['conv5_1'],get_weight_bias(vgg_layers,30))\n", | |
" net['conv5_3'] = build_net('conv',net['conv5_2'],get_weight_bias(vgg_layers,32))\n", | |
" net['conv5_4'] = build_net('conv',net['conv5_3'],get_weight_bias(vgg_layers,34))\n", | |
" net['pool5'] = build_net('pool',net['conv5_4'])\n", | |
" return net\n", | |
"\n", | |
"# 内容损失函数\n", | |
"def build_content_loss(p, x):\n", | |
" M = p.shape[1]*p.shape[2]\n", | |
" N = p.shape[3]\n", | |
" loss = (1./(2* N**0.5 * M**0.5 )) * tf.reduce_sum(tf.pow((x - p),2)) \n", | |
" return loss\n", | |
"\n", | |
"\n", | |
"def gram_matrix(x, area, depth):\n", | |
" x1 = tf.reshape(x,(area,depth))\n", | |
" g = tf.matmul(tf.transpose(x1), x1)\n", | |
" return g\n", | |
"\n", | |
"def gram_matrix_val(x, area, depth):\n", | |
" x1 = x.reshape(area,depth)\n", | |
" g = np.dot(x1.T, x1)\n", | |
" return g\n", | |
"\n", | |
"# 风格损失函数,A为风格标准图片,G为训练后的结果图片\n", | |
"def build_style_loss(a, x):\n", | |
" M = a.shape[1]*a.shape[2]\n", | |
" N = a.shape[3]\n", | |
" A = gram_matrix_val(a, M, N )\n", | |
" G = gram_matrix(x, M, N )\n", | |
" loss = (1./(4 * N**2 * M**2)) * tf.reduce_sum(tf.pow((G - A),2))\n", | |
" return loss\n", | |
"\n", | |
"\n", | |
"# 读取图片函数,同时做白化\n", | |
"def read_image(path):\n", | |
" image = scipy.misc.imread(path)\n", | |
" image = image[np.newaxis,:IMAGE_H,:IMAGE_W,:] \n", | |
" image = image - MEAN_VALUES\n", | |
" return image\n", | |
"\n", | |
"# 写图片函数\n", | |
"def write_image(path, image):\n", | |
" image = image + MEAN_VALUES\n", | |
" image = image[0]\n", | |
" image = np.clip(image, 0, 255).astype('uint8')\n", | |
" scipy.misc.imsave(path, image)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 定义主函数" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def main():\n", | |
" net = build_vgg19(VGG_MODEL)\n", | |
" sess = tf.Session()\n", | |
" sess.run(tf.initialize_all_variables())\n", | |
"# 建立一个纯噪音图片做为训练参数,使内容符合内容图片,而风格符合风格图片\n", | |
" noise_img = np.random.uniform(-20, 20, (1, IMAGE_H, IMAGE_W, 3)).astype('float32')\n", | |
" content_img = read_image(CONTENT_IMG)\n", | |
" style_img = read_image(STYLE_IMG)\n", | |
"# 将内容图片输入到VGG网络中,取出conv4_2层输出结果,计算内容损失\n", | |
" sess.run([net['input'].assign(content_img)])\n", | |
" cost_content = sum(map(lambda l,: l[1]*build_content_loss(sess.run(net[l[0]]) , net[l[0]])\n", | |
" , CONTENT_LAYERS))\n", | |
"# 将风格图片输入到VGG网络中,取出conv1_1-conv5_1五个层的输出结果,计算风格损失\n", | |
" sess.run([net['input'].assign(style_img)])\n", | |
" cost_style = sum(map(lambda l: l[1]*build_style_loss(sess.run(net[l[0]]) , net[l[0]])\n", | |
" , STYLE_LAYERS))\n", | |
"# 加总两种损失做为最小化训练目标,用cost_style做为调整系数\n", | |
" cost_total = cost_content + STYLE_STRENGTH * cost_style\n", | |
" optimizer = tf.train.AdamOptimizer(2.0)\n", | |
"\n", | |
" train = optimizer.minimize(cost_total)\n", | |
" sess.run(tf.initialize_all_variables())\n", | |
"# 把内容图片加噪音后,做为VGG网络输入层,算法将学习去调整这个输入层,来使得训练目标最小\n", | |
" sess.run(net['input'].assign( INI_NOISE_RATIO* noise_img + (1.-INI_NOISE_RATIO) * content_img))\n", | |
"\n", | |
" if not os.path.exists(OUTOUT_DIR):\n", | |
" os.mkdir(OUTOUT_DIR)\n", | |
"\n", | |
" for i in range(500):\n", | |
" sess.run(train)\n", | |
" print i\n", | |
" if i%100 ==0:\n", | |
" result_img = sess.run(net['input'])\n", | |
" print sess.run(cost_total)\n", | |
" write_image(os.path.join(OUTOUT_DIR,'%s.png'%(str(i).zfill(4))),result_img)\n", | |
" \n", | |
" write_image(os.path.join(OUTOUT_DIR,OUTPUT_IMG),result_img)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Run the style transfer; checkpoint and result images are written to OUTOUT_DIR.\n", | |
"main()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment