import os
import abc
import logging
import sys
import functools
import operator
import argparse
import time
import math

#from pbreports.util import setup_log

__version__ = '2.5'

log = logging.getLogger(__name__)

_REQUIRED_HEADER_FIELDS = ('Readlength', 'ReadScore',
                           'PassedFilter', 'SequencingZMW', "ReadId")

# Format is { Name: ( dType, f(str->dType) ) }
VALID_COLUMNS = {"Movie": ("|S64", str),
                 "ReadId": ("|S64", str),
                 "#Bases": (int, int),
                 "Readlength": (int, int),
                 "ReadScore": (float, float),
                 "Productivity": (int, int),
                 "SequencingZMW": (int, int),
                 "PassedFilter": (int, int),
                 "Sandwiches": (int, int),
                 "Whitelisted": (int, int),
                 "SNR": (float, float),
                 "ArtifactScore": ("|S64", str)}

COLUMN_TYPES = {k: operator.getitem(v, 1) for k, v in VALID_COLUMNS.iteritems()}
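
# A minimal sketch of how COLUMN_TYPES is used downstream (illustrative
# only; the values shown are made up). Each entry maps a CSV column name
# to the str -> dType converter applied when a raw row is turned into a
# Record (see _row_to_record below):
#
# >>> COLUMN_TYPES['ReadScore']('0.85')
# 0.85
# >>> COLUMN_TYPES['Readlength']('1523')
# 1523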
class _BaseFilterException(Exception):
    pass


class CsvParserError(_BaseFilterException):
    pass


class NoFilteredReadsError(_BaseFilterException):
    pass


class NoPassedFilteredReadsError(_BaseFilterException):
    pass
def setup_log(alog, file_name=None, level=logging.DEBUG):
    """
    Util function for setting up logging.

    Due to how smrtpipe logs, the default behavior is that logging is
    redirected to stdout. If a file name is given, the log will be
    written to that file instead.

    :param alog: (log instance) Log instance that handlers and filters
        will be added to.
    :param file_name: (str, None) Path to file. If None, stdout will be used.
    :param level: (int) logging level
    """
    if file_name is None:
        handler = logging.StreamHandler(sys.stdout)
    else:
        handler = logging.FileHandler(file_name)
    str_formatter = '[%(levelname)s] %(asctime)-15s [%(name)s %(funcName)s %(lineno)d] %(message)s'
    formatter = logging.Formatter(str_formatter)
    handler.setFormatter(formatter)
    alog.addHandler(handler)
    alog.setLevel(level)
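
# Usage sketch (the log file path is hypothetical):
#
# >>> setup_log(log, file_name='/tmp/filter_stats.log', level=logging.INFO)
# >>> log.info("logging is now configured")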
def _validate_resource(func, resource):
    """Validate the existence of a file/dir"""
    if func(resource):
        return os.path.abspath(resource)
    else:
        raise IOError("Unable to find {f}".format(f=resource))

validate_file = functools.partial(_validate_resource, os.path.isfile)
validate_dir = functools.partial(_validate_resource, os.path.isdir)
validate_output_dir = functools.partial(_validate_resource, os.path.isdir)
class BaseAggregator(object):
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def apply(self, record):
        pass


class CountAggregator(BaseAggregator):

    def __init__(self, record_field, total=0):
        self.total = total
        self.record_field = record_field

    def apply(self, record):
        self.total += 1

    def __repr__(self):
        return "<{k} {f} total={t}>".format(k=self.__class__.__name__, t=self.total, f=self.record_field)


class SumAggregator(BaseAggregator):

    def __init__(self, record_field, total=0):
        self.total = total
        self.record_field = record_field

    def apply(self, record):
        self.total += getattr(record, self.record_field)

    def __repr__(self):
        _d = dict(k=self.__class__.__name__,
                  t=self.total,
                  f=self.record_field)
        return "<{k} {f} total={t} >".format(**_d)
class MeanAggregator(BaseAggregator):

    def __init__(self, record_field):
        self.record_field = record_field
        self.total = 0
        self.nvalues = 0

    def apply(self, record):
        self.nvalues += 1
        v = getattr(record, self.record_field)
        self.total += v

    @property
    def mean(self):
        if self.nvalues == 0:
            return 0.0
        # cast to float to keep Python 2 integer division from truncating
        # the mean of integer fields (e.g., Readlength)
        return self.total / float(self.nvalues)

    def __repr__(self):
        _d = dict(n=self.nvalues, t=self.total,
                  k=self.__class__.__name__,
                  f=self.record_field,
                  m=self.mean)
        return "<{k} {f} mean={m} nvalue={n} total={t}>".format(**_d)
class Histogram(BaseAggregator):

    def __init__(self, record_field, min_value, dx, nbins=10):
        self.record_field = record_field
        self.min_value = min_value
        # bin width
        self.dx = dx
        self.bins = [0 for _ in xrange(nbins)]

    @property
    def nbins(self):
        return len(self.bins)

    @property
    def max_value(self):
        return self.nbins * self.dx

    def apply(self, record):
        """Adaptively compute the histogram. If there are not enough bins,
        more will be added."""
        v = getattr(record, self.record_field)
        # bin index; cast to float so integer fields (e.g., Readlength)
        # don't hit Python 2 integer division
        n = int(math.ceil(float(v) / self.dx))
        max_v = (self.nbins - 1) * self.dx
        if v >= max_v:
            # the value falls off the right edge of the current bins;
            # grow the bin list (with padding) to accommodate it
            delta = v - max_v
            n_new_bins = float(delta) / self.dx
            i = int(math.ceil(n_new_bins)) + 2
            for _ in xrange(i):
                self.bins.append(0)
        self.bins[n] += 1

    def __repr__(self):
        _d = dict(k=self.__class__.__name__,
                  f=self.record_field,
                  n=self.min_value,
                  x=self.max_value,
                  dx=self.dx,
                  nbins=len(self.bins))
        return "<{k} {f} nbins={nbins} dx={dx} min={n} max={x}>".format(**_d)
def _get_header_fields_from_csv(file_name):
    """Peek into the CSV file and extract the column headers.

    :raises: CsvParserError if unable to extract column headers.
    """
    with open(file_name, 'r') as f:
        header = f.readline()
        if ',' in header.rstrip():
            return header.rstrip().split(',')
        else:
            msg = "Malformed CSV. Unable to get column headers in {h}.".format(h=header)
            raise CsvParserError(msg)


def _validate_header(headers):
    return all([field in headers for field in _REQUIRED_HEADER_FIELDS])
class Record(object):

    def __init__(self, **kwargs):
        # setattr is used so column names that are not valid Python
        # identifiers (e.g., '#Bases') still work via getattr/setattr
        for k, v in kwargs.iteritems():
            setattr(self, k, v)

    def __repr__(self):
        o = " ".join([''.join([x, '=', str(getattr(self, x))]) for x in _REQUIRED_HEADER_FIELDS])
        _d = dict(k=self.__class__.__name__, o=o)
        return "<{k} {o}>".format(**_d)


def _row_to_record(column_names, row):
    if ',' not in row:
        raise CsvParserError("Malformed row {r}".format(r=row))
    d = {}
    for name, value in zip(column_names, row.strip().split(',')):
        # look up the str -> dType converter for this column
        t = COLUMN_TYPES[name]
        d[name] = t(value)
    return Record(**d)
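
# Example of parsing one row (hypothetical header and values):
#
# >>> names = ['Movie', 'ReadId', 'Readlength', 'ReadScore']
# >>> rec = _row_to_record(names, 'm130308_0123,read/0/0_100,100,0.85')
# >>> rec.ReadScore
# 0.85
#
# Note that zip() silently truncates when a row has fewer fields than
# the header; a malformed row only fails fast when it has no commas.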
def _filter_record_by_attribute(attr_name, func, value, record):
    v = getattr(record, attr_name)
    return func(value, v)


def _filter_record(filter_func, record):
    """Returns Bool"""
    return filter_func(record)


def null_filter(record):
    return True


def _multi_filter_record(filter_funcs, record):
    """Returns Bool"""
    for filter_func in filter_funcs:
        if not filter_func(record):
            # Bail out at the first chance
            return False
    return True


def _apply(filter_funcs, aggregators, record):
    """Run the filters and call the apply method on each aggregator if
    the record passes filtering.
    """
    if not isinstance(filter_funcs, (list, tuple)):
        filter_funcs = [filter_funcs]
    if not isinstance(aggregators, (list, tuple)):
        aggregators = [aggregators]
    if _multi_filter_record(filter_funcs, record):
        for aggregator in aggregators:
            aggregator.apply(record)
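
# Binding filters and aggregators with functools.partial yields a
# 'model': a function with signature f(record). A minimal sketch, with
# illustrative names, for some Record instance 'record':
#
# >>> nbases = SumAggregator('#Bases')
# >>> seq_zmw = lambda record: record.SequencingZMW > 0
# >>> model = functools.partial(_apply, [seq_zmw], [nbases])
# >>> model(record)  # updates nbases.total iff the filter passes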
def applyer(row_to_record_func, iterable, funcs):
    for it in iterable:
        record = row_to_record_func(it)
        for func in funcs:
            func(record)
        del record


def _to_table(nbases, nreads, mean_readlength, mean_read_score):
    """Create a pbreports.model.Table instance"""
    # placeholder stub; pbreports is not imported in this standalone version
    return ""
def get_parser():
    desc = ""
    parser = argparse.ArgumentParser(description=desc)
    # the 'version' keyword to ArgumentParser is deprecated in the
    # stdlib argparse; use a version action instead
    parser.add_argument('--version', action='version',
                        version='%(prog)s ' + __version__)
    parser.add_argument('filter_summary_csv', type=validate_file,
                        help="Filter CSV file.")
    parser.add_argument('-o', "--output", dest='output_dir', default=os.getcwd(), type=validate_dir,
                        help="Output directory for generated histogram images.")
    parser.add_argument('-r', '--report', dest='json_report',
                        help='Path of JSON report.')
    parser.add_argument("--dpi", default=60, type=int,
                        help="dots/inch")
    parser.add_argument('--debug', action='store_true',
                        help="Enable debug mode to stdout.")
    return parser
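
# Example invocation (the script and file names are hypothetical):
#
#   $ python filter_stats.py filtered_summary.csv -o ./plots \
#         -r filter_stats_report.json --dpi 72 --debug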
def run_report(filter_csv, output_dir, base_report_name, dpi):
    """Main point of entry.

    The filter stats report has two main modes.

    All Reads: (i.e., PreFilter)
        SequencingZMW > 0
        - total bases
        - total number of reads
        - mean readlength
        - mean readscore

    HQ Region: (i.e., PostFilter)
        PassedFilter > 0, SequencingZMW > 0
        - total bases
        - total number of reads
        - mean readlength
        - mean readscore

    Generates:
        - Pre and Post filter ReadLength histograms with SDF (with thumbnails)
        - Pre and Post filter ReadScore histograms with SDF (with thumbnails)
        - Pre and Post table of total bases, # of reads, mean readlength,
          mean readscore
    """
    P = functools.partial
    headers = _get_header_fields_from_csv(filter_csv)
    if not _validate_header(headers):
        raise CsvParserError("Missing required columns {c} in {f}".format(
            c=_REQUIRED_HEADER_FIELDS, f=filter_csv))
    row_to_rec_func = P(_row_to_record, headers)

    ##### General Filters
    # General construct to create a func with signature f(record) -> Bool
    seq_zmw_filter_f = lambda record: record.SequencingZMW > 0
    hq_filter_f = lambda record: record.PassedFilter > 0
    ####################### Pre-Filter Aggregator(s)
    nbases_ag = SumAggregator('#Bases')
    nreads_ag = CountAggregator('Readlength')
    readlength_ag = MeanAggregator('Readlength')
    # the histograms are adaptively computed; only the min value and the
    # bin width (dx) are fixed up front
    readlength_hist_ag = Histogram('Readlength', 0, dx=10)
    read_score_hist_ag = Histogram('ReadScore', 0, dx=0.01)
    readscore_ag = SumAggregator('ReadScore', total=0)
    readscore_mean_ag = MeanAggregator('ReadScore')

    ## Create/bind core functions that can be passed to the applyer method.
    # Calling these 'Models'. A model is a list of filters and aggregators.
    # The signature of _apply is ([filter1, filter2], aggregators, record);
    # calling functools.partial returns a function with signature f(record).
    pre_filters = [seq_zmw_filter_f]
    pre_agros = [nbases_ag, nreads_ag,
                 readscore_ag, readscore_mean_ag,
                 readlength_ag,
                 readlength_hist_ag,
                 read_score_hist_ag]
    pre_models = [P(_apply, pre_filters, pre_agros)]

    ####################### Post-Filter Aggregator(s)
    post_nbases_ag = SumAggregator('#Bases')
    post_nreads_ag = CountAggregator('Readlength')
    post_readlength_ag = MeanAggregator('Readlength')
    # as above, the histograms grow their bins as needed
    post_readlength_hist_ag = Histogram('Readlength', 0, dx=10)
    post_readscore_hist_ag = Histogram('ReadScore', 0, dx=0.01)
    post_readscore_ag = SumAggregator('ReadScore')
    post_readscore_mean_ag = MeanAggregator('ReadScore')

    # Post Filter Models
    post_filters = [seq_zmw_filter_f, hq_filter_f]
    post_agros = [post_nbases_ag, post_nreads_ag,
                  post_readlength_ag, post_readscore_ag,
                  post_readscore_mean_ag,
                  post_readlength_hist_ag, post_readscore_hist_ag]
    post_models = [P(_apply, post_filters, post_agros)]

    models = pre_models + post_models

    with open(filter_csv, 'r') as f:
        # skip the header line; each data row is dispatched to every model
        _ = f.readline()
        applyer(row_to_rec_func, f, models)
    # Now plot the data, make tables from the Aggregators
    log.info("*" * 10)
    log.info("PreFilter Results")
    log.info(nbases_ag)
    log.info(nreads_ag)
    log.info(readlength_ag)
    log.info(readscore_ag)
    log.info(readscore_mean_ag)
    log.info(readlength_hist_ag)
    log.info(read_score_hist_ag)
    log.info("*" * 10)
    log.info("PostFilter Results")
    log.info(post_nbases_ag)
    log.info(post_nreads_ag)
    log.info(post_readlength_ag)
    log.info(post_readscore_ag)
    log.info(post_readscore_mean_ag)
    log.info(post_readlength_hist_ag)
    log.info(post_readscore_hist_ag)
    log.info("completed runner")
    # this should return a pbreports.model.Report instance
    return 0
def main():
    """Main point of entry."""
    parser = get_parser()
    args = parser.parse_args()
    to_debug = args.debug
    output_dir = args.output_dir
    filter_csv = args.filter_summary_csv
    json_report = args.json_report
    dpi = args.dpi
    level = logging.DEBUG if to_debug else logging.INFO
    started_at = time.time()
    if to_debug:
        setup_log(log, level=level)
    else:
        log.addHandler(logging.NullHandler())
    try:
        # argument order must match the run_report signature
        # (filter_csv, output_dir, base_report_name, dpi)
        report = run_report(filter_csv, output_dir, json_report, dpi)
        #report.write_json(json_report)
        rcode = 0
    except Exception as e:
        log.error(e, exc_info=True)
        sys.stderr.write(str(e) + "\n")
        rcode = 1
    run_time = time.time() - started_at
    log.info("Completed main with returncode {r} in {s:.2f} sec".format(r=rcode, s=run_time))
    return rcode


if __name__ == '__main__':
    sys.exit(main())