Created
May 7, 2017 07:45
-
-
Save tag1216/2358da45274a6ac781c33b0cc980fa14 to your computer and use it in GitHub Desktop.
巨大ファイルのソート
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import heapq | |
import os | |
import re | |
from argparse import ArgumentParser | |
from contextlib import contextmanager | |
from operator import itemgetter | |
from tempfile import TemporaryDirectory, mktemp | |
import sys | |
from typing import IO, Callable, List | |
def large_sort(input_file: IO, output_file: IO, key: Callable=None, reverse: bool=False, limit_chars: int=1024*1024*64):
    """Sort the lines of ``input_file`` into ``output_file`` via temporary files.

    The input is consumed in chunks (see ``_read_parts``); every chunk is
    sorted in memory and stored as its own file inside a temporary directory.
    The chunk files are then merged lazily with ``heapq.merge`` and the merged
    lines are written to ``output_file``. The temporary directory is removed
    when the function returns.

    Args:
        input_file: Text stream to read lines from.
        output_file: Text stream that receives the sorted lines.
        key: Optional one-argument function that extracts the comparison key
            from a line; ``None`` compares the lines themselves.
        reverse: When true, sort in descending order.
        limit_chars: Size hint handed to ``_read_parts`` for each chunk.
    """
    with TemporaryDirectory() as work_dir:
        # Phase 1: split the input into individually sorted chunk files.
        for chunk in _read_parts(input_file, limit_chars):
            ordered = sorted(chunk, key=key, reverse=reverse)
            _write_part(ordered, work_dir)

        # Phase 2: merge the sorted chunk files line by line. The same
        # key/reverse settings are used so the merge agrees with the
        # ordering of each chunk.
        with _open_tmp_files(work_dir) as chunk_files:
            merged = heapq.merge(*chunk_files, key=key, reverse=reverse)
            for record in merged:
                output_file.write(record)
def _read_parts(input_file, limit_chars):
    """Yield successive lists of newline-terminated lines from ``input_file``.

    Each list is obtained with ``input_file.readlines(limit_chars)``, so
    ``limit_chars`` acts as the size hint for one chunk. Iteration stops when
    ``readlines`` returns an empty list.

    The last line of every chunk is given a trailing ``"\\n"`` if it has no
    line terminator. In practice this only affects the final line of an input
    that does not end with a newline; without it, that line could be sorted
    into the middle of a chunk and fuse with the following record when the
    chunk is written out.

    Args:
        input_file: Text stream to read from.
        limit_chars: Size hint passed to ``readlines`` for each chunk.

    Yields:
        Non-empty lists of lines, each line ending with a line terminator.
    """
    lines = input_file.readlines(limit_chars)
    while lines:
        # "\r" is accepted as well so a stream opened with newline="" does
        # not get a second terminator appended.
        if not lines[-1].endswith(("\n", "\r")):
            lines[-1] += "\n"
        yield lines
        lines = input_file.readlines(limit_chars)
def _write_part(lines, tmp_dir):
    """Write ``lines`` to a new, uniquely named file in ``tmp_dir``.

    Args:
        lines: Iterable of strings, written as-is (no terminators are added).
        tmp_dir: Directory in which the file is created.

    Returns:
        The path of the file that was written.
    """
    # Imported here because the module-level import only provides
    # TemporaryDirectory and mktemp.
    from tempfile import NamedTemporaryFile

    # mktemp() only chooses a name and leaves a window in which another
    # process can create that path first; NamedTemporaryFile creates the file
    # atomically. delete=False keeps the file on disk after it is closed so
    # it can be reopened later for merging.
    with NamedTemporaryFile("w", dir=tmp_dir, delete=False) as tmp_file:
        tmp_file.writelines(lines)
    return tmp_file.name
@contextmanager
def _open_tmp_files(tmp_dir):
    """Context manager yielding every file in ``tmp_dir`` opened for reading.

    All files are closed when the ``with`` block exits. Files are registered
    with an ``ExitStack`` as they are opened, so if opening one of them fails
    (for example because the process runs out of file descriptors) the files
    opened before it are closed instead of leaking.

    Args:
        tmp_dir: Directory whose entries are opened in text mode.

    Yields:
        A list of open file objects, in ``os.listdir`` order.
    """
    # Imported here because the module-level import only provides
    # contextmanager.
    from contextlib import ExitStack

    with ExitStack() as stack:
        yield [
            stack.enter_context(open(os.path.join(tmp_dir, filename), "r"))
            for filename in os.listdir(tmp_dir)
        ]
def key_func(keys: List[str]=None, separator: str=" "):
    """Build a sort-key function from column specifications.

    Each specification is a 1-based column number, optionally followed by
    ``n`` to compare that column as an integer (for example ``"1"`` or
    ``"2n"``).

    Args:
        keys: Column specifications, in priority order. ``None`` or an empty
            list means "no key".
        separator: String on which each row is split into columns.

    Returns:
        ``None`` when no keys are given; otherwise a function that takes a
        row (a line of text), strips its newline, splits it on ``separator``
        and returns the list of selected column values.

    Raises:
        ValueError: If a specification is not of the form ``<digits>[n]`` or
            names column 0.
    """
    if not keys:
        return None
    pattern = re.compile("([0-9]+)(n?)")
    getters = []
    for key in keys:
        # fullmatch rejects trailing junk such as "1x", which match() would
        # silently accept as column 1.
        m = pattern.fullmatch(key)
        if m is None:
            raise ValueError("invalid key specification: {!r}".format(key))
        column = int(m.group(1)) - 1
        if column < 0:
            # Column "0" would become index -1 and silently select the last
            # field of every row.
            raise ValueError("key columns are numbered from 1: {!r}".format(key))
        number = bool(m.group(2))
        getter = _itemgetter_int(column) if number else itemgetter(column)
        getters.append(getter)

    def func(row):
        values = row.strip("\n").split(separator)
        return [f(values) for f in getters]

    return func
def _itemgetter_int(index):
    """Return a callable that picks item ``index`` of its argument as an ``int``.

    The returned callable indexes its argument with ``index`` and converts the
    result with ``int``.
    """
    return lambda fields: int(fields[index])
def _parse_args():
    """Parse the command line and return the resulting namespace.

    Recognized arguments:
        -t / --field-separator: column separator, stored as ``separator``
            (default: a single space).
        -k / --key: sort-key specification; may be repeated, collected into
            the list ``keys``.
        -r / --reverse: flag stored as ``reverse`` (default ``False``).
        -l / --limit: integer stored as ``limit`` (default 64 * 1024 * 1024).
        file: optional positional input file name.
    """
    # (flags, keyword arguments) pairs, registered in this order.
    argument_specs = (
        (("-t", "--field-separator"),
         dict(dest="separator", default=" ", help="フィールド区切り文字を指定する")),
        (("-k", "--key"),
         dict(dest="keys", action="append", help="ソート対象フィールド指定する 例) -k 1 -k 2n")),
        (("-r", "--reverse"),
         dict(dest="reverse", action="store_true", default=False, help="降順ソート")),
        (("-l", "--limit"),
         dict(dest="limit", type=int, default=1024*1024*64, help="一度に読み込む文字数の制限")),
        (("file",),
         dict(nargs="?", help="入力ファイル")),
    )
    cli = ArgumentParser()
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def main():
    """Command-line entry point.

    Sorts the file named on the command line, or standard input when no file
    is given, and writes the result to standard output. A file opened here is
    closed again when sorting finishes or fails; standard input is left open.
    """
    options = _parse_args()

    def sort_stream(source):
        # The key function is built here, after the input has been opened,
        # exactly where the sort is started.
        large_sort(
            source,
            sys.stdout,
            key=key_func(options.keys, options.separator),
            reverse=options.reverse,
            limit_chars=options.limit,
        )

    if options.file:
        with open(options.file, "r") as source:
            sort_stream(source)
    else:
        sort_stream(sys.stdin)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment