Created
January 4, 2018 14:12
-
-
Save tizot/4d599cc7dc6a7e5e0dfa237e29857128 to your computer and use it in GitHub Desktop.
How to update spaCy's named-entity recognizer?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"toc": "true" | |
}, | |
"source": [ | |
"# Table of Contents\n", | |
" <p>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:18:25.806038Z", | |
"start_time": "2018-01-04T13:18:25.245120Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import random\n", | |
"import spacy\n", | |
"import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:21:28.279142Z", | |
"start_time": "2018-01-04T13:21:28.273581Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# New entity label; it is registered on the NER component further down.\n", | |
"LABEL = 'MON'\n", | |
"\n", | |
"# Each example is (sentence, BILUO tags), with exactly one tag per token.\n", | |
"# The last tag of every list covers the sentence-final punctuation mark\n", | |
"# (assumes the tokenizer emits it as its own token -- verify if the model changes).\n", | |
"TRAIN_DATA = [\n", | |
"    (\"Google a renommé ses applications business.\", [\"U-ORG\", \"O\", \"O\", \"O\", \"O\", \"O\", \"O\"]),\n", | |
"    (\"Uber a dépensé 1 million d'euros cette semaine.\", [\"U-ORG\", \"O\", \"O\", \"B-MON\", \"L-MON\", \"O\", \"O\", \"O\", \"O\", \"O\"]),\n", | |
"    (\"Qui est Shaka Khan ?\", [\"O\", \"O\", \"B-PER\", \"L-PER\", \"O\"]),\n", | |
"    (\"J'aime Londres et Berlin.\", [\"O\", \"O\", \"U-LOC\", \"O\", \"U-LOC\", \"O\"]),\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:21:41.789178Z", | |
"start_time": "2018-01-04T13:21:30.976206Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Load the pipeline that spaCy resolves for the name 'fr'; every later code cell uses this nlp object.\n", | |
"nlp = spacy.load('fr')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:21:41.862279Z", | |
"start_time": "2018-01-04T13:21:41.791512Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('Google a renommé ses applications business.',\n", | |
" {'entities': [(0, 6, 'ORG')]}),\n", | |
" (\"Uber a dépensé 1 million d'euros cette semaine.\",\n", | |
" {'entities': [(0, 4, 'ORG'), (15, 24, 'MON')]}),\n", | |
" ('Qui est Shaka Khan ?', {'entities': [(8, 18, 'PER')]}),\n", | |
" (\"J'aime Londres et Berlin.\",\n", | |
" {'entities': [(7, 14, 'LOC'), (18, 24, 'LOC')]})]" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Turn the per-token BILUO tags into (start_char, end_char, label) offsets,\n", | |
"# the annotation format passed to nlp.update() in the training cell.\n", | |
"train_data = [\n", | |
"    (text, {'entities': spacy.gold.offsets_from_biluo_tags(nlp(text), biluo)})\n", | |
"    for text, biluo in TRAIN_DATA\n", | |
"]\n", | |
"train_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:21:43.003286Z", | |
"start_time": "2018-01-04T13:21:42.920792Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[(Google, 'ORG')]\n", | |
"[(Uber, 'LOC')]\n", | |
"[(Shaka Khan, 'PER')]\n", | |
"[(Londres, 'LOC'), (Berlin, 'LOC')]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Baseline: entities found in the training sentences before the NER is updated.\n", | |
"for sentence, _ in TRAIN_DATA:\n", | |
"    doc = nlp(sentence)\n", | |
"    found = []\n", | |
"    for ent in doc.ents:\n", | |
"        found.append((ent, ent.label_))\n", | |
"    print(found)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:21:45.531788Z", | |
"start_time": "2018-01-04T13:21:45.313070Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Look up the pipeline component registered under the name 'ner'.\n", | |
"ner = nlp.get_pipe('ner')\n", | |
"# Register LABEL as an additional label on that component.\n", | |
"ner.add_label(LABEL)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:22:07.554649Z", | |
"start_time": "2018-01-04T13:22:01.777994Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"N_ITER = 20  # number of passes over the training examples\n", | |
"\n", | |
"# Train the NER only: disable every other component rather than hard-coding\n", | |
"# 'tagger' and 'parser', so the cell still runs if the pipeline differs.\n", | |
"other_pipes = [name for name in nlp.pipe_names if name != 'ner']\n", | |
"with nlp.disable_pipes(*other_pipes):\n", | |
"    # NOTE(review): depending on the spaCy release, begin_training() may\n", | |
"    # re-initialise existing weights -- confirm for the installed version.\n", | |
"    optimizer = nlp.begin_training()\n", | |
"    for epoch in range(N_ITER):\n", | |
"        # Shuffles in place, so train_data ends up in a different order.\n", | |
"        random.shuffle(train_data)\n", | |
"        # One update per example (batch size 1).\n", | |
"        for text, annotations in train_data:\n", | |
"            nlp.update([text], [annotations], sgd=optimizer)\n", | |
"# Uncomment to save the updated pipeline to disk:\n", | |
"# nlp.to_disk('/model')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-01-04T13:22:08.272617Z", | |
"start_time": "2018-01-04T13:22:08.178271Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[(Google, 'ORG')]\n", | |
"[(Uber, 'ORG'), (1 million, 'MON')]\n", | |
"[(Shaka Khan, 'PER')]\n", | |
"[(Londres, 'LOC'), (Berlin, 'LOC')]\n" | |
] | |
} | |
], | |
"source": [ | |
"# After training: run the same sentences again to compare with the earlier predictions.\n", | |
"for sentence, _ in TRAIN_DATA:\n", | |
"    doc = nlp(sentence)\n", | |
"    found = []\n", | |
"    for ent in doc.ents:\n", | |
"        found.append((ent, ent.label_))\n", | |
"    print(found)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
}, | |
"toc": { | |
"colors": { | |
"hover_highlight": "#DAA520", | |
"running_highlight": "#FF0000", | |
"selected_highlight": "#FFD700" | |
}, | |
"moveMenuLeft": true, | |
"nav_menu": { | |
"height": "12px", | |
"width": "252px" | |
}, | |
"navigate_menu": true, | |
"number_sections": true, | |
"sideBar": false, | |
"threshold": 4, | |
"toc_cell": true, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment