Implementation of nanonet using keras.
""" | |
Reimplementation of nanonet using keras. | |
Follow the instructions at | |
https://www.tensorflow.org/install/install_linux | |
to setup an NVIDIA GPU with CUDA8.0 and cuDNN v5.1. | |
virtualenv venv --python=python3 | |
. venv/bin/activate | |
pip install numpy | |
pip install git+https://github.com/nanoporetech/nanonet@e8ff1edf
pip install --upgrade tensorflow-gpu keras numpy
python keras_call.py --help

Reuses bits of nanonet for peripheral calculations and decoding. Overall speed
is limited by the decoding step. Reads are chunked into a maximum of 1000 feature
vectors for processing on the GPU. Batch sizes are set for a GPU with 11GB of
memory. Stitching together of read chunks is not performed. On an AWS K80 GPU,
the network performs at around 140 feature vectors per second.

For training, files should have data labelled as nanonet requires, or can be
hacked in a similar fashion to nanonet.

Results should be roughly equivalent to nanonet, except that the LSTM
implementation used here does not contain peepholes.

This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at
http://mozilla.org/MPL/2.0/.

(c) 2017 Oxford Nanopore Technologies Ltd.
"""
import argparse
import errno
import os
import sys
from glob import glob
import multiprocessing
import itertools
from collections import Counter
import logging

import numpy as np
from keras.utils import to_categorical
from keras.models import Sequential, model_from_json
from keras.layers import LSTM, Dense
from keras.layers.wrappers import Bidirectional
from keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard, EarlyStopping

from nanonet.features import *
from nanonet.nanonetcall import *
from nanonet.fast5 import Fast5
from nanonet.util import all_nmers, tang_imap
from nanonet.eventdetection.filters import minknow_event_detect


def mkdir_p(path, info=None):
    """Make a directory if it doesn't exist."""
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            info = " {}".format(info) if info is not None else ""
            logging.warning("The path {} exists.{}".format(path, info))
        else:
            raise


def grouper(iterable, n):
    """Yield fixed size chunks of an iterable. Remainder is not padded."""
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


def fast5_to_features(fast5_files, section='template', window=[-1, 0, 1], event_detect=True,
                      ed_params={'window_lengths': [3, 6], 'thresholds': [1.4, 1.1], 'peak_height': 0.2},
                      sloika_model=False):
    """Generate features from scratch (not using mapping event data as in training)."""
    skipped = 0
    for f in fast5_files:
        try:
            with Fast5(f) as fh:
                if event_detect:
                    raw = fh.get_read(raw=True)
                    events = minknow_event_detect(
                        raw, fh.sample_rate, **ed_params
                    )
                else:
                    events = fh.get_read()
        except Exception as e:
            skipped += 1
            continue
        try:
            X = events_to_features(events, window=window, sloika_model=sloika_model)
        except TypeError:
            skipped += 1
            continue
        yield f, X
    logging.info("Skipped generating features for {} reads.".format(skipped))


def create_labels(kmer_len, alphabet):
    kmers = all_nmers(kmer_len, alpha=alphabet)
    bad_kmer = 'X' * kmer_len
    kmers.append(bad_kmer)
    all_kmers = {k: i for i, k in enumerate(kmers)}
    return kmers, all_kmers
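
# For the defaults used below (kmer_len=5, alphabet='ACGT') this gives
# 4**5 + 1 = 1025 states: every 5-mer plus the junk kmer 'XXXXX'. This is
# where the hard-coded num_classes = 1025 in train() comes from.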


def make_training_input(fast5_files, window=[-1, 0, 1], kmer_len=5, alphabet='ACGT',
                        chunk_size=1000, min_chunk=1000, trim=10,
                        get_events=get_events_ont_mapping, get_labels=get_labels_ont_mapping,
                        callback_kwargs={'section': 'template', 'kmer_len': 5}):
    """Generate training input, adapted from nanonet's equivalent.

    :param fast5_files: list of .fast5 files to process.
    :param window: event window to derive features.
    :param kmer_len: length of kmers to learn.
    :param alphabet: alphabet of kmers.
    :param chunk_size: chunk size to break reads into for SGE batching.
    :param min_chunk: minimum chunk size (used to discard the remainder of reads).
    :param trim: no. of feature vectors to trim (from either end).
    :param get_events: callback to return event data, will be passed the .fast5 filename.
    :param get_labels: callback to return event kmer labels, will be passed the .fast5 filename.
    :param callback_kwargs: kwargs for both `get_events` and `get_labels`.

    :returns: dictionary of structure {filename: (features, labels)}. labels
        are state indices which will need transforming according to the keras
        loss function that is used.
    """
    # Our state labels are kmers plus a junk kmer
    kmers, all_kmers = create_labels(kmer_len, alphabet)
    data = dict()
    for i, f in enumerate(fast5_files):
        try:
            # Run callbacks to get features and labels
            X = events_to_features(get_events(f, **callback_kwargs), window=window)
            labels = get_labels(f, **callback_kwargs)
        except Exception:
            logging.debug("Couldn't get features/labels: {}".format(f))
            continue
        try:
            X = X[trim:-trim]
            labels = labels[trim:-trim]
            if len(X) != len(labels):
                raise RuntimeError('Length of features and labels not equal.')
        except Exception:
            logging.debug("Feature/labels bad: {}".format(f))
            continue
        try:
            # convert kmers to ints
            y = np.fromiter(
                (all_kmers[k] for k in labels),
                dtype=np.int16, count=len(labels)
            )
        except Exception as e:
            # Checks for erroneous alphabet or kmer length
            raise RuntimeError(
                'Could not convert kmer labels to ints in file {}. '
                'Check labels are no longer than {} and contain only {}.'.format(f, kmer_len, alphabet)
            )
        else:
            for chunk, (X_chunk, y_chunk) in enumerate(zip(chunker(X, chunk_size), chunker(y, chunk_size))):
                if len(X_chunk) < min_chunk:
                    break
                ident = '{}_{}'.format(f, chunk)
                data[ident] = (X_chunk, y_chunk)
    return data


def generate_features(fast5_files, jobs=multiprocessing.cpu_count()):
    """Generate training features and labels using multi-processing."""
    all_data = dict()
    logging.info("Processing {} files.".format(len(fast5_files)))
    n_processed = 0
    files_per_worker = 100
    file_gen = (list(x) for x in grouper(fast5_files, files_per_worker))
    for i, data in enumerate(tang_imap(make_training_input, file_gen, unordered=True, threads=jobs)):
        all_data.update(data)
        n_processed += len(data)
        logging.info("Processed {} read chunks ({} files).".format(n_processed, (i + 1) * files_per_worker))
    logging.info("Finished generating features.")
    return all_data


def build_model(timesteps, data_dim, num_classes):
    """Builds a nanonet-style graph.

    The keras LSTM implementation follows Graves 2013 (with forget gates
    with bias equal 1). Usually we add in peepholes.
    """
    layer_size = 96
    model = Sequential()
    # Bidirectional wrapper takes a copy of the first argument and reverses
    # the direction. Weights are independent between components.
    model.add(Bidirectional(
        LSTM(layer_size, return_sequences=True, name='lstm1', implementation=2),
        input_shape=(timesteps, data_dim)
    ))
    model.add(Dense(128, activation='tanh', name='ff1'))
    model.add(Bidirectional(
        LSTM(layer_size, return_sequences=True, name='lstm2', implementation=2)
    ))
    model.add(Dense(128, activation='tanh', name='ff2'))
    model.add(Dense(num_classes, activation='softmax', name='classify'))
    return model
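
# A minimal usage sketch of the network on its own (an assumption for
# illustration, not part of the training/prediction entry points below;
# the shapes follow this script's defaults of 1000-event chunks and
# 4**5 + 1 = 1025 output states, while data_dim depends on the nanonet
# feature vectors):
#
#   model = build_model(timesteps=1000, data_dim=x_train.shape[2], num_classes=1025)
#   model.compile(loss='sparse_categorical_crossentropy', optimizer='rmsprop',
#                 metrics=['accuracy'])
#   class_probs = model.predict(x_train, batch_size=100)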


def save_model(fname, model):
    """Save model definition."""
    with open(fname, 'w') as json_file:
        json_file.write(model.to_json())


def load_model(structure, weights):
    """Load a model from a .json file with weights initialized from .hdf."""
    with open(structure) as json_file:
        model_json = json_file.read()
    model = model_from_json(model_json)
    model.load_weights(weights)
    return model


def save_feature_file(fname, data):
    """Save feature dictionary."""
    np.save(fname, data)


def load_feature_file(fname):
    """Load the result of `save_feature_file` back to the original
    representation.
    """
    data = dict()
    # np.save stores the dict as a 0-d object array; index with [()] to
    # recover the original dictionary.
    src = np.load(fname)
    for key in src[()].keys():
        data[key] = src[()][key]
    return data


def run_training(train_name, x_train, y_train, num_classes, model_data=None):
    """Run training."""
    data_dim = x_train.shape[2]
    timesteps = x_train.shape[1]

    if model_data is None:
        model = build_model(timesteps, data_dim, num_classes)
    else:
        model = load_model(*model_data)
        # TODO: should check model data dimensions match data dimensions

    logging.info("data_dim: {}, time_steps: {}, num_classes: {}".format(data_dim, timesteps, num_classes))
    logging.info("\n{}".format(model.summary()))
    save_model(os.path.join(train_name, 'model_structure.json'), model)

    callbacks = [
        # Best model according to training set accuracy
        ModelCheckpoint(os.path.join(train_name, 'weights.best.hdf5'),
                        monitor='acc', verbose=1, save_best_only=True, mode='max'),
        # Best model according to validation set accuracy
        ModelCheckpoint(os.path.join(train_name, 'weights.best.val.hdf5'),
                        monitor='val_acc', verbose=1, save_best_only=True, mode='max'),
        # Checkpoints when training set accuracy improves
        ModelCheckpoint(os.path.join(train_name, 'weights-improvement-{epoch:02d}-{acc:.2f}.hdf5'),
                        monitor='acc', verbose=1, save_best_only=True, mode='max'),
        # Stop when no improvement; patience is the number of epochs to allow without improvement
        EarlyStopping(monitor='val_loss', patience=20),
        # Log of epoch stats
        CSVLogger(os.path.join(train_name, 'training.log')),
        # Allow us to run tensorboard to see how things are going. Some
        # features require validation data, not clear why.
        TensorBoard(log_dir=os.path.join(train_name, 'logs'),
                    histogram_freq=5, batch_size=100, write_graph=True, write_grads=True, write_images=True)
    ]

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['accuracy'],
    )

    # maybe possible to increase batch_size for faster processing
    model.fit(
        x_train, y_train,
        batch_size=100, epochs=5000,
        validation_split=0.2,
        callbacks=callbacks,
    )


def train(args):
    """Training program."""
    train_name = args.train_name
    mkdir_p(train_name, info='Results will be overwritten and may use pregenerated features.')

    dataset_name = os.path.join(train_name, '{}_squiggles.npy'.format(train_name))
    logging.info("Using {} for feature storage/reading.".format(dataset_name))

    fast5s = glob(os.path.join(args.fast5_path, "*.fast5"))[:2000]
    logging.info("Found {} input files.".format(len(fast5s)))
    if not os.path.isfile(dataset_name):
        logging.info("Creating dataset. This may take a while.")
        data = generate_features(fast5s)
        save_feature_file(dataset_name, data)
    else:
        logging.info("Loading dataset from file.")
        data = load_feature_file(dataset_name)
    logging.info("Got {} squiggle chunks for training.".format(len(data)))

    x_data = []
    y_labels = []
    for fname in data.keys():
        x_data.append(data[fname][0])
        # this is the form required by keras' sparse_categorical_crossentropy
        y_labels.append([[yi, ] for yi in data[fname][1]])
    # stack the individual samples into one big tensor
    x_data = np.stack(x_data)
    y_labels = np.stack(y_labels)

    num_classes = 1025  # TODO: obtain this from somewhere (4**5 kmers + 1 junk state, see create_labels)
    run_training(train_name, x_data, y_labels, num_classes, model_data=args.model_data)


def post_to_call(post, min_prob=1e-5):
    """Decode a posterior matrix into a basecall, returning (sequence, quality)."""
    kmers, _ = create_labels(5, 'ACGT')
    post, good_events = clean_post(post, kmers, min_prob)
    if post is None:
        return None
    # Decode kmers
    score, states = decoding.decode_homogenous(post, log=False)
    # Form basecall
    kmers = [x for x in kmers if 'X' not in x]
    qdata = get_qdata(post, kmers)
    seq, qual, kmer_path = form_basecall(qdata, kmers, states)
    return seq, qual


def run_prediction(data, model_structure, model_weights, output_file='basecalls.fasta'):
    """Run the network over feature chunks and decode the output to basecalls."""
    from timeit import default_timer as now

    model = load_model(model_structure, model_weights)
    logging.info('\n{}'.format(model.summary()))

    t0 = now()
    class_probs = model.predict(data, batch_size=1500, verbose=1)
    t1 = now()
    logging.info('Running network took {}s for data of shape {}'.format(t1 - t0, data.shape))

    t0 = now()
    count = 0
    with open(output_file, 'w') as fasta:
        for i, call in enumerate(tang_imap(post_to_call, class_probs, unordered=True, threads=multiprocessing.cpu_count())):
            if call is not None:
                count += 1
                seq, qual = call
                fasta.write(">block_{}\n{}\n".format(i, seq))
    t1 = now()
    logging.info('Decoding took {}s for {} blocks.'.format(t1 - t0, count))


def predict(args):
    """Inference program."""
    if args.fast5_path:
        fast5s = glob(os.path.join(args.fast5_path, "*.fast5"))[:500]
        logging.info("Found {} input files.".format(len(fast5s)))
        logging.info("Creating dataset. This may take a while.")
        data = generate_features(fast5s)
    else:
        logging.info("Loading dataset from file.")
        data = load_feature_file(args.feature_file)
    logging.info("Got {} squiggle chunks for inference.".format(len(data)))

    x_data = np.stack([x[0] for x in data.values()])
    run_prediction(x_data, args.model, args.weights)


def main():
    logging.basicConfig(format='[%(asctime)s - %(name)s] %(message)s', datefmt='%H:%M:%S', level=logging.INFO)

    parser = argparse.ArgumentParser('Squiggle Demultiplexer')
    subparsers = parser.add_subparsers(title='subcommands', description='valid commands', help='additional help', dest='command')
    subparsers.required = True

    tparser = subparsers.add_parser('train', help='Train a model from labelled squiggles.')
    tparser.set_defaults(func=train)
    tparser.add_argument('fast5_path', help='Path to training fast5 files.')
    tparser.add_argument('--train_name', type=str, default='keras_train', help='Name for training run.')
    tparser.add_argument('--model_data', nargs=2, metavar=('def.json', 'weights.hdf'), help='Model definition and initial weights.')

    pparser = subparsers.add_parser('predict', help='Basecall squiggles with a trained model.')
    pparser.set_defaults(func=predict)
    pparser.add_argument('model', help='Model structure json file from training.')
    pparser.add_argument('weights', help='Model weights HDF5 file from training.')
    ingroup = pparser.add_mutually_exclusive_group(required=True)
    ingroup.add_argument('--fast5_path', help='Path to fast5 files.')
    ingroup.add_argument('--feature_file', help='Pregenerated features as stored during training.')

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()