Created
April 29, 2026 22:33
-
-
Save apcamargo/e43da080bc6a1818cdab69e24fa3fc45 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| import argparse | |
| import math | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
# Fixed background probability assigned to the terminal "X" state and to
# states that are excluded from the lambda/background fit.
ANY_BACK = 1e-5
# LU pivots (and Newton derivatives) smaller than this are treated as singular.
MATRIX_SINGULARITY_THRESHOLD = 1e-15
# Tolerance for the symmetry check and for treating two scores as equal.
CONVERGENCE_TOLERANCE = 1e-10
# |f(lambda)| and step-size threshold that stops the Newton iteration; also
# the |f(lambda)| threshold that stops the bisection fallback early.
NEWTON_CONVERGENCE_TOLERANCE = 1e-10
# Lower end of the interval used by the bisection fallback.
LAMBDA_LOWER_BOUND = 1e-6
# Interval width at which the bisection fallback stops.
BISECTION_CONVERGENCE_TOLERANCE = 1e-12
# Largest per-component change at which power iteration is considered converged.
POWER_ITERATION_TOLERANCE = 1e-12
# Fraction of the distance to the bracket edge that a clamped Newton step covers.
DAMPING_FACTOR = 0.8
# Multiplier applied to the lambda upper bound to get the limit tried once
# while bracketing.
LAMBDA_UPPER_BOUND_SAFETY = 0.95
# Largest accepted |f(0)| when bracketing lambda.
PROBABILITY_VALIDATION_TOLERANCE = 1e-10
# Largest accepted |f(lambda)| for the final lambda/background solution.
RESTRICTION_VALIDATION_TOLERANCE = 1e-6
# Row/column score sums with a magnitude below this are rejected.
ZERO_ROW_COL_THRESHOLD = 1e-10
# Iteration caps for the individual numerical loops.
MAX_NEWTON_ITERATIONS = 50
MAX_BISECTION_ITERATIONS = 100
MAX_POWER_ITERATIONS = 1000
MAX_BRACKET_ATTEMPTS = 50
# The bracketing upper end starts at BRACKET_START_LAMBDA and is multiplied by
# BRACKET_INCREMENT_FACTOR after each failed attempt.
BRACKET_INCREMENT_FACTOR = 1.5
BRACKET_START_LAMBDA = 0.01
# Square matrix of floats, stored as a list of rows.
FloatMatrix = list[list[float]]
# Square matrix of ints, stored as a list of rows.
IntMatrix = list[list[int]]
# (score, summed probability mass of all cells carrying that score) pairs.
ScoreProbSums = list[tuple[float, float]]
# (score, list of (row, column) cells carrying that score) pairs.
ScoreIndexGroups = list[tuple[float, list[tuple[int, int]]]]
@dataclass
class ParsedMatrix:
    """Scoring matrix file contents after conversion to pair probabilities."""

    # Number of letters in the alphabet header.
    alphabet_size: int
    # Alphabet letters in header order (index -> letter).
    num2aa: list[str]
    # Letter -> index lookup for the alphabet.
    aa2num: dict[str, int]
    # Pair probability matrix indexed by alphabet position.
    prob_matrix: FloatMatrix
    # True when the last alphabet letter is "X".
    has_terminal_x: bool
@dataclass
class ReducedResult:
    """Outcome of an alphabet reduction, ready for printing."""

    # One "(A B ...)" string per reduced letter listing the merged source letters.
    aa_groups: list[str]
    # Representative letter for each reduced state, in matrix order.
    alphabet: list[str]
    # Integer score matrix over the reduced alphabet.
    score_matrix: IntMatrix
def validate_score_matrix(score_matrix: FloatMatrix) -> None:
    """Reject score matrices that cannot be used for lambda estimation.

    Raises:
        ValueError: if the matrix is empty or not square, contains NaN or
            Inf, is not symmetric within CONVERGENCE_TOLERANCE, or holds the
            same value in every cell.
    """
    dimension = len(score_matrix)
    is_square = dimension > 0 and all(len(row) == dimension for row in score_matrix)
    if not is_square:
        raise ValueError("score matrix must be square and non-empty")

    flat: list[float] = []
    for row in score_matrix:
        flat.extend(row)

    for entry in flat:
        if math.isnan(entry) or math.isinf(entry):
            raise ValueError("score matrix contains NaN or Inf")

    # |a - b| is symmetric, so in row-major order the first mismatching cell
    # always lies in the upper triangle; scanning only that half reports the
    # same (i, j) as a full scan.
    for i in range(dimension):
        for j in range(i + 1, dimension):
            mismatch = abs(score_matrix[i][j] - score_matrix[j][i])
            if mismatch > CONVERGENCE_TOLERANCE:
                raise ValueError(f"score matrix is not symmetric at ({i}, {j})")

    spread = max(flat) - min(flat)
    if spread <= 0.0:
        raise ValueError("score matrix has no score variation")
def find_lambda_bound(score_matrix: FloatMatrix) -> float:
    """Return the smallest per-row / per-column maximum positive score.

    Every row and every column must contain at least one positive and one
    negative score, and no row or column may have a near-zero score sum.

    Raises:
        ValueError: if any of those conditions is violated.
    """
    n = len(score_matrix)
    columns = [[score_matrix[r][c] for r in range(n)] for c in range(n)]

    totals = [sum(row) for row in score_matrix] + [sum(column) for column in columns]
    if any(abs(total) < ZERO_ROW_COL_THRESHOLD for total in totals):
        raise ValueError("score matrix has near-zero row or column score sums")

    # All rows are checked before any column so the first reported problem
    # is the same as with a row-major scan.
    row_maxes: list[float] = []
    for r in range(n):
        entries = [score_matrix[r][c] for c in range(n)]
        positives = [entry for entry in entries if entry > 0.0]
        if not positives or not any(entry < 0.0 for entry in entries):
            raise ValueError(f"score matrix row {r} lacks positive and negative scores")
        row_maxes.append(max(positives))

    col_maxes: list[float] = []
    for c, column in enumerate(columns):
        positives = [entry for entry in column if entry > 0.0]
        if not positives or not any(entry < 0.0 for entry in column):
            raise ValueError(
                f"score matrix column {c} lacks positive and negative scores"
            )
        col_maxes.append(max(positives))

    upper_bound = min(min(row_maxes), min(col_maxes))
    if upper_bound <= 0.0:
        raise ValueError("invalid non-positive lambda upper bound")
    return upper_bound
def group_score_indices(score_matrix: FloatMatrix) -> ScoreIndexGroups:
    """Bucket matrix cells by score value.

    Cells are visited in row-major order. A cell joins the first existing
    bucket whose score is within CONVERGENCE_TOLERANCE of its own score;
    otherwise it opens a new bucket keyed by its score.
    """
    buckets: ScoreIndexGroups = []
    for r, row in enumerate(score_matrix):
        for c, value in enumerate(row):
            for known_score, members in buckets:
                if abs(known_score - value) < CONVERGENCE_TOLERANCE:
                    members.append((r, c))
                    break
            else:
                # No bucket matched: this score has not been seen before.
                buckets.append((value, [(r, c)]))
    return buckets
def weight_score_groups(
    score_index_groups: ScoreIndexGroups, p: list[float], q: list[float]
) -> ScoreProbSums:
    """Replace each score's cell list with its total mass ``sum(p[i] * q[j])``."""
    weighted: ScoreProbSums = []
    for score, cells in score_index_groups:
        # Plain left-to-right accumulation keeps the rounding order fixed.
        mass = 0.0
        for row, col in cells:
            mass = mass + p[row] * q[col]
        weighted.append((score, mass))
    return weighted
def lambda_constraint(groups: ScoreProbSums, lambda_value: float) -> float:
    """Evaluate ``sum(prob * exp(lambda * score)) - 1`` over the score groups."""
    weighted_terms = [
        prob_sum * math.exp(lambda_value * score) for score, prob_sum in groups
    ]
    return sum(weighted_terms) - 1.0
def lambda_constraint_deriv(groups: ScoreProbSums, lambda_value: float) -> float:
    """Evaluate ``sum(prob * score * exp(lambda * score))`` over the score groups."""
    slope_terms = [
        prob_sum * score * math.exp(lambda_value * score) for score, prob_sum in groups
    ]
    return sum(slope_terms)
def bracket_lambda(
    score_index_groups: ScoreIndexGroups,
    p: list[float],
    q: list[float],
    lambda_upper_bound: float,
) -> tuple[float, float]:
    """Find an interval ``(0.0, high)`` with a positive constraint value at ``high``.

    The constraint is evaluated with the fixed probabilities ``p`` and ``q``.
    ``high`` starts at BRACKET_START_LAMBDA and grows geometrically; the value
    ``lambda_upper_bound * LAMBDA_UPPER_BOUND_SAFETY`` is tried once when the
    growth first passes it.

    Raises:
        ValueError: if the constraint is not ~0 at lambda=0, if its derivative
            at lambda=0 is not negative, or if no positive value is found
            within MAX_BRACKET_ATTEMPTS.
    """
    groups = weight_score_groups(score_index_groups, p, q)
    # At lambda=0 every exponential is 1, so f(0) is the total mass minus 1.
    f_zero = lambda_constraint(groups, 0.0)
    if abs(f_zero) > PROBABILITY_VALIDATION_TOLERANCE:
        raise ValueError("restriction value at lambda=0 is not close to zero")
    # f'(0) is the mass-weighted sum of scores; it must be negative.
    f_prime_zero = lambda_constraint_deriv(groups, 0.0)
    if f_prime_zero >= 0.0:
        raise ValueError("restriction derivative at lambda=0 is not negative")
    lambda_low = 0.0
    lambda_high = BRACKET_START_LAMBDA
    tried_upper_limit = False
    upper_limit = lambda_upper_bound * LAMBDA_UPPER_BOUND_SAFETY
    for _ in range(MAX_BRACKET_ATTEMPTS):
        if lambda_constraint(groups, lambda_high) > 0.0:
            return lambda_low, lambda_high
        lambda_high *= BRACKET_INCREMENT_FACTOR
        # The first time growth overshoots the limit, clamp to it and test it
        # immediately; later iterations are allowed to grow past it.
        if not tried_upper_limit and lambda_high > upper_limit:
            lambda_high = upper_limit
            tried_upper_limit = True
            if lambda_constraint(groups, lambda_high) > 0.0:
                return lambda_low, lambda_high
    raise ValueError("failed to bracket lambda")
def lu_factorize(matrix: FloatMatrix) -> tuple[list[int], FloatMatrix] | None:
    """LU-factorize a copy of ``matrix`` using row pivoting.

    Returns:
        ``(pivot_indices, lu_matrix)`` where ``pivot_indices[i]`` is the
        original row now stored at row ``i`` and ``lu_matrix`` holds the
        multipliers below the diagonal and the upper factor on and above it,
        or ``None`` when a pivot column is smaller than
        MATRIX_SINGULARITY_THRESHOLD in magnitude.
    """
    n = len(matrix)
    work = [list(row) for row in matrix]
    permutation = list(range(n))

    for col in range(n - 1):
        # First row (from the diagonal down) with the largest magnitude.
        best_row = max(range(col, n), key=lambda r: abs(work[r][col]))
        if abs(work[best_row][col]) < MATRIX_SINGULARITY_THRESHOLD:
            return None

        if best_row != col:
            permutation[col], permutation[best_row] = (
                permutation[best_row],
                permutation[col],
            )
            work[col], work[best_row] = work[best_row], work[col]

        pivot_row = work[col]
        for r in range(col + 1, n):
            target = work[r]
            multiplier = target[col] / pivot_row[col]
            target[col] = multiplier
            for c in range(col + 1, n):
                target[c] -= multiplier * pivot_row[c]

    return permutation, work
def solve_lu_system(
    pivot_indices: list[int], lu_matrix: FloatMatrix, rhs: list[float]
) -> list[float] | None:
    """Solve a linear system from a pivoted LU factorization.

    ``rhs`` is permuted with ``pivot_indices``, forward-substituted through
    the unit-lower factor and back-substituted through the upper factor.

    Returns:
        The solution vector, or ``None`` if a diagonal entry of the upper
        factor is smaller than MATRIX_SINGULARITY_THRESHOLD in magnitude.
    """
    n = len(lu_matrix)

    # Permute the right-hand side, then eliminate with the stored multipliers.
    forward = [rhs[pivot_indices[i]] for i in range(n)]
    for i in range(1, n):
        lower_row = lu_matrix[i]
        for j in range(i):
            forward[i] -= lower_row[j] * forward[j]

    solution = [0.0] * n
    for i in reversed(range(n)):
        upper_row = lu_matrix[i]
        remainder = forward[i]
        for j in range(i + 1, n):
            remainder -= upper_row[j] * solution[j]
        diagonal = upper_row[i]
        if abs(diagonal) < MATRIX_SINGULARITY_THRESHOLD:
            return None
        solution[i] = remainder / diagonal
    return solution
def exp_score_matrix(score_matrix: FloatMatrix, lambda_value: float) -> FloatMatrix:
    """Return a new matrix holding ``exp(lambda_value * score)`` for every cell."""
    exponentiated: FloatMatrix = []
    for row in score_matrix:
        new_row = []
        for score in row:
            new_row.append(math.exp(lambda_value * score))
        exponentiated.append(new_row)
    return exponentiated
def solve_lu_eigenvector(exp_matrix: FloatMatrix) -> list[float] | None:
    """Solve for a probability vector through the LU factorization.

    Each unit vector is solved against the factorized matrix and the
    solutions are added up; the sum is then scaled to total 1.

    Returns:
        The normalized vector, or ``None`` if the matrix is empty, the
        factorization or a solve fails, the total is not positive, or any
        component is negative.
    """
    n = len(exp_matrix)
    if n <= 0:
        return None

    factored = lu_factorize(exp_matrix)
    if factored is None:
        return None
    pivot_indices, lu_matrix = factored

    accumulated = [0.0] * n
    for unit_position in range(n):
        unit = [1.0 if k == unit_position else 0.0 for k in range(n)]
        column = solve_lu_system(pivot_indices, lu_matrix, unit)
        if column is None:
            return None
        accumulated = [acc + part for acc, part in zip(accumulated, column)]

    total = sum(accumulated)
    if total <= 0.0:
        return None

    normalized = [value / total for value in accumulated]
    return None if any(value < 0.0 for value in normalized) else normalized
def solve_power_probs(
    exp_matrix: FloatMatrix,
) -> list[float] | None:
    """Run power iteration on ``exp_matrix`` starting from the uniform vector.

    Iterates are scaled to unit Euclidean length; iteration stops once the
    largest component change drops below POWER_ITERATION_TOLERANCE or after
    MAX_POWER_ITERATIONS rounds. The final vector is rescaled to sum to 1.

    Returns:
        The rescaled vector, or ``None`` if the matrix is empty, an iterate
        has zero length, or the final total is not positive.
    """
    n = len(exp_matrix)
    if n <= 0:
        return None

    current = [1.0 / n] * n
    for _ in range(MAX_POWER_ITERATIONS):
        product = []
        for i in range(n):
            acc = 0.0
            for j in range(n):
                acc += exp_matrix[i][j] * current[j]
            product.append(acc)

        length = math.sqrt(sum(value * value for value in product))
        if length <= 0.0:
            return None

        scaled = [value / length for value in product]
        largest_change = max(abs(scaled[i] - current[i]) for i in range(n))
        current = scaled
        if largest_change < POWER_ITERATION_TOLERANCE:
            break

    total = sum(current)
    if total <= 0.0:
        return None
    return [value / total for value in current]
def solve_bg_at_lambda(
    score_matrix: FloatMatrix,
    lambda_value: float,
    score_index_groups: ScoreIndexGroups,
) -> tuple[float, float, list[float]] | None:
    """Solve the background at a fixed lambda and evaluate the constraint there.

    The LU-based solver is tried first; power iteration is the fallback.

    Returns:
        ``(f, f_prime, probabilities)``, or ``None`` for an empty matrix, a
        non-positive lambda, or when both solvers fail.
    """
    if len(score_matrix) <= 0:
        return None
    if lambda_value <= 0.0:
        return None

    exp_matrix = exp_score_matrix(score_matrix, lambda_value)
    probabilities = solve_lu_eigenvector(exp_matrix)
    if probabilities is None:
        probabilities = solve_power_probs(exp_matrix)
        if probabilities is None:
            return None

    weighted = weight_score_groups(score_index_groups, probabilities, probabilities)
    return (
        lambda_constraint(weighted, lambda_value),
        lambda_constraint_deriv(weighted, lambda_value),
        probabilities,
    )
def bisect_lambda_bg(
    score_matrix: FloatMatrix,
    lambda_upper_bound: float,
    score_index_groups: ScoreIndexGroups,
) -> tuple[float, list[float]]:
    """Find lambda and the matching background by bisection.

    The background is re-solved at every trial lambda. The interval starts at
    LAMBDA_LOWER_BOUND and its upper end is grown until the constraint is
    positive there while being non-positive at the lower end.

    Returns:
        ``(lambda, background)`` for the last midpoint evaluated, or for the
        bracketing upper end when no midpoint was evaluated.

    Raises:
        ValueError: if the background cannot be solved at the lower end or at
            a midpoint, or if no bracketing upper end is found within
            MAX_BRACKET_ATTEMPTS.
    """
    lambda_low = LAMBDA_LOWER_BOUND
    low_result = solve_bg_at_lambda(score_matrix, lambda_low, score_index_groups)
    if low_result is None:
        raise ValueError("could not solve background at low lambda")
    f_low, _, _ = low_result
    lambda_high = BRACKET_START_LAMBDA
    high_result: tuple[float, float, list[float]] | None = None
    upper_limit = lambda_upper_bound * LAMBDA_UPPER_BOUND_SAFETY
    tried_upper_limit = False
    for _ in range(MAX_BRACKET_ATTEMPTS):
        high_result = solve_bg_at_lambda(score_matrix, lambda_high, score_index_groups)
        if high_result is not None and f_low <= 0.0 < high_result[0]:
            break
        lambda_high *= BRACKET_INCREMENT_FACTOR
        # The first time growth overshoots the limit, clamp to it and test it
        # immediately; later iterations are allowed to grow past it.
        if not tried_upper_limit and lambda_high > upper_limit:
            lambda_high = upper_limit
            tried_upper_limit = True
            high_result = solve_bg_at_lambda(
                score_matrix, lambda_high, score_index_groups
            )
            if high_result is not None and f_low <= 0.0 < high_result[0]:
                break
    else:
        # Reached only when the loop never hit a break, i.e. no bracket found.
        raise ValueError("failed to bracket solved lambda")
    best_lambda = lambda_high
    best_probs = high_result[2]
    for _ in range(MAX_BISECTION_ITERATIONS):
        if lambda_high - lambda_low < BISECTION_CONVERGENCE_TOLERANCE:
            break
        mid = 0.5 * (lambda_low + lambda_high)
        mid_result = solve_bg_at_lambda(score_matrix, mid, score_index_groups)
        if mid_result is None:
            raise ValueError(f"could not solve background at lambda={mid}")
        f_mid, _, probs_mid = mid_result
        best_lambda = mid
        best_probs = probs_mid
        if abs(f_mid) < NEWTON_CONVERGENCE_TOLERANCE:
            break
        # Keep the half of the interval on which the sign change remains.
        if f_mid > 0.0:
            lambda_high = mid
        else:
            lambda_low = mid
    return best_lambda, best_probs
def estimate_lambda_bg(
    score_matrix: FloatMatrix,
) -> tuple[float, list[float]]:
    """Estimate lambda and background probabilities from a score matrix.

    A damped Newton iteration is started from the midpoint of a bracket found
    with uniform probabilities; the background is re-solved at each trial
    lambda. If the result fails the final constraint check, the bisection
    fallback is used instead.

    Returns:
        ``(lambda, background)``.

    Raises:
        ValueError: if the matrix fails validation, the background cannot be
            solved at a trial lambda, the final constraint check fails even
            after the fallback, or the estimated lambda is non-positive.
    """
    size = len(score_matrix)
    validate_score_matrix(score_matrix)
    lambda_upper_bound = find_lambda_bound(score_matrix)
    score_index_groups = group_score_indices(score_matrix)
    # Uniform probabilities are used only to bracket the starting point.
    p = [1.0 / size for _ in range(size)]
    lambda_low, lambda_high = bracket_lambda(
        score_index_groups, p, p, lambda_upper_bound
    )
    lambda_current = 0.5 * (lambda_low + lambda_high)
    lambda_value = lambda_current
    epsilon = NEWTON_CONVERGENCE_TOLERANCE
    for iteration in range(MAX_NEWTON_ITERATIONS):
        solved_result = solve_bg_at_lambda(
            score_matrix, lambda_current, score_index_groups
        )
        if solved_result is None:
            raise ValueError(f"could not solve background at lambda={lambda_current}")
        f_value, f_prime, p_current = solved_result
        if abs(f_value) < epsilon:
            p = p_current
            lambda_value = lambda_current
            break
        # A vanishing derivative makes the Newton step undefined; stop and let
        # the final check below decide whether the fallback is needed.
        if abs(f_prime) < MATRIX_SINGULARITY_THRESHOLD:
            break
        newton_step = -f_value / f_prime
        lambda_new = lambda_current + newton_step
        # A step that leaves the bracket is shortened so that it covers
        # DAMPING_FACTOR of the distance to the bracket edge it would cross.
        if lambda_new < lambda_low or lambda_new > lambda_high:
            if newton_step > 0.0:
                damping_factor = min(
                    1.0,
                    (lambda_high - lambda_current) * DAMPING_FACTOR / newton_step,
                )
            else:
                damping_factor = min(
                    1.0,
                    (lambda_low - lambda_current) * DAMPING_FACTOR / newton_step,
                )
            lambda_new = lambda_current + damping_factor * newton_step
        if abs(lambda_new - lambda_current) < epsilon:
            p = p_current
            lambda_value = lambda_new
            break
        lambda_current = lambda_new
        # Out of iterations: keep the last solved background with the latest lambda.
        if iteration == MAX_NEWTON_ITERATIONS - 1:
            p = p_current
            lambda_value = lambda_current
    final_groups = weight_score_groups(score_index_groups, p, p)
    final_restriction = lambda_constraint(final_groups, lambda_value)
    if abs(final_restriction) > RESTRICTION_VALIDATION_TOLERANCE:
        # Newton did not land on a valid solution; retry with bisection.
        lambda_value, p = bisect_lambda_bg(
            score_matrix, lambda_upper_bound, score_index_groups
        )
        final_groups = weight_score_groups(score_index_groups, p, p)
        final_restriction = lambda_constraint(final_groups, lambda_value)
    if abs(final_restriction) > RESTRICTION_VALIDATION_TOLERANCE:
        raise ValueError(
            "lambda/background estimation failed final restriction check "
            f"(|f(lambda)|={abs(final_restriction):.6g})"
        )
    if lambda_value <= 0.0:
        raise ValueError("estimated lambda is non-positive")
    return lambda_value, p
def parse_score_matrix(
    path: Path,
) -> tuple[int, list[str], dict[str, int], bool, FloatMatrix]:
    """Read a whitespace-separated scoring matrix file.

    The first non-comment line with more than one token is the alphabet
    header; only the first character of each token is used, upper-cased.
    Every later line with more than one token is a matrix row that starts
    with its residue letter followed by at least one score per alphabet
    letter. Blank lines, single-token lines and comment lines are skipped.
    A comment line is one whose first non-blank character is ``#``.

    Returns:
        ``(alphabet_size, num2aa, aa2num, has_terminal_x, score_matrix)``
        where ``has_terminal_x`` is true when the last header letter is "X".
        Rows that never appear in the file stay filled with 0.0.

    Raises:
        ValueError: if no header is found, a header or row token does not
            start with a letter, the header repeats a letter, a row names an
            unknown residue, a row is too short, or a score is not a number.
        OSError: if the file cannot be read.
    """
    aa2num: dict[str, int] = {}
    num2aa: list[str] = []
    score_matrix: FloatMatrix | None = None

    for line in path.read_text().splitlines():
        words = line.split()
        # Test the first token rather than the raw line so that indented
        # comments are skipped instead of being parsed as header or row data.
        if not words or words[0].startswith("#"):
            continue
        if len(words) <= 1:
            continue

        if score_matrix is None:
            for index, word in enumerate(words):
                aa = word[0].upper()
                if not aa.isalpha():
                    raise ValueError("Scoring matrix must start with alphabet header")
                # A repeated letter would overwrite its index and leave one
                # row of the matrix permanently unfilled.
                if aa in aa2num:
                    raise ValueError(
                        f"{path}: duplicate residue {aa!r} in alphabet header"
                    )
                aa2num[aa] = index
                num2aa.append(aa)
            alphabet_size = len(num2aa)
            score_matrix = [
                [0.0 for _ in range(alphabet_size)] for _ in range(alphabet_size)
            ]
            continue

        aa_char = words[0][0].upper()
        if not aa_char.isalpha():
            raise ValueError("First element in matrix row must be an alphabet letter")
        row = aa2num.get(aa_char)
        if row is None:
            raise ValueError(f"{path}: unknown row residue {aa_char!r}")
        alphabet_size = len(num2aa)
        if len(words) < alphabet_size + 1:
            raise ValueError(f"{path}: incomplete matrix row for {aa_char}")
        for col in range(alphabet_size):
            score_matrix[row][col] = float(words[col + 1])

    if score_matrix is None:
        raise ValueError(f"{path}: could not find alphabet header")
    return len(num2aa), num2aa, aa2num, num2aa[-1] == "X", score_matrix
def x_has_real_positive(score_matrix: FloatMatrix, x_index: int) -> bool:
    """Tell whether row or column ``x_index`` has a positive score.

    Only the first ``len(score_matrix) - 1`` positions are inspected, so the
    last row/column position is never looked at.
    """
    for other in range(len(score_matrix) - 1):
        if score_matrix[x_index][other] > 0.0:
            return True
        if score_matrix[other][x_index] > 0.0:
            return True
    return False
def is_estimable_index(
    score_matrix: FloatMatrix, index: int, candidates: list[int]
) -> bool:
    """Check that ``index`` has mixed-sign scores against ``candidates``.

    Both the row and the column of ``index``, restricted to the candidate
    positions, must contain at least one positive and one negative score.
    """
    row_values = [score_matrix[index][other] for other in candidates]
    col_values = [score_matrix[other][index] for other in candidates]

    row_mixed = any(v > 0.0 for v in row_values) and any(v < 0.0 for v in row_values)
    col_mixed = any(v > 0.0 for v in col_values) and any(v < 0.0 for v in col_values)
    return row_mixed and col_mixed
def select_estimable_indices(
    score_matrix: FloatMatrix, has_terminal_x: bool
) -> list[int]:
    """Pick the matrix positions that can take part in the lambda fit.

    A terminal X without any positive score against the other positions is
    dropped first. Positions are then filtered repeatedly with
    ``is_estimable_index`` until the set stops changing.

    Raises:
        ValueError: if fewer than two positions remain.
    """
    n = len(score_matrix)
    active = list(range(n))
    if has_terminal_x and not x_has_real_positive(score_matrix, n - 1):
        active.remove(n - 1)

    # Removing one position can make another non-estimable, hence the fixpoint.
    stable = False
    while not stable:
        survivors = [
            index
            for index in active
            if is_estimable_index(score_matrix, index, active)
        ]
        stable = survivors == active
        active = survivors

    if len(active) < 2:
        raise ValueError("score matrix has fewer than two estimable states")
    return active
def submatrix_for_indices(score_matrix: FloatMatrix, indices: list[int]) -> FloatMatrix:
    """Extract the rows and columns listed in ``indices``, in that order."""
    selected: FloatMatrix = []
    for row_index in indices:
        source_row = score_matrix[row_index]
        selected.append([source_row[col_index] for col_index in indices])
    return selected
def estimate_bg(
    score_matrix: FloatMatrix, has_terminal_x: bool
) -> tuple[float, list[float]]:
    """Estimate lambda and a background vector covering the full alphabet.

    Lambda and the background are fitted on the estimable positions only.
    Every excluded position receives ANY_BACK and the fitted probabilities
    are scaled down by the total mass given to excluded positions. With a
    terminal X, the last position is set to ANY_BACK in any case.

    Raises:
        ValueError: if the excluded positions would use up all the mass, or
            if position selection or the fit itself fails.
    """
    n = len(score_matrix)
    fit_indices = select_estimable_indices(score_matrix, has_terminal_x)
    lambda_value, fit_bg = estimate_lambda_bg(
        submatrix_for_indices(score_matrix, fit_indices)
    )

    fitted = set(fit_indices)
    excluded = [index for index in range(n) if index not in fitted]
    scale = 1.0 - (ANY_BACK * len(excluded))
    if scale <= 0.0:
        raise ValueError("too many non-estimable states for fixed background")

    background = [0.0] * n
    for local_index, original_index in enumerate(fit_indices):
        background[original_index] = fit_bg[local_index] * scale
    for index in excluded:
        background[index] = ANY_BACK
    if has_terminal_x:
        background[n - 1] = ANY_BACK
    return lambda_value, background
def pair_probs_from_scores(
    score_matrix: FloatMatrix, background: list[float], lambda_value: float
) -> FloatMatrix:
    """Compute ``exp(lambda * score[i][j]) * background[i] * background[j]`` per cell."""
    n = len(score_matrix)
    pair_probs: FloatMatrix = []
    for i in range(n):
        row = []
        for j in range(n):
            odds = math.exp(lambda_value * score_matrix[i][j])
            row.append(odds * background[i] * background[j])
        pair_probs.append(row)
    return pair_probs
def parse_matrix_file(path: Path) -> ParsedMatrix:
    """Load a scoring matrix file and convert its scores to pair probabilities."""
    parsed_fields = parse_score_matrix(path)
    alphabet_size, num2aa, aa2num, has_terminal_x, score_matrix = parsed_fields

    lambda_value, background = estimate_bg(score_matrix, has_terminal_x)
    return ParsedMatrix(
        alphabet_size=alphabet_size,
        num2aa=num2aa,
        aa2num=aa2num,
        prob_matrix=pair_probs_from_scores(score_matrix, background, lambda_value),
        has_terminal_x=has_terminal_x,
    )
def zeros(size: int) -> FloatMatrix:
    """Return a ``size`` x ``size`` matrix of 0.0 with independent rows."""
    return [[0.0] * size for _ in range(size)]
def compute_background(
    prob_matrix: FloatMatrix, alphabet_size: int, contains_x: bool
) -> list[float]:
    """Return the row sums of the leading ``alphabet_size`` square of ``prob_matrix``.

    When ``contains_x`` is true the last entry is replaced by ANY_BACK.
    """
    background = [
        sum(prob_matrix[i][j] for j in range(alphabet_size))
        for i in range(alphabet_size)
    ]
    if contains_x:
        background[alphabet_size - 1] = ANY_BACK
    return background
def generate_log_odds(
    prob_matrix: FloatMatrix, alphabet_size: int, contains_x: bool
) -> FloatMatrix:
    """Return ``log2(P[i][j] / (bg[i] * bg[j]))`` for every cell.

    ``bg`` is the background from ``compute_background`` for the same
    arguments.
    """
    background = compute_background(prob_matrix, alphabet_size, contains_x)
    log_odds: FloatMatrix = []
    for i in range(alphabet_size):
        row = []
        for j in range(alphabet_size):
            expected = background[i] * background[j]
            row.append(math.log2(prob_matrix[i][j] / expected))
        log_odds.append(row)
    return log_odds
def round_to_short(value: float) -> int:
    """Round to the nearest integer, with halves going away from zero."""
    if value < 0.0:
        return int(value - 0.5)
    return int(value + 0.5)
def generate_integer_scores(
    prob_matrix: FloatMatrix,
    alphabet_size: int,
    bit_factor: float,
    has_terminal_x: bool,
) -> IntMatrix:
    """Scale the log-odds matrix by ``bit_factor`` and round each cell to an int."""
    log_odds = generate_log_odds(prob_matrix, alphabet_size, contains_x=has_terminal_x)
    return [
        [round_to_short(bit_factor * log_odds[i][j]) for j in range(alphabet_size)]
        for i in range(alphabet_size)
    ]
def log_odds_mi(pair_probs: FloatMatrix, size: int) -> float:
    """Sum ``P[i][j] * log2(P[i][j] / (bg[i] * bg[j]))`` over all cells.

    ``bg`` holds the row sums of ``pair_probs``.
    """
    marginals = compute_background(pair_probs, size, contains_x=False)
    total = 0.0
    for i in range(size):
        row = pair_probs[i]
        for j in range(size):
            joint = row[j]
            total += joint * math.log2(joint / (marginals[i] * marginals[j]))
    return total
def merge_pair_probs(
    pair_probs: FloatMatrix, size: int, kept_idx: int, removed_idx: int
) -> FloatMatrix:
    """Fold state ``removed_idx`` into ``kept_idx`` and drop its row and column.

    The remaining states keep their relative order. The kept row and column
    absorb the removed ones, and the kept diagonal cell also absorbs the
    removed diagonal cell.
    """
    merged = zeros(size - 1)
    survivors = [idx for idx in range(size) if idx != removed_idx]
    for new_row, old_row in enumerate(survivors):
        row_is_kept = old_row == kept_idx
        for new_col, old_col in enumerate(survivors):
            col_is_kept = old_col == kept_idx
            cell = pair_probs[old_row][old_col]
            if col_is_kept:
                cell += pair_probs[old_row][removed_idx]
            if row_is_kept:
                cell += pair_probs[removed_idx][old_col]
            if row_is_kept and col_is_kept:
                cell += pair_probs[removed_idx][removed_idx]
            merged[new_row][new_col] = cell
    return merged
def find_best_mi_merge(
    pair_probs: FloatMatrix, size: int
) -> tuple[int, int, FloatMatrix]:
    """Try every pair merge and keep the one with the highest ``log_odds_mi``.

    Returns:
        ``(kept_idx, removed_idx, merged_pair_probs)`` with
        ``kept_idx < removed_idx``; on ties the first pair tried wins.

    Raises:
        ValueError: if no candidate merge scored above negative infinity.
    """
    top_mi = -math.inf
    winner: tuple[int, int, FloatMatrix] | None = None
    for kept in range(size):
        for removed in range(kept + 1, size):
            candidate = merge_pair_probs(pair_probs, size, kept, removed)
            candidate_mi = log_odds_mi(candidate, size - 1)
            if candidate_mi > top_mi:
                top_mi = candidate_mi
                winner = (kept, removed, candidate)
    if winner is None:
        raise ValueError("no merge candidates available")
    return winner
def build_group_strings(
    reduced_alphabet: list[str],
    aa2num: dict[str, int],
    source_alphabet: list[str],
) -> list[str]:
    """Format the members of each reduced letter as ``"(A B ...)"``.

    A source letter belongs to a representative when both map to the same
    number in ``aa2num``. Members appear in ``source_alphabet`` order.
    """
    group_strings: list[str] = []
    for representative in reduced_alphabet:
        group_number = aa2num[representative]
        members = []
        for letter in source_alphabet:
            if aa2num[letter] == group_number:
                members.append(letter)
        group_strings.append("(" + " ".join(members) + ")")
    return group_strings
def reduce_matrix(
    parsed: ParsedMatrix, target_size: int, bit_factor: float
) -> ReducedResult:
    """Shrink the alphabet to ``target_size`` letters by greedy pair merging.

    Each step merges the pair of states chosen by ``find_best_mi_merge``. A
    terminal X is kept out of the merging and re-attached as the last state
    of the reduced matrix.

    Returns:
        The letter groups, the representative letters and the integer score
        matrix of the reduced alphabet.

    Raises:
        ValueError: if ``target_size`` is not smaller than the original
            alphabet size or is less than 2.
    """
    original_size = parsed.alphabet_size
    if target_size >= original_size:
        raise ValueError("Reduced alphabet has to be smaller than the original one")
    if target_size < 2:
        raise ValueError("Reduced alphabet size must be at least 2")
    has_terminal_x = parsed.has_terminal_x
    # With a terminal X, only the states before it take part in merging.
    start_merge_size = original_size - 1 if has_terminal_x else original_size
    target_merge_size = target_size - 1 if has_terminal_x else target_size
    work_pair_probs = [
        row[:start_merge_size] for row in parsed.prob_matrix[:start_merge_size]
    ]
    # Maps every original letter to the original index of its group's
    # representative; updated after each merge.
    work_aa2num = parsed.aa2num.copy()
    reduced_alphabet = parsed.num2aa[:]
    reduce_steps = start_merge_size - target_merge_size
    merge_size = start_merge_size
    for _ in range(reduce_steps):
        kept_idx, removed_idx, merged_pair_probs = find_best_mi_merge(
            work_pair_probs, merge_size
        )
        kept_aa = reduced_alphabet[kept_idx]
        removed_aa = reduced_alphabet[removed_idx]
        del reduced_alphabet[removed_idx]
        kept_int = work_aa2num[kept_aa]
        removed_int = work_aa2num[removed_aa]
        # Re-point every letter of the removed group at the kept representative.
        for aa in parsed.num2aa:
            if work_aa2num[aa] == removed_int:
                work_aa2num[aa] = kept_int
        merge_size -= 1
        work_pair_probs = merged_pair_probs
    aa_groups = build_group_strings(
        reduced_alphabet,
        work_aa2num,
        parsed.num2aa,
    )
    reduced_num2aa = reduced_alphabet[:]
    final_pair_probs = zeros(target_size)
    for i in range(merge_size):
        for j in range(merge_size):
            final_pair_probs[i][j] = work_pair_probs[i][j]
    if has_terminal_x:
        reduced_bg = [0.0 for _ in range(original_size)]
        for i in range(target_size):
            reduced_bg[i] = sum(final_pair_probs[i][j] for j in range(target_size))
        reduced_bg[target_size - 1] = ANY_BACK
        # This uses the pre-compaction X mapping to match ReducedMatrix.cpp.
        # NOTE(review): that index is original_size - 1, which is >= target_size,
        # so the value read below is the initial 0.0 and the scaling leaves
        # reduced_bg unchanged - confirm this is the intended behaviour.
        precompact_x_idx = work_aa2num["X"]
        for i in range(target_size - 1):
            reduced_bg[i] *= 1.0 - reduced_bg[precompact_x_idx]
        orig_prob = parsed.prob_matrix
        orig_bg = compute_background(orig_prob, original_size, contains_x=True)
        original_x_index = original_size - 1
        # Fill the X row and column from the odds ratio of each
        # representative against X in the original matrix.
        for i, aa in enumerate(reduced_num2aa):
            old_index = work_aa2num[aa]
            x_pair_odds = orig_prob[old_index][original_x_index] / (
                orig_bg[old_index] * orig_bg[original_x_index]
            )
            final_pair_probs[target_size - 1][i] = (
                x_pair_odds * reduced_bg[i] * reduced_bg[target_size - 1]
            )
            final_pair_probs[i][target_size - 1] = (
                x_pair_odds * reduced_bg[target_size - 1] * reduced_bg[i]
            )
    score_matrix = generate_integer_scores(
        final_pair_probs, target_size, bit_factor, has_terminal_x
    )
    return ReducedResult(
        aa_groups=aa_groups,
        alphabet=reduced_num2aa,
        score_matrix=score_matrix,
    )
def format_score_matrix(alphabet: list[str], score_matrix: IntMatrix) -> str:
    """Render the matrix as text: a header line, then one labelled line per row.

    Letters and scores are right-aligned in 4-character columns separated by
    single spaces. The result has no trailing newline.
    """
    header = " " + " ".join(f"{aa:>4}" for aa in alphabet)
    body = [
        f"{aa} " + " ".join(f"{score:>4d}" for score in row)
        for aa, row in zip(alphabet, score_matrix)
    ]
    return "\n".join([header, *body])
def format_result(result: ReducedResult) -> str:
    """Build the printable report: two comment lines, then the score matrix."""
    groups = " ".join(result.aa_groups)
    matrix_text = format_score_matrix(result.alphabet, result.score_matrix)
    pieces = [
        "# Reduced amino acid alphabet:\n",
        f"# {groups}\n",
        f"{matrix_text}\n",
    ]
    return "".join(pieces)
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line: matrix path, reduced size and optional bit factor."""
    parser = argparse.ArgumentParser(
        description="Compute and print a reduced amino-acid substitution matrix."
    )
    add = parser.add_argument
    add("matrix", type=Path, help="Scoring matrix file")
    add(
        "alphabet_size",
        type=int,
        help="Reduced alphabet size (range 2 to original_size - 1)",
    )
    add(
        "--bit-factor",
        type=float,
        default=2.0,
        help="Score scaling factor for the printed integer matrix (default: 2.0)",
    )
    namespace = parser.parse_args(argv)
    return namespace
def main(argv: list[str]) -> int:
    """Run the command-line tool.

    Parses ``argv``, reduces the matrix and prints the result to stdout.

    Returns:
        0 on success, 1 when reading, parsing or the numerical work fails;
        in that case a one-line ``error: ...`` message goes to stderr.
    """
    args = parse_args(argv)
    try:
        parsed = parse_matrix_file(args.matrix)
        reduced = reduce_matrix(parsed, args.alphabet_size, args.bit_factor)
    # ArithmeticError covers ZeroDivisionError and also the OverflowError that
    # math.exp raises for large lambda * score products, so numerical failures
    # are reported as a message instead of a traceback.
    except (OSError, ValueError, ArithmeticError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1
    print(format_result(reduced), end="")
    return 0
# When run as a script, pass the command-line arguments to main() and use its
# return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment