A script that reads MeSH descriptors and PubTator document data from data files and creates a SQLite database storing the encoded docs for later training use. (deprecated)
#!/usr/bin/env python3
"""Preprocess the PubTator corpus and the ScopeNotes of MeSH descriptors for
language model training (LmBMET).

1) Given the original PubTator bioconcept-annotated documents, this script
interpolates the concept codes into the document texts. Before that, it counts
word frequencies and builds a vocabulary that includes the entire set of
bioconcepts (MeSH in particular). If a pre-trained embeddings file (.vec) is
provided, the vocabulary is read from the embeddings instead.

2) Along with the PubTator documents and their vocabulary, this script parses
the MeSH descriptors, extracts their definitions (concept name and scope note),
and stores them in a separate table, "mesh".

To make the set of MeSH terms complete, all the concepts and their definitions
are added to the training dataset.

Raw datafiles can be obtained from the following links:
- PubTator: ftp://ftp.ncbi.nlm.nih.gov/pub/lu/PubTator/
- MeSH Descriptors:
"""
from typing import List
import argparse
import gzip
import logging
import sqlite3
import time
from collections import Counter
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import spacy
from bounter import bounter
from lxml import etree
from tqdm import tqdm

import utils as tp_utils
from data import Vocab

# ------------------------------------------------------------------------------
# Vocabulary
# ------------------------------------------------------------------------------

def build_vocab():
    """Read the datafile in chunks, count word frequencies, and build a vocab"""
    mesh_keys = ['εmesh_' + ent for ent in meshes]
    vocab_options = {
        'special': ['<unk>', '<eod>'],
        'concepts': mesh_keys,
        'min_freq': 10,
        'max_size': 200000
    }
    vocab_ = Vocab(**vocab_options)
    agg = bounter(size_mb=1024)  # Counter is inefficient with a large key set

    # Read docs in chunks and count word frequencies in worker processes
    logger.info('Reading documents to build a vocabulary...')
    docs = []
    timer = tp_utils.Timer()

    def cb_aggregate_counts(res):
        agg.update(res)

    p = Pool()
    fn = partial(mp_parse_docs, nlp)
    cnt = 0
    with gzip.open(args.pubtator_file, 'rt') as f:
        for doc in tp_utils.read_pubtator_doc(f, fields=['title', 'body']):
            docs.append(' '.join(doc))
            cnt += 1
            if cnt % 5000 == 0:
                print('Counting words in {} docs...\r'.format(cnt), end='')
                p.apply_async(fn, args=(docs, ), callback=cb_aggregate_counts)
                docs = []
    # Remainder
    p.apply_async(fn, args=(docs, ), callback=cb_aggregate_counts)
    p.close()
    p.join()

    logger.info('agg cardinality %d' % agg.cardinality())
    vocab_.counter.update({k: v for k, v in agg.items()})
    agg = None
    logger.info('Completed counting words of %d PubTator documents '
                'in %.2f secs' % (cnt-1, timer.time()))
    logger.debug('20 most common tokens: [%s]' % vocab_.counter.most_common(20))
    vocab_.build_vocab()
    logger.debug('idx2sym %s' % vocab_.idx2sym[:10])
    return vocab_, cnt-1

def read_vocab(vocab_size=-1):
    """Read a vocabulary from a pre-trained embeddings file (.vec)"""
    # Separate concepts from regular symbols
    concepts = []
    symbols = []
    with args.wbmet.open() as f:
        for line in f:
            w = line.split()[0]
            if w.startswith('εmesh_'):
                concepts.append(w)
            else:
                symbols.append(w)
    logger.info('{} concepts and {} regular words found in the given embeddings'
                ''.format(len(concepts), len(symbols)))
    vocab = Vocab(special=['<unk>', '<eod>'], concepts=concepts)
    n_reg_words = len(symbols) if vocab_size < 0 \
        else max(0, vocab_size - len(vocab.special))
    logger.info('Adding {} special codes and {} regular words...'
                ''.format(len(vocab.special), n_reg_words))
    for sym in symbols:
        if vocab_size < 0:
            vocab.add_symbol(sym)
        elif len(vocab) < vocab_size:
            vocab.add_symbol(sym)
        else:
            break
    return vocab

# ------------------------------------------------------------------------------
# Parsing PubTator
# ------------------------------------------------------------------------------

def mp_parse_docs(nlp, docs: List[str]):
    """Parse documents for counting word frequencies"""
    word_freq = Counter()
    for doc in docs:
        word_freq.update([t.lower_ for t in nlp(doc) if not t.is_space])
    return word_freq

def mp_encode_docs(nlp, vocab, docs):
    """Encode a list of texts (in PubTator structure): preprocess the text,
    interpolate entities, and convert the tokens into word-level vocabulary
    indices.

    :param nlp: spaCy client
    :param vocab: vocabulary built from PubTator documents and MeSH descriptors
    :param docs: list of (pmid, doc) pairs
    """
    encoded_docs = []
    for pmid, doc in docs:
        encoded_doc = []
        for t in nlp(doc):
            if t.is_space:
                continue
            # Lowercase to match the vocabulary, which is built from lowercased
            # tokens in mp_parse_docs (concept codes are already lowercase)
            sym = t.lower_
            if sym in vocab.sym2idx:
                encoded_doc.append(vocab.sym2idx[sym])
            else:
                encoded_doc.append(vocab.sym2idx['<unk>'])
        encoded_doc.append(vocab.sym2idx['<eod>'])
        encoded_docs.append((pmid, ' '.join(map(str, encoded_doc))))
    return encoded_docs
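

# Note: each pair returned above is (pmid, enc_text), where enc_text is a
# space-separated string of vocabulary indices ending with the index of
# '<eod>'; the concrete index values depend on the Vocab implementation.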

def encode_pubtator_docs():
    logger.info('Encoding PubTator documents...')
    # Initialize DB; create the 'documents' table
    conn = sqlite3.connect(args.db_file.as_posix(), check_same_thread=False)
    c = conn.cursor()
    c.execute('CREATE TABLE documents (pmid INTEGER PRIMARY KEY, '
              'enc_text TEXT);')
    conn.commit()
    if num_docs > 0:
        pbar = tqdm(total=num_docs)
    cnt_done = 0  # number of docs inserted, updated by the callback

    def cb_insert_encoded_docs(res):
        nonlocal cnt_done
        cnt_done += len(res)
        c = conn.cursor()
        c.executemany("INSERT INTO documents VALUES (?,?)", res)
        conn.commit()
        if num_docs > 0:
            pbar.update(len(res))
        else:
            print('{} docs inserted...\r'.format(cnt_done), end='')

    # Read docs in batches and encode them in worker processes
    docs = []
    batch_size = 5000
    cnt_read = 0  # number of docs read; kept separate from cnt_done above
    fn = partial(mp_encode_docs, nlp, vocab)
    p = Pool()
    with gzip.open(args.pubtator_file, 'rt') as f:
        for pmid, doc in tp_utils.read_pubtator_doc(f, annotate=True):
            docs.append((pmid, doc))
            cnt_read += 1
            if cnt_read % batch_size == 0:
                p.apply_async(fn, args=(docs, ),
                              callback=cb_insert_encoded_docs)
                docs = []
    # Submit the remaining partial batch
    if docs:
        p.apply_async(fn, args=(docs, ), callback=cb_insert_encoded_docs)
    p.close()
    p.join()
    if num_docs > 0:
        pbar.close()
    conn.close()

# ------------------------------------------------------------------------------
# Parsing MeSH concepts
# ------------------------------------------------------------------------------

def encode_mesh_definitions(nlp, vocab, meshes):
    logger.info('Encoding MeSH definitions...')
    # Initialize DB; create the 'mesh' table
    conn = sqlite3.connect(args.db_file.as_posix(), check_same_thread=False)
    c = conn.cursor()
    c.execute('CREATE TABLE mesh (mshid TEXT PRIMARY KEY, enc_doc TEXT);')
    conn.commit()
    entries = []
    for k, (name, note) in meshes.items():
        tokens = [t.lower_ for t in nlp(name + ' ' + note) if not t.is_space]
        tok_indices = [vocab.sym2idx[t]
                       if t in vocab.sym2idx else vocab.sym2idx['<unk>']
                       for t in tokens]
        entries.append((k, ' '.join(map(str, tok_indices))))
    c.executemany('INSERT INTO mesh VALUES (?, ?)', entries)
    conn.commit()
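

# A minimal sketch of how downstream training code might read the encoded
# documents back from the database created above. It only assumes the two
# table schemas defined in this script; the function name and its signature
# are illustrative, not part of the original pipeline.
def iter_encoded_docs(db_path, table='documents'):
    """Yield (key, [token ids]) pairs from the 'documents' or 'mesh' table."""
    conn = sqlite3.connect(db_path)
    try:
        for key, enc in conn.execute('SELECT * FROM {}'.format(table)):
            # Each encoded row is a space-separated string of vocabulary indices
            yield key, [int(i) for i in enc.split()]
    finally:
        conn.close()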


# RUN~!
if __name__ == '__main__':
    # Logger
    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s-%(levelname)s] -- %(message)s')
    logger = logging.getLogger()

    parser = argparse.ArgumentParser()
    parser.add_argument('--pubtator', type=str,
                        default='data/pubtator/bioconcepts2pubtator_offsets.gz',
                        help='Path to the PubTator data file')
    parser.add_argument('--mesh', type=str, default='data/mesh/desc2019.gz',
                        help='Path to the MeSH descriptors data file')
    parser.add_argument('--wbmet', type=str, default='',
                        help='Path to a pre-trained wbmet embeddings file')
    parser.add_argument('--vocab_size', type=int, default=-1,
                        help='Limit the max vocabulary size. '
                             'If -1, include all seen words.')
    args = parser.parse_args()

    # Path setup
    cwd = Path(__file__).resolve().parent
    args.pubtator_file = cwd / Path(args.pubtator)
    args.mesh_file = cwd / Path(args.mesh)
    args.wbmet = None if args.wbmet == '' else cwd / Path(args.wbmet)
    args.db_file = cwd / 'data/pubtator-{}.db'.format(time.strftime('%m%d_%H%M'))
    logger.info('Creating corpus DB [{}]'.format(args.db_file))

    # spaCy client
    nlp = spacy.load('en', disable=['parser', 'ner', 'tagger'])

    # Read MeSH descriptors and instantiate a vocabulary with the concepts
    logger.info("Reading MeSH terms into a vocabulary...")
    data = etree.parse(gzip.open(args.mesh_file, 'rb'))
    meshes = {}
    for rec in data.iter("DescriptorRecord"):
        mshid = rec.find("DescriptorUI").text.lower()
        name = rec.find("DescriptorName/String").text
        elm = rec.find('ConceptList/Concept[@PreferredConceptYN="Y"]/ScopeNote')
        scope_note = elm.text if elm is not None and elm.text else ''
        meshes[mshid] = [name.strip(), scope_note.strip()]

    # Build a vocabulary
    if args.wbmet is not None:
        vocab = read_vocab(vocab_size=args.vocab_size)
        num_docs = -1
    else:
        vocab, num_docs = build_vocab()
    logger.info('Saving the vocabulary into a corpus database...')
    vocab.save2db(args.db_file.as_posix())

    # Encode all PubTator documents
    encode_pubtator_docs()

    # Encode MeSH definitions
    encode_mesh_definitions(nlp, vocab, meshes)