Coroutine model for running the filter stats report.
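A sketch of how the script is meant to be invoked (the script file name and CSV path below are hypothetical, and the -o directory must already exist):

    python filter_stats_coroutine.py filtered_summary.csv -o ./output -r filter_stats.json --debug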
"""Filter Stats Report using a Coroutine model""" | |
import os | |
import abc | |
import logging | |
import sys | |
import functools | |
import operator | |
import argparse | |
import time | |
import math | |
#from pbreports.util import setup_log | |
import itertools | |
__version__ = '2.5' | |
log = logging.getLogger(__name__) | |
_REQUIRED_HEADER_FIELDS = ('Readlength', 'ReadScore', | |
'PassedFilter', 'SequencingZMW', "ReadId") | |
# Format is { Name: ( dType, f(str->dType) ) } | |
VALID_COLUMNS = {"Movie": ("|S64", str), | |
"ReadId": ("|S64", str), | |
"#Bases": (int, int), | |
"Readlength": (int, int), | |
"ReadScore": (float, float), | |
"Productivity": (int, int), | |
"SequencingZMW": (int, int), | |
"PassedFilter": (int, int), | |
"Sandwiches": (int, int), | |
"Whitelisted": (int, int), | |
"SNR": (float, float), | |
"ArtifactScore": ("|S64", str)} | |
COLUMN_TYPES = {k: operator.getitem(v, 1) for k, v in VALID_COLUMNS.iteritems()} | |
class _BaseFilterException(Exception):
    pass


class CsvParserError(_BaseFilterException):
    pass


class NoFilteredReadsError(_BaseFilterException):
    pass


class NoPassedFilteredReadsError(_BaseFilterException):
    pass


def trace(func):
    """Log entry and exit of the wrapped function."""

    def f(*args, **kwargs):
        log.info("Calling {f}".format(f=func.__name__))
        result = func(*args, **kwargs)
        log.info("Completed {f}.".format(f=func.__name__))
        return result

    return f
def setup_log(alog, file_name=None, level=logging.DEBUG):
    """
    Util function for setting up logging.

    Due to how smrtpipe logs, the default behavior is that the stdout
    is where the logging is redirected. If a file name is given the log
    will be written to that file.

    :param alog: (log instance) Log instance that handlers and filters will
    be added to.
    :param file_name: (str, None) Path to file. If None, stdout will be used.
    :param level: (int) logging level
    """
    if file_name is None:
        handler = logging.StreamHandler(sys.stdout)
    else:
        handler = logging.FileHandler(file_name)

    str_formatter = '[%(levelname)s] %(asctime)-15s [%(name)s %(funcName)s %(lineno)d] %(message)s'
    formatter = logging.Formatter(str_formatter)
    handler.setFormatter(formatter)
    alog.addHandler(handler)
    alog.setLevel(level)
def _validate_resource(func, resource):
    """Validate the existence of a file/dir"""
    if func(resource):
        return os.path.abspath(resource)
    else:
        raise IOError("Unable to find {f}".format(f=resource))


validate_file = functools.partial(_validate_resource, os.path.isfile)
validate_dir = functools.partial(_validate_resource, os.path.isdir)
validate_output_dir = functools.partial(_validate_resource, os.path.isdir)
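# The partials above give each validator the single-argument signature that
# argparse's 'type' callable expects. Illustrative sketch (the path is
# hypothetical, not part of this gist):
#
#     >>> validate_file('/path/to/filtered_summary.csv')
#     '/path/to/filtered_summary.csv'   # absolute path, or IOError if missing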
class BaseAggregator(object):
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def apply(self, record):
        pass


class CountAggregator(BaseAggregator):

    def __init__(self, record_field, total=0):
        self.total = total
        self.record_field = record_field

    def apply(self, record):
        self.total += 1

    def __repr__(self):
        return "<{k} {f} total={t}>".format(k=self.__class__.__name__, t=self.total, f=self.record_field)


class MinAggregator(BaseAggregator):

    def __init__(self, record_field):
        self.record_field = record_field
        self.value = None

    def apply(self, record):
        v = getattr(record, self.record_field)
        if self.value is None:
            self.value = v
        if v < self.value:
            self.value = v

    def __repr__(self):
        return "<{k} {f} min={t}>".format(k=self.__class__.__name__, t=self.value, f=self.record_field)


class MaxAggregator(BaseAggregator):

    def __init__(self, record_field):
        self.record_field = record_field
        self.value = None

    def apply(self, record):
        v = getattr(record, self.record_field)
        if self.value is None:
            self.value = v
        if v > self.value:
            self.value = v

    def __repr__(self):
        return "<{k} {f} max={t}>".format(k=self.__class__.__name__, t=self.value, f=self.record_field)
class SumAggregator(BaseAggregator):

    def __init__(self, record_field, total=0):
        self.total = total
        self.record_field = record_field

    def apply(self, record):
        self.total += getattr(record, self.record_field)

    def __repr__(self):
        _d = dict(k=self.__class__.__name__,
                  t=self.total,
                  f=self.record_field)
        return "<{k} {f} total={t}>".format(**_d)


class MeanAggregator(BaseAggregator):

    def __init__(self, record_field):
        self.record_field = record_field
        self.total = 0
        self.nvalues = 0

    def apply(self, record):
        self.nvalues += 1
        v = getattr(record, self.record_field)
        self.total += v

    @property
    def mean(self):
        if self.nvalues == 0:
            return 0.0
        # cast to float so integer fields (e.g. Readlength) are not truncated
        return self.total / float(self.nvalues)

    def __repr__(self):
        _d = dict(n=self.nvalues, t=self.total,
                  k=self.__class__.__name__,
                  f=self.record_field,
                  m=self.mean)
        return "<{k} {f} mean={m} nvalue={n} total={t}>".format(**_d)
class HistogramAggregator(BaseAggregator):

    def __init__(self, record_field, min_value, dx, nbins=10):
        self.record_field = record_field
        self.min_value = min_value
        # bin width
        self.dx = dx
        self.bins = [0 for _ in xrange(nbins)]

    @property
    def nbins(self):
        return len(self.bins)

    @property
    def max_value(self):
        return self.nbins * self.dx

    def apply(self, record):
        """Adaptively compute the histogram. If there are not enough bins,
        more will be added."""
        v = getattr(record, self.record_field)
        # index of the bin this value falls into
        n = int(math.ceil(v / self.dx))
        # If the value is larger than the current range covered by the bins,
        # grow the bin list before incrementing.
        max_v = (self.nbins - 1) * self.dx
        if v >= max_v:
            delta = v - max_v
            n_new_bins = delta / self.dx
            i = int(math.ceil(n_new_bins)) + 2
            # add more bins
            for _ in xrange(i):
                self.bins.append(0)
        self.bins[n] += 1

    def __repr__(self):
        _d = dict(k=self.__class__.__name__,
                  f=self.record_field,
                  n=self.min_value,
                  x=self.max_value,
                  dx=self.dx,
                  nbins=len(self.bins))
        return "<{k} {f} nbins={nbins} dx={dx} min={n} max={x}>".format(**_d)
def _get_header_fields_from_csv(file_name):
    """Peek into the CSV file and extract the column headers.

    :raises: CsvParserError if unable to extract column headers.
    """
    with open(file_name, 'r') as f:
        header = f.readline()
        if ',' in header.rstrip():
            return header.rstrip().split(',')
        else:
            msg = "Malformed CSV. Unable to get column headers in {h}.".format(h=header)
            raise CsvParserError(msg)


def _validate_header(headers):
    return all([field in headers for field in _REQUIRED_HEADER_FIELDS])


class Record(object):

    def __init__(self, **kwargs):
        for k, v in kwargs.iteritems():
            setattr(self, k, v)

    def __repr__(self):
        o = " ".join([''.join([x, '=', str(getattr(self, x))]) for x in _REQUIRED_HEADER_FIELDS])
        _d = dict(k=self.__class__.__name__, o=o)
        return "<{k} {o}>".format(**_d)


def _row_to_record(column_names, row):
    if ',' not in row:
        raise CsvParserError("Malformed row {r}".format(r=row))
    d = {}
    for name, value in zip(column_names, row.strip().split(',')):
        # look up the str -> type converter for this column
        t = COLUMN_TYPES[name]
        v = t(value)
        d[name] = v
    r = Record(**d)
    return r
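# Sketch of the row -> Record conversion (the column order and values below
# are made up for illustration):
#
#     >>> cols = ['Movie', 'ReadId', '#Bases', 'Readlength', 'ReadScore',
#     ...         'SequencingZMW', 'PassedFilter']
#     >>> _row_to_record(cols, 'm1,m1/42,350,350,0.85,1,1')
#     <Record Readlength=350 ReadScore=0.85 PassedFilter=1 SequencingZMW=1 ReadId=m1/42>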
def coroutine(func):
    """Main decorator used to prime a coroutine when it is created."""

    def start(*args, **kwargs):
        cr = func(*args, **kwargs)
        cr.next()
        return cr

    return start
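# The decorator advances each generator to its first 'yield' so callers can
# .send() into it immediately. Minimal sketch (the 'printer' name is
# illustrative only):
#
#     >>> @coroutine
#     ... def printer():
#     ...     while True:
#     ...         item = (yield)
#     ...         print item
#     >>> p = printer()     # already primed; no explicit p.next() needed
#     >>> p.send("hello")
#     hello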
@coroutine
def broadcast(targets):
    """Send the record to multiple targets"""
    while True:
        item = (yield)
        if item is not None:
            for target in targets:
                target.send(item)


@coroutine
def filter_record_by(f, target):
    """Filter record using a func (f)"""
    while True:
        record = (yield)
        if f(record):
            target.send(record)


@trace
@coroutine
def to_record(f, target):
    """Convert the row of the CSV to a Record instance"""
    while True:
        line = (yield)
        record = f(line)
        target.send(record)


@coroutine
def apply_aggregator(aggregator):
    """Call the apply method on the aggregator instance"""
    while True:
        record = (yield)
        aggregator.apply(record)


@trace
def process_file(file_handle, target):
    """Wrapper for processing a file via pipeline."""
    for line in file_handle:
        target.send(line)
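# Sketch of wiring a small pipeline by hand (every name is defined in this
# file; StringIO stands in for a real CSV file handle and the data is made up):
#
#     >>> from StringIO import StringIO
#     >>> cols = ['Readlength', 'ReadScore', 'PassedFilter', 'SequencingZMW', 'ReadId']
#     >>> to_rec = functools.partial(_row_to_record, cols)
#     >>> counter = CountAggregator('Readlength')
#     >>> pipe = to_record(to_rec, broadcast([apply_aggregator(counter)]))
#     >>> process_file(StringIO("100,0.8,1,1,a\n200,0.9,1,1,b\n"), pipe)
#     >>> counter.total
#     2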
def get_parser():
    desc = "Filter Stats Report using a coroutine model"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--version', action='version', version='%(prog)s ' + __version__)
    parser.add_argument('filter_summary_csv', type=validate_file,
                        help="Filter CSV file.")
    parser.add_argument('-o', "--output", dest='output_dir', default=os.getcwd(), type=validate_dir,
                        help="Output directory for histogram images generated")
    parser.add_argument('-r', '--report', dest='json_report',
                        help='Path of JSON report.')
    parser.add_argument("--dpi", default=60, type=int,
                        help="dots/inch")
    parser.add_argument('--debug', action='store_true',
                        help="Enable debug mode to stdout.")
    return parser
def run_report(filter_csv, output_dir, base_report_name, dpi):
    """Main point of entry

    The filter stats report has two main modes.

    All Reads: (i.e., PreFilter)
        SequencingZMW > 0
        - total bases
        - total number of reads
        - mean readlength
        - mean readscore

    HQ Region: (i.e., PostFilter)
        PassedFilter > 0, SequencingZMW > 0
        - total bases
        - total number of reads
        - mean readlength
        - mean readscore

    Generates:
        - Pre and Post filter ReadLength histograms with SDF (with thumbnails)
        - Pre and Post filter ReadScore Histogram with SDF (with thumbnails)
        - Pre and Post table of total bases, # of reads, mean readlength, mean readscore
    """
    row_to_rec_func = functools.partial(_row_to_record, _get_header_fields_from_csv(filter_csv))

    ##### General Filters
    # General construct to create a func with signature f(record) -> Bool
    def seq_zmw_filter_f(record):
        if isinstance(record, Record):
            return record.SequencingZMW > 0
        return False

    def hq_filter_f(record):
        if isinstance(record, Record):
            return record.PassedFilter > 0
        return False
    ####################### Pre-Filter Aggregator(s)
    nbases_ag = SumAggregator('#Bases')
    nreads_ag = CountAggregator('Readlength')
    readlength_ag = MeanAggregator('Readlength')
    max_readlength_ag = MaxAggregator('Readlength')
    min_readlength_ag = MinAggregator('Readlength')
    # the histogram is adaptively computed. The min value and dx (bin width)
    # are fixed; bins are added as records arrive.
    readlength_hist_ag = HistogramAggregator('Readlength', 0, dx=10)
    read_score_hist_ag = HistogramAggregator('ReadScore', 0, dx=0.01)
    readscore_ag = SumAggregator('ReadScore', total=0)
    readscore_mean_ag = MeanAggregator('ReadScore')

    # Create/bind core functions that can be passed to the apply method.
    # Calling these 'Models'. A model is a list of filters and an aggregator.
    # Signature to _apply is ([filter1, filter2], aggregator, record);
    # calling functools.partial returns a function with signature f(record).
    pre_agros = [nbases_ag, nreads_ag,
                 readscore_ag, readscore_mean_ag,
                 max_readlength_ag, min_readlength_ag,
                 readlength_ag,
                 readlength_hist_ag,
                 read_score_hist_ag]
    ####################### Post-Filter Aggregator(s)
    post_nbases_ag = SumAggregator('#Bases')
    post_nreads_ag = CountAggregator('Readlength')
    post_readlength_ag = MeanAggregator('Readlength')
    post_min_readlength_ag = MinAggregator('Readlength')
    post_max_readlength_ag = MaxAggregator('Readlength')
    # the histogram is adaptively computed. The min value and dx (bin width)
    # are fixed; bins are added as records arrive.
    post_readlength_hist_ag = HistogramAggregator('Readlength', 0, dx=10)
    post_readscore_hist_ag = HistogramAggregator('ReadScore', 0, dx=0.01)
    post_readscore_ag = SumAggregator('ReadScore')
    post_readscore_mean_ag = MeanAggregator('ReadScore')

    post_agros = [post_nbases_ag, post_nreads_ag,
                  post_readlength_ag, post_readscore_ag,
                  post_min_readlength_ag,
                  post_max_readlength_ag,
                  post_readscore_mean_ag,
                  post_readlength_hist_ag, post_readscore_hist_ag]

    # create a primed coroutine for each aggregator; makes it easy to use broadcast
    pre_agros_c = [apply_aggregator(a) for a in pre_agros]
    post_agros_c = [apply_aggregator(a) for a in post_agros]
    with open(filter_csv, 'r') as f:
        # skip the header row; the column names were already extracted above
        _ = f.readline()
        process_file(f,
                     to_record(row_to_rec_func,
                               broadcast([
                                   filter_record_by(seq_zmw_filter_f, broadcast(pre_agros_c)),
                                   filter_record_by(seq_zmw_filter_f, filter_record_by(hq_filter_f, broadcast(post_agros_c)))])))
    # Now plot the data, make tables from the Aggregators
    log.info("*" * 10)
    log.info("PreFilter Results")
    log.info(nbases_ag)
    log.info(nreads_ag)
    log.info(max_readlength_ag)
    log.info(min_readlength_ag)
    log.info(readlength_ag)
    log.info(readscore_ag)
    log.info(readscore_mean_ag)
    log.info(readlength_hist_ag)
    log.info(read_score_hist_ag)
    log.info("*" * 10)
    log.info("PostFilter Results")
    log.info(post_nbases_ag)
    log.info(post_min_readlength_ag)
    log.info(post_max_readlength_ag)
    log.info(post_readlength_ag)
    log.info(post_readscore_mean_ag)
    log.info(post_readlength_hist_ag)
    log.info(post_readscore_hist_ag)
    log.info("completed runner")

    # this should return a pbreports.model.Report instance
    return 0
def main():
    """Main point of entry."""
    parser = get_parser()
    args = parser.parse_args()

    to_debug = args.debug
    output_dir = args.output_dir
    filter_csv = args.filter_summary_csv
    json_report = args.json_report
    dpi = args.dpi

    level = logging.DEBUG if to_debug else logging.INFO
    started_at = time.time()

    if to_debug:
        setup_log(log, level=level)
    else:
        log.addHandler(logging.NullHandler())

    try:
        # keep the argument order in sync with run_report's signature
        report = run_report(filter_csv, output_dir, json_report, dpi)
        #report.write_json(json_report)
        rcode = 0
    except Exception as e:
        log.error(e, exc_info=True)
        sys.stderr.write(str(e) + "\n")
        rcode = 1

    run_time = time.time() - started_at
    log.info("Completed main with returncode {r} in {s:.2f}s".format(r=rcode, s=run_time))
    return rcode


if __name__ == '__main__':
    sys.exit(main())