Created
March 21, 2018 14:04
-
-
Save enihsyou/46ed719b9ec9c8d23505d51ab26a44eb to your computer and use it in GitHub Desktop.
给别人写的作业…One-Hot方式的编解码器
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import json | |
import re | |
from typing import List, TextIO | |
def debug(*args):
    """Print ``args`` only when the module-level ``DEBUG`` flag is truthy.

    ``DEBUG`` is assigned only when this file is executed as a script, so the
    flag is looked up defensively: if it has never been defined (for example
    when the module is imported), the call is a silent no-op instead of
    raising ``NameError``.

    Args:
        *args: Values forwarded unchanged to ``print``.
    """
    if globals().get("DEBUG", False):
        print(*args)
def _str_to_list(string): | |
regex = re.compile(' ') | |
word_list = regex.split(string) | |
return word_list | |
def _list_to_str(word_list: List) -> str: | |
return " ".join(map(str, word_list)) | |
class OneHotDecoder:
    """Translate a space-separated code sequence back into its words.

    The decoder is built from a word -> code dictionary and stores the
    inverted code -> word mapping in ``self.dict``.
    """

    def __init__(self, dictionary: dict) -> None:
        # Invert the mapping. Codes are stringified so that they compare
        # equal to the textual tokens obtained by splitting the cipher.
        inverse: dict = {}
        for word, code in dictionary.items():
            inverse[str(code)] = word
        self.dict: dict = inverse

    def _transform(self, cipher: str) -> List[str]:
        """Return the word for each space-separated token of ``cipher``.

        Raises ``KeyError`` for a token missing from the dictionary.
        """
        debug(self.dict)
        words: List[str] = []
        for token in _str_to_list(cipher):
            # Translate the token through the inverted dictionary.
            word = self.dict[token]
            debug(token, ' -> ', word)
            words.append(word)
        return words

    def decode(self, cipher: str) -> str:
        """Decode ``cipher`` and return the words joined by single spaces."""
        return _list_to_str(self._transform(cipher))
class OneHotEncoder:
    """Encode space-separated text as a sequence of integer codes.

    The input is taken from the first of ``text``, ``file_path`` and ``file``
    that is not ``None``. After ``encode()`` has run, ``self.dict`` holds the
    word -> code mapping that was used.
    """

    def __init__(self, text: str = None, file_path: str = None, file: TextIO = None) -> None:
        """Load the text to encode.

        Args:
            text: The text itself.
            file_path: Path of a text file to read; used when ``text`` is None.
            file: An already opened text file; used when both ``text`` and
                ``file_path`` are None. It is closed after being read.
        """
        self.text = ""  # original text
        self.words = []  # words in their order of appearance
        self.dict = dict()  # word -> integer code, filled in by encode()

        content = text
        if content is None:
            if file_path is not None:
                # A separate local name keeps the ``file`` parameter intact.
                with open(file_path, "r") as handle:
                    content = handle.read()
            elif file is not None:
                # The handle is closed here even when reading fails.
                try:
                    content = file.read()
                finally:
                    file.close()

        if content is not None:
            self.text = content
            self.words = _str_to_list(content)

    def _generate_dictionary(self, word_list: List[str]):
        """Assign an increasing code to every distinct word of ``word_list``.

        Codes follow the order in which words first appear, so the same input
        always produces the same dictionary (iterating a ``set`` would make
        the numbering depend on string hash randomization).
        """
        # dict.fromkeys removes duplicates while preserving insertion order.
        for index, word in enumerate(dict.fromkeys(word_list)):
            self.dict[word] = index

    def _transform(self) -> List[int]:
        """Return the code of every word in ``self.words``, in order."""
        debug(self.text)
        debug(self.dict)
        result: List[int] = []
        for word in self.words:
            # Translate the word through the dictionary.
            patched = self.dict[word]
            debug(word, ' -> ', patched)
            result.append(patched)
        return result

    def encode(self):
        """Build the dictionary and return the codes as one spaced string."""
        # Build the dictionary.
        self._generate_dictionary(self.words)
        # Encode the words.
        transformed = self._transform()
        # Join the list into a single output string.
        return _list_to_str(transformed)
def encode_function():
    """Interactively encode a data file.

    Prompts for the data file path (an empty answer selects ``data.txt``),
    encodes its content and writes two files next to it: ``<path>.key`` with
    the dictionary as JSON and ``<path>.cipher`` with the encoded text.
    """
    file_path = input("Please input data file path: ") or "data.txt"

    encoder = OneHotEncoder(file_path=file_path)
    encoded = encoder.encode()
    debug(encoded)

    key_path = file_path + ".key"
    with open(key_path, "w") as file:
        file.write(json.dumps(encoder.dict))
    # Message fixed: this reports the key file ("密钥"), not "密密".
    print("写入密钥到", key_path)

    cipher_path = file_path + ".cipher"
    with open(cipher_path, "w") as file:
        file.write(encoded)
    print("写入密文到", cipher_path)
def decode_function():
    """Interactively decode a cipher file.

    Prompts for the key file and the cipher file (empty answers select
    ``data.txt.key`` and ``data.txt.cipher``), decodes the cipher with the
    dictionary loaded from the key file and writes the result to
    ``<cipher path>.decoded``.
    """
    # Both paths are asked for before any file is opened.
    key_path = input("Please input KEY file path: ") or "data.txt.key"
    cipher_path = input("Please input CIPHER file path: ") or "data.txt.cipher"

    with open(key_path, "r") as key_file:
        key_dict = json.load(key_file)
    with open(cipher_path, "r") as cipher_file:
        cipher_text = cipher_file.read()

    decoded = OneHotDecoder(key_dict).decode(cipher_text)
    debug(decoded)

    output_path = cipher_path + ".decoded"
    with open(output_path, "w") as output_file:
        output_file.write(decoded)
    print("写入解密到", output_path)
if __name__ == '__main__':
    DEBUG = False  # toggles the debug output

    # Menu number -> action to run.
    functions = {
        1: encode_function,
        2: decode_function
    }
    prompt = "1: 加密\n2:解密\n"

    choice = input(prompt)
    # Anything other than a listed menu number ends the program with status 1.
    if choice not in [str(number) for number in functions]:
        exit(1)
    else:
        functions[int(choice)]()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment