Last active
October 14, 2022 16:30
-
-
Save jeromekelleher/33927b941f2d63317049aacce16ff63a to your computer and use it in GitHub Desktop.
Compute branch GRM using numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import tskit | |
import numpy as np | |
import numba | |
@numba.njit
def sv_tables_init(parent_array):
    """
    Build the lookup tables used for nearest-common-ancestor queries.

    This is an implementation of Schieber and Vishkin's nearest common
    ancestor algorithm from TAOCP volume 4A, pg.164-167 [K11]_. The input
    tree is preprocessed into a sideways heap in O(n) time, after which the
    nearest common ancestor of an arbitrary pair of nodes can be found in
    O(1) time.

    NB internally this assumes that the tree uses 1-based addressing and 0
    is a special value. We would like to update this to use 0-based indexing
    natively and also use the built-in triply linked tree to save some time
    and memory.

    :param parent_array: 0-based parent of every node. An entry of -1 maps
        onto the special value 0 after the 1-based conversion, so such nodes
        become children of the special node.
    :return: The tuple ``(lambd, pi, tau, beta, alpha)`` of int32 arrays,
        each of length ``1 + parent_array.shape[0]``, in the 1-based
        representation.
    """
    size = 1 + parent_array.shape[0]
    LAMBDA = 0

    # Convert to the 1-based representation assumed here; slot 0 stays 0.
    oriented_forest = np.zeros(size, dtype=np.int32)
    oriented_forest[1:] = 1 + parent_array

    # Triply-linked tree. FIXME we shouldn't need to build this as it's
    # available already in tskit
    child = np.zeros(size, dtype=np.int32)
    parent = np.zeros(size, dtype=np.int32)
    sib = np.zeros(size, dtype=np.int32)
    for u in range(size):
        v = oriented_forest[u]
        # Push u onto the front of v's child list.
        sib[u] = child[v]
        child[v] = u
        parent[u] = v

    lambd = np.zeros(size, dtype=np.int32)
    pi = np.zeros(size, dtype=np.int32)
    tau = np.zeros(size, dtype=np.int32)
    beta = np.zeros(size, dtype=np.int32)
    alpha = np.zeros(size, dtype=np.int32)

    # First traversal: number the nodes in visiting order (pi) and fill in
    # lambd, tau and beta.
    n = 0
    # Seed for the recurrence lambd[n] = 1 + lambd[n >> 1].
    lambd[0] = -1
    p = child[LAMBDA]
    while p != LAMBDA:
        # Walk down the chain of first children, numbering as we go.
        while True:
            n += 1
            pi[p] = n
            tau[n] = LAMBDA
            lambd[n] = 1 + lambd[n >> 1]
            if child[p] == LAMBDA:
                break
            p = child[p]
        beta[p] = n
        # Climb until a sibling is found or we fall off the top.
        while True:
            tau[beta[p]] = parent[p]
            if sib[p] != LAMBDA:
                p = sib[p]
                break
            p = parent[p]
            if p == LAMBDA:
                break
            h = lambd[n & -pi[p]]
            beta[p] = ((n >> h) | 1) << h

    # Begin the second traversal
    lambd[0] = lambd[n]
    pi[LAMBDA] = 0
    beta[LAMBDA] = 0
    alpha[LAMBDA] = 0
    p = child[LAMBDA]
    while p != LAMBDA:
        # Walk down first children, combining the parent's alpha with the
        # lowest set bit of this node's beta.
        while True:
            alpha[p] = alpha[parent[p]] | (beta[p] & -beta[p])
            if child[p] == LAMBDA:
                break
            p = child[p]
        # Climb until a sibling is found or we fall off the top.
        while True:
            if sib[p] != LAMBDA:
                p = sib[p]
                break
            p = parent[p]
            if p == LAMBDA:
                break
    return lambd, pi, tau, beta, alpha
@numba.njit
def _sv_mrca(x, y, lambd, pi, tau, beta, alpha):
    """
    Answer a nearest-common-ancestor query in the 1-based representation.

    :param x: 1-based ID of the first node.
    :param y: 1-based ID of the second node.
    :param lambd, pi, tau, beta, alpha: The tables returned by
        ``sv_tables_init``.
    :return: The 1-based ID selected by the query, i.e. whichever of the two
        candidate ancestors has the smaller ``pi`` value.
    """
    bx = beta[x]
    by = beta[y]
    if bx <= by:
        h = lambd[by & -bx]
    else:
        h = lambd[bx & -by]
    # Keep only the alpha bits shared by both nodes at position h or above,
    # then take the position of the lowest surviving bit.
    shared = alpha[x] & alpha[y] & -(1 << h)
    h = lambd[shared & -shared]
    j = ((bx >> h) | 1) << h
    below_h = (1 << h) - 1

    if j == bx:
        xhat = x
    else:
        ell = lambd[alpha[x] & below_h]
        xhat = tau[((bx >> ell) | 1) << ell]

    if j == by:
        yhat = y
    else:
        ell = lambd[alpha[y] & below_h]
        yhat = tau[((by >> ell) | 1) << ell]

    return xhat if pi[xhat] <= pi[yhat] else yhat
@numba.njit
def sv_mrca(x, y, lambd, pi, tau, beta, alpha):
    """
    Answer a nearest-common-ancestor query using 0-based node IDs.

    The tables use 1-based indexes internally, so both arguments are shifted
    up by one for the lookup and the answer is shifted back down.
    """
    one_based = _sv_mrca(x + 1, y + 1, lambd, pi, tau, beta, alpha)
    return one_based - 1
@numba.njit
def _B_matrix_sv(I, parent, time, root_time):
    """
    Compute the symmetric pairwise matrix ``B`` for a single tree.

    ``B[j, k]`` is the sum, over every node ``u`` in row ``j`` of ``I`` and
    every node ``v`` in row ``k`` of ``I``, of ``root_time - time[mrca(u, v)]``.

    :param I: 2D integer array; row ``j`` lists the node IDs of individual
        ``j``.
    :param parent: Parent array of the tree, passed to ``sv_tables_init``.
    :param time: Array of node times, indexed by node ID.
    :param root_time: Time that each MRCA time is subtracted from.
    :return: An ``(N, N)`` float array, where ``N = I.shape[0]``.
    """
    # Preprocess so that we can answer MRCA queries in constant time.
    lambd, pi, tau, beta, alpha = sv_tables_init(parent)
    num_rows = I.shape[0]
    B = np.zeros((num_rows, num_rows))
    for row in range(num_rows):
        row_nodes = I[row]
        # Only the upper triangle is computed; the result is mirrored.
        for col in range(row, num_rows):
            col_nodes = I[col]
            total = 0
            for a in row_nodes:
                for b in col_nodes:
                    ancestor = sv_mrca(a, b, lambd, pi, tau, beta, alpha)
                    total += root_time - time[ancestor]
            B[row, col] = total
            B[col, row] = total
    return B
@numba.njit
def _normalise(B):
    """
    Centre the square matrix ``B`` by its row means and overall mean.

    Each output entry is ``B[i, j] - m[i] - m[j] + g``, where ``m[i]`` is the
    mean of row ``i`` of ``B`` and ``g`` is the mean of all entries. Row means
    are used for both subtracted terms, so this matches row/column centring
    only when ``B`` is symmetric.

    :param B: Square 2D array.
    :return: A new array with the same shape and dtype as ``B``.
    """
    size = B.shape[0]
    grand_mean = np.mean(B)
    # Row means are accumulated by hand because numba doesn't support
    # np.mean with an axis argument.
    row_mean = np.zeros(size)
    for r in range(size):
        acc = 0.0
        for c in range(size):
            acc += B[r, c]
        row_mean[r] = acc / size
    centred = np.zeros_like(B)
    for r in range(size):
        for c in range(size):
            centred[r, c] = B[r, c] - row_mean[r] - row_mean[c] + grand_mean
    return centred
def branch_genetic_relatedness_matrix(ts):
    """
    Compute the branch-based genetic relatedness matrix between individuals.

    For every tree, the pairwise matrix from ``_B_matrix_sv`` is centred with
    ``_normalise`` and added to the running total, weighted by the tree's
    span. Trees whose number of roots equals the number of samples are
    skipped.

    :param ts: The tree sequence to analyse.
    :return: A ``(num_individuals, num_individuals)`` float array.

    NOTE(review): node IDs are stored in two slots per individual, so this
    presumably expects diploid individuals - confirm against callers.
    NOTE(review): ``tree.root`` is read for every tree that is not skipped;
    this looks like it requires such trees to have a single root - confirm.
    """
    num_inds = ts.num_individuals
    ind_nodes = np.zeros((num_inds, 2), dtype=np.int32)
    for individual in ts.individuals():
        ind_nodes[individual.id] = individual.nodes

    node_times = ts.nodes_time
    total = np.zeros((num_inds, num_inds))
    for tree in ts.trees():
        if tree.num_roots != ts.num_samples:
            height = node_times[tree.root]
            B = _B_matrix_sv(ind_nodes, tree.parent_array, node_times, height)
            total += _normalise(B) * tree.span
    return total
if __name__ == "__main__":
    # Command line: <input tree sequence file> <output text file>.
    cli_args = sys.argv[1:]
    if len(cli_args) != 2:
        print(f"usage: {sys.argv[0]} file.trees relatedness.txt")
        sys.exit(1)
    trees_path, output_path = cli_args
    ts = tskit.load(trees_path)
    K = branch_genetic_relatedness_matrix(ts)
    # K2 = genetic_relatedness_matrix(ts)
    # assert np.allclose(K, K2)
    # Write the matrix out as plain text.
    np.savetxt(output_path, K)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment