Skip to content

Instantly share code, notes, and snippets.

@nh13
Created January 10, 2025 20:31
Show Gist options
  • Save nh13/1f3bd11cd40970f35d63535b766e8f3a to your computer and use it in GitHub Desktop.
Save nh13/1f3bd11cd40970f35d63535b766e8f3a to your computer and use it in GitHub Desktop.

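Tests for pybwa: a Snakemake workflow that downloads reads from SRA, aligns them with both the `bwa` command-line tool and pybwa (`aln` and `mem` modes), and reports where the two sets of alignments differ. The gist contains two files: `bwa.py` (the pybwa driver and BAM comparison script) and `Snakefile`. Create the environment and run the workflow with: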

mamba create -y -n pybwa setuptools cython python=3.11 poetry=1.8 snakemake-minimal=7 defopt samtools bwa pysam fgpyo sra-tools
conda activate pybwa
snakemake --cores $(nproc) --snakefile Snakefile --directory output
import defopt
from pathlib import Path
import pysam
import itertools
from fgpyo.sam import Template
from typing import Iterator
from pysam import AlignedSegment
from pybwa import BwaAln
from pybwa import BwaAlnOptions
from pybwa import BwaIndex
from pybwa.libbwamem import BwaMem
from pybwa.libbwamem import BwaMemOptions
from pybwa.libbwamem import BwaMemOptionsBuilder
def _align(bwa, opt, records, writer):
for rec in bwa.align(opt=opt, queries=records):
if isinstance(rec, list):
for r in rec:
writer.write(r)
else:
writer.write(rec)
def _run(fasta: Path, fastq: Path, bam: Path, chunk_size: int, bwa, opt) -> None:
    """Align every read in ``fastq`` with ``bwa``/``opt`` and write the results to ``bam``.

    The output BAM header is taken from the sequence dictionary stored next to the
    reference (the ``fasta`` path with a ``.dict`` suffix). Reads are handed to
    ``_align`` each time ``chunk_size`` of them have accumulated, followed by one
    final batch holding any remaining reads.
    """
    dict_path = str(fasta.with_suffix(".dict"))
    with pysam.AlignmentFile(dict_path) as dict_file:
        sam_header = dict_file.header

    with pysam.FastxFile(str(fastq)) as fastx:
        with pysam.AlignmentFile(str(bam), mode="wb", header=sam_header) as out_bam:
            batch = []
            for read in fastx:
                batch.append(read)
                if len(batch) != chunk_size:
                    continue
                _align(bwa, opt, batch, out_bam)
                # Start a fresh list rather than clearing the one just handed off.
                batch = []
            # Flush the trailing partial batch, if any.
            if batch:
                _align(bwa, opt, batch, out_bam)
def aln(*, fasta: Path, fastq: Path, bam: Path, chunk_size: int = 10000) -> None:
    """Run bwa aln with pybwa

    Args:
        fasta: reference FASTA. It is also passed as the bwa index prefix, and its
            ``.dict`` sibling supplies the output BAM header.
        fastq: reads to align.
        bam: path of the BAM file to write.
        chunk_size: number of reads handed to the aligner per batch.
    """
    aln_options = BwaAlnOptions()
    aligner = BwaAln(prefix=fasta)
    _run(
        fasta=fasta,
        fastq=fastq,
        bam=bam,
        chunk_size=chunk_size,
        bwa=aligner,
        opt=aln_options,
    )
def mem(*, fasta: Path, fastq: Path, bam: Path, chunk_size: int = 10000) -> None:
    """Run bwa mem with pybwa

    Args:
        fasta: reference FASTA. It is also passed as the bwa index prefix, and its
            ``.dict`` sibling supplies the output BAM header.
        fastq: reads to align.
        bam: path of the BAM file to write.
        chunk_size: number of reads handed to the aligner per batch.
    """
    options_builder = BwaMemOptionsBuilder()
    # Presumably the pybwa counterpart of `bwa mem -5`, which the Snakefile's
    # bwa_mem rule passes to the command-line tool this output is compared with.
    options_builder.query_coord_as_primary = True
    mem_options = options_builder.build()
    aligner = BwaMem(prefix=fasta)
    _run(
        fasta=fasta,
        fastq=fastq,
        bam=bam,
        chunk_size=chunk_size,
        bwa=aligner,
        opt=mem_options,
    )
def all_r1s(template: Template) -> Iterator[AlignedSegment]:
    """Return an iterator over every R1 record of ``template``.

    The primary R1 (when present) comes first, then the R1 secondaries, then the
    R1 supplementals.
    """
    primary = template.r1
    leading = [] if primary is None else [primary]
    parts = (leading, template.r1_secondaries, template.r1_supplementals)
    return itertools.chain(*parts)
def all_r2s(template: Template) -> Iterator[AlignedSegment]:
    """Return an iterator over every R2 record of ``template``.

    The primary R2 (when present) comes first, then the R2 secondaries, then the
    R2 supplementals.
    """
    primary = template.r2
    leading = [] if primary is None else [primary]
    parts = (leading, template.r2_secondaries, template.r2_supplementals)
    return itertools.chain(*parts)
def compare(*, left_bam: Path, right_bam: Path) -> None:
    """Compare the single-end alignments in two BAMs and print every difference found.

    Templates are read from both BAMs in lockstep and must appear in the same order.
    Each difference is printed to stdout as a tab-separated line with four fields:
    a category (e.g. ``mapq``, ``cigar``), the read name (``null`` for a read-name
    mismatch), the left value and the right value. A mismatch in mapping quality,
    alignment score, contig, start, flag or CIGAR skips the remaining checks for
    that read; mismatched record counts are reported but do not.

    Args:
        left_bam: first BAM to compare (its values are printed in the third field).
        right_bam: second BAM to compare (its values are printed in the fourth field).

    Raises:
        Exception: if the two BAMs have different template names at the same position.
        AssertionError: if the BAMs hold different numbers of templates, or a read
            violates the single-end expectations checked below.
    """
    with pysam.AlignmentFile(str(left_bam)) as lh, pysam.AlignmentFile(str(right_bam)) as rh:
        for lefts, rights in itertools.zip_longest(Template.iterator(lh), Template.iterator(rh)):
            # zip_longest pads the shorter side with None, so a None here means the
            # two BAMs hold different numbers of templates.
            assert lefts is not None, f"{right_bam} has more templates than {left_bam}"
            assert rights is not None, f"{left_bam} has more templates than {right_bam}"

            # They should _always_ have the same name; otherwise the files are out
            # of step and nothing after this point can be compared meaningfully.
            if lefts.name != rights.name:
                print(f"qname\tnull\t{lefts.name}\t{rights.name}")
                raise Exception("Unrecoverable")

            # Check the # of records for this template. We continue regardless since
            # we want to check the primary alignments.
            num_lefts = len(list(lefts.all_recs()))
            num_rights = len(list(rights.all_recs()))
            if num_lefts != num_rights:
                print(f"num_alignments\t{lefts.name}\t{num_lefts}\t{num_rights}")
            # If either side has no records, there is nothing more to compare.
            if num_lefts == 0 or num_rights == 0:
                continue

            # Ditto for # of R1s
            num_lefts = len(list(all_r1s(lefts)))
            num_rights = len(list(all_r1s(rights)))
            if num_lefts != num_rights:
                print(f"num_r1\t{lefts.name}\t{num_lefts}\t{num_rights}")

            # Ditto for # of R2s
            num_lefts = len(list(all_r2s(lefts)))
            num_rights = len(list(all_r2s(rights)))
            if num_lefts != num_rights:
                print(f"num_r2\t{lefts.name}\t{num_lefts}\t{num_rights}")

            # We must have primary alignments for R1 (and not R2) for single-end
            assert lefts.r2 is None
            assert rights.r2 is None
            left_r1 = lefts.r1
            right_r1 = rights.r1
            assert left_r1 is not None and right_r1 is not None
            assert left_r1.query_name == right_r1.query_name

            # Check mapping quality
            if left_r1.mapping_quality != right_r1.mapping_quality:
                print(f"mapq\t{lefts.name}\t{left_r1.mapping_quality}\t{right_r1.mapping_quality}")
                continue

            assert left_r1.is_unmapped == right_r1.is_unmapped
            # Do not continue if the reads are not mapped
            if left_r1.is_unmapped:
                continue

            # Alignment score (only if AS is present)
            assert left_r1.has_tag("AS") == right_r1.has_tag("AS")
            if left_r1.has_tag("AS"):
                if left_r1.get_tag("AS") != right_r1.get_tag("AS"):
                    print(f"AS\t{lefts.name}\t{left_r1.get_tag('AS')}\t{right_r1.get_tag('AS')}")
                    continue

            # MAPQ is equal on both sides at this point. For MAPQ 0 reads, only check
            # that XS is consistent, then skip the contig/start/flag/CIGAR checks
            # below (NOTE(review): presumably because the reported placement of a
            # MAPQ 0 read is arbitrary).
            if left_r1.mapping_quality == 0:
                assert left_r1.has_tag("XS") == right_r1.has_tag("XS")
                # AS is optional (see above), so only compare it with XS when both
                # exist; the AS/XS presence asserts make the left side representative.
                if left_r1.has_tag("XS") and left_r1.has_tag("AS"):
                    assert left_r1.get_tag("AS") >= left_r1.get_tag("XS")
                    assert right_r1.get_tag("AS") >= right_r1.get_tag("XS")
                continue

            # Contig and start
            if left_r1.reference_name != right_r1.reference_name:
                print(f"contig\t{lefts.name}\t{left_r1.reference_name}\t{right_r1.reference_name}")
                continue
            if left_r1.reference_start != right_r1.reference_start:
                print(f"start\t{lefts.name}\t{left_r1.reference_start}\t{right_r1.reference_start}")
                continue

            # SAM flags
            if left_r1.flag != right_r1.flag:
                print(f"flag\t{lefts.name}\t{left_r1.flag}\t{right_r1.flag}")
                continue

            # Cigar (report the CIGAR strings themselves, not the flags)
            if left_r1.cigarstring != right_r1.cigarstring:
                print(f"cigar\t{lefts.name}\t{left_r1.cigarstring}\t{right_r1.cigarstring}")
                continue

            # TODO: others? SA? The whole string (be careful about tag ordering)?
if __name__ == '__main__':
    # Expose aln, mem and compare as defopt sub-commands of this script.
    subcommands = [aln, mem, compare]
    defopt.run(subcommands)
from pathlib import Path
# SRA run accessions to download, align and compare.
samples = [
    "SRR7733443",
    "SRR9932168",
    "SRR10286935",  # from SRX6999918
]

# The number of reads to examine from each sample.
num_reads = 1_000_000

# Placeholder paths: point these at the reference FASTA (bwa index prefix) and at
# the directory holding bwa.py.
fasta = Path("/path/to/human_g1k_v37/human_g1k_v37.fasta")
root_dir = Path("/path/to/this/dir")

# Alignment modes to compare; each has a bwa_<method> and a pybwa_<method> rule.
methods = ["aln", "mem"]
# Target rule: one diff summary per (sample, method) pair.
rule all:
    input:
        expand("diff/{method}/{sample}.summary.txt", sample=samples, method=methods),
# Download the first num_reads reads of an SRA run as a gzipped FASTQ.
# pipefail is disabled for the pipeline, presumably because head exits early and
# fastq-dump then fails on the closed pipe. That also hides a genuine fastq-dump
# failure, which would leave an empty (but valid) gzip file behind, so the rule
# fails explicitly when no reads were written.
rule download:
    output:
        fastq="fastqs/{sample}.fastq.gz"
    log: "fastqs/{sample}.log"
    params:
        num_lines = num_reads * 4
    shell:
        """
        (
            set +o pipefail
            fastq-dump --stdout {wildcards.sample} | head -n {params.num_lines} | gzip -c > {output.fastq}
            if [[ -z "$(gzip -dc {output.fastq} | head -n 1)" ]]; then
                echo "no reads were downloaded for {wildcards.sample}" >&2
                exit 1
            fi
        ) &> {log}
        """
# Reference alignments from the bwa command-line tool: bwa aln + samse -> BAM.
rule bwa_aln:
    input:
        fasta=fasta,
        fastq="fastqs/{sample}.fastq.gz",
    output:
        sai="bwa/aln/{sample}.sai",
        bam="bwa/aln/{sample}.bam",
    log:
        "bwa/aln/{sample}.log",
    shell:
        """
        (
            bwa aln {input.fasta} {input.fastq} > {output.sai};
            bwa samse {input.fasta} {output.sai} {input.fastq} | samtools view -Sb - > {output.bam};
        ) &> {log}
        """
# Reference alignments from the bwa command-line tool: bwa mem -5 -> BAM.
rule bwa_mem:
    input:
        fasta=fasta,
        fastq="fastqs/{sample}.fastq.gz",
    output:
        bam="bwa/mem/{sample}.bam",
    log:
        "bwa/mem/{sample}.log",
    shell:
        """
        (
            bwa mem -5 {input.fasta} {input.fastq} | samtools view -Sb - > {output.bam};
        ) &> {log}
        """
# Alignments from pybwa's aln mode, produced by `bwa.py aln`.
# bwa.py is declared as an input (rather than a param) so that editing the script
# makes this output out of date. The .dict file is listed because bwa.py reads the
# BAM header from it.
rule pybwa_aln:
    input:
        fastq="fastqs/{sample}.fastq.gz",
        fasta=fasta,
        seq_dict=fasta.with_suffix(".dict"),
        pybwa=root_dir / "bwa.py",
    output:
        bam="pybwa/aln/{sample}.bam"
    log: "pybwa/aln/{sample}.log"
    shell:
        """
        (
            python {input.pybwa} aln --fasta {input.fasta} --fastq {input.fastq} --bam {output.bam}
        ) &> {log}
        """
# Alignments from pybwa's mem mode, produced by `bwa.py mem`.
# bwa.py is declared as an input (rather than a param) so that editing the script
# makes this output out of date. The .dict file is listed because bwa.py reads the
# BAM header from it.
rule pybwa_mem:
    input:
        fastq="fastqs/{sample}.fastq.gz",
        fasta=fasta,
        seq_dict=fasta.with_suffix(".dict"),
        pybwa=root_dir / "bwa.py",
    output:
        bam="pybwa/mem/{sample}.bam"
    log: "pybwa/mem/{sample}.log"
    shell:
        """
        (
            python {input.pybwa} mem --fasta {input.fasta} --fastq {input.fastq} --bam {output.bam}
        ) &> {log}
        """
# Diff the pybwa BAM against the bwa BAM for one sample/method.
# detailed.txt holds one line per difference (category first). summary.txt holds
# the total line count followed by per-category counts, most frequent first.
# bwa.py is an input so that edits to the comparison logic re-trigger this rule.
rule compare:
    input:
        left="pybwa/{method}/{sample}.bam",
        right="bwa/{method}/{sample}.bam",
        pybwa=root_dir / "bwa.py",
    output:
        d_txt="diff/{method}/{sample}.detailed.txt",
        s_txt="diff/{method}/{sample}.summary.txt"
    log:
        "diff/{method}/{sample}.log"
    shell:
        """
        (
            python {input.pybwa} compare --left {input.left} --right {input.right} > {output.d_txt};
            num_errors=$(wc -l {output.d_txt} | awk '{{print $1}}');
            echo -e "${{num_errors}}\ttotal" > {output.s_txt};
            cut -f 1 {output.d_txt} | sort | uniq -c | awk '{{print $1, $2}}' | sort -n -k 1 -r >> {output.s_txt};
        ) &> {log}
        """
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment