# Gist endrebak/027c8b51da52e12f4424ef176d1951bf, created June 20, 2023 05:10.
# Save it to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from typing import List, Optional

import bioframe.core.arrops as arrops
import numpy as np
import polars as pl
import polars as pl
import pyoframe as pf
import pytest
# Shared chromosome column and the interval columns of the first set.
CHROMOSOME_PROPERTY = "chromosome"
STARTS_PROPERTY = "starts"
ENDS_PROPERTY = "ends"

# Interval columns of the second set.
CHROMOSOME2_PROPERTY = "chromosome_2"
STARTS2_PROPERTY = "starts_2"
ENDS2_PROPERTY = "ends_2"

# "2in1": for each set-1 interval, the index range [starts_2in1, ends_2in1)
# into the sorted set-2 starts, plus the derived non-empty mask and range size.
STARTS_2IN1_PROPERTY = "starts_2in1"
ENDS_2IN1_PROPERTY = "ends_2in1"
MASK_2IN1_PROPERTY = "mask_2in1"
LENGTHS_2IN1_PROPERTY = "lengths_2in1"

# "1in2": the mirror image, indexing into the sorted set-1 starts.
STARTS_1IN2_PROPERTY = "starts_1in2"
ENDS_1IN2_PROPERTY = "ends_1in2"
MASK_1IN2_PROPERTY = "mask_1in2"
LENGTHS_1IN2_PROPERTY = "lengths_1in2"

# Column groupings and a scratch column name used while rebuilding output.
DF_COLUMNS_PROPERTY = [STARTS_PROPERTY, ENDS_PROPERTY]
DF2_COLUMNS_PROPERTY = [STARTS2_PROPERTY, ENDS2_PROPERTY]
TEMP_CHROMOSOME_PROPERTY = f"__{CHROMOSOME_PROPERTY}__"
def search(
    col1: str,
    col2: str,
    side: str = "left"
) -> pl.Expr:
    """Return the insertion positions of ``col2``'s values within ``col1``.

    Both columns are exploded first, so list columns are searched element by
    element. ``side`` is passed through to ``Expr.search_sorted``. The values
    of ``col1`` must already be sorted for the positions to be meaningful.
    """
    sorted_values = pl.col(col1).explode()
    queries = pl.col(col2).explode()
    return sorted_values.search_sorted(queries, side=side)
def lengths(
    starts: str,
    ends: str,
    outname: str = ""
) -> pl.Expr:
    """Return an expression for ``ends - starts`` (both exploded), named ``outname``.

    The difference is exploded once more before aliasing.
    """
    end_values = pl.col(ends).explode()
    start_values = pl.col(starts).explode()
    difference = end_values.sub(start_values)
    return difference.explode().alias(outname)
def find_starts_in_ends(starts, ends, starts_2, ends_2, closed: bool = False) -> List[pl.Expr]:
    """Build the four ``search_sorted`` expressions that locate overlaps.

    Outputs (each imploded back into a list column):

    * STARTS_2IN1_PROPERTY / ENDS_2IN1_PROPERTY: for every interval of the
      first set, the positions of its start and end within the sorted
      ``starts_2``. The range between them indexes the second-set intervals
      whose start lies inside it.
    * STARTS_1IN2_PROPERTY / ENDS_1IN2_PROPERTY: the mirror image, that is,
      positions of ``starts_2`` / ``ends_2`` within the sorted ``starts``.
      The start search uses ``side="right"`` so a pair with equal starts is
      found only once (by the 2-in-1 search above).

    Args:
        starts, ends: column names of the first interval set.
        starts_2, ends_2: column names of the second interval set.
        closed: if True, treat intervals as closed, so touching intervals
            (``end == other start``) overlap. The two end searches then use
            ``side="right"``. Previously this flag was accepted but ignored;
            the default (False) reproduces the old half-open behaviour.
    """
    end_side = "right" if closed else "left"
    return [
        search(starts_2, starts).alias(STARTS_2IN1_PROPERTY).implode(),
        search(starts_2, ends, side=end_side).alias(ENDS_2IN1_PROPERTY).implode(),
        search(starts, starts_2, side="right").alias(STARTS_1IN2_PROPERTY).implode(),
        search(starts, ends_2, side=end_side).alias(ENDS_1IN2_PROPERTY).implode(),
    ]
def compute_masks() -> List[pl.Expr]:
    """Return boolean list columns flagging intervals with at least one hit.

    A position range is non-empty when its end index exceeds its start index.
    This is evaluated for the 1-in-2 direction first, then for 2-in-1.
    """
    def _nonempty(end_col, start_col, out_col):
        # True where [start, end) of search positions contains any index.
        ends_exploded = pl.col(end_col).explode()
        starts_exploded = pl.col(start_col).explode()
        return ends_exploded.gt(starts_exploded).implode().alias(out_col)

    return [
        _nonempty(ENDS_1IN2_PROPERTY, STARTS_1IN2_PROPERTY, MASK_1IN2_PROPERTY),
        _nonempty(ENDS_2IN1_PROPERTY, STARTS_2IN1_PROPERTY, MASK_2IN1_PROPERTY),
    ]
def apply_masks() -> List[pl.Expr]:
    """Drop search positions whose interval has no hits.

    The 2-in-1 start/end columns are filtered by MASK_2IN1_PROPERTY and the
    1-in-2 columns by MASK_1IN2_PROPERTY; results are imploded back to lists.
    """
    column_groups = [
        ([STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY], MASK_2IN1_PROPERTY),
        ([STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY], MASK_1IN2_PROPERTY),
    ]
    return [
        pl.col(group).explode().filter(pl.col(mask).explode()).implode()
        for group, mask in column_groups
    ]
def add_lengths() -> List[pl.Expr]:
    """Return expressions for the size of every (masked) position range.

    Produces LENGTHS_2IN1_PROPERTY = ends_2in1 - starts_2in1 and
    LENGTHS_1IN2_PROPERTY = ends_1in2 - starts_1in2, each imploded to a list.

    Returns a *list* of two expressions. The previous ``-> pl.Expr``
    annotation was wrong, and callers concatenate the result with another
    list.
    """
    return [
        pl.col(ENDS_2IN1_PROPERTY).explode().sub(pl.col(STARTS_2IN1_PROPERTY).explode()).alias(LENGTHS_2IN1_PROPERTY).implode(),
        pl.col(ENDS_1IN2_PROPERTY).explode().sub(pl.col(STARTS_1IN2_PROPERTY).explode()).alias(LENGTHS_1IN2_PROPERTY).implode(),
    ]
def repeat_frame(columns, mask, startsin, endsin) -> pl.Expr:
    """Repeat each kept row of ``columns`` once per hit.

    Rows are first filtered by ``mask``. Each surviving row is then repeated
    ``endsin - startsin`` times and the repeats are flattened.
    """
    kept_rows = pl.col(columns).explode().filter(pl.col(mask).explode())
    hit_counts = pl.col(endsin).explode() - pl.col(startsin).explode()
    return kept_rows.repeat_by(hit_counts).explode()
def repeat_other(columns, starts, diffs):
    """Gather consecutive runs of rows from ``columns``.

    For the k-th entry of ``starts`` (s_k) and ``diffs`` (d_k), the rows at
    indices s_k, s_k + 1, ..., s_k + d_k - 1 are taken. All runs are
    concatenated in order.
    """
    run_sizes = pl.col(diffs).explode()
    # Offset at which run k begins in the flattened output, repeated d_k times.
    run_offsets = run_sizes.cumsum().sub(run_sizes).repeat_by(run_sizes).explode()
    # 0 .. sum(d) - 1: the flattened output position of each gathered row.
    positions = pl.arange(0, run_sizes.sum()).explode().alias("length_sum_arange")
    # s_k repeated d_k times: the first source index of each run.
    run_starts = pl.col(starts).explode().repeat_by(run_sizes).alias("cat_starts").explode()
    # s_k + (position - offset) walks s_k, s_k + 1, ... within each run.
    gather_indices = run_starts.add(positions.sub(run_offsets))
    return pl.col(columns).explode().take(gather_indices)
def join(
    df: pl.LazyFrame,
    df2: pl.LazyFrame,
    suffix: str,
    starts: str,
    ends: str,
    starts_2: str,
    ends_2: str,
) -> pl.LazyFrame:
    """Overlap-join two interval frames, ignoring the chromosome column.

    Processing steps:

    1. Sort both frames by ``starts``/``ends``.
    2. Collapse every column except CHROMOSOME_PROPERTY into one row of list
       columns. Chromosomes are therefore NOT compared: every interval of
       ``df`` is tested against every interval of ``df2``.
    3. Cross-join the two rows.
    4. Locate overlaps with ``search_sorted`` and expand the matching rows.

    Args:
        df: first interval frame; must contain ``starts`` and ``ends``.
        df2: second interval frame. It is sorted by the same ``starts`` /
            ``ends`` names, so it must contain those columns too. The cross
            join renames the clashing ones using ``suffix``.
        suffix: suffix the cross join appends to ``df2``'s clashing columns.
        starts, ends: start/end column names of ``df``.
        starts_2, ends_2: start/end column names of ``df2`` after the join.
            These are presumably ``starts + suffix`` / ``ends + suffix``;
            TODO confirm.

    Returns:
        The lazy query; call ``.collect()`` to run it. The previous version
        printed the collected result and returned ``None``.
    """
    # One row of list columns per frame, so the cross join yields one row.
    i1 = df.sort(starts, ends).select([pl.exclude(CHROMOSOME_PROPERTY).implode()])
    i2 = df2.sort(starts, ends).select([pl.exclude(CHROMOSOME_PROPERTY).implode()])
    j = i1.join(i2, how="cross", suffix=suffix)
    # Use the caller's column names rather than the hard-coded module
    # constants. The starts_2/ends_2 parameters were previously ignored.
    res = j.with_columns(find_starts_in_ends(starts, ends, starts_2, ends_2))
    df_columns = [starts, ends]
    df2_columns = [starts_2, ends_2]
    res2 = res.with_columns(
        compute_masks()
    ).with_columns(
        apply_masks()
    ).with_columns(
        add_lengths()
    ).select(
        # NOTE(review): pl.concat over multi-column expressions is relied on
        # here exactly as before; confirm it stacks the columns as intended.
        # df's columns: pairs found from df's side, then from df2's side.
        pl.concat(
            [
                repeat_frame(df_columns, MASK_2IN1_PROPERTY, STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY),
                repeat_other(df_columns, STARTS_1IN2_PROPERTY, LENGTHS_1IN2_PROPERTY),
            ]
        ),
        # df2's columns, in the matching order.
        pl.concat(
            [
                repeat_other(df2_columns, STARTS_2IN1_PROPERTY, LENGTHS_2IN1_PROPERTY),
                repeat_frame(df2_columns, MASK_1IN2_PROPERTY, STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY),
            ]
        ),
    )
    return res2
def genomics_join(g: Optional[pl.LazyFrame] = None) -> pl.LazyFrame:
    """Overlap-join two interval sets stored side by side, per chromosome.

    ``g`` must contain CHROMOSOME_PROPERTY, the first set's STARTS_PROPERTY /
    ENDS_PROPERTY columns and the second set's STARTS2_PROPERTY /
    ENDS2_PROPERTY columns.

    The work is done per chromosome group:

    * ``find_starts_in_ends`` / ``compute_masks`` / ``apply_masks`` /
      ``add_lengths`` locate the overlapping index ranges.
    * The "top" half expands pairs found from the first set's side.
    * The "bottom" half expands pairs found from the second set's side.
    * The two halves are concatenated.

    NOTE(review): the overlap search relies on ``search_sorted``, so the start
    columns presumably must be sorted within each chromosome. This function
    does not sort them; TODO confirm with callers.

    For the author's sample input the collected result was::

        shape: (6, 5)
        ┌────────────┬────────┬──────┬──────────┬────────┐
        │ chromosome ┆ starts ┆ ends ┆ starts_2 ┆ ends_2 │
        │ ---        ┆ ---    ┆ ---  ┆ ---      ┆ ---    │
        │ str        ┆ i64    ┆ i64  ┆ i64      ┆ i64    │
        ╞════════════╪════════╪══════╪══════════╪════════╡
        │ chr1       ┆ 0      ┆ 6    ┆ 1        ┆ 2      │
        │ chr1       ┆ 0      ┆ 6    ┆ 3        ┆ 8      │
        │ chr1       ┆ 5      ┆ 7    ┆ 6        ┆ 7      │
        │ chr1       ┆ 6      ┆ 10   ┆ 6        ┆ 7      │
        │ chr1       ┆ 5      ┆ 7    ┆ 3        ┆ 8      │
        │ chr1       ┆ 6      ┆ 10   ┆ 3        ┆ 8      │
        └────────────┴────────┴──────┴──────────┴────────┘

    Args:
        g: the input LazyFrame. If omitted, a module-level ``g`` is used,
            which is what the original implementation read.

    Returns:
        The lazy query; call ``.collect()`` to run it. The previous version
        printed the result and then always failed on a leftover ``assert 0``.

    Raises:
        ValueError: if no frame is passed and no module-level ``g`` exists.
    """
    if g is None:
        # Backward compatibility: the original body referenced a global ``g``.
        g = globals().get("g")
        if g is None:
            raise ValueError("genomics_join() requires an input LazyFrame; pass it as `g`.")
    res = g.groupby(CHROMOSOME_PROPERTY).agg(
        find_starts_in_ends(STARTS_PROPERTY, ENDS_PROPERTY, STARTS2_PROPERTY, ENDS2_PROPERTY)
    ).groupby(CHROMOSOME_PROPERTY).agg(
        compute_masks()
    ).groupby(CHROMOSOME_PROPERTY).agg(
        apply_masks()
    ).groupby(CHROMOSOME_PROPERTY).agg(
        # TODO: can this step be folded into compute_masks/apply_masks too?
        add_lengths() + [pl.all().explode()]
    ).groupby(CHROMOSOME_PROPERTY).agg(
        [
            # Top half: first-set rows, each repeated once per hit.
            pl.col(CHROMOSOME_PROPERTY).repeat_by(pl.col(LENGTHS_2IN1_PROPERTY).list.sum()).alias(CHROMOSOME_PROPERTY + "_top"),
            repeat_frame(DF_COLUMNS_PROPERTY, MASK_2IN1_PROPERTY, STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY).suffix("_top"),
            repeat_other(DF2_COLUMNS_PROPERTY, STARTS_2IN1_PROPERTY, LENGTHS_2IN1_PROPERTY).suffix("_top"),
            # Bottom half: the row count is the SUM of the 1-in-2 lengths,
            # mirroring the top half. It previously used list.lengths(), which
            # counts intervals rather than output rows.
            pl.col(CHROMOSOME_PROPERTY).repeat_by(pl.col(LENGTHS_1IN2_PROPERTY).list.sum()).alias(CHROMOSOME_PROPERTY + "_bottom"),
            repeat_other(DF_COLUMNS_PROPERTY, STARTS_1IN2_PROPERTY, LENGTHS_1IN2_PROPERTY).suffix("_bottom"),
            repeat_frame(DF2_COLUMNS_PROPERTY, MASK_1IN2_PROPERTY, STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY).suffix("_bottom"),
        ]
    ).groupby(CHROMOSOME_PROPERTY).agg(
        # One chromosome label per output row (top rows + bottom rows), then
        # each interval column as top followed by bottom.
        [pl.col(CHROMOSOME_PROPERTY).repeat_by(
            pl.col(f"{STARTS_PROPERTY}_top").list.lengths() + pl.col(f"{STARTS2_PROPERTY}_bottom").list.lengths()
        ).alias(TEMP_CHROMOSOME_PROPERTY)] + [
            pl.col(f"{col}_top").list.concat(pl.col(f"{col}_bottom")).alias(col).explode()
            for col in DF_COLUMNS_PROPERTY + DF2_COLUMNS_PROPERTY
        ],
    ).select(
        [pl.col(TEMP_CHROMOSOME_PROPERTY).explode().alias(CHROMOSOME_PROPERTY), pl.exclude(CHROMOSOME_PROPERTY, TEMP_CHROMOSOME_PROPERTY).explode()]
    )
    return res
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.