Skip to content

Instantly share code, notes, and snippets.

@apcamargo
Created April 29, 2026 22:33
Show Gist options
  • Select an option

  • Save apcamargo/e43da080bc6a1818cdab69e24fa3fc45 to your computer and use it in GitHub Desktop.

Select an option

Save apcamargo/e43da080bc6a1818cdab69e24fa3fc45 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import argparse
import math
import sys
from dataclasses import dataclass
from pathlib import Path
# Fixed background probability assigned to states that are not fitted
# (excluded states and a terminal "X").
ANY_BACK = 1e-5
# LU pivots and Newton derivatives with a smaller magnitude are treated as zero.
MATRIX_SINGULARITY_THRESHOLD = 1e-15
# Tolerance for the symmetry check and for treating two scores as equal.
CONVERGENCE_TOLERANCE = 1e-10
# Stop Newton (and bisection) once |f(lambda)| or the step drops below this.
NEWTON_CONVERGENCE_TOLERANCE = 1e-10
# Lower end of the bracket used by the bisection fallback.
LAMBDA_LOWER_BOUND = 1e-6
# Bisection stops once the bracket is narrower than this.
BISECTION_CONVERGENCE_TOLERANCE = 1e-12
# Power iteration stops once no component changes by more than this.
POWER_ITERATION_TOLERANCE = 1e-12
# Fraction of the distance to the bracket edge a clamped Newton step may cover.
DAMPING_FACTOR = 0.8
# Fraction of the lambda upper bound tried once while bracketing.
LAMBDA_UPPER_BOUND_SAFETY = 0.95
# Allowed deviation of the restriction function from zero at lambda = 0.
PROBABILITY_VALIDATION_TOLERANCE = 1e-10
# Allowed |f(lambda)| for the final lambda/background solution.
RESTRICTION_VALIDATION_TOLERANCE = 1e-6
# Row/column score sums with a smaller magnitude are rejected.
ZERO_ROW_COL_THRESHOLD = 1e-10
# Iteration caps for the individual solvers.
MAX_NEWTON_ITERATIONS = 50
MAX_BISECTION_ITERATIONS = 100
MAX_POWER_ITERATIONS = 1000
MAX_BRACKET_ATTEMPTS = 50
# Bracketing starts at BRACKET_START_LAMBDA and multiplies the upper end by
# BRACKET_INCREMENT_FACTOR after every failed attempt.
BRACKET_INCREMENT_FACTOR = 1.5
BRACKET_START_LAMBDA = 0.01
# Square matrices stored as lists of rows.
FloatMatrix = list[list[float]]
IntMatrix = list[list[int]]
# (score, summed probability weight) pairs.
ScoreProbSums = list[tuple[float, float]]
# (score, [(row, col), ...]) pairs: the matrix cells that share one score.
ScoreIndexGroups = list[tuple[float, list[tuple[int, int]]]]
@dataclass
class ParsedMatrix:
    """Alphabet and pair-probability matrix derived from a scoring matrix file."""

    # Number of symbols in the alphabet header.
    alphabet_size: int
    # Index -> symbol, in header order.
    num2aa: list[str]
    # Symbol -> index.
    aa2num: dict[str, int]
    # Pair probabilities: exp(lambda * score[i][j]) * background[i] * background[j].
    prob_matrix: FloatMatrix
    # True when the last alphabet symbol is "X".
    has_terminal_x: bool
@dataclass
class ReducedResult:
    """Outcome of an alphabet reduction, ready to be formatted."""

    # One "(A B C)"-style string per reduced symbol listing its members.
    aa_groups: list[str]
    # Representative symbol of every reduced group, in matrix order.
    alphabet: list[str]
    # Integer log-odds scores for the reduced alphabet.
    score_matrix: IntMatrix
def validate_score_matrix(score_matrix: FloatMatrix) -> None:
    """Reject score matrices that cannot be used for lambda estimation.

    Raises:
        ValueError: if the matrix is empty or not square, contains NaN or
            infinite entries, is not symmetric within CONVERGENCE_TOLERANCE,
            or has the same value in every cell.
    """
    dimension = len(score_matrix)
    if dimension <= 0:
        raise ValueError("score matrix must be square and non-empty")
    for row in score_matrix:
        if len(row) != dimension:
            raise ValueError("score matrix must be square and non-empty")

    flat = [entry for row in score_matrix for entry in row]
    if not all(math.isfinite(entry) for entry in flat):
        raise ValueError("score matrix contains NaN or Inf")

    for i, row in enumerate(score_matrix):
        for j, entry in enumerate(row):
            mirrored = score_matrix[j][i]
            if abs(entry - mirrored) > CONVERGENCE_TOLERANCE:
                raise ValueError(f"score matrix is not symmetric at ({i}, {j})")

    spread = max(flat) - min(flat)
    if spread <= 0.0:
        raise ValueError("score matrix has no score variation")
def find_lambda_bound(score_matrix: FloatMatrix) -> float:
    """Return an upper bound for lambda derived from the positive scores.

    The bound is the smallest, over all rows and columns, of the largest
    positive score found in that row or column.

    Raises:
        ValueError: if any row or column sum is closer to zero than
            ZERO_ROW_COL_THRESHOLD, if any row or column lacks either a
            positive or a negative score, or if the bound is not positive.
    """
    n = len(score_matrix)
    rows = [[score_matrix[i][j] for j in range(n)] for i in range(n)]
    columns = [[score_matrix[i][j] for i in range(n)] for j in range(n)]

    totals = [sum(row) for row in score_matrix] + [sum(col) for col in columns]
    if any(abs(total) < ZERO_ROW_COL_THRESHOLD for total in totals):
        raise ValueError("score matrix has near-zero row or column score sums")

    def has_both_signs(line: list[float]) -> bool:
        return any(v > 0.0 for v in line) and any(v < 0.0 for v in line)

    def largest_positive(line: list[float]) -> float:
        best = 0.0
        for v in line:
            if v > 0.0:
                best = max(best, v)
        return best

    # Rows are checked before columns so the first reported problem is a row.
    for i, row in enumerate(rows):
        if not has_both_signs(row):
            raise ValueError(f"score matrix row {i} lacks positive and negative scores")
    for j, col in enumerate(columns):
        if not has_both_signs(col):
            raise ValueError(
                f"score matrix column {j} lacks positive and negative scores"
            )

    row_peaks = [largest_positive(row) for row in rows]
    col_peaks = [largest_positive(col) for col in columns]
    upper_bound = min(min(row_peaks), min(col_peaks))
    if upper_bound <= 0.0:
        raise ValueError("invalid non-positive lambda upper bound")
    return upper_bound
def group_score_indices(score_matrix: FloatMatrix) -> ScoreIndexGroups:
    """Group matrix cells by score value.

    Cells are visited in row-major order. A cell joins the first existing
    group whose score is within CONVERGENCE_TOLERANCE of its own; otherwise
    it opens a new group. Each group is ``(score, [(row, col), ...])`` where
    ``score`` is the value of the first cell placed in it.
    """
    buckets: ScoreIndexGroups = []
    for i, row in enumerate(score_matrix):
        for j, score in enumerate(row):
            for representative, cells in buckets:
                if abs(representative - score) < CONVERGENCE_TOLERANCE:
                    cells.append((i, j))
                    break
            else:
                # No existing group is close enough: start a new one.
                buckets.append((score, [(i, j)]))
    return buckets
def weight_score_groups(
    score_index_groups: ScoreIndexGroups, p: list[float], q: list[float]
) -> ScoreProbSums:
    """Attach a probability weight to every score group.

    For each ``(score, cells)`` group the weight is the sum of
    ``p[i] * q[j]`` over the group's ``(i, j)`` cells. Returns the
    ``(score, weight)`` pairs in the order of the input groups.
    """

    def group_weight(cells: list[tuple[int, int]]) -> float:
        # Plain left-to-right accumulation starting from 0.0.
        weight = 0.0
        for row, col in cells:
            weight += p[row] * q[col]
        return weight

    return [(score, group_weight(cells)) for score, cells in score_index_groups]
def lambda_constraint(groups: ScoreProbSums, lambda_value: float) -> float:
    """Return ``sum(weight * exp(lambda_value * score)) - 1`` over all groups."""
    weighted_total = sum(
        prob_sum * math.exp(lambda_value * score) for score, prob_sum in groups
    )
    return weighted_total - 1.0
def lambda_constraint_deriv(groups: ScoreProbSums, lambda_value: float) -> float:
    """Return the derivative of ``lambda_constraint`` with respect to lambda.

    That is ``sum(weight * score * exp(lambda_value * score))`` over all
    groups, with the weights treated as constants.
    """
    terms = (
        prob_sum * score * math.exp(lambda_value * score)
        for score, prob_sum in groups
    )
    return sum(terms)
def bracket_lambda(
    score_index_groups: ScoreIndexGroups,
    p: list[float],
    q: list[float],
    lambda_upper_bound: float,
) -> tuple[float, float]:
    """Find an interval ``(0.0, high)`` with a positive restriction value at ``high``.

    The upper end starts at BRACKET_START_LAMBDA and is multiplied by
    BRACKET_INCREMENT_FACTOR after each failed attempt. The first time it
    would exceed ``lambda_upper_bound * LAMBDA_UPPER_BOUND_SAFETY`` it is set
    to that limit and tested immediately; afterwards growth continues
    unclamped.

    Raises:
        ValueError: if the restriction is not ~0 at lambda = 0, if its
            derivative at 0 is not negative, or if no bracket is found within
            MAX_BRACKET_ATTEMPTS attempts.
    """
    weighted = weight_score_groups(score_index_groups, p, q)

    def restriction(candidate: float) -> float:
        return lambda_constraint(weighted, candidate)

    if abs(restriction(0.0)) > PROBABILITY_VALIDATION_TOLERANCE:
        raise ValueError("restriction value at lambda=0 is not close to zero")
    if lambda_constraint_deriv(weighted, 0.0) >= 0.0:
        raise ValueError("restriction derivative at lambda=0 is not negative")

    floor = 0.0
    ceiling = BRACKET_START_LAMBDA
    cap = lambda_upper_bound * LAMBDA_UPPER_BOUND_SAFETY
    cap_pending = True
    attempts = 0
    while attempts < MAX_BRACKET_ATTEMPTS:
        attempts += 1
        if restriction(ceiling) > 0.0:
            return floor, ceiling
        ceiling *= BRACKET_INCREMENT_FACTOR
        if cap_pending and ceiling > cap:
            # Try the capped value once, in the same attempt.
            ceiling = cap
            cap_pending = False
            if restriction(ceiling) > 0.0:
                return floor, ceiling
    raise ValueError("failed to bracket lambda")
def lu_factorize(matrix: FloatMatrix) -> tuple[list[int], FloatMatrix] | None:
    """LU-factorize a square matrix with partial (row) pivoting.

    Works on a copy of ``matrix``. Returns ``(pivot_indices, lu_matrix)``
    where ``pivot_indices[i]`` is the original row now stored at row ``i``,
    and ``lu_matrix`` holds the elimination multipliers below the diagonal
    and the upper-triangular factor on and above it. Returns None when a
    pivot column has no entry of magnitude >= MATRIX_SINGULARITY_THRESHOLD.
    The last diagonal entry is not checked here.
    """
    n = len(matrix)
    work = [list(row) for row in matrix]
    order = list(range(n))
    for k in range(n - 1):
        # First row at or below k with the largest magnitude in column k.
        best_row = max(range(k, n), key=lambda r: abs(work[r][k]))
        if abs(work[best_row][k]) < MATRIX_SINGULARITY_THRESHOLD:
            return None
        if best_row != k:
            order[k], order[best_row] = order[best_row], order[k]
            work[k], work[best_row] = work[best_row], work[k]
        pivot_row = work[k]
        for i in range(k + 1, n):
            target = work[i]
            multiplier = target[k] / pivot_row[k]
            target[k] = multiplier
            for j in range(k + 1, n):
                target[j] -= multiplier * pivot_row[j]
    return order, work
def solve_lu_system(
    pivot_indices: list[int], lu_matrix: FloatMatrix, rhs: list[float]
) -> list[float] | None:
    """Solve a linear system from LU factors in the layout ``lu_factorize`` returns.

    ``rhs`` is permuted with ``pivot_indices``, then forward substitution is
    applied using the entries below the diagonal (unit diagonal implied),
    followed by back substitution with the entries on and above it. Returns
    None when a diagonal entry is smaller in magnitude than
    MATRIX_SINGULARITY_THRESHOLD.
    """
    n = len(lu_matrix)

    # Forward substitution on the row-permuted right-hand side.
    forward = [rhs[pivot_indices[i]] for i in range(n)]
    for i in range(n):
        row = lu_matrix[i]
        for j in range(i):
            forward[i] -= row[j] * forward[j]

    # Back substitution, last row first.
    solution = [0.0] * n
    for i in reversed(range(n)):
        row = lu_matrix[i]
        remainder = forward[i]
        for j in range(i + 1, n):
            remainder -= row[j] * solution[j]
        diagonal = row[i]
        if abs(diagonal) < MATRIX_SINGULARITY_THRESHOLD:
            return None
        solution[i] = remainder / diagonal
    return solution
def exp_score_matrix(score_matrix: FloatMatrix, lambda_value: float) -> FloatMatrix:
    """Return a new matrix with ``exp(lambda_value * score)`` in every cell."""
    exponentiated: FloatMatrix = []
    for row in score_matrix:
        new_row = []
        for score in row:
            new_row.append(math.exp(lambda_value * score))
        exponentiated.append(new_row)
    return exponentiated
def solve_lu_eigenvector(exp_matrix: FloatMatrix) -> list[float] | None:
    """Derive a normalized probability vector from ``exp_matrix`` via LU solves.

    For every unit vector ``e_j`` the system ``exp_matrix @ x = e_j`` is
    solved and the solutions are summed component-wise; the sum is then
    scaled to total 1. Returns None when the matrix is empty, factorization
    or any solve fails, the total is not positive, or any normalized
    component is negative.
    """
    n = len(exp_matrix)
    if n <= 0:
        return None
    decomposition = lu_factorize(exp_matrix)
    if decomposition is None:
        return None
    permutation, factors = decomposition

    accumulated = [0.0] * n
    for unit_pos in range(n):
        unit_vector = [1.0 if k == unit_pos else 0.0 for k in range(n)]
        column = solve_lu_system(permutation, factors, unit_vector)
        if column is None:
            return None
        for k in range(n):
            accumulated[k] += column[k]

    mass = sum(accumulated)
    if mass <= 0.0:
        return None
    normalized = [component / mass for component in accumulated]
    for component in normalized:
        if component < 0.0:
            return None
    return normalized
def solve_power_probs(
    exp_matrix: FloatMatrix,
) -> list[float] | None:
    """Estimate a probability vector by power iteration on ``exp_matrix``.

    Starts from the uniform vector, repeatedly multiplies by the matrix and
    rescales to unit Euclidean length, and stops after MAX_POWER_ITERATIONS
    rounds or once no component moves by POWER_ITERATION_TOLERANCE or more.
    The result is rescaled to sum to 1. Returns None for an empty matrix, a
    non-positive vector norm, or a non-positive final sum.
    """
    n = len(exp_matrix)
    if n <= 0:
        return None
    current = [1.0 / n] * n
    for _ in range(MAX_POWER_ITERATIONS):
        product = []
        for i in range(n):
            row = exp_matrix[i]
            dot = 0.0
            for j in range(n):
                dot += row[j] * current[j]
            product.append(dot)
        length = math.sqrt(sum(value * value for value in product))
        if length <= 0.0:
            return None
        rescaled = [value / length for value in product]
        largest_shift = max(abs(new - old) for new, old in zip(rescaled, current))
        current = rescaled
        if largest_shift < POWER_ITERATION_TOLERANCE:
            break
    total = sum(current)
    if total <= 0.0:
        return None
    return [value / total for value in current]
def solve_bg_at_lambda(
    score_matrix: FloatMatrix,
    lambda_value: float,
    score_index_groups: ScoreIndexGroups,
) -> tuple[float, float, list[float]] | None:
    """Solve the background probabilities at a fixed lambda.

    The LU-based solver is tried first and power iteration is the fallback.
    Returns ``(f, f_prime, probabilities)`` where ``f`` and ``f_prime`` are
    the restriction function and its lambda-derivative evaluated with the
    solved probabilities used for both sequences. Returns None for an empty
    matrix, a non-positive lambda, or when both solvers fail.
    """
    if len(score_matrix) <= 0:
        return None
    if lambda_value <= 0.0:
        return None

    weights = exp_score_matrix(score_matrix, lambda_value)
    probabilities = solve_lu_eigenvector(weights)
    if probabilities is None:
        probabilities = solve_power_probs(weights)
        if probabilities is None:
            return None

    weighted = weight_score_groups(score_index_groups, probabilities, probabilities)
    return (
        lambda_constraint(weighted, lambda_value),
        lambda_constraint_deriv(weighted, lambda_value),
        probabilities,
    )
def bisect_lambda_bg(
    score_matrix: FloatMatrix,
    lambda_upper_bound: float,
    score_index_groups: ScoreIndexGroups,
) -> tuple[float, list[float]]:
    """Find lambda and background probabilities by bracketing plus bisection.

    The background is re-solved at every trial lambda, so the restriction
    value being bisected is the one returned by ``solve_bg_at_lambda``.

    Returns:
        ``(lambda, probabilities)`` from the last successfully evaluated
        point (the bracket's upper end if no midpoint was evaluated).

    Raises:
        ValueError: if the background cannot be solved at the lower bound or
            at a midpoint, or if no sign-changing bracket is found within
            MAX_BRACKET_ATTEMPTS attempts.
    """
    lambda_low = LAMBDA_LOWER_BOUND
    low_result = solve_bg_at_lambda(score_matrix, lambda_low, score_index_groups)
    if low_result is None:
        raise ValueError("could not solve background at low lambda")
    f_low, _, _ = low_result
    lambda_high = BRACKET_START_LAMBDA
    high_result: tuple[float, float, list[float]] | None = None
    upper_limit = lambda_upper_bound * LAMBDA_UPPER_BOUND_SAFETY
    tried_upper_limit = False
    # Grow lambda_high until f changes sign between the low and high ends.
    for _ in range(MAX_BRACKET_ATTEMPTS):
        high_result = solve_bg_at_lambda(score_matrix, lambda_high, score_index_groups)
        if high_result is not None and f_low <= 0.0 < high_result[0]:
            break
        lambda_high *= BRACKET_INCREMENT_FACTOR
        # The first time growth passes the safety limit, test the limit itself.
        if not tried_upper_limit and lambda_high > upper_limit:
            lambda_high = upper_limit
            tried_upper_limit = True
            high_result = solve_bg_at_lambda(
                score_matrix, lambda_high, score_index_groups
            )
            if high_result is not None and f_low <= 0.0 < high_result[0]:
                break
    else:
        # Loop ran out of attempts without hitting a break.
        raise ValueError("failed to bracket solved lambda")
    # A break above guarantees high_result is not None here.
    best_lambda = lambda_high
    best_probs = high_result[2]
    for _ in range(MAX_BISECTION_ITERATIONS):
        if lambda_high - lambda_low < BISECTION_CONVERGENCE_TOLERANCE:
            break
        mid = 0.5 * (lambda_low + lambda_high)
        mid_result = solve_bg_at_lambda(score_matrix, mid, score_index_groups)
        if mid_result is None:
            raise ValueError(f"could not solve background at lambda={mid}")
        f_mid, _, probs_mid = mid_result
        # The most recent midpoint is always reported as the current best.
        best_lambda = mid
        best_probs = probs_mid
        if abs(f_mid) < NEWTON_CONVERGENCE_TOLERANCE:
            break
        # f is positive at the upper end, so a positive midpoint replaces it.
        if f_mid > 0.0:
            lambda_high = mid
        else:
            lambda_low = mid
    return best_lambda, best_probs
def estimate_lambda_bg(
    score_matrix: FloatMatrix,
) -> tuple[float, list[float]]:
    """Estimate lambda and background probabilities for a score matrix.

    A bracket for lambda is found using uniform probabilities, then a damped
    Newton iteration is run in which the background is re-solved at every
    trial lambda. If the Newton result does not satisfy the restriction
    within RESTRICTION_VALIDATION_TOLERANCE, ``bisect_lambda_bg`` is used.

    Returns:
        ``(lambda, probabilities)``.

    Raises:
        ValueError: if the matrix is invalid, bracketing fails, the
            background cannot be solved at a trial lambda, the final
            restriction check fails, or the estimated lambda is not positive.
    """
    size = len(score_matrix)
    validate_score_matrix(score_matrix)
    lambda_upper_bound = find_lambda_bound(score_matrix)
    score_index_groups = group_score_indices(score_matrix)
    # Uniform probabilities are only the starting point for bracketing.
    p = [1.0 / size for _ in range(size)]
    lambda_low, lambda_high = bracket_lambda(
        score_index_groups, p, p, lambda_upper_bound
    )
    lambda_current = 0.5 * (lambda_low + lambda_high)
    lambda_value = lambda_current
    epsilon = NEWTON_CONVERGENCE_TOLERANCE
    for iteration in range(MAX_NEWTON_ITERATIONS):
        solved_result = solve_bg_at_lambda(
            score_matrix, lambda_current, score_index_groups
        )
        if solved_result is None:
            raise ValueError(f"could not solve background at lambda={lambda_current}")
        f_value, f_prime, p_current = solved_result
        if abs(f_value) < epsilon:
            p = p_current
            lambda_value = lambda_current
            break
        # A vanishing derivative stops Newton; p and lambda_value keep their
        # earlier values and the restriction check below decides what follows.
        if abs(f_prime) < MATRIX_SINGULARITY_THRESHOLD:
            break
        newton_step = -f_value / f_prime
        lambda_new = lambda_current + newton_step
        # Shorten steps that would leave the bracket: cover at most
        # DAMPING_FACTOR of the distance to the edge being approached.
        if lambda_new < lambda_low or lambda_new > lambda_high:
            if newton_step > 0.0:
                damping_factor = min(
                    1.0,
                    (lambda_high - lambda_current) * DAMPING_FACTOR / newton_step,
                )
            else:
                damping_factor = min(
                    1.0,
                    (lambda_low - lambda_current) * DAMPING_FACTOR / newton_step,
                )
            lambda_new = lambda_current + damping_factor * newton_step
        if abs(lambda_new - lambda_current) < epsilon:
            p = p_current
            lambda_value = lambda_new
            break
        lambda_current = lambda_new
        # Out of iterations: keep the last solved probabilities and lambda.
        if iteration == MAX_NEWTON_ITERATIONS - 1:
            p = p_current
            lambda_value = lambda_current
    final_groups = weight_score_groups(score_index_groups, p, p)
    final_restriction = lambda_constraint(final_groups, lambda_value)
    if abs(final_restriction) > RESTRICTION_VALIDATION_TOLERANCE:
        # Newton did not produce an acceptable solution; fall back to bisection.
        lambda_value, p = bisect_lambda_bg(
            score_matrix, lambda_upper_bound, score_index_groups
        )
        final_groups = weight_score_groups(score_index_groups, p, p)
        final_restriction = lambda_constraint(final_groups, lambda_value)
        if abs(final_restriction) > RESTRICTION_VALIDATION_TOLERANCE:
            raise ValueError(
                "lambda/background estimation failed final restriction check "
                f"(|f(lambda)|={abs(final_restriction):.6g})"
            )
    if lambda_value <= 0.0:
        raise ValueError("estimated lambda is non-positive")
    return lambda_value, p
def parse_score_matrix(
    path: Path,
) -> tuple[int, list[str], dict[str, int], bool, FloatMatrix]:
    """Read a whitespace-separated scoring matrix file.

    The first non-comment line with more than one word is the alphabet
    header; only the upper-cased first character of each word is used. Every
    later line with more than one word is a matrix row: a residue letter
    followed by at least ``alphabet_size`` scores. Blank lines, lines with a
    single word, and lines whose first word starts with ``#`` (indented or
    not) are skipped. Rows missing from the file are left as zeros.

    Returns:
        ``(alphabet_size, num2aa, aa2num, has_terminal_x, score_matrix)``
        where ``has_terminal_x`` is True when the last header symbol is "X".

    Raises:
        ValueError: if no header is found, a header or row symbol is not a
            letter, the header repeats a residue, a row names an unknown
            residue, a row is too short, or a score is not a number.
        OSError: if the file cannot be read.
    """
    aa2num: dict[str, int] = {}
    num2aa: list[str] = []
    score_matrix: FloatMatrix | None = None
    for line in path.read_text().splitlines():
        words = line.split()
        # Test the first word rather than the raw line so that indented
        # comments are skipped instead of being rejected as malformed rows.
        if not words or words[0].startswith("#"):
            continue
        if len(words) <= 1:
            continue

        if score_matrix is None:
            for index, word in enumerate(words):
                aa = word[0].upper()
                if not aa.isalpha():
                    raise ValueError("Scoring matrix must start with alphabet header")
                # A repeated symbol would overwrite its aa2num entry while
                # still growing num2aa, leaving an all-zero orphan row.
                if aa in aa2num:
                    raise ValueError(
                        f"{path}: duplicate residue {aa!r} in alphabet header"
                    )
                aa2num[aa] = index
                num2aa.append(aa)
            alphabet_size = len(num2aa)
            score_matrix = [
                [0.0 for _ in range(alphabet_size)] for _ in range(alphabet_size)
            ]
            continue

        aa_char = words[0][0].upper()
        if not aa_char.isalpha():
            raise ValueError("First element in matrix row must be an alphabet letter")
        row = aa2num.get(aa_char)
        if row is None:
            raise ValueError(f"{path}: unknown row residue {aa_char!r}")
        alphabet_size = len(num2aa)
        if len(words) < alphabet_size + 1:
            raise ValueError(f"{path}: incomplete matrix row for {aa_char}")
        try:
            score_matrix[row] = [float(word) for word in words[1 : alphabet_size + 1]]
        except ValueError as exc:
            raise ValueError(f"{path}: invalid score in row {aa_char}: {exc}") from exc

    if score_matrix is None:
        raise ValueError(f"{path}: could not find alphabet header")
    return len(num2aa), num2aa, aa2num, num2aa[-1] == "X", score_matrix
def x_has_real_positive(score_matrix: FloatMatrix, x_index: int) -> bool:
    """Tell whether state ``x_index`` scores positively against another state.

    Only partners ``0 .. len(score_matrix) - 2`` are inspected, in both the
    row and the column of ``x_index``, so the last state's own diagonal cell
    is never considered.
    """
    partner_count = len(score_matrix) - 1
    for partner in range(partner_count):
        if score_matrix[x_index][partner] > 0.0:
            return True
        if score_matrix[partner][x_index] > 0.0:
            return True
    return False
def is_estimable_index(
    score_matrix: FloatMatrix, index: int, candidates: list[int]
) -> bool:
    """Check that ``index`` has mixed-sign scores against ``candidates``.

    Returns True only when, restricted to the candidate states, both the row
    and the column of ``index`` contain at least one positive and at least
    one negative score.
    """
    # Read every cell up front so all candidates are always indexed.
    row_cells = [score_matrix[index][other] for other in candidates]
    col_cells = [score_matrix[other][index] for other in candidates]

    def mixed_signs(cells: list[float]) -> bool:
        return any(cell > 0.0 for cell in cells) and any(cell < 0.0 for cell in cells)

    return mixed_signs(row_cells) and mixed_signs(col_cells)
def select_estimable_indices(
    score_matrix: FloatMatrix, has_terminal_x: bool
) -> list[int]:
    """Pick the states whose background probability can be fitted.

    A terminal X (the last state) is dropped up front unless it scores
    positively against some other state. The remaining states are then
    filtered repeatedly with ``is_estimable_index`` until the set stops
    changing.

    Raises:
        ValueError: if fewer than two states remain.
    """
    state_count = len(score_matrix)
    survivors = list(range(state_count))
    if has_terminal_x and not x_has_real_positive(score_matrix, state_count - 1):
        survivors.remove(state_count - 1)

    # Removing one state can make another non-estimable, so iterate to a
    # fixed point.
    stable = False
    while not stable:
        still_estimable = [
            index
            for index in survivors
            if is_estimable_index(score_matrix, index, survivors)
        ]
        stable = still_estimable == survivors
        survivors = still_estimable

    if len(survivors) < 2:
        raise ValueError("score matrix has fewer than two estimable states")
    return survivors
def submatrix_for_indices(score_matrix: FloatMatrix, indices: list[int]) -> FloatMatrix:
    """Return the square sub-matrix restricted to ``indices`` (rows and columns)."""
    selection: FloatMatrix = []
    for row_index in indices:
        source_row = score_matrix[row_index]
        selection.append([source_row[col_index] for col_index in indices])
    return selection
def estimate_bg(
    score_matrix: FloatMatrix, has_terminal_x: bool
) -> tuple[float, list[float]]:
    """Estimate lambda and a background probability for every state.

    Lambda and the background are fitted on the estimable sub-matrix only.
    Every excluded state receives ANY_BACK and the fitted probabilities are
    scaled by ``1 - ANY_BACK * number_of_excluded_states``. With a terminal
    X, the last entry is set to ANY_BACK even if X took part in the fit.

    Raises:
        ValueError: if the scale factor is not positive, or if selection or
            estimation of the sub-matrix fails.
    """
    state_count = len(score_matrix)
    fitted = select_estimable_indices(score_matrix, has_terminal_x)
    lambda_value, fitted_bg = estimate_lambda_bg(
        submatrix_for_indices(score_matrix, fitted)
    )

    skipped = set(range(state_count)).difference(fitted)
    shrink = 1.0 - (ANY_BACK * len(skipped))
    if shrink <= 0.0:
        raise ValueError("too many non-estimable states for fixed background")

    background = [0.0] * state_count
    for position, state in enumerate(fitted):
        background[state] = fitted_bg[position] * shrink
    for state in skipped:
        background[state] = ANY_BACK
    if has_terminal_x:
        background[state_count - 1] = ANY_BACK
    return lambda_value, background
def pair_probs_from_scores(
    score_matrix: FloatMatrix, background: list[float], lambda_value: float
) -> FloatMatrix:
    """Convert scores to pair probabilities.

    Cell ``(i, j)`` becomes
    ``exp(lambda_value * score[i][j]) * background[i] * background[j]``.
    """
    state_count = len(score_matrix)
    pair_probs: FloatMatrix = []
    for i in range(state_count):
        row = []
        for j in range(state_count):
            odds = math.exp(lambda_value * score_matrix[i][j])
            row.append(odds * background[i] * background[j])
        pair_probs.append(row)
    return pair_probs
def parse_matrix_file(path: Path) -> ParsedMatrix:
    """Read a scoring matrix file and derive its pair-probability matrix.

    The scores are parsed, lambda and the background are estimated from
    them, and the scores are converted into pair probabilities.
    """
    parsed_fields = parse_score_matrix(path)
    alphabet_size, num2aa, aa2num, has_terminal_x, score_matrix = parsed_fields
    lambda_value, background = estimate_bg(score_matrix, has_terminal_x)
    return ParsedMatrix(
        alphabet_size=alphabet_size,
        num2aa=num2aa,
        aa2num=aa2num,
        prob_matrix=pair_probs_from_scores(score_matrix, background, lambda_value),
        has_terminal_x=has_terminal_x,
    )
def zeros(size: int) -> FloatMatrix:
    """Return a ``size`` x ``size`` matrix of 0.0 with independent rows."""
    matrix: FloatMatrix = []
    for _ in range(size):
        matrix.append([0.0] * size)
    return matrix
def compute_background(
    prob_matrix: FloatMatrix, alphabet_size: int, contains_x: bool
) -> list[float]:
    """Return the row sums of the leading ``alphabet_size`` block of ``prob_matrix``.

    When ``contains_x`` is true the last entry is replaced with ANY_BACK.
    """
    marginals = [
        sum(prob_matrix[row][col] for col in range(alphabet_size))
        for row in range(alphabet_size)
    ]
    if contains_x:
        marginals[alphabet_size - 1] = ANY_BACK
    return marginals
def generate_log_odds(
    prob_matrix: FloatMatrix, alphabet_size: int, contains_x: bool
) -> FloatMatrix:
    """Return base-2 log-odds scores for a pair-probability matrix.

    Cell ``(i, j)`` is ``log2(P[i][j] / (bg[i] * bg[j]))`` with ``bg`` taken
    from ``compute_background``.
    """
    marginals = compute_background(prob_matrix, alphabet_size, contains_x)
    log_odds: FloatMatrix = []
    for i in range(alphabet_size):
        row = []
        for j in range(alphabet_size):
            expected = marginals[i] * marginals[j]
            row.append(math.log2(prob_matrix[i][j] / expected))
        log_odds.append(row)
    return log_odds
def round_to_short(value: float) -> int:
    """Round to the nearest integer, with halves going away from zero."""
    if value < 0.0:
        return int(value - 0.5)
    return int(value + 0.5)
def generate_integer_scores(
    prob_matrix: FloatMatrix,
    alphabet_size: int,
    bit_factor: float,
    has_terminal_x: bool,
) -> IntMatrix:
    """Return integer scores: log-odds scaled by ``bit_factor`` and rounded."""
    log_odds = generate_log_odds(prob_matrix, alphabet_size, contains_x=has_terminal_x)
    return [
        [round_to_short(bit_factor * log_odds[i][j]) for j in range(alphabet_size)]
        for i in range(alphabet_size)
    ]
def log_odds_mi(pair_probs: FloatMatrix, size: int) -> float:
    """Return ``sum(P[i][j] * log2(P[i][j] / (bg[i] * bg[j])))`` over all pairs.

    ``bg`` is the vector of row sums of ``pair_probs``; no X adjustment is
    applied.
    """
    marginals = compute_background(pair_probs, size, contains_x=False)
    total = 0.0
    for i in range(size):
        row = pair_probs[i]
        for j in range(size):
            joint = row[j]
            total += joint * math.log2(joint / (marginals[i] * marginals[j]))
    return total
def merge_pair_probs(
    pair_probs: FloatMatrix, size: int, kept_idx: int, removed_idx: int
) -> FloatMatrix:
    """Fold state ``removed_idx`` into ``kept_idx`` and return the smaller matrix.

    The surviving states keep their relative order. The removed state's row
    and column are added to the kept state's row and column, and the
    removed diagonal cell is added to the kept diagonal cell.
    """
    merged = zeros(size - 1)
    survivors = [state for state in range(size) if state != removed_idx]
    for new_row, old_row in enumerate(survivors):
        row_is_kept = old_row == kept_idx
        for new_col, old_col in enumerate(survivors):
            col_is_kept = old_col == kept_idx
            cell = pair_probs[old_row][old_col]
            if col_is_kept:
                cell += pair_probs[old_row][removed_idx]
            if row_is_kept:
                cell += pair_probs[removed_idx][old_col]
                if col_is_kept:
                    cell += pair_probs[removed_idx][removed_idx]
            merged[new_row][new_col] = cell
    return merged
def find_best_mi_merge(
    pair_probs: FloatMatrix, size: int
) -> tuple[int, int, FloatMatrix]:
    """Try every pair merge and return the one with the highest ``log_odds_mi``.

    Pairs ``(i, j)`` with ``i < j`` are tried in order and the first pair
    reaching the maximum wins. Returns ``(kept_idx, removed_idx,
    merged_pair_probs)``.

    Raises:
        ValueError: if no candidate pair produced a comparable value.
    """
    winner: tuple[int, int, FloatMatrix] | None = None
    winning_mi = -math.inf
    for keep in range(size):
        for drop in range(keep + 1, size):
            candidate = merge_pair_probs(pair_probs, size, keep, drop)
            candidate_mi = log_odds_mi(candidate, size - 1)
            if candidate_mi > winning_mi:
                winning_mi = candidate_mi
                winner = (keep, drop, candidate)
    if winner is None:
        raise ValueError("no merge candidates available")
    return winner
def build_group_strings(
    reduced_alphabet: list[str],
    aa2num: dict[str, int],
    source_alphabet: list[str],
) -> list[str]:
    """Render each reduced symbol's members as a ``"(A B C)"`` string.

    A source symbol belongs to a representative's group when ``aa2num`` maps
    both to the same number. Members appear in ``source_alphabet`` order and
    groups in ``reduced_alphabet`` order.
    """

    def render(representative: str) -> str:
        group_number = aa2num[representative]
        members = [aa for aa in source_alphabet if aa2num[aa] == group_number]
        return "(" + " ".join(members) + ")"

    return [render(representative) for representative in reduced_alphabet]
def reduce_matrix(
    parsed: ParsedMatrix, target_size: int, bit_factor: float
) -> ReducedResult:
    """Greedily merge alphabet symbols down to ``target_size`` and rescore.

    Each step applies the merge chosen by ``find_best_mi_merge``. A terminal
    X is kept out of the merging and re-attached afterwards, so it counts
    towards ``target_size``.

    Raises:
        ValueError: if ``target_size`` is not smaller than the original
            alphabet or is below 2.
    """
    original_size = parsed.alphabet_size
    if target_size >= original_size:
        raise ValueError("Reduced alphabet has to be smaller than the original one")
    if target_size < 2:
        raise ValueError("Reduced alphabet size must be at least 2")
    has_terminal_x = parsed.has_terminal_x
    # With a terminal X only the leading (size - 1) states take part in merging.
    start_merge_size = original_size - 1 if has_terminal_x else original_size
    target_merge_size = target_size - 1 if has_terminal_x else target_size
    work_pair_probs = [
        row[:start_merge_size] for row in parsed.prob_matrix[:start_merge_size]
    ]
    # work_aa2num maps every original symbol to the ORIGINAL index of its
    # current representative; reduced_alphabet is the compacted symbol list.
    work_aa2num = parsed.aa2num.copy()
    reduced_alphabet = parsed.num2aa[:]
    reduce_steps = start_merge_size - target_merge_size
    merge_size = start_merge_size
    for _ in range(reduce_steps):
        kept_idx, removed_idx, merged_pair_probs = find_best_mi_merge(
            work_pair_probs, merge_size
        )
        kept_aa = reduced_alphabet[kept_idx]
        removed_aa = reduced_alphabet[removed_idx]
        del reduced_alphabet[removed_idx]
        kept_int = work_aa2num[kept_aa]
        removed_int = work_aa2num[removed_aa]
        # Re-point every member of the removed group at the kept representative.
        for aa in parsed.num2aa:
            if work_aa2num[aa] == removed_int:
                work_aa2num[aa] = kept_int
        merge_size -= 1
        work_pair_probs = merged_pair_probs
    aa_groups = build_group_strings(
        reduced_alphabet,
        work_aa2num,
        parsed.num2aa,
    )
    reduced_num2aa = reduced_alphabet[:]
    # Copy the merged block; with a terminal X the last row/column starts at
    # zero and is filled in below.
    final_pair_probs = zeros(target_size)
    for i in range(merge_size):
        for j in range(merge_size):
            final_pair_probs[i][j] = work_pair_probs[i][j]
    if has_terminal_x:
        reduced_bg = [0.0 for _ in range(original_size)]
        for i in range(target_size):
            reduced_bg[i] = sum(final_pair_probs[i][j] for j in range(target_size))
        reduced_bg[target_size - 1] = ANY_BACK
        # This uses the pre-compaction X mapping to match ReducedMatrix.cpp.
        # NOTE(review): only the first target_size entries of reduced_bg are
        # written above, so if work_aa2num["X"] is still the original last
        # index this reads 0.0 and the scaling below changes nothing.
        precompact_x_idx = work_aa2num["X"]
        for i in range(target_size - 1):
            reduced_bg[i] *= 1.0 - reduced_bg[precompact_x_idx]
        orig_prob = parsed.prob_matrix
        orig_bg = compute_background(orig_prob, original_size, contains_x=True)
        original_x_index = original_size - 1
        # Rebuild the X row/column: reuse each representative's original odds
        # ratio against X with the reduced background probabilities.
        for i, aa in enumerate(reduced_num2aa):
            old_index = work_aa2num[aa]
            x_pair_odds = orig_prob[old_index][original_x_index] / (
                orig_bg[old_index] * orig_bg[original_x_index]
            )
            final_pair_probs[target_size - 1][i] = (
                x_pair_odds * reduced_bg[i] * reduced_bg[target_size - 1]
            )
            final_pair_probs[i][target_size - 1] = (
                x_pair_odds * reduced_bg[target_size - 1] * reduced_bg[i]
            )
    score_matrix = generate_integer_scores(
        final_pair_probs, target_size, bit_factor, has_terminal_x
    )
    return ReducedResult(
        aa_groups=aa_groups,
        alphabet=reduced_num2aa,
        score_matrix=score_matrix,
    )
def format_score_matrix(alphabet: list[str], score_matrix: IntMatrix) -> str:
    """Render an integer score matrix as text with a header line.

    The header lists the symbols right-aligned in width-4 fields. Each later
    line starts with the row symbol followed by that row's scores in width-4
    fields. Lines are joined with newlines, without a trailing newline.
    """
    header_cells = [f"{aa:>4}" for aa in alphabet]
    rendered = [" " + " ".join(header_cells)]
    for aa, row in zip(alphabet, score_matrix):
        score_cells = [f"{score:>4d}" for score in row]
        rendered.append(f"{aa} " + " ".join(score_cells))
    return "\n".join(rendered)
def format_result(result: ReducedResult) -> str:
    """Return the printable report: two ``#`` comment lines, then the matrix.

    The second comment line lists the symbol groups; the text ends with a
    newline.
    """
    group_line = " ".join(result.aa_groups)
    table = format_score_matrix(result.alphabet, result.score_matrix)
    pieces = [
        "# Reduced amino acid alphabet:\n",
        f"# {group_line}\n",
        f"{table}\n",
    ]
    return "".join(pieces)
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse command-line arguments.

    Positional arguments are the matrix path and the reduced alphabet size;
    ``--bit-factor`` optionally overrides the score scaling factor.
    """
    cli = argparse.ArgumentParser(
        description="Compute and print a reduced amino-acid substitution matrix."
    )
    positionals = (
        ("matrix", Path, "Scoring matrix file"),
        (
            "alphabet_size",
            int,
            "Reduced alphabet size (range 2 to original_size - 1)",
        ),
    )
    for dest, converter, description in positionals:
        cli.add_argument(dest, type=converter, help=description)
    cli.add_argument(
        "--bit-factor",
        type=float,
        default=2.0,
        help="Score scaling factor for the printed integer matrix (default: 2.0)",
    )
    return cli.parse_args(argv)
def main(argv: list[str]) -> int:
    """Run the command-line tool.

    Parses ``argv``, reads the matrix, reduces the alphabet, and prints the
    result to standard output.

    Returns:
        0 on success; 1 after printing ``error: ...`` to standard error when
        reading, validation, or a numeric step fails.
    """
    args = parse_args(argv)
    try:
        parsed = parse_matrix_file(args.matrix)
        reduced = reduce_matrix(parsed, args.alphabet_size, args.bit_factor)
    except (OSError, ValueError, ArithmeticError) as exc:
        # ArithmeticError covers ZeroDivisionError and also the OverflowError
        # that math.exp raises when lambda * score is too large, so extreme
        # matrices produce an error message instead of a traceback.
        print(f"error: {exc}", file=sys.stderr)
        return 1
    print(format_result(reduced), end="")
    return 0
# Script entry point: main's return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment