# Source: GitHub Gist endrebak/027c8b51da52e12f4424ef176d1951bf
# Author: @endrebak, created June 20, 2023 05:10.
from typing import List
import polars as pl
import numpy as np
import bioframe.core.arrops as arrops
import pyoframe as pf
import polars as pl
import pytest
# Column names shared by the interval-join expressions in this module.

# Chromosome key columns; TEMP_CHROMOSOME_PROPERTY is a scratch name used by
# genomics_join while it rebuilds the chromosome column.
CHROMOSOME_PROPERTY, CHROMOSOME2_PROPERTY = "chromosome", "chromosome_2"
TEMP_CHROMOSOME_PROPERTY = f"__{CHROMOSOME_PROPERTY}__"

# Interval bounds of the first frame and of the second frame.
STARTS_PROPERTY, ENDS_PROPERTY = "starts", "ends"
STARTS2_PROPERTY, ENDS2_PROPERTY = "starts_2", "ends_2"
DF_COLUMNS_PROPERTY = [STARTS_PROPERTY, ENDS_PROPERTY]
DF2_COLUMNS_PROPERTY = [STARTS2_PROPERTY, ENDS2_PROPERTY]

# Search positions of frame-1 bounds inside the frame-2 starts ("2in1") and of
# frame-2 bounds inside the frame-1 starts ("1in2"); see find_starts_in_ends.
STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY = "starts_2in1", "ends_2in1"
STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY = "starts_1in2", "ends_1in2"

# Per-row "has at least one match" flags and match counts for each direction.
MASK_1IN2_PROPERTY, MASK_2IN1_PROPERTY = "mask_1in2", "mask_2in1"
LENGTHS_2IN1_PROPERTY, LENGTHS_1IN2_PROPERTY = "lengths_2in1", "lengths_1in2"
def search(
    col1: str,
    col2: str,
    side: str = "left"
) -> pl.Expr:
    """Build an expression giving the insertion index of each ``col2`` value in ``col1``.

    Both columns are exploded first, so list columns are searched element-wise.
    ``col1`` must be sorted ascending for the indices to be meaningful, because
    ``search_sorted`` assumes sorted input. ``side`` is passed straight through to
    ``search_sorted`` ("left" or "right" tie-breaking).
    """
    sorted_values = pl.col(col1).explode()
    queries = pl.col(col2).explode()
    return sorted_values.search_sorted(queries, side=side)
def lengths(
    starts: str,
    ends: str,
    outname: str = ""
) -> pl.Expr:
    """Build an expression for the element-wise width ``ends - starts``.

    Both columns are exploded before subtracting. The difference is exploded once
    more and then renamed to ``outname``.

    NOTE(review): the default ``outname`` is the empty string, so callers should
    pass a real name. This helper is not called anywhere in this file.
    """
    widths = pl.col(ends).explode() - pl.col(starts).explode()
    return widths.explode().alias(outname)
def find_starts_in_ends(starts, ends, starts_2, ends_2, closed: bool = False) -> List[pl.Expr]:
    """Return the four search-position expressions used to locate overlaps.

    Assume both start columns are sorted ascending. Then:

    * ``STARTS_2IN1_PROPERTY`` / ``ENDS_2IN1_PROPERTY`` are the left-side positions
      of each frame-1 start / end within the frame-2 starts. Frame-2 rows in
      ``[starts_2in1, ends_2in1)`` therefore satisfy ``start <= start_2 < end``.
    * ``STARTS_1IN2_PROPERTY`` is the right-side position of each frame-2 start
      within the frame-1 starts. ``ENDS_1IN2_PROPERTY`` is the left-side position of
      each frame-2 end within the frame-1 starts. Frame-1 rows in
      ``[starts_1in2, ends_1in2)`` therefore satisfy ``start_2 < start < end_2``.
      The right side keeps equal starts from being counted in both directions.

    Each result is imploded back into a single list value.

    NOTE(review): ``closed`` is accepted but currently ignored.
    """
    # (sorted column searched, column whose values are placed, side, output name)
    plan = (
        (starts_2, starts, "left", STARTS_2IN1_PROPERTY),
        (starts_2, ends, "left", ENDS_2IN1_PROPERTY),
        (starts, starts_2, "right", STARTS_1IN2_PROPERTY),
        (starts, ends_2, "left", ENDS_1IN2_PROPERTY),
    )
    return [
        search(haystack, needles, side=side).alias(name).implode()
        for haystack, needles, side, name in plan
    ]
def compute_masks() -> List[pl.Expr]:
    """Return expressions flagging which search ranges are non-empty.

    ``MASK_1IN2_PROPERTY`` is ``ENDS_1IN2 > STARTS_1IN2`` and ``MASK_2IN1_PROPERTY``
    is ``ENDS_2IN1 > STARTS_2IN1``. Both are computed element-wise on the exploded
    lists and imploded back into one list value. A True entry means that row's
    search range contains at least one partner row.
    """
    triples = (
        (STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY, MASK_1IN2_PROPERTY),
        (STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY, MASK_2IN1_PROPERTY),
    )
    return [
        (pl.col(end_col).explode() > pl.col(start_col).explode()).implode().alias(mask_col)
        for start_col, end_col, mask_col in triples
    ]
def apply_masks() -> List[pl.Expr]:
    """Return expressions that keep only the search positions whose mask is True.

    The ``*_2IN1`` start/end lists are filtered by ``MASK_2IN1_PROPERTY``, and the
    ``*_1IN2`` lists by ``MASK_1IN2_PROPERTY``. The columns keep their names and are
    imploded back into list values.
    """
    def _keep(columns, mask):
        # Explode, drop rows whose mask entry is False, and re-pack into a list.
        return pl.col(columns).explode().filter(pl.col(mask).explode()).implode()

    return [
        _keep([STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY], MASK_2IN1_PROPERTY),
        _keep([STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY], MASK_1IN2_PROPERTY),
    ]
def add_lengths() -> List[pl.Expr]:
    """Return expressions computing the size of each search range.

    The two list columns produced are:

    * ``LENGTHS_2IN1_PROPERTY`` = ``ENDS_2IN1 - STARTS_2IN1``
    * ``LENGTHS_1IN2_PROPERTY`` = ``ENDS_1IN2 - STARTS_1IN2``

    Each is computed element-wise on the exploded lists and imploded back. When
    this runs after ``apply_masks`` (as in ``join``), each entry is the number of
    partner rows that a kept row overlaps.

    The return annotation was previously ``pl.Expr``, but a list of expressions is
    returned (callers concatenate it with other lists).
    """
    return [
        pl.col(ENDS_2IN1_PROPERTY).explode().sub(pl.col(STARTS_2IN1_PROPERTY).explode()).alias(LENGTHS_2IN1_PROPERTY).implode(),
        pl.col(ENDS_1IN2_PROPERTY).explode().sub(pl.col(STARTS_1IN2_PROPERTY).explode()).alias(LENGTHS_1IN2_PROPERTY).implode(),
    ]
def repeat_frame(columns, mask, startsin, endsin) -> pl.Expr:
    """Repeat each masked row of ``columns`` once per partner it overlaps.

    Rows of ``columns`` whose ``mask`` entry is False are dropped. Each remaining
    row is repeated ``endsin - startsin`` times, and the repeats are flattened.

    NOTE(review): ``startsin``/``endsin`` must already be filtered by the same mask
    so the lengths line up. ``join`` ensures this by running ``apply_masks`` first.
    """
    kept = pl.col(columns).explode().filter(pl.col(mask).explode())
    repeats = pl.col(endsin).explode() - pl.col(startsin).explode()
    return kept.repeat_by(repeats).explode()
def repeat_other(columns, starts, diffs) -> pl.Expr:
    """Gather the rows ``starts[i] .. starts[i] + diffs[i] - 1`` of ``columns`` for every i.

    The gathered rows are emitted in order, one block per i. Example:
    starts=[2, 5], diffs=[2, 3] gathers rows [2, 3, 5, 6, 7].
    The row indices are built without a Python loop as
        repeat(starts, diffs) + (arange(sum(diffs)) - repeat(exclusive_cumsum(diffs), diffs))
    where the parenthesised term is the 0-based offset inside each block.
    """
    return pl.col(columns).explode().take(
        # Each start repeated diffs[i] times: [2, 2, 5, 5, 5] in the example.
        pl.col(starts).explode().repeat_by(pl.col(diffs).explode()).alias("cat_starts").explode().add(
            # Running row number 0 .. sum(diffs) - 1: [0, 1, 2, 3, 4].
            pl.arange(0, pl.col(diffs).explode().sum()).explode().alias("length_sum_arange").sub(
                # Exclusive prefix sum of diffs, repeated per block: [0, 0, 2, 2, 2].
                # Subtracting it leaves the within-block offset [0, 1, 0, 1, 2].
                pl.col(diffs).explode().cumsum().sub(pl.col(diffs).explode()).repeat_by(pl.col(diffs).explode()).explode()
            )
        )
    )
def join(
    df: pl.LazyFrame,
    df2: pl.LazyFrame,
    suffix: str,
    starts: str,
    ends: str,
    starts_2: str,
    ends_2: str,
):
    """Overlap-join two interval frames, print the collected result and return it.

    The steps are:

    1. Sort each frame by ``starts``/``ends``.
    2. Implode every column except the chromosome column into one list row.
    3. Cross-join the two one-row frames; ``suffix`` disambiguates clashing
       names from ``df2``.
    4. Compute search positions, masks and range lengths with this module's
       helpers.
    5. Expand every overlapping pair into one output row. The pairs found by
       placing frame-1 bounds in the frame-2 starts come first, followed by those
       found by placing frame-2 bounds in the frame-1 starts.

    NOTE(review): the chromosome column is dropped rather than matched, so
    intervals on different chromosomes can pair up here. ``genomics_join`` groups
    by chromosome instead.

    Args:
        df: First interval frame.
        df2: Second interval frame. It is sorted by ``starts``/``ends`` before the
            cross join applies ``suffix``, so it presumably uses the same column
            names as ``df`` — confirm against callers.
        suffix: Suffix the cross join appends to clashing ``df2`` column names.
        starts, ends: Interval-bound columns of ``df``.
        starts_2, ends_2: Interval-bound columns of ``df2`` after suffixing.

    Returns:
        The collected result frame. The original only printed it and returned
        ``None``; it is still printed.
    """
    # Collapse each sorted frame into a single row of list columns.
    left = df.sort(starts, ends).select([pl.exclude(CHROMOSOME_PROPERTY).implode()])
    right = df2.sort(starts, ends).select([pl.exclude(CHROMOSOME_PROPERTY).implode()])
    paired = left.join(right, how="cross", suffix=suffix)

    # Use the caller's column names. The original ignored these parameters and
    # always used the module constants, which fails for any other naming.
    located = paired.with_columns(find_starts_in_ends(starts, ends, starts_2, ends_2))
    cols_1 = [starts, ends]
    cols_2 = [starts_2, ends_2]

    expanded = (
        located.with_columns(compute_masks())
        .with_columns(apply_masks())
        .with_columns(add_lengths())
        .select(
            pl.concat(
                [
                    repeat_frame(cols_1, MASK_2IN1_PROPERTY, STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY),
                    repeat_other(cols_1, STARTS_1IN2_PROPERTY, LENGTHS_1IN2_PROPERTY),
                ]
            ),
            pl.concat(
                [
                    repeat_other(cols_2, STARTS_2IN1_PROPERTY, LENGTHS_2IN1_PROPERTY),
                    repeat_frame(cols_2, MASK_1IN2_PROPERTY, STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY),
                ]
            ),
        )
    )
    # Collect once so the printed frame and the returned frame are the same object.
    result = expanded.collect()
    print(result)
    return result
def genomics_join(g=None):
    """Overlap-join intervals within each chromosome, print the result and return it.

    This applies the same search / mask / length / expansion helpers as ``join``.
    Here, though, they run inside ``groupby(chromosome).agg(...)`` steps, so
    intervals are paired only within a chromosome group. The result is then
    flattened to the columns chromosome, starts, ends, starts_2, ends_2.

    NOTE(review): unlike ``join``, nothing here sorts the start columns, but
    ``search_sorted`` needs them sorted. The caller presumably provides sorted
    starts — confirm.

    Args:
        g: LazyFrame with the columns named by ``CHROMOSOME_PROPERTY``,
            ``STARTS_PROPERTY``, ``ENDS_PROPERTY``, ``STARTS2_PROPERTY`` and
            ``ENDS2_PROPERTY``. NOTE(review): whether its rows hold scalars or
            lists is not visible here — confirm against the caller.
            When omitted, a module-level ``g`` is used if one exists. The original
            body read that global, which this file never defines.

    Returns:
        The collected result frame. The original only printed it; it is still
        printed.

    Raises:
        NameError: if ``g`` is omitted and no module-level ``g`` exists.
    """
    if g is None:
        g = globals().get("g")
        if g is None:
            raise NameError("genomics_join() needs an interval frame: pass it as `g`")

    res = (
        g.groupby(CHROMOSOME_PROPERTY)
        .agg(find_starts_in_ends(STARTS_PROPERTY, ENDS_PROPERTY, STARTS2_PROPERTY, ENDS2_PROPERTY))
        .groupby(CHROMOSOME_PROPERTY)
        .agg(compute_masks())
        .groupby(CHROMOSOME_PROPERTY)
        .agg(apply_masks())
        .groupby(CHROMOSOME_PROPERTY)
        # TODO: can this step be added to compute_masks and apply_masks to lengths too?
        .agg(add_lengths() + [pl.all().explode()])
        .groupby(CHROMOSOME_PROPERTY)
        .agg(
            [
                # "_top": pairs found by placing frame-1 bounds in the frame-2 starts.
                # Both expansions below emit sum(LENGTHS_2IN1) rows.
                pl.col(CHROMOSOME_PROPERTY).repeat_by(pl.col(LENGTHS_2IN1_PROPERTY).list.sum()).alias(CHROMOSOME_PROPERTY + "_top"),
                repeat_frame(DF_COLUMNS_PROPERTY, MASK_2IN1_PROPERTY, STARTS_2IN1_PROPERTY, ENDS_2IN1_PROPERTY).suffix("_top"),
                repeat_other(DF2_COLUMNS_PROPERTY, STARTS_2IN1_PROPERTY, LENGTHS_2IN1_PROPERTY).suffix("_top"),
                # "_bottom": pairs found by placing frame-2 bounds in the frame-1 starts.
                # Both expansions below emit sum(LENGTHS_1IN2) rows, so repeat by the
                # sum, as the "_top" column does (this previously used .list.lengths()).
                # Neither chromosome column is selected by the next step.
                pl.col(CHROMOSOME_PROPERTY).repeat_by(pl.col(LENGTHS_1IN2_PROPERTY).list.sum()).alias(CHROMOSOME_PROPERTY + "_bottom"),
                repeat_other(DF_COLUMNS_PROPERTY, STARTS_1IN2_PROPERTY, LENGTHS_1IN2_PROPERTY).suffix("_bottom"),
                repeat_frame(DF2_COLUMNS_PROPERTY, MASK_1IN2_PROPERTY, STARTS_1IN2_PROPERTY, ENDS_1IN2_PROPERTY).suffix("_bottom"),
            ]
        )
        .groupby(CHROMOSOME_PROPERTY)
        .agg(
            # One chromosome label per output row (top rows + bottom rows) ...
            [
                pl.col(CHROMOSOME_PROPERTY).repeat_by(
                    pl.col(f"{STARTS_PROPERTY}_top").list.lengths()
                    + pl.col(f"{STARTS2_PROPERTY}_bottom").list.lengths()
                ).alias(TEMP_CHROMOSOME_PROPERTY)
            ]
            # ... and each interval column as its top rows followed by its bottom rows.
            + [
                pl.col(f"{col}_top").list.concat(pl.col(f"{col}_bottom")).alias(col).explode()
                for col in DF_COLUMNS_PROPERTY + DF2_COLUMNS_PROPERTY
            ],
        )
        .select(
            [
                pl.col(TEMP_CHROMOSOME_PROPERTY).explode().alias(CHROMOSOME_PROPERTY),
                pl.exclude(CHROMOSOME_PROPERTY, TEMP_CHROMOSOME_PROPERTY).explode(),
            ]
        )
    )
    out = res.collect()
    print(out)
    # Example output from the author's run (the input frame is not shown in this file):
    # shape: (6, 5)
    # ┌────────────┬────────┬──────┬──────────┬────────┐
    # │ chromosome ┆ starts ┆ ends ┆ starts_2 ┆ ends_2 │
    # │ --- ┆ --- ┆ --- ┆ --- ┆ --- │
    # │ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
    # ╞════════════╪════════╪══════╪══════════╪════════╡
    # │ chr1 ┆ 0 ┆ 6 ┆ 1 ┆ 2 │
    # │ chr1 ┆ 0 ┆ 6 ┆ 3 ┆ 8 │
    # │ chr1 ┆ 5 ┆ 7 ┆ 6 ┆ 7 │
    # │ chr1 ┆ 6 ┆ 10 ┆ 6 ┆ 7 │
    # │ chr1 ┆ 5 ┆ 7 ┆ 3 ┆ 8 │
    # │ chr1 ┆ 6 ┆ 10 ┆ 3 ┆ 8 │
    # └────────────┴────────┴──────┴──────────┴────────┘
    return out
# GitHub page footer: sign up or sign in to comment on this gist.