This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
@contextlib.contextmanager
def open_func(file_name):
    """Open *file_name* for reading as a generator-based context manager.

    Code before ``yield`` plays the role of ``__enter__``; code after it
    plays the role of ``__exit__``.

    Args:
        file_name: Path of the file to open in text read mode ('r').

    Yields:
        The open file object.
    """
    # __enter__ part
    print('open file:', file_name, 'in __enter__')
    file_handler = open(file_name, 'r')
    try:
        yield file_handler
    finally:
        # __exit__ part: runs both on normal exit and when the with-body
        # raises, so the file handle is never leaked.
        file_handler.close()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def collate_fn(data): | |
""" | |
default_collate比较好地实现了对图片等的操作,但是并不支持对文字等不等长序列的操作 | |
dataloader的一个参数 | |
输入data是list of (x, y), list的长度是batch_size | |
返回xs, ys, lens | |
(batch_size, x或y的size)的tensor | |
""" | |
# Sort a data list |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data as data | |
class MyDataset(data.Dataset): | |
''' | |
Dataset must define __getitem__ and __len__ | |
''' | |
def __init__(self, others): | |
pass | |
def __getitem__(self, index): | |
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Optional registry loaded from ``registry.json`` located next to this
# module; an empty dict is used when that file does not exist.
_registry_path = Path(__file__).parent / 'registry.json'
try:
    # EAFP: open directly instead of checking exists() first, which avoids
    # the race where the file disappears between the check and the open.
    with _registry_path.open(encoding='utf-8') as f:
        _REGISTRY = json.load(f)
except FileNotFoundError:
    _REGISTRY = {}
def short_name(cls: type) -> str:
    """Return the bare class name of *cls*.

    Any dotted prefix in ``cls.__name__`` is dropped, keeping only the text
    after the last ``'.'`` (or the whole name when it contains no dot).
    """
    _prefix, _dot, bare = cls.__name__.rpartition('.')
    return bare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
def log1(func):
    """Decorate *func* so each call prints ``call <name>():`` before running.

    The wrapper forwards all positional and keyword arguments unchanged and
    returns whatever *func* returns. ``functools.wraps`` copies the wrapped
    function's metadata (``__name__``, ``__doc__``, ...) onto the wrapper.
    """
    @functools.wraps(func)
    def wrapper(*args, **kw):
        # Interpolate the name into the message; passing it as a separate
        # print() argument would emit the literal "%s" placeholder.
        print("call %s():" % func.__name__)
        return func(*args, **kw)
    return wrapper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch.utils.data as data | |
from tqdm import tqdm | |
def train(): | |
# config saving | |
model_path = './ckpt' | |
if not os.path.exists(model_path): | |
os.mkdir(model_path) | |
save_step = 2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
# Remove every non-word character (\W) from st_before.
# NOTE(review): st_before is defined elsewhere (not visible here); presumably a str.
# Raw string: '\W' in a plain literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
st_after = re.sub(r'\W', '', st_before)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LogMixin:
    """Mixin that gives instances a ``logger`` property.

    The logger is named ``<module>.<ClassName>``, where ``<module>`` is the
    ``__name__`` of the module defining this mixin and ``<ClassName>`` is
    the concrete class of the instance.
    """

    @property
    def logger(self):
        """Return the ``logging.Logger`` named ``<module>.<ClassName>``."""
        # NOTE(review): ``__name__`` is this mixin's module, not the
        # subclass's module; use ``type(self).__module__`` if loggers should
        # follow each subclass's own module hierarchy.
        qualified = f"{__name__}.{self.__class__.__name__}"
        return logging.getLogger(qualified)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Walk the indices of s from the last one down to 0.
# NOTE(review): s is defined elsewhere (not visible here).
for j in reversed(range(len(s))):
    pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def is_number(s): | |
try: | |
float(s) | |
return True | |
except ValueError: | |
pass | |
try: | |
import unicodedata | |
unicodedata.numeric(s) | |
return True |