Created
March 20, 2025 15:38
-
-
Save jodyphelan/ffeb3bfa7e6beaba9442381b953aa2e0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
import os | |
from pathogenprofiler import filecheck | |
import csv | |
import pathogenprofiler as pp | |
from tqdm import tqdm | |
import json | |
import ntm_profiler as ntmp | |
from ntm_profiler.models import ProfileResult, SpeciesResult | |
from typing import List, Tuple, Optional | |
import argparse | |
import sys | |
# CLI setup for the collate script: package constants and argument parsing.
__softwarename__ = "ntm-profiler"
# Default location of the bundled database files (<prefix>/share/ntm-profiler/).
__default_data_dir__ = f'{sys.base_prefix}/share/{__softwarename__}/'

# create parser
parser = argparse.ArgumentParser(description='Collate results from multiple samples together')
parser.add_argument('--outfile','-o',default="ntmprofiler.collate.txt",help='Output file name')
parser.add_argument('--samples',help='File with samples (one per line)')
parser.add_argument('--dir','-d',default=".",help='Storage directory')
parser.add_argument('--suffix',default=".results.json",type=str,help='Input results files suffix')
parser.add_argument('--format',default="txt",choices=["txt","csv"],type=str,help='Output file type')
parser.add_argument('--version', action='version', version="NTM-Profiler version %s" % ntmp.__version__)
parser.add_argument('--db_dir',type=os.path.abspath,default=__default_data_dir__,help='Storage directory')
parser.add_argument('--temp',help="Temp directory to process all files",type=str,default=".")
parser.add_argument('--logging',type=str.upper,default="INFO",choices=["DEBUG","INFO","WARNING","ERROR","CRITICAL"],help='Logging level')
args = parser.parse_args()
class VariantDB:
    """In-memory index of variants observed across samples.

    Tracks which samples carry which (gene, change) variants, per-sample
    variant frequencies, and a flat list of variant rows for dumping to CSV.
    """

    def __init__(self, json_db: Optional[dict] = None):
        # sample id -> set of (gene, change) keys seen in that sample
        self.samples2variants = defaultdict(set)
        # (gene, change) -> set of sample ids carrying that variant
        self.variant2samples = defaultdict(set)
        # (sample id, gene, change) -> frequency of the variant in that sample
        self.variant_frequencies = {}
        # sample ids in the order their results were added
        self.samples = list()
        # one dict per observed variant (model fields plus 'sample')
        self.variant_rows = []
        if json_db:
            # Pre-register known variants so they are listed even when no
            # sample carries them.
            for gene in json_db:
                for mutation in json_db[gene]:
                    self.variant2samples[(gene,mutation)] = set()

    # NOTE: the annotation is a string so this module can be imported without
    # ntm_profiler installed; the expected type is still ProfileResult.
    def add_result(self, result: "ProfileResult") -> None:
        """Record all dr/other variants from one sample's profile result."""
        self.samples.append(result.id)
        for var in result.dr_variants + result.other_variants:
            key = (result.id,var.gene_name,var.change)
            self.variant_frequencies[key] = var.freq
            key = (var.gene_name,var.change)
            self.variant2samples[key].add(result.id)
            self.samples2variants[result.id].add(key)
            d = var.model_dump()
            d['sample'] = result.id
            self.variant_rows.append(d)

    def get_frequency(self,key: Tuple[str,str,str]) -> float:
        """Return the frequency for (sample, gene, change), or 0.0 if unseen."""
        return self.variant_frequencies.get(key,0.0)

    def get_variant_list(self) -> List[Tuple[str,str]]:
        """Return all known (gene, change) variant keys."""
        return list(self.variant2samples.keys())

    def write_dump(self,filename: str) -> None:
        """Write one CSV row per recorded variant to *filename*."""
        fields = ["sample","gene_name","change","freq","type"]
        # newline="" is required by the csv module to avoid spurious blank
        # lines on Windows.
        with open(filename,"w",newline="") as O:
            writer = csv.DictWriter(O,fieldnames=fields)
            writer.writeheader()
            for row in self.variant_rows:
                # Select only the dump columns; use .get so a model without
                # one of them produces an empty cell instead of a KeyError.
                d = {k:row.get(k) for k in fields}
                writer.writerow(d)
# Build the list of sample names to collate: either from an explicit file
# (--samples, one name per line) or by scanning --dir for files ending in
# the results suffix.
if args.samples:
    with open(args.samples) as f:
        samples = [line.rstrip() for line in f]
else:
    # Strip only the trailing suffix; str.replace would also clobber an
    # occurrence of the suffix in the middle of a filename.
    samples = [x[:-len(args.suffix)] for x in os.listdir(args.dir) if x.endswith(args.suffix)]
if len(samples)==0:
    pp.logging.info(f"\nNo result files found in directory '{args.dir}'. Do you need to specify '--dir'?\n")
    # sys.exit instead of quit(): quit() comes from the site module and is
    # not guaranteed to exist in every interpreter environment.
    sys.exit(0)
# Loop through the per-sample result files, collecting one summary row per
# sample plus per-(sample, drug) resistance calls for profile results.
variant_db = VariantDB()
rows = []
drug_resistance_results = []
resistance_dbs_used = set()
for s in tqdm(samples):
    # Data has the same structure as the .result.json files.
    # Context manager ensures the JSON file handle is closed promptly.
    with open(filecheck("%s/%s%s" % (args.dir,s,args.suffix))) as f:
        data = json.load(f)
    if data['result_type']=='Species':
        result = SpeciesResult(**data)
    else:
        result = ProfileResult(**data)
    row = {
        'id': s
    }
    if len(result.species.species)>0:
        # Multiple species hits become ';'-separated values in one cell.
        row['species'] = ";".join([hit.species for hit in result.species.species])
        row['closest-sequence'] = ";".join([hit.prediction_info['accession'] for hit in result.species.species])
        row['ANI'] = ";".join([str(hit.prediction_info['ani']) for hit in result.species.species])
    else:
        row['species'] = None
        row['closest-sequence'] = None
        row['ANI'] = None
    if isinstance(result, ProfileResult):
        resistance_dbs_used.add(result.pipeline.resistance_db_version['name'])
        variant_db.add_result(result)
        row['barcode'] = ";".join([x.id for x in result.barcode])
        # One record per (variant, drug) pair; the per-drug columns are
        # assembled from these after the loop.
        for var in result.dr_variants + result.dr_genes:
            for d in var.drugs:
                drug_resistance_results.append({
                    'id': s,
                    'drug': d['drug'],
                    'var': var.get_str(),
                })
    rows.append(row)
# Collect the union of drugs covered by every resistance DB that was used.
drugs = set()
for res_db in resistance_dbs_used:
    res_db_conf = pp.get_db(args.db_dir,res_db,verbose=False)
    drugs.update(res_db_conf['drugs'])
drugs = sorted(list(drugs))

# Group resistance calls by (sample id, drug) in one pass instead of
# rescanning the whole call list for every row/drug combination
# (previously O(rows * drugs * calls)).
_calls_by_id_drug = defaultdict(list)
for _call in drug_resistance_results:
    _calls_by_id_drug[(_call['id'], _call['drug'])].append(_call['var'])
for row in rows:
    for drug in drugs:
        row[drug] = "; ".join(_calls_by_id_drug[(row['id'], drug)])
# Choose the column delimiter from the requested output format.
if args.format=="txt":
    args.sep = "\t"
else:
    args.sep = ","
fields = [
    'id',
    'species',
    'closest-sequence',
    'ANI',
    'barcode'
] + drugs
# newline="" is required by the csv module to avoid spurious blank lines on
# Windows. extrasaction='ignore' drops row keys not listed in `fields`;
# missing keys (e.g. 'barcode' for species-only results) are written as
# empty cells by DictWriter's default restval.
with open(args.outfile,"w",newline="") as O:
    writer = csv.DictWriter(O,fieldnames=fields,delimiter=args.sep,extrasaction='ignore')
    writer.writeheader()
    writer.writerows(rows)

# Dump the per-variant table next to the main collated output.
variant_db.write_dump(args.outfile + ".variants.csv")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment