Skip to content

Instantly share code, notes, and snippets.

@jodyphelan
Created March 20, 2025 15:38
Show Gist options
  • Save jodyphelan/ffeb3bfa7e6beaba9442381b953aa2e0 to your computer and use it in GitHub Desktop.
Save jodyphelan/ffeb3bfa7e6beaba9442381b953aa2e0 to your computer and use it in GitHub Desktop.
import argparse
import csv
import json
import logging
import os
import sys
from collections import defaultdict
from typing import List, Optional, Tuple

import pathogenprofiler as pp
from pathogenprofiler import filecheck
from tqdm import tqdm

import ntm_profiler as ntmp
from ntm_profiler.models import ProfileResult, SpeciesResult
__softwarename__ = "ntm-profiler"
# Default database location: <python base prefix>/share/ntm-profiler/
__default_data_dir__ = f'{sys.base_prefix}/share/{__softwarename__}/'

# Command-line interface for collating per-sample result files.
parser = argparse.ArgumentParser(description='Collate results from multiple samples together')
parser.add_argument('--outfile', '-o', default="ntmprofiler.collate.txt", help='Output file name')
parser.add_argument('--samples', help='File with samples (one per line)')
parser.add_argument('--dir', '-d', default=".", help='Directory containing the result files')
parser.add_argument('--suffix', default=".results.json", type=str, help='Input results files suffix')
parser.add_argument('--format', default="txt", choices=["txt", "csv"], type=str, help='Output file type')
parser.add_argument('--version', action='version', version="NTM-Profiler version %s" % ntmp.__version__)
parser.add_argument('--db_dir', type=os.path.abspath, default=__default_data_dir__, help='Database directory')
# NOTE(review): --temp is accepted but not used anywhere in this script.
parser.add_argument('--temp', help="Temp directory to process all files", type=str, default=".")
parser.add_argument('--logging', type=str.upper, default="INFO",
                    choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help='Logging level')
args = parser.parse_args()

# Apply --logging; previously the option was parsed but had no effect.
# basicConfig only installs a handler when the root logger has none, so the
# level is also set explicitly in case a library already configured logging.
logging.basicConfig(level=args.logging)
logging.getLogger().setLevel(args.logging)
class VariantDB:
    """Index of variants observed across a set of profiled samples.

    Tracks which (gene_name, change) variants each sample carries, the
    frequency recorded for each (sample, gene_name, change), and keeps one
    flat row per variant occurrence so the collection can be dumped to CSV.
    """

    def __init__(self, json_db: Optional[dict] = None):
        """Create an empty variant index.

        Args:
            json_db: optional mapping of gene name -> iterable of mutation
                strings. Each (gene, mutation) pair is pre-registered with an
                empty sample set so it appears in ``get_variant_list()`` even
                if no sample carries it.
        """
        # sample id -> set of (gene_name, change) keys carried by that sample
        self.samples2variants = defaultdict(set)
        # (gene_name, change) -> set of sample ids carrying that variant
        self.variant2samples = defaultdict(set)
        # (sample id, gene_name, change) -> frequency reported for that call
        self.variant_frequencies = {}
        # sample ids in the order they were added
        self.samples = list()
        # one dict per variant occurrence: var.model_dump() plus a 'sample' key
        self.variant_rows = []
        if json_db:
            for gene in json_db:
                for mutation in json_db[gene]:
                    self.variant2samples[(gene, mutation)] = set()

    def add_result(self, result: "ProfileResult") -> None:
        """Record every variant in ``result.dr_variants`` and ``result.other_variants``."""
        self.samples.append(result.id)
        for var in result.dr_variants + result.other_variants:
            freq_key = (result.id, var.gene_name, var.change)
            self.variant_frequencies[freq_key] = var.freq
            variant_key = (var.gene_name, var.change)
            self.variant2samples[variant_key].add(result.id)
            self.samples2variants[result.id].add(variant_key)
            row = var.model_dump()
            row['sample'] = result.id
            self.variant_rows.append(row)

    def get_frequency(self, key: Tuple[str, str, str]) -> float:
        """Return the frequency for (sample, gene_name, change), or 0.0 if never recorded."""
        return self.variant_frequencies.get(key, 0.0)

    def get_variant_list(self) -> List[Tuple[str, str]]:
        """Return every known (gene_name, change) key, including pre-registered ones."""
        return list(self.variant2samples.keys())

    def write_dump(self, filename: str) -> None:
        """Write one CSV row per recorded variant occurrence to ``filename``.

        Columns are sample, gene_name, change, freq and type. Raises KeyError
        if a recorded row lacks one of these keys.
        """
        fields = ["sample", "gene_name", "change", "freq", "type"]
        # newline='' is required by the csv module: the writer emits its own
        # "\r\n" terminators, which must not be translated a second time.
        with open(filename, "w", newline="") as out:
            writer = csv.DictWriter(out, fieldnames=fields)
            writer.writeheader()
            for row in self.variant_rows:
                writer.writerow({k: row[k] for k in fields})
# Collect the sample names to collate: read them from --samples (one per
# line), or derive them from the names of result files in --dir.
if args.samples:
    with open(args.samples) as sample_file:
        # Skip blank lines (e.g. a trailing empty line) instead of treating
        # them as a sample with an empty name.
        samples = [line.rstrip() for line in sample_file if line.strip()]
else:
    # Remove only the trailing suffix; str.replace would also strip any
    # occurrence of the suffix elsewhere in the file name. An empty suffix
    # matches nothing, as before.
    samples = [
        name[:-len(args.suffix)]
        for name in os.listdir(args.dir)
        if args.suffix and name.endswith(args.suffix)
    ]
if len(samples) == 0:
    pp.logging.info(f"\nNo result files found in directory '{args.dir}'. Do you need to specify '--dir'?\n")
    # sys.exit is always available; quit() is only injected by the site module.
    sys.exit(0)
# Load each sample's result JSON and build one summary row per sample.
# Drug-level hits go into ``drug_resistance_results`` (one dict per
# sample/drug/variant) and are merged into ``rows`` once the drug list is known.
variant_db = VariantDB()
rows = []
drug_resistance_results = []
resistance_dbs_used = set()
for s in tqdm(samples):
    # Data has the same structure as the .result.json files.
    # The context manager closes each file right away instead of leaving the
    # handle open until garbage collection.
    with open(filecheck("%s/%s%s" % (args.dir, s, args.suffix))) as result_file:
        data = json.load(result_file)
    if data['result_type'] == 'Species':
        result = SpeciesResult(**data)
    else:
        result = ProfileResult(**data)

    row = {'id': s}
    species_hits = result.species.species
    if len(species_hits) > 0:
        row['species'] = ";".join([hit.species for hit in species_hits])
        row['closest-sequence'] = ";".join([hit.prediction_info['accession'] for hit in species_hits])
        row['ANI'] = ";".join([str(hit.prediction_info['ani']) for hit in species_hits])
    else:
        row['species'] = None
        row['closest-sequence'] = None
        row['ANI'] = None

    # Resistance, barcode and variant data are only read from ProfileResult
    # objects; species-only results contribute just the species columns.
    if isinstance(result, ProfileResult):
        resistance_dbs_used.add(result.pipeline.resistance_db_version['name'])
        variant_db.add_result(result)
        row['barcode'] = ";".join([x.id for x in result.barcode])
        # Drug hits come from both dr_variants and dr_genes.
        for var in result.dr_variants + result.dr_genes:
            for drug_entry in var.drugs:
                drug_resistance_results.append({
                    'id': s,
                    'drug': drug_entry['drug'],
                    'var': var.get_str(),
                })
    rows.append(row)
# Union of the drugs defined by every resistance database that was used,
# sorted so the output columns have a stable order.
drugs = set()
for res_db in resistance_dbs_used:
    res_db_conf = pp.get_db(args.db_dir, res_db, verbose=False)
    drugs.update(res_db_conf['drugs'])
drugs = sorted(drugs)

# Index the hits by (sample id, drug) in one pass. The previous version
# rescanned the whole hit list for every row/drug pair, which is quadratic.
# Appending in list order keeps the original variant order within each cell.
hits_by_sample_drug = defaultdict(list)
for hit in drug_resistance_results:
    hits_by_sample_drug[(hit['id'], hit['drug'])].append(hit['var'])
for row in rows:
    for drug in drugs:
        row[drug] = "; ".join(hits_by_sample_drug.get((row['id'], drug), []))
# Column separator follows --format: tab for "txt", comma for "csv".
args.sep = "\t" if args.format == "txt" else ","
fields = [
    'id',
    'species',
    'closest-sequence',
    'ANI',
    'barcode',
] + drugs
# newline='' lets the csv writer control line endings itself (otherwise
# "\r\n" becomes "\r\r\n" on platforms that translate newlines).
with open(args.outfile, "w", newline="") as O:
    # extrasaction='ignore' drops any row keys that are not output columns.
    writer = csv.DictWriter(O, fieldnames=fields, delimiter=args.sep, extrasaction='ignore')
    writer.writeheader()
    writer.writerows(rows)
# Per-variant table written next to the summary file.
variant_db.write_dump(args.outfile + ".variants.csv")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment