Created
November 28, 2016 02:34
-
-
Save Lyken17/da43f6c21df8139b0c043db0c753e3b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import optparse, sys, os\n", | |
"from collections import namedtuple\n", | |
"import random \n", | |
"import math\n", | |
"\n", | |
"import numpy as np\n", | |
"\n", | |
"import bleu" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 92, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Experiment configuration; the data-set specific paths are derived from `prefix`.\n", | |
"prefix = 'train'\n", | |
"\n", | |
"\n", | |
"class opts:\n", | |
"    \"\"\"Plain namespace of settings, used like a parsed command-line options object.\"\"\"\n", | |
"\n", | |
"    # source-side (.fr) and reference (.en) sentence files, read line by line\n", | |
"    origin = os.path.join(\"data\", prefix + \".fr\")\n", | |
"    reference = os.path.join(\"data\", prefix + \".en\")\n", | |
"    # n-best candidate lists for the same data set\n", | |
"    nbest = os.path.join(\"data\", prefix + \".nbest\")\n", | |
"    rich_nbest = os.path.join(\"data\", prefix + \"_rich.nbest\")\n", | |
"    # hansards corpus paths (not referenced by the cells shown in this notebook)\n", | |
"    f_data = os.path.join(\"data\", \"hansards.fr\")\n", | |
"    e_data = os.path.join(\"data\", \"hansards.en\")\n", | |
"    # sys.maxint only exists on Python 2; fall back to sys.maxsize elsewhere\n", | |
"    num_sents = getattr(sys, \"maxint\", sys.maxsize)\n", | |
"    iter = 10" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 93, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# One n-best entry: the hypothesis text, its feature vector and its BLEU score.\n", | |
"translation_candidate = namedtuple(\"candidate\", \"sentence, features, bleu_score\")\n", | |
"\n", | |
"# Read both sentence files under context managers so the handles are closed\n", | |
"# as soon as the lists are built instead of lingering until garbage collection.\n", | |
"with open(opts.reference) as reference_file:\n", | |
"    en_ref = [line.strip() for line in reference_file]\n", | |
"with open(opts.origin) as origin_file:\n", | |
"    fr_ori = [line.strip() for line in origin_file]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 94, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(1989, 1989)" | |
] | |
}, | |
"execution_count": 94, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Number of reference lines and of source lines that were loaded.\n", | |
"len(en_ref), len(fr_ori)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 95, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"load 0 sentences\n", | |
"load 5000 sentences\n", | |
"load 10000 sentences\n", | |
"load 15000 sentences\n", | |
"load 20000 sentences\n", | |
"load 25000 sentences\n", | |
"load 30000 sentences\n", | |
"load 35000 sentences\n", | |
"load 40000 sentences\n" | |
] | |
} | |
], | |
"source": [ | |
"# Load the n-best list. Each line holds three fields separated by '|||':\n", | |
"#   sentence id ||| hypothesis ||| space-separated feature values\n", | |
"# nbests[k] collects every candidate read for sentence id k.\n", | |
"max_sentences = 200  # only candidates with a sentence id below this are loaded\n", | |
"zero_stats = [0] * 10  # loop-invariant zero vector the BLEU statistics are added onto\n", | |
"\n", | |
"nbests = []\n", | |
"with open(opts.nbest) as nbest_file:  # handle is closed even if a line fails to parse\n", | |
"    for n, line in enumerate(nbest_file):\n", | |
"        (sent_id, sentence, features) = line.strip().split(\"|||\")\n", | |
"        (sent_id, sentence) = (int(sent_id), sentence.strip())\n", | |
"        # NOTE: `features` keeps its name because the parameter cell reads it after this loop\n", | |
"        features = np.array([float(h) for h in features.strip().split()])\n", | |
"\n", | |
"        if sent_id >= max_sentences:\n", | |
"            break\n", | |
"\n", | |
"        # Validate before en_ref is indexed so a short reference file is reported clearly.\n", | |
"        if len(en_ref) <= sent_id:\n", | |
"            raise ValueError(\"not enough english reference\")\n", | |
"\n", | |
"        s = sentence.split()\n", | |
"        r = en_ref[sent_id].split()\n", | |
"\n", | |
"        # smoothed BLEU of this candidate against its reference\n", | |
"        bleu_stats = [sum(scores) for scores in zip(zero_stats, bleu.bleu_stats(s, r))]\n", | |
"        bleu_score = bleu.smoothed_bleu(bleu_stats)\n", | |
"\n", | |
"        # grow the outer list until it has a slot for this sentence id\n", | |
"        while len(nbests) <= sent_id:\n", | |
"            nbests.append([])\n", | |
"\n", | |
"        nbests[sent_id].append(translation_candidate(sentence, features, bleu_score))\n", | |
"\n", | |
"        # n counts input lines (candidates), not distinct sentence ids\n", | |
"        if n % 5000 == 0:\n", | |
"            print (\"load %d sentences\" % n)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 96, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"the american president barack obama will fly into oslo , norway for 26 hours to receive the nobel peace prize , the fourth american president in history to do so .\n", | |
"the american president barack obama will fly away for oslo , norway , to spend 26 hours and the fourth american president in history to receive the nobel peace prize .\n" | |
] | |
} | |
], | |
"source": [ | |
"# test BLEU scores: show a reference next to the candidate with the highest BLEU\n", | |
"idx = 1\n", | |
"# print(...) with a single argument prints the same text on Python 2 and Python 3\n", | |
"print(en_ref[idx])\n", | |
"# max() returns the first top-scoring candidate, the same element a stable\n", | |
"# descending sort would place at position 0\n", | |
"print(max(nbests[idx], key=lambda s: s.bleu_score).sentence)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 101, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# parameters for perceptron\n", | |
"tau = 5000    # candidate pairs drawn per sentence in each epoch\n", | |
"alpha = 0.1   # a drawn pair is kept only if its BLEU scores differ by more than this\n", | |
"xi = 100      # number of kept pairs (largest BLEU gap first) used for updates\n", | |
"eta = 0.5     # step size of each weight update\n", | |
"epochs = 15   # passes over all loaded sentences\n", | |
"\n", | |
"# Size the weight vector from the loaded candidates rather than from a variable\n", | |
"# that happens to be left over from the loading loop. Re-running this cell resets theta.\n", | |
"num_features = len(nbests[0][0].features)\n", | |
"theta = np.ones(num_features) / num_features" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 102, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 0, mistakes : 1251\n", | |
"Epoch 1, mistakes : 1303\n", | |
"Epoch 2, mistakes : 1232\n", | |
"Epoch 3, mistakes : 1284\n", | |
"Epoch 4, mistakes : 1319\n", | |
"Epoch 5, mistakes : 1330\n", | |
"Epoch 6, mistakes : 1307\n", | |
"Epoch 7, mistakes : 1287\n", | |
"Epoch 8, mistakes : 1317\n", | |
"Epoch 9, mistakes : 1241\n", | |
"Epoch 10, mistakes : 1299\n", | |
"Epoch 11, mistakes : 1302\n", | |
"Epoch 12, mistakes : 1245\n", | |
"Epoch 13, mistakes : 1280\n", | |
"Epoch 14, mistakes : 1286\n" | |
] | |
} | |
], | |
"source": [ | |
"def sample_pairs(nbest, tau, alpha):\n", | |
"    \"\"\"Draw up to `tau` ordered candidate pairs from one n-best list.\n", | |
"\n", | |
"    Two candidates are drawn without replacement `tau` times; a draw is kept only\n", | |
"    when the BLEU scores differ by more than `alpha`. Each kept pair is returned\n", | |
"    as (higher BLEU, lower BLEU). Lists with fewer than two candidates yield no\n", | |
"    pairs, because random.sample would raise ValueError on them.\n", | |
"    \"\"\"\n", | |
"    sample_list = []\n", | |
"    if len(nbest) < 2:\n", | |
"        return sample_list\n", | |
"    for _ in range(tau):\n", | |
"        first, second = random.sample(nbest, 2)\n", | |
"        if math.fabs(first.bleu_score - second.bleu_score) > alpha:\n", | |
"            if first.bleu_score > second.bleu_score:\n", | |
"                sample_list.append((first, second))\n", | |
"            else:\n", | |
"                sample_list.append((second, first))\n", | |
"    return sample_list\n", | |
"\n", | |
"\n", | |
"# Perceptron training: theta is updated in place whenever the current weights\n", | |
"# do not score the higher-BLEU candidate of a pair strictly above the other one.\n", | |
"for epoch in range(epochs):\n", | |
"    mistakes = 0\n", | |
"    for nbest in nbests:\n", | |
"        # keep only the xi pairs with the largest BLEU gap\n", | |
"        samples = sorted(sample_pairs(nbest, tau, alpha),\n", | |
"                         key=lambda s: s[0].bleu_score - s[1].bleu_score,\n", | |
"                         reverse=True)[:xi]\n", | |
"        # TODO: shuffle samples here\n", | |
"\n", | |
"        # s1/s2 stay bound after the loop; later inspection cells read s1\n", | |
"        for s1, s2 in samples:\n", | |
"            if np.dot(s1.features, theta) <= np.dot(s2.features, theta):\n", | |
"                mistakes += 1\n", | |
"                theta += eta * (s1.features - s2.features)\n", | |
"\n", | |
"    print(\"Epoch %d, mistakes : %d\" % (epoch, mistakes))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([-23.31 , -2. , -5. , -3. , -11. , -4.777]),\n", | |
" array([-0.56183333, -1.83333333, 2.16666667, -1.83333333, -0.83333333,\n", | |
" -0.37633333]))" | |
] | |
}, | |
"execution_count": 105, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Show one feature vector next to the current weight vector.\n", | |
"# NOTE(review): s1 is whatever the training cell's pair loop left bound, so this\n", | |
"# only works after that cell has run in the same kernel.\n", | |
"s1.features, theta" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 106, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"22.39407933337624" | |
] | |
}, | |
"execution_count": 106, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Dot product of the candidate bound to s1 with the weight vector, i.e. the\n", | |
"# quantity the training loop compares between the two candidates of a pair.\n", | |
"np.dot(s1.features, theta)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[candidate(sentence='barack obama will be the fourth american president to receive the winner of the nobel peace prize', features=array([-31.637, -1. , -5. , -2. , -17. , -7.383]), bleu_score=0.5879295158304544),\n", | |
" candidate(sentence='barack obama would be the fourth american president to receive the nobel peace prize', features=array([-26.311, -1. , -5. , -2. , -14. , -6.08 ]), bleu_score=0.7399596201502179)]" | |
] | |
}, | |
"execution_count": 73, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Draw two candidates (without replacement) from the first sentence's n-best list.\n", | |
"random.sample(nbests[0], 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"barack obama will be the fourth american president to receive the nobel peace prize laureate\n", | |
"barack obama becomes the fourth american president to receive the nobel peace prize\n", | |
"0.696798457126\n" | |
] | |
} | |
], | |
"source": [ | |
"# Score one randomly chosen candidate of sentence idx against its reference.\n", | |
"idx = 0\n", | |
"\n", | |
"sentence = random.choice(nbests[idx]).sentence\n", | |
"s = sentence.split()\n", | |
"r = en_ref[idx].split()\n", | |
"\n", | |
"# print(...) with a single argument prints the same text on Python 2 and Python 3\n", | |
"print(sentence)\n", | |
"print(en_ref[idx])\n", | |
"\n", | |
"# add the candidate's statistics onto a zero vector of 10 slots, then smooth\n", | |
"stats = [0] * 10\n", | |
"stats = [sum(scores) for scores in zip(stats, bleu.bleu_stats(s, r))]\n", | |
"print(bleu.smoothed_bleu(stats))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"candidate(sentence='barack obama will be the fourth american president to receive the winner of the nobel peace prize', features=array([-31.637, -1. , -5. , -2. , -17. , -7.383]), bleu_score=0.5879295158304544)" | |
] | |
}, | |
"execution_count": 76, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Pick one candidate at random from sentence idx's n-best list.\n", | |
"# NOTE(review): idx is not set in this cell; it must already be defined by an earlier cell.\n", | |
"random.choice(nbests[idx])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment