Created
November 11, 2019 15:05
-
-
Save jw910731/47d3eb279173d3310ea7cb3142982c40 to your computer and use it in GitHub Desktop.
NPSC judge script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import sys | |
| import subprocess as subproc | |
| import difflib | |
| import decimal | |
| from subprocess import PIPE | |
| from typing import * | |
| from argparse import ArgumentParser | |
def get_dirlist(parent_dir: str) -> List[str]:
    """Return the names of the subdirectories directly inside *parent_dir*.

    Names are bare entry names (not joined paths), in ``os.listdir`` order.
    An entry counts when ``os.path.isdir`` is true for it, so symlinks that
    point at directories are included.
    """
    return [
        entry
        for entry in os.listdir(parent_dir)
        if os.path.isdir(os.path.join(parent_dir, entry))
    ]
def get_filelist(parent_dir: str) -> List[str]:
    """Return the names of the files directly inside *parent_dir*.

    Names are bare entry names (not joined paths), in ``os.listdir`` order.
    An entry counts when ``os.path.isfile`` is true for it, so symlinks that
    point at regular files are included.
    """
    def is_file(name: str) -> bool:
        return os.path.isfile(os.path.join(parent_dir, name))

    return list(filter(is_file, os.listdir(parent_dir)))
def list_testdata(list: List[str]) -> List[str]:
    """Return the distinct base names (extension stripped) of *list*.

    Only the last extension is removed (``os.path.splitext``), so
    ``"1.in"`` and ``"1.ans"`` both map to ``"1"``. Order of first
    appearance is preserved.

    Args:
        list: File names. The parameter shadows the builtin ``list`` and is
            kept under that name for backward compatibility with keyword
            callers.

    Returns:
        The de-duplicated base names.
    """
    names = (os.path.splitext(filename)[0] for filename in list)
    # dict keys keep insertion order, which gives an O(n) order-preserving
    # dedup instead of an O(n^2) `not in` scan over a list. `[*...]` is used
    # because the builtin `list` is shadowed by the parameter.
    return [*dict.fromkeys(names)]
def float_diff(a_list: List[str], b_list: List[str], precision: int) -> Iterator[Tuple[int, bool]]:
    """Compare two lists of numeric strings line by line within a tolerance.

    Each line is parsed with ``decimal.Decimal``. Lines at index ``i`` match
    when ``|a - b| < 10 ** -precision``. The comparison is done exactly, with
    no rounding to a context precision. A line that exists in only one of the
    lists is a mismatch.

    Args:
        a_list: Lines of the first text (in this script, the expected answer).
        b_list: Lines of the second text (in this script, the program output).
        precision: Number of decimal places of absolute tolerance.

    Returns:
        A lazy iterator yielding one ``(i, matched)`` pair per 0-based line
        index ``i`` in ``range(max(len(a_list), len(b_list)))``. If either
        line at index ``i`` cannot be parsed as a number, ``(-(i + 1), False)``
        is yielded for that index instead, and iteration continues with the
        next line.
    """
    import fractions  # local import: exact rational arithmetic for the tolerance test

    tolerance = fractions.Fraction(10) ** -precision
    line_count = max(len(a_list), len(b_list))

    def processor():
        for i in range(line_count):
            # Reset per line. Otherwise a shorter list would reuse values from
            # an earlier line, or hit an unbound name if it was empty.
            line_a = line_b = None
            try:
                if i < len(a_list):
                    line_a = decimal.Decimal(a_list[i])
                if i < len(b_list):
                    line_b = decimal.Decimal(b_list[i])
            except decimal.InvalidOperation:
                yield -(i + 1), False
                continue  # nothing valid to compare on this line
            if line_a is None or line_b is None:
                yield i, False
            elif line_a.is_finite() and line_b.is_finite():
                # Fraction keeps every digit. Decimal arithmetic (or normalize())
                # would round to the context precision and hide differences
                # between large-magnitude numbers.
                diff = abs(fractions.Fraction(line_a) - fractions.Fraction(line_b))
                yield i, diff < tolerance
            else:
                # Infinity/NaN: subtraction or ordering raises InvalidOperation,
                # so require the very same value (compare_total never signals).
                yield i, line_a.compare_total(line_b) == 0

    return processor()
def list_cleanup(l: List):
    """Remove one trailing falsy element (typically ``''``) from *l* in place.

    Splitting newline-terminated text on ``'\\n'`` leaves an empty last
    element; this drops it. At most one element is removed, and only when
    *l* is non-empty and its last element is falsy. Returns ``None``.
    """
    if l and not l[-1]:
        l.pop()
# Main Function


def _read_choice(test_list: List[str]) -> int:
    """Print the numbered test sets and prompt until a valid index is entered.

    Exits with status 1 if stdin is closed before a valid choice arrives.
    """
    for index, name in enumerate(test_list):
        print('{}) {}'.format(index, name))
    while True:
        try:
            raw = input('Input test number: ')
        except EOFError:
            # stdin is closed, so no valid choice can ever arrive.
            sys.exit(1)
        try:
            choice = int(raw)
        except ValueError:
            print('Input Invalid')
            continue
        # Negative numbers are rejected too. They would otherwise silently
        # index from the end of the list, or raise IndexError.
        if not 0 <= choice < len(test_list):
            print('Input Invalid')
            continue
        return choice


def _print_mismatch(ans_lines: List[str], out_lines: List[str]) -> None:
    """Print a run of mismatched expected lines, a separator, then the output lines."""
    for text in ans_lines:
        print(text)
    print("--------")
    for text in out_lines:
        print(text)


def _report_float_diff(ans: str, out: str, float_precision: int) -> None:
    """Print each run of consecutive lines where *out* differs numerically from *ans*.

    Exits with status 1 if a line cannot be parsed as a number.
    """
    ans_list = ans.split('\n')
    out_list = out.split('\n')
    list_cleanup(ans_list)
    list_cleanup(out_list)
    ans_diff = []
    out_diff = []
    for line, stat in float_diff(ans_list, out_list, float_precision):
        if line < 0:
            # float_diff reports a parse failure on 0-based line i as -(i + 1),
            # so -line is the human-facing 1-based line number.
            print("Input on line {} is unable to convert to number.".format(-line), file=sys.stderr)
            sys.exit(1)
        if not stat:
            ans_diff.append(ans_list[line] if line < len(ans_list) else "")
            out_diff.append(out_list[line] if line < len(out_list) else "")
        elif ans_diff:
            # A matching line closes the current run of mismatches.
            _print_mismatch(ans_diff, out_diff)
            ans_diff = []
            out_diff = []
    if ans_diff:
        # The comparison ended inside a run of mismatches.
        _print_mismatch(ans_diff, out_diff)


def _run_case(execute: str, work_root: str, case: str, float_precision: Optional[int]) -> None:
    """Feed ``<case>.in`` to *execute* and print how its stdout differs from ``<case>.ans``."""
    in_path = os.path.join(work_root, case + '.in')
    ans_path = os.path.join(work_root, case + '.ans')
    if not (os.path.isfile(in_path) and os.path.isfile(ans_path)):
        # Stray files in the data directory (e.g. README) have no .in/.ans pair.
        print("Skipping {}: missing .in or .ans file".format(case), file=sys.stderr)
        return
    print("========Start Case {}========".format(case))
    with open(in_path, 'r') as in_f:
        in_text = in_f.read()
    with open(ans_path, 'r') as ans_f:
        ans = ans_f.read()
    try:
        result = subproc.run([execute], input=in_text.encode(), stdout=PIPE)
    except OSError as err:
        print("Unable to run {}: {}".format(execute, err), file=sys.stderr)
        sys.exit(1)
    # The program under test may print bytes that are not valid UTF-8.
    # Show them as U+FFFD instead of aborting the whole run.
    out = result.stdout.decode(errors='replace')
    if float_precision is not None:
        _report_float_diff(ans, out, float_precision)
    else:
        ans_list = ans.splitlines(True)
        out_list = out.splitlines(True)
        sys.stdout.writelines(difflib.context_diff(ans_list, out_list, fromfile="Correct Answer", tofile="Your Answer"))
    print("============================")


def _main() -> None:
    """Parse CLI arguments, let the user pick a test set, and judge every case in it."""
    parser = ArgumentParser()
    parser.add_argument("testcase", help="Where NPSC testcase is located.")
    parser.add_argument("exec", help="The executable to be tested.")
    # The misspelled "--float-presicion" is kept for existing users.
    # "--float-precision" is accepted as a correctly spelled alias.
    parser.add_argument("--float-presicion", "--float-precision", "-f", dest="float_precise", help="To enable tolerance of round-off error with in given precision.")
    args = parser.parse_args()
    execute = os.path.abspath(args.exec)
    test_root = os.path.abspath(args.testcase)

    if args.float_precise is None:
        float_precision = None
    else:
        try:
            float_precision = int(args.float_precise)
        except ValueError:
            print('Invalid argument type', file=sys.stderr)
            sys.exit(1)

    # os.listdir order is arbitrary, so sort to get a stable menu.
    test_list = sorted(get_dirlist(test_root))
    if not test_list:
        # Without this check the prompt below would loop forever.
        print('No test set found in {}'.format(test_root), file=sys.stderr)
        sys.exit(1)
    choice = _read_choice(test_list)

    work_root = os.path.join(test_root, test_list[choice], 'data', 'secret')
    for case in sorted(list_testdata(get_filelist(work_root))):
        _run_case(execute, work_root, case, float_precision)


if __name__ == "__main__":
    _main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment