Last active: November 27, 2022 15:46
-
-
Save qguv/ca8809401efea5f682d61471d2c7cf50 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import collections
import csv
import functools
import itertools
import math
import sys

import primefac
# Registry of (safe-wrapped function, {kwarg name: candidate values}) pairs.
# The @tf decorator appends to it; the __main__ block hands it to analyze().
tf_try_calls: list = []
class ThreshholdError(Exception):
    """Raised by print_compose when a candidate composition fails a filter
    (unchanged output, too much overlap, or too little collapse)."""
def dict_combinations(d):
    """Expand a {key: iterable of options} mapping into every concrete choice.

    Returns a list of dicts, one per element of the Cartesian product of the
    option iterables, with keys kept in ``d``'s order. An empty mapping
    yields ``[{}]``.
    """
    keys = list(d.keys())
    option_lists = list(d.values())
    return [dict(zip(keys, choice)) for choice in itertools.product(*option_lists)]
def safe(fn):
    """Wrap ``fn`` so that any exception it raises is returned as ``str(exc)``.

    The analysis compares function outputs as plain values, so an error
    message becomes just another value a koan can map to instead of aborting
    the search. ``functools.wraps`` preserves ``__name__`` (used by
    get_call_name), ``__doc__`` and ``__qualname__``, and exposes the wrapped
    function as ``__wrapped__``.
    """
    @functools.wraps(fn)
    def _safe_fn(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # deliberate catch-all: errors are data here
            return str(e)
    return _safe_fn
def tf(**kwarg_possibilities):
    """Decorator factory: register the decorated function as a candidate transformation.

    ``kwarg_possibilities`` maps each keyword-argument name to the values to
    try. The function is recorded in ``tf_try_calls`` wrapped by ``safe`` and
    is returned unchanged, so several decorators can be stacked to register
    different argument grids for the same function.
    """
    def _register(fn):
        entry = (safe(fn), kwarg_possibilities)
        tf_try_calls.append(entry)
        return fn
    return _register
def get_digits(x, base=10):
    """Return the digits of non-negative integer ``x`` in ``base``, most significant first.

    Any integer base >= 2 is accepted. Digits are returned as ints, so bases
    above 10 produce digit values >= 10 (e.g. 0xBEEF -> [11, 14, 14, 15]).

    Raises:
        ValueError: if ``x`` is not exactly an int (bools are rejected too,
            so outputs like ``is_prime``'s don't pass as 0/1), if ``x`` is
            negative, or if ``base`` is not an int >= 2.
    """
    # type() instead of isinstance(): bool is a subclass of int and must be rejected.
    if type(x) is not int or type(base) is not int or base < 2 or x < 0:
        raise ValueError(f"can't convert {x} to base-{base} int")
    if x == 0:
        return [0]
    digits = []
    while x:
        x, digit = divmod(x, base)
        digits.append(digit)
    digits.reverse()  # divmod yields least significant digit first
    return digits
@tf(base=[2, 8, 10, 16], holy_four=[True, False], hex_upper=[True, False])
@functools.cache
def count_holes(n, holy_four=True, hex_upper=False, base=10):
    """Count the enclosed "holes" in the glyphs of n's digits in ``base``.

    Holes per digit value: 0, 6, 9, 10 (A) and 13 (D) have one; 8 has two;
    4 has one only when ``holy_four``. With ``hex_upper`` the glyphs are
    uppercase: B has two holes and E none. Otherwise lowercase b and e have
    one each.
    """
    holes_per_digit = {0: 1, 6: 1, 8: 2, 9: 1, 10: 1, 13: 1}
    if holy_four:
        holes_per_digit[4] = 1
    if hex_upper:
        holes_per_digit[11] = 2  # two-hole B; E has none
    else:
        holes_per_digit[11] = 1  # one-hole b
        holes_per_digit[14] = 1  # one-hole e
    return sum(holes_per_digit.get(digit, 0) for digit in get_digits(n, base=base))
@tf(power_base=list(range(2, 17)))
@functools.cache
def is_power_of(n, power_base=2, precision=15):
    """Return True if ``n == power_base ** k`` for some integer k >= 0.

    Positive integers with an integer base >= 2 are checked exactly by
    repeated division. A float test on ``math.log`` can misclassify them:
    exact powers may come out just off an integer, and huge non-powers such as
    2**60 + 1 are indistinguishable from 2**60 once converted to float.

    Every other input (negative n, floats, invalid bases) uses the original
    log-based test, rounded to ``precision`` decimal places. Those inputs
    therefore raise the same errors as before, e.g. ValueError for negative n.
    """
    if n == 0:
        return False
    if isinstance(n, int) and isinstance(power_base, int) and n > 0 and power_base > 1:
        while n % power_base == 0:
            n //= power_base
        return n == 1
    return round(math.log(n, power_base), precision).is_integer()
@tf()
@functools.cache
def count_syllables(x, log=lambda x: None):
    """Return the number of syllables in the English name of integer ``x``.

    Zero is special-cased as "ze-ro" (2 syllables), because the helper
    treats a zero part as silent. ``log`` is called with each word counted.
    """
    if x:
        return _count_syllables(x, log=log)
    log('zero')
    return 2
def _count_syllables(x, suffix=None, log=lambda x: None):
    """Count the syllables in the English name of integer ``x``.

    Recursive helper for count_syllables. A zero part of a number is silent
    (the units of "twenty", the rest of "one hundred"), so ``x == 0`` counts
    0 syllables here. Numbers of a million or more are read as
    "<n> thousand" with n >= 1000.

    Args:
        x: the integer to read.
        suffix: ``'ty'`` or ``'teen'`` when ``x`` is the stem digit of a
            tens/teens word; the suffix itself adds one syllable.
        log: called with each word as it is counted (debugging aid). It is
            passed down to every recursive call.
    """
    ty_stems = ['', '', 'twen', 'thir', 'for', 'fif', 'six', 'seven', 'eigh', 'nine']
    teen_stems = ['', '', '', 'thir', 'four', 'fif', 'six', 'seven', 'eigh', 'nine']
    units = ['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']

    syllables = 0
    if x < 0:
        log('negative')
        syllables += 3  # neg-a-tive
        x = -x
    if x > 999:
        syllables += _count_syllables(x // 1000, log=log)
        log('thousand')
        syllables += 2
        x %= 1000
    if x > 99:
        syllables += _count_syllables(x // 100, log=log)
        log('hundred')
        syllables += 2
        x %= 100
    if x > 19:
        syllables += _count_syllables(x // 10, suffix='ty', log=log)
        x %= 10
    elif x > 12:
        return syllables + _count_syllables(x % 10, suffix='teen', log=log)
    elif x > 9:
        log(['ten', 'eleven', 'twelve'][x % 10])
        return syllables + (3 if x == 11 else 1)
    if x > 0:
        if suffix:
            stems = ty_stems if suffix == 'ty' else teen_stems
            log(stems[x] + suffix)
            syllables += 1  # the -ty / -teen ending is its own syllable
        else:
            log(units[x])
        syllables += 2 if x == 7 else 1  # "seven" is the only two-syllable digit
    # Return the running total: 20, 100, 1000, ... end here with x == 0.
    return syllables
@tf(base=[2, 8, 10, 16])
@functools.cache
def digits_are_sorted(x, base=10):
    """Return True if x's digits in ``base`` are monotonic, either
    non-decreasing or non-increasing."""
    digits = get_digits(x, base=base)
    neighbours = list(zip(digits, digits[1:]))
    ascending = all(left <= right for left, right in neighbours)
    descending = all(left >= right for left, right in neighbours)
    return ascending or descending
def concat(digits, base=10):
    """Fold an iterable of digits (most significant first) into an int in ``base``.

    This is the inverse of get_digits. An empty iterable gives 0.
    """
    return functools.reduce(lambda acc, digit: acc * base + digit, digits, 0)
# @tf(base=[2, 8, 10, 16], reverse=[False, True])
def sort_digits(x, base=10, reverse=False):
    """Return x with its base-``base`` digits rearranged into ascending
    order (descending if ``reverse``), reassembled in the same base."""
    ordered = sorted(get_digits(x, base=base), reverse=reverse)
    return concat(ordered, base=base)
@tf(base=[2], remove_digit=list(range(2)))
@tf(base=[8], remove_digit=list(range(8)))
@tf(base=[10], remove_digit=list(range(10)))
@tf(base=[16], remove_digit=list(range(16)))
def remove_digit(x, base=10, remove_digit=0):
    """Drop every occurrence of digit ``remove_digit`` from x's base-``base``
    digits and reassemble the rest. Removing all digits gives 0."""
    kept = [digit for digit in get_digits(x, base=base) if digit != remove_digit]
    return concat(kept, base=base)
@tf(base=[2], digit_from=list(range(2)), digit_to=list(range(2)))
@tf(base=[8], digit_from=list(range(8)), digit_to=list(range(8)))
@tf(base=[10], digit_from=list(range(10)), digit_to=list(range(10)))
@tf(base=[16], digit_from=list(range(16)), digit_to=list(range(16)))
def replace_digit(x, base=10, digit_from=0, digit_to=0):
    """Replace every ``digit_from`` digit of x (in ``base``) with ``digit_to``.

    When the two digits are equal, ``x`` is returned untouched, without even
    being converted to digits.
    """
    if digit_from != digit_to:
        swapped = [digit_to if digit == digit_from else digit
                   for digit in get_digits(x, base=base)]
        return concat(swapped, base=base)
    return x
@tf(base=[2, 8, 10, 16], reverse=[False, True])
def sort_unique_digits(x, base=10, reverse=False):
    """Return x's distinct base-``base`` digits, sorted (descending if
    ``reverse``) and reassembled into a number in the same base.

    The digits are reassembled in ``base``, as sort_digits does. Previously
    they were always concatenated in base 10, so e.g. hex digits [1, 15]
    became 1*10 + 15 = 25 instead of 0x1f.
    """
    unique_digits = set(get_digits(x, base=base))
    return concat(sorted(unique_digits, reverse=reverse), base=base)
@tf(base=[2, 8, 10, 16])
@functools.cache
def digit_sum(x, base=10):
    """Return the sum of x's digits in ``base``."""
    digits = get_digits(x, base=base)
    return sum(digits)
@tf(base=[2, 8, 10, 16])
@functools.cache
def unique_digit_sum(x, base=10):
    """Return the sum of x's distinct digits in ``base``, each counted once."""
    distinct = frozenset(get_digits(x, base=base))
    return sum(distinct)
def product(xs):
    """Multiply the items of ``xs`` left to right, starting from 1.

    An empty iterable gives 1.
    """
    result = 1
    for item in xs:
        result = result * item
    return result
@tf(base=[2, 8, 10, 16])
@functools.cache
def digit_product(x, base=10):
    """Return the product of x's digits in ``base``."""
    digits = get_digits(x, base=base)
    return product(digits)
@tf(base=[2, 8, 10, 16])
@functools.cache
def unique_digit_product(x, base=10):
    """Return the product of x's distinct digits in ``base``, each used once."""
    distinct = frozenset(get_digits(x, base=base))
    return product(distinct)
@functools.cache
def get_prime_factors(x):
    """Return x's prime factors as a tuple, in the order primefac yields them.

    Cached, so repeated calls return the same tuple.
    """
    factors = primefac.primefac(x)
    return tuple(factors)
def get_unique_prime_factors(x):
    """Return x's distinct prime factors as an ascending tuple."""
    distinct = {*get_prime_factors(x)}
    return tuple(sorted(distinct))
@tf()
def count_prime_factors(x):
    """Return how many prime factors get_prime_factors reports for x,
    including repeats."""
    factors = get_prime_factors(x)
    return len(factors)
@tf()
def count_unique_prime_factors(x):
    """Return how many distinct prime factors x has."""
    return len(get_unique_prime_factors(x))
@tf()
@functools.cache
def is_prime(x):
    """Primality test for x, delegated to primefac.isprime (result cached)."""
    verdict = primefac.isprime(x)
    return verdict
@tf()
@functools.cache
def prime_factors_sum(x):
    """Return the sum of x's prime factors, including repeats."""
    total = 0
    for factor in get_prime_factors(x):
        total += factor
    return total
@tf()
@functools.cache
def unique_prime_factors_sum(x):
    """Return the sum of x's distinct prime factors."""
    return sum(get_unique_prime_factors(x))
def has_duplicates(xs):
    """Return True as soon as some item of the iterable ``xs`` repeats.

    Items must be hashable. Iteration stops at the first repeat.
    """
    seen = set()
    for item in xs:
        size_before = len(seen)
        seen.add(item)
        if len(seen) == size_before:  # add() was a no-op: already present
            return True
    return False
@tf(base=[2, 8, 10, 16])
@functools.cache
def digits_have_duplicates(x, base=10):
    """Return True if some digit occurs more than once in x (in ``base``)."""
    digits = get_digits(x, base=base)
    return has_duplicates(digits)
@tf()
@functools.cache
def prime_factors_have_duplicates(x):
    """Return True if some prime factor of x is repeated (x is not squarefree)."""
    factors = get_prime_factors(x)
    return has_duplicates(factors)
def reduce_digital(fn, x, base=10):
    """Repeatedly replace x with ``fn(digits of x)`` until x has a single digit.

    ``fn`` receives the digit list in ``base``. The loop only terminates if
    ``fn`` eventually produces a one-digit value.
    """
    digits = get_digits(x, base=base)
    while len(digits) > 1:
        x = fn(digits)
        digits = get_digits(x, base=base)
    return x
@tf(base=[2, 8, 10, 16])
@functools.cache
def digital_sum(x, base=10):
    """Return the digital root of x: sum its digits until one digit remains."""
    return reduce_digital(fn=sum, x=x, base=base)
@tf(base=[2, 8, 10, 16])
@functools.cache
def digital_product(x, base=10):
    '''
    Multiply x's digits repeatedly until a single digit remains.

    >>> digital_product(0, base=10)
    0
    >>> digital_product(1, base=10)
    1
    >>> digital_product(24, base=10)
    8
    >>> digital_product(25, base=10)
    0
    >>> digital_product(26, base=10)
    2
    '''
    return reduce_digital(fn=product, x=x, base=base)
@tf(target_digit=list(range(10)), base=[2, 8, 10, 16])
def count_consecutive_digits(n, target_digit=0, base=10):
    '''
    Return the length of the longest run of ``target_digit`` in n's digits
    (in ``base``), or 0 if the digit never occurs.

    >>> count_consecutive_digits(1020030002001, target_digit=0, base=10)
    3
    >>> count_consecutive_digits(1, target_digit=0, base=10)
    0
    '''
    run_lengths = (
        sum(1 for _ in run)
        for digit, run in itertools.groupby(get_digits(n, base=base))
        if digit == target_digit
    )
    return max(run_lengths, default=0)
@tf(divisor=list(range(2, 16)))
def mod(x, divisor=2):
    """Return ``x % divisor``."""
    remainder = x % divisor
    return remainder
# @tf(base=[16], digit=list(range(16)))
# @tf(base=[10], digit=list(range(10)))
# @tf(base=[8], digit=list(range(8)))
# @tf(base=[2], digit=list(range(2)))
def count_digit(x, base=10, digit=0):
    """Return how many times ``digit`` occurs among x's digits in ``base``."""
    return get_digits(x, base=base).count(digit)
def get_call_name(call):
    """Render an ``(fn, kwargs)`` call as ``name('k': v, ...)``, or just
    ``name`` when there are no kwargs."""
    fn, kwargs = call
    if kwargs:
        rendered_kwargs = str(kwargs)[1:-1]  # dict repr without the braces
        return f'{fn.__name__}({rendered_kwargs})'
    return fn.__name__
def get_overlap(intersection, htbn_mapped, not_htbn_mapped):
    """Return the percentage of mapped koans (both groups together) whose
    mapped value lies in ``intersection``.

    Returns 0.0 when both mappings are empty, instead of raising
    ZeroDivisionError.
    """
    mapped = [*htbn_mapped.values(), *not_htbn_mapped.values()]
    if not mapped:
        return 0.0
    amount = sum(1 for value in mapped if value in intersection)
    return 100 * amount / len(mapped)
def fmt_set(xs):
    """Format a collection of values for a rule description.

    Returns the bare value when there is exactly one item, otherwise
    ``in {a, b, ...}``. Values are sorted. When they are not mutually
    comparable (e.g. ints mixed with the error strings that ``safe``
    produces), they are grouped by type name and ordered by their text,
    instead of raising TypeError and aborting the whole search.
    """
    try:
        ordered = sorted(xs)
    except TypeError:
        ordered = sorted(xs, key=lambda value: (type(value).__name__, str(value)))
    items = ', '.join(str(value) for value in ordered)
    if len(xs) == 1:
        return items
    return f'in {{{items}}}'
def fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection):
    """Describe the composition as a rule and return ``(rule_text, penalty)``.

    Builds a positive ("AKHTBN if ...") rule from the koans with BN and a
    negative ("AKHTBN unless ...") rule from the koans without BN. Both
    candidates are printed, and the one with the strictly lower penalty wins.
    Ties go to the positive rule.
    """
    positive_rule, positive_penalty = _fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection)
    print(f"positive rule ({positive_penalty}): {positive_rule}")
    # The negative rule describes the other group, so the two mappings swap roles.
    negative_rule, negative_penalty = _fmt_rule(name, calls, not_htbn_mapped, htbn_mapped, intersection)
    print(f"negative rule ({negative_penalty}): {negative_rule}")
    prefer_negative = negative_penalty < positive_penalty
    return ((f"AKHTBN unless {negative_rule}", negative_penalty)
            if prefer_negative
            else (f"AKHTBN if {positive_rule}", positive_penalty))
def _fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection, negative=False):
    """Describe the koans in ``htbn_mapped`` as a rule over the mapped values.

    ``intersection`` holds values produced by both groups. The koans mapped
    onto those values must be listed as exceptions, and the rule lists
    whichever side has fewer of them. The returned penalty grows with the
    number of calls, the size of the value set, the number of exceptions,
    and any non-decimal ``base`` kwarg.

    Returns ``(rule_text, penalty)``. ``negative`` is accepted but unused.
    """
    # Values only this group produces, and its koans that hit a shared value.
    in_set = {value for value in htbn_mapped.values() if value not in intersection}
    positive_exceptions = {koan for koan, value in htbn_mapped.items() if value in intersection}
    # Koans of the other group that hit a shared value.
    negative_exceptions = {koan for koan, value in not_htbn_mapped.items() if value in intersection}

    if len(positive_exceptions) > len(negative_exceptions):
        # Cheaper to absorb the shared values and exclude the other group's koans.
        positive_exceptions = set()
        in_set.update(intersection)
        rule = f"{name} is {fmt_set(in_set)} and is not {fmt_set(negative_exceptions)}"
    elif positive_exceptions:
        # Cheaper to list this group's ambiguous koans explicitly.
        negative_exceptions = set()
        rule = f"{name} is {fmt_set(in_set)} or is {fmt_set(positive_exceptions)}"
    else:
        rule = f"{name} is {fmt_set(in_set)}"

    uses_exotic_base = any(call[1].get('base', 10) != 10 for call in calls)
    penalty = (3 * len(calls)
               + len(in_set)
               + 4 * (len(positive_exceptions) + len(negative_exceptions))
               + (5 if uses_exotic_base else 0))
    return (rule, penalty)
def print_compose(htbn, not_htbn, calls, retention_percent=50, overlap_percent=20):
    """Thread both koan groups through ``calls`` and score the resulting mapping.

    ``calls`` is a sequence of ``(fn, kwargs)`` pairs applied in order to
    every koan. ThreshholdError is raised when:
    - a call leaves both mappings unchanged;
    - more than ``overlap_percent`` % of mapped koans land on a value that
      both groups produce;
    - either group keeps more than ``retention_percent`` % distinct values,
      i.e. the calls do not collapse its koans enough.

    Returns ``(rule, penalty, details)``: the rule text and penalty from
    fmt_rule, plus a printable report ending with the mapping table. As side
    effects, fmt_rule prints its candidate rules and a ``penalty name`` line
    is printed.
    """
    htbn_mapped = {koan: koan for koan in htbn}
    not_htbn_mapped = {koan: koan for koan in not_htbn}
    for fn, kwargs in calls:
        previous = (htbn_mapped, not_htbn_mapped)
        htbn_mapped = {koan: fn(value, **kwargs) for koan, value in htbn_mapped.items()}
        not_htbn_mapped = {koan: fn(value, **kwargs) for koan, value in not_htbn_mapped.items()}
        # abort if any function leaves koans unchanged
        if previous == (htbn_mapped, not_htbn_mapped):
            raise ThreshholdError("function leaves koans unchanged")

    htbn_unique = set(htbn_mapped.values())
    not_htbn_unique = set(not_htbn_mapped.values())
    intersection = htbn_unique & not_htbn_unique

    overlap = get_overlap(intersection, htbn_mapped, not_htbn_mapped)
    if overlap > overlap_percent:
        raise ThreshholdError(f"overlap {overlap}% exceeds threshhold {overlap_percent}%")

    # Distinct mapped values as a percentage of the group's koans; lower is better.
    htbn_retention = 100 * len(htbn_unique) / len(htbn)
    not_htbn_retention = 100 * len(not_htbn_unique) / len(not_htbn)
    if htbn_retention > retention_percent:
        raise ThreshholdError(f"retention {htbn_retention} in koans with BN doesn't meet threshhold {retention_percent}")
    if not_htbn_retention > retention_percent:
        raise ThreshholdError(f"retention {not_htbn_retention} in koans without BN doesn't meet threshhold {retention_percent}")

    name = ', '.join(get_call_name(call) for call in calls)
    rule, penalty = fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection)

    # No shared values and one group collapsed to a single value: a clean split.
    clean_split = not intersection and 1 in (len(htbn_unique), len(not_htbn_unique))
    parts = [f"\n=== {name} ===\n"]
    if clean_split:
        parts += [
            "#############################\n",
            "### LIKELY FOUND THE RULE ###\n",
            "#############################\n",
        ]
    parts += [
        f'penalty: {penalty}\n',
        f"htbn retention: {htbn_retention}%\n",
        f"not_htbn retention: {not_htbn_retention}%\n",
        f"overlap: {overlap}%\n",
        print_table(htbn_mapped, not_htbn_mapped, intersection),
    ]
    details = ''.join(parts)
    print(f'\n{penalty} {name}\n')
    return (rule, penalty, details)
def print_table(htbn_mapped, not_htbn_mapped, intersection):
    """Return (not print) a two-column text table of ``koan: mapped value``.

    The left column holds the koans with BN, the right the koans without.
    A value that both groups produce (is in ``intersection``) is flagged
    with '!'. The shorter column is padded with blank cells.
    """
    row_fmt = "%10s %10s%1s | %10s %10s%s\n"

    def _cells(koan, value):
        # One side of a row: (label, shown value, shared-value marker).
        label = '' if koan is None else f'{koan}:'
        marker = '!' if value in intersection else ''
        shown = 'Err' if type(value) is ValueError else str(value)
        return (label, shown, marker)

    rows = [row_fmt % ('has ', 'fn(has)', '', 'not has ', 'fn(not has)', '')]
    paired = itertools.zip_longest(htbn_mapped.items(), not_htbn_mapped.items(), fillvalue=(None, ''))
    for left, right in paired:
        rows.append(row_fmt % (_cells(*left) + _cells(*right)))
    return ''.join(rows)
def _shared_kwarg_values(calls):
    """Map each keyword used by ``calls`` to the candidate values every call accepts.

    All functions in a composition receive the same value for a shared
    keyword, so the candidate lists are intersected across calls.
    """
    common_kwargs = {}
    for _, kwargs in calls:
        for name, values in kwargs.items():
            if name in common_kwargs:
                common_kwargs[name].intersection_update(values)
            else:
                common_kwargs[name] = set(values)
    return common_kwargs


def analyze_until_break(htbn, not_htbn, calls_to_try):
    """Search ever-longer compositions of ``calls_to_try`` until the user stops.

    Round i tries every length-i sequence from ``calls_to_try`` (with
    repetition), pinning shared kwargs to each jointly valid combination. It
    prints progress markers: 'x' for a sequence skipped for lack of common
    kwarg values, '>' for a sequence being tried, 'o' for a composition that
    passed print_compose's thresholds. The user is prompted between rounds.

    Returns the collected ``(rule, penalty, details)`` tuples when the user
    presses Ctrl-C, or when stdin reaches EOF. Previously EOF raised and
    discarded everything found.
    """
    rules = []
    i = 0
    try:
        while True:
            i += 1
            print(f"\nround {i}:")
            for calls in itertools.product(calls_to_try, repeat=i):
                common_kwargs = _shared_kwarg_values(calls)
                # if any kwarg has no possibilities left, this sequence can't be run
                if not all(common_kwargs.values()):
                    print('x', end='')
                    continue
                print('>', end='')
                for pinned_kwargs in dict_combinations(common_kwargs):
                    pinned_calls = [(fn, {k: pinned_kwargs[k] for k in kwargs}) for fn, kwargs in calls]
                    try:
                        rules.append(print_compose(htbn, not_htbn, pinned_calls))
                    except ThreshholdError as e:
                        print(f'\n{e}')  # DEBUG: why this composition was rejected
                    else:
                        print('o', end='')
            input(f"\non to round {i+1}?")
    except (KeyboardInterrupt, EOFError):
        return rules
def analyze(htbn, not_htbn, calls_to_try):
    """Run the search, then page through the found rules, lowest penalty first.

    Ctrl-C at the paging prompt stops the listing. If stdin is closed (EOF),
    nobody can page, so the remaining rules are printed without pausing
    instead of crashing with EOFError.
    """
    rules = analyze_until_break(htbn, not_htbn, calls_to_try=calls_to_try)
    rules.sort(key=lambda found: found[1])
    interactive = True
    for rule, penalty, details in rules:
        if interactive:
            try:
                input("\nPress enter to continue...")
            except KeyboardInterrupt:
                return
            except EOFError:
                interactive = False
        print(penalty, rule, f'\n{details}\n')
@safe
def lemon4(n, base=10):
    """Hand-written candidate rule: return 'ok' if n passes every check.

    Checks, all in ``base``: no run of two or more zero digits, last digit is
    not 4, and the digit product is even. A failure raises ValueError listing
    the failed checks, which ``@safe`` turns into the message string.
    """
    errors = []
    # Use the same base as the other checks (this used to be hard-wired to 10).
    if count_consecutive_digits(n, target_digit=0, base=base) > 1:
        errors.append('consecutive zero digits')
    digits = get_digits(n, base=base)
    if digits[-1] == 4:
        errors.append('mod 10 == 4')  # label assumes the default base 10
    if product(digits) % 2 == 1:
        errors.append('digit product odd')
    if errors:
        raise ValueError(', '.join(errors))
    return 'ok'
@safe
def lemon5(n, base=10):
    """Hand-written candidate rule: describe n by its prime-factor count if
    it passes the checks.

    Checks, in ``base``: no run of two or more zero digits, and the digit
    product is even. A failure raises ValueError listing the failed checks,
    which ``@safe`` turns into the message string. Otherwise returns e.g.
    ``"3 prime factors"``.
    """
    errors = []
    # Use the same base as the digit-product check (this used to be hard-wired to 10).
    if count_consecutive_digits(n, target_digit=0, base=base) > 1:
        errors.append('consecutive zero digits')
    digits = get_digits(n, base=base)
    if product(digits) % 2 == 1:
        errors.append('digit product odd')
    if errors:
        raise ValueError(', '.join(errors))
    return f"{count_prime_factors(n)} prime factors"
def read_export(f):
    """Split a quoted-CSV koan export into ``(htbn, not_htbn)`` lists of ints.

    Each row is expected to look like ``"<a>","<number>","<flag>"``. The
    first field is ignored, the second is the koan's number, and a third
    field of ``1`` puts the koan in ``htbn`` (any other value puts it in
    ``not_htbn``). Blank lines are skipped, as are rows whose number field
    is not an integer (e.g. a header) and rows with fewer than three fields.
    Short rows previously raised IndexError.

    Args:
        f: an iterable of text lines, such as an open text file.
    """
    htbn = []
    not_htbn = []
    for raw_line in f:
        line = raw_line.strip()
        if not line:
            continue
        # csv honours quoting, so commas or quotes inside a field can't break the split
        fields = next(csv.reader([line]))
        if len(fields) < 3:
            continue
        try:
            n = int(fields[1])
        except ValueError:
            continue
        group = htbn if fields[2] == '1' else not_htbn
        group.append(n)
    return (htbn, not_htbn)
if __name__ == "__main__":
    # Fail with a usage message instead of an IndexError when no file is given.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} EXPORT_FILE")
    with open(sys.argv[1], 'r') as f:
        htbn, not_htbn = read_export(f)
    htbn.sort()
    not_htbn.sort()
    # analyze(htbn, not_htbn, [(lemon5, dict())])
    analyze(htbn, not_htbn, tf_try_calls)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.