import os
import abc
import logging
import sys
import functools
import operator
import argparse
import time
import math
#from pbreports.util import setup_log
__version__ = '2.5'
log = logging.getLogger(__name__)
_REQUIRED_HEADER_FIELDS = ('Readlength', 'ReadScore',
'PassedFilter', 'SequencingZMW', "ReadId")
# Format is { Name: ( dType, f(str->dType) ) }
VALID_COLUMNS = {"Movie": ("|S64", str),
"ReadId": ("|S64", str),
"#Bases": (int, int),
"Readlength": (int, int),
"ReadScore": (float, float),
"Productivity": (int, int),
"SequencingZMW": (int, int),
"PassedFilter": (int, int),
"Sandwiches": (int, int),
"Whitelisted": (int, int),
"SNR": (float, float),
"ArtifactScore": ("|S64", str)}
COLUMN_TYPES = {k: operator.getitem(v, 1) for k, v in VALID_COLUMNS.iteritems()}
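# Illustrative note: COLUMN_TYPES maps a column name to its str -> value
# converter, e.g. COLUMN_TYPES['ReadScore']('0.85') returns 0.85 (float) and
# COLUMN_TYPES['Readlength']('500') returns 500 (int).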
class _BaseFilterException(Exception):
pass
class CsvParserError(_BaseFilterException):
pass
class NoFilteredReadsError(_BaseFilterException):
pass
class NoPassedFilteredReadsError(_BaseFilterException):
pass
def setup_log(alog, file_name=None, level=logging.DEBUG):
"""
Util function for setting up logging.
    Because of how smrtpipe handles logging, output is sent to stdout by
    default. If a file name is given, the log is written to that file instead.

    :param alog: (logging.Logger) Logger instance to which the handler and
        formatter will be added.
    :param file_name: (str, None) Path to the log file. If None, stdout is used.
:param level: (int) logging level
"""
if file_name is None:
handler = logging.StreamHandler(sys.stdout)
else:
handler = logging.FileHandler(file_name)
str_formatter = '[%(levelname)s] %(asctime)-15s [%(name)s %(funcName)s %(lineno)d] %(message)s'
formatter = logging.Formatter(str_formatter)
handler.setFormatter(formatter)
alog.addHandler(handler)
alog.setLevel(level)
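# Minimal usage sketch (the logger name and file name below are hypothetical):
#   mylog = logging.getLogger('filter_stats')
#   setup_log(mylog, file_name='filter_stats.log', level=logging.INFO)
#   mylog.info("logging configured")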
def _validate_resource(func, resource):
"""Validate the existence of a file/dir"""
if func(resource):
return os.path.abspath(resource)
else:
raise IOError("Unable to find {f}".format(f=resource))
validate_file = functools.partial(_validate_resource, os.path.isfile)
validate_dir = functools.partial(_validate_resource, os.path.isdir)
validate_output_dir = functools.partial(_validate_resource, os.path.isdir)
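# These partials share the signature f(path) -> absolute path, raising IOError
# if the resource does not exist. For example (hypothetical path):
#   csv_path = validate_file('filtered_summary.csv')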
class BaseAggregator(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def apply(self, record):
pass
class CountAggregator(BaseAggregator):
def __init__(self, record_field, total=0):
self.total = total
self.record_field = record_field
def apply(self, record):
self.total += 1
def __repr__(self):
return "<{k} {f} total={t}>".format(k=self.__class__.__name__, t=self.total, f=self.record_field)
class SumAggregator(BaseAggregator):
def __init__(self, record_field, total=0):
self.total = total
self.record_field = record_field
def apply(self, record):
self.total += getattr(record, self.record_field)
def __repr__(self):
_d = dict(k=self.__class__.__name__,
t=self.total,
f=self.record_field)
return "<{k} {f} total={t} >".format(**_d)
class MeanAggregator(BaseAggregator):
def __init__(self, record_field):
self.record_field = record_field
self.total = 0
self.nvalues = 0
def apply(self, record):
self.nvalues += 1
v = getattr(record, self.record_field)
self.total += v
@property
def mean(self):
if self.nvalues == 0:
return 0.0
        # Use float division so integer fields (e.g. Readlength) do not truncate.
        return float(self.total) / self.nvalues
def __repr__(self):
_d = dict(n=self.nvalues, t=self.total,
k=self.__class__.__name__,
f=self.record_field,
m=self.mean)
return "<{k} {f} mean={m} nvalue={n} total={t}>".format(**_d)
class Histogram(BaseAggregator):
def __init__(self, record_field, min_value, dx, nbins=10):
self.record_field = record_field
self.min_value = min_value
# bin width
self.dx = dx
self.bins = [0 for _ in xrange(nbins)]
@property
def nbins(self):
return len(self.bins)
@property
def max_value(self):
return self.nbins * self.dx
def apply(self, record):
"""Adaptively compute the histogram. If there are not enough bins,
more will be added."""
v = getattr(record, self.record_field)
        # If the value falls beyond the current bin range, more bins are added below
n = int(math.ceil(v / self.dx))
max_v = (self.nbins - 1) * self.dx
if v >= max_v:
delta = v - max_v
n_new_bins = delta / self.dx
i = int(math.ceil(n_new_bins)) + 2
# add more bins
#log.info(("Adding more bins ", delta, n_new_bins, i))
for _ in xrange(i):
self.bins.append(0)
#log.info((v, n, max_v ))
#log.info("{k} {f} Adding value {v} index={n} to nbins {b} dx {x}".format(v=v, b=self.nbins, x=self.dx, n=n, f=self.record_field, k=self.__class__.__name__))
self.bins[n] += 1
def __repr__(self):
_d = dict(k=self.__class__.__name__,
f=self.record_field,
n=self.min_value,
x=self.max_value,
dx=self.dx,
nbins=len(self.bins))
return "<{k} {f} nbins={nbins} dx={dx} min={n} max={x}>".format(**_d)
def _get_header_fields_from_csv(file_name):
"""Peak into the CSV file and extract the column headers.
:raises: CsvParserError if unable to extract column headers.
"""
with open(file_name, 'r') as f:
header = f.readline()
if ',' in header.rstrip():
return header.rstrip().split(',')
else:
msg = "Malformed CSV. Enable to get column headers in {h}.".format(h=header)
raise CsvParserError(msg)
def _validate_header(headers):
return all([field in headers for field in _REQUIRED_HEADER_FIELDS])
class Record(object):
def __init__(self, **kwargs):
for k, v in kwargs.iteritems():
setattr(self, k, v)
def __repr__(self):
o = " ".join([''.join([x, '=', str(getattr(self, x))]) for x in _REQUIRED_HEADER_FIELDS])
_d = dict(k=self.__class__.__name__, o=o)
return "<{k} {o}>".format(**_d)
def _row_to_record(column_names, row):
if ',' not in row:
raise CsvParserError("Malformed row {r}".format(r=row))
d = {}
for name, value in zip(column_names, row.strip().split(',')):
t = COLUMN_TYPES[name]
#print row
#print name, t, value
v = t(value)
d[name] = v
r = Record(**d)
return r
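# Illustrative conversion (hypothetical header and row): values are coerced via
# COLUMN_TYPES before being attached to the Record.
#   rec = _row_to_record(['Movie', 'ReadId', 'Readlength'], 'm1,r1,500\n')
#   rec.Readlength  # -> 500 (int)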
def _filter_record_by_attribute(attr_name, func, value, record):
v = getattr(record, attr_name)
return func(value, v)
def _filter_record(filter_func, record):
"""Returns Bool"""
return filter_func(record)
def null_filter(record):
return True
def _multi_filter_record(filter_funcs, record):
"""Returns Bool"""
for filter_func in filter_funcs:
if not filter_func(record):
# Bail out at the first chance
return False
return True
def _apply(filter_funcs, aggregators, record):
"""Run the filters and call apply method on the aggregator if
the record passes filtering.
"""
if not isinstance(filter_funcs, (list, tuple)):
filter_funcs = [filter_funcs]
if not isinstance(aggregators, (list, tuple)):
aggregators = [aggregators]
if _multi_filter_record(filter_funcs, record):
for aggregator in aggregators:
aggregator.apply(record)
def applyer(row_to_record_func, iterable, funcs):
for it in iterable:
record = row_to_record_func(it)
for func in funcs:
func(record)
del record
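# Illustrative wiring of the pieces above (hypothetical column names and row):
# functools.partial freezes the filters and aggregators so each model has the
# signature f(record), which applyer() then calls once per CSV row.
#   total = SumAggregator('#Bases')
#   model = functools.partial(_apply, [null_filter], [total])
#   rows = ['m1,r1,500\n']
#   to_rec = functools.partial(_row_to_record, ['Movie', 'ReadId', '#Bases'])
#   applyer(to_rec, rows, [model])
#   total.total  # -> 500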
def _to_table(nbases, nreads, mean_readlength, mean_read_score):
    """Create a pbreports.model.Table instance (placeholder; not implemented here)."""
    return ""
def get_parser():
desc = ""
parser = argparse.ArgumentParser(version=__version__, description=desc)
parser.add_argument('filter_summary_csv', type=validate_file,
help="Filter CSV file.")
parser.add_argument('-o', "--output", dest='output_dir', default=os.getcwd(), type=validate_dir,
help="Output directory for histogram images generated")
parser.add_argument('-r', '--report', dest='json_report',
help='Path of JSON report.')
parser.add_argument("--dpi", default=60, type=int,
help="dots/inch")
parser.add_argument('--debug', action='store_true',
help="Enable debug mode to stdout.")
return parser
def run_report(filter_csv, output_dir, base_report_name, dpi):
"""Main point of entry
The filter stats report has two main modes.
All Reads: (i.e., PreFilter)
SequencingZMW > 0
- total bases
- total number of reads
- mean readlength
- mean readscore
HQ Region: (i.e., PostFilter)
PassedFilter > 0, SequencingZMW > 0
- total bases
- total number of reads
- mean readlength
- mean readscore
Generates:
- Pre and Post filter ReadLength histograms with SDF (with thumbnails)
- Pre and Post filter ReadScore Histogram with SDF (with thumbnails)
    - Pre and Post table of total bases, # of reads, mean readlength, mean readscore
"""
P = functools.partial
row_to_rec_func = P(_row_to_record, _get_header_fields_from_csv(filter_csv))
##### General Filters
# General construct to create a func with signature f(record) -> Bool
seq_zmw_filter_f = lambda record: record.SequencingZMW > 0
hq_filter_f = lambda record: record.PassedFilter > 0
####################### Pre-Filter Aggregator(s)
nbases_ag = SumAggregator('#Bases')
nreads_ag = CountAggregator('Readlength')
readlength_ag = MeanAggregator('Readlength')
    # The histogram is computed adaptively: bins are added as needed, so only
    # the min value and the bin width (dx) need to be given up front.
readlength_hist_ag = Histogram('Readlength', 0, dx=10)
read_score_hist_ag = Histogram('ReadScore', 0, dx=0.01)
readscore_ag = SumAggregator('ReadScore', total=0)
readscore_mean_ag = MeanAggregator('ReadScore')
    ## Create/bind core functions that can be passed to the applyer function.
    # Calling these 'Models'. A model is a list of filters plus a list of aggregators.
    # The signature of _apply is ([filter1, filter2], [aggregator1, ...], record);
    # functools.partial freezes the first two arguments, yielding a function with signature f(record).
pre_filters = [seq_zmw_filter_f]
pre_agros = [nbases_ag, nreads_ag,
readscore_ag, readscore_mean_ag,
readlength_ag,
readlength_hist_ag,
read_score_hist_ag]
pre_models = [P(_apply, pre_filters, pre_agros)]
####################### Post-Filter Aggregator(s)
#
post_nbases_ag = SumAggregator('#Bases')
post_nreads_ag = CountAggregator('Readlength')
post_readlength_ag = MeanAggregator('Readlength')
    # As above, the histogram is computed adaptively; only the min value and bin width (dx) are fixed.
post_readlength_hist_ag = Histogram('Readlength', 0, dx=10)
post_readscore_hist_ag = Histogram('ReadScore', 0, dx=0.01)
post_readscore_ag = SumAggregator('ReadScore')
post_readscore_mean_ag = MeanAggregator('ReadScore')
# Post Filter Models
post_filters = [seq_zmw_filter_f, hq_filter_f]
post_agros = [post_nbases_ag, post_nreads_ag,
post_readlength_ag, post_readscore_ag,
post_readscore_mean_ag,
post_readlength_hist_ag, post_readscore_hist_ag]
post_models = [P(_apply, post_filters, post_agros)]
models = pre_models + post_models
with open(filter_csv, 'r') as f:
# read in header
_ = f.readline()
applyer(row_to_rec_func, f, models)
    # Now plot the data and build tables from the aggregators
log.info("*" * 10)
log.info("PreFilter Results")
log.info(nbases_ag)
log.info(nreads_ag)
log.info(readlength_ag)
log.info(readscore_ag)
log.info(readscore_mean_ag)
log.info(readlength_hist_ag)
log.info(read_score_hist_ag)
log.info("*" * 10)
log.info("PostFilter Results")
log.info(post_nbases_ag)
log.info(post_readlength_ag)
log.info(post_readscore_mean_ag)
log.info(post_readlength_hist_ag)
log.info(post_readscore_hist_ag)
log.info("completed runner")
# this should return a pbreports.model.Report instance
return 0
def main():
"""Main point of entry."""
parser = get_parser()
args = parser.parse_args()
to_debug = args.debug
output_dir = args.output_dir
filter_csv = args.filter_summary_csv
json_report = args.json_report
dpi = args.dpi
level = logging.DEBUG if to_debug else logging.INFO
started_at = time.time()
if to_debug:
setup_log(log, level=level)
else:
log.addHandler(logging.NullHandler())
try:
        report = run_report(filter_csv, output_dir, json_report, dpi)
#report.write_json(json_report)
rcode = 0
except Exception as e:
log.error(e, exc_info=True)
sys.stderr.write(str(e) + "\n")
rcode = 1
run_time = time.time() - started_at
log.info("Completed main with returncode {r} in {s:.2f}".format(r=rcode, s=run_time))
return rcode
if __name__ == '__main__':
sys.exit(main())