Last active: November 6, 2020 07:25
utils
# A reusable argparse action that parses date-interval expressions such as
# "LAST 7", "PREVIOUS WEEK", or "FROM 2020-01-01 TO 2020-02-01".
import argparse
import datetime

from dateutil.relativedelta import relativedelta
from pyparsing import (
    ParseException,
    pyparsing_common as ppc,
    CaselessKeyword,
    StringEnd,
)


class Interval(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __str__(self):
        return 'from {start} to {end}'.format(start=self.start, end=self.end)


def handle_last(tokens):
    # "LAST n": the n days ending today.
    end = datetime.date.today()
    start = end - relativedelta(days=tokens.n)
    return Interval(start, end)


def handle_previous(tokens):
    # "PREVIOUS DAY|WEEK|MONTH": a fixed-size window ending today.
    end = datetime.date.today()
    if tokens.day:
        start = end - relativedelta(days=1)
    elif tokens.week:
        start = end - relativedelta(days=7)
    else:  # tokens.month
        start = end - relativedelta(months=1)
    return Interval(start, end)


def handle_fromto(tokens):
    return Interval(tokens.start, tokens.end)


def make_date_parser():
    date_expr = ppc.iso8601_date.copy()
    date_expr.setParseAction(ppc.convertToDate())

    # "+" builds a sequential match (And); "|" picks the first matching
    # alternative (MatchFirst). The original used "&" (Each), which matches
    # in any order and is not what a keyword grammar wants.
    expr_last = (
        CaselessKeyword('LAST') + ppc.integer.setResultsName('n') + StringEnd()
    ).setResultsName('interval').setParseAction(handle_last)

    expr_prev = (
        CaselessKeyword('PREVIOUS')
        + (
            CaselessKeyword('DAY').setResultsName('day')
            | CaselessKeyword('WEEK').setResultsName('week')
            | CaselessKeyword('MONTH').setResultsName('month')
        )
        + StringEnd()
    ).setResultsName('interval').setParseAction(handle_previous)

    expr_fromto_date = (
        CaselessKeyword('FROM') + date_expr.setResultsName('start')
        + CaselessKeyword('TO') + date_expr.setResultsName('end')
        + StringEnd()
    ).setResultsName('interval').setParseAction(handle_fromto)

    return expr_fromto_date | expr_last | expr_prev


class IntervalAction(argparse.Action):
    def __init__(self, option_strings, dest, **kwargs):
        self._parser = make_date_parser()
        super(IntervalAction, self).__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        try:
            rv = self._parser.parseString(values)
        except ParseException:
            parser.error('argument %s is not valid' % '/'.join(self.option_strings))
        else:
            setattr(namespace, self.dest, rv.interval)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--interval', action=IntervalAction, required=True)
    args = parser.parse_args()
    print('interval: %s' % args.interval)
    print('interval start: %s' % args.interval.start)
    print('interval end: %s' % args.interval.end)
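Because argparse accepts an explicit argv list, the grammar can be exercised without a shell. A minimal sketch, assuming the classes above are in scope:

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--interval', action=IntervalAction, required=True)

args = parser.parse_args(['-i', 'LAST 7'])            # the 7 days ending today
args = parser.parse_args(['-i', 'PREVIOUS WEEK'])     # the same window, keyword form
args = parser.parse_args(['-i', 'FROM 2020-01-01 TO 2020-02-01'])
print(args.interval)  # from 2020-01-01 to 2020-02-01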
# A simple record class for extracted sentences; extra metadata is captured
# via **kwargs and merged back in to_dict().
import copy


class Documents:
    def __init__(self, sentence, rule_score, text_index, field, file_name,
                 **kwargs):
        self.sentence = sentence
        self.rule_score = rule_score
        self.text_index = text_index
        self.field = field
        self.file_name = file_name
        self.kwargs = kwargs

    def __str__(self):
        return f'sentence: {self.sentence}, ' \
               f'rule_score: {self.rule_score}, ' \
               f'text_index: {self.text_index}, ' \
               f'field: {self.field}, ' \
               f'file_name: {self.file_name}'

    def to_dict(self):
        # Copy the instance attributes, then fold the extra kwargs into the
        # top level of the dict.
        x = copy.deepcopy(self.__dict__)
        del x['kwargs']
        x.update(self.kwargs)
        return x
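A short usage sketch (the field values are illustrative); note how the extra keyword argument ends up at the top level of the dict:

doc = Documents(sentence='Total: 100 USD', rule_score=0.9, text_index=3,
                field='amount', file_name='a.txt', page_id=1)
print(doc.to_dict())
# {'sentence': 'Total: 100 USD', 'rule_score': 0.9, 'text_index': 3,
#  'field': 'amount', 'file_name': 'a.txt', 'page_id': 1}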
# Stable MD5 fingerprints for dicts (key-sorted JSON) and plain strings.
import hashlib
import json


def hashing(data):
    if isinstance(data, dict):
        # sort_keys makes the digest independent of key insertion order
        return hashlib.md5(json.dumps(data, sort_keys=True).encode()).hexdigest()
    elif isinstance(data, str):
        return hashlib.md5(data.encode()).hexdigest()
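Because the dict is serialized with sort_keys=True, key order does not affect the digest:

assert hashing({'a': 1, 'b': 2}) == hashing({'b': 2, 'a': 1})
print(hashing('hello'))  # 5d41402abc4b2a76b9719d911017c592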
# Force a UTF-8 locale (useful in minimal containers and CI images).
export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
export LANGUAGE=en_US.UTF-8
# Write model-evaluation results to a multi-sheet Excel workbook.
# write_path, metrics, agg_df, merge_df, output, generate_metrics and
# SIM_THRESHOLD come from the surrounding script; the xlsxwriter engine is
# required because write_text_to_excel uses writer.book.add_worksheet.
import pandas as pd


def write_text_to_excel(lines, sheet_name, writer):
    """Write plain-text lines into a worksheet, one row per line.

    :param lines: list of strings
    :param sheet_name: name of the worksheet to create
    :param writer: a pd.ExcelWriter backed by xlsxwriter
    """
    ws = writer.book.add_worksheet(sheet_name)
    ws.hide_gridlines(2)  # hide on-screen and printed gridlines
    cell_format = writer.book.add_format({'font_name': 'Courier'})
    for line_id, line in enumerate(lines):
        ws.write(line_id, 0, line, cell_format)


with pd.ExcelWriter(write_path, engine='xlsxwriter') as writer:
    metrics.to_frame('acc').to_excel(writer, sheet_name='overall summary model')
    agg_df.groupby(['field_id', 'field_name']) \
        .apply(lambda x: generate_metrics(df=x,
                                          value_col='prediction_with_threshold',
                                          sim_col='match_prediction',
                                          gold_col='gold_value',
                                          threshold=SIM_THRESHOLD)) \
        .sort_values('field_id') \
        .to_excel(writer, sheet_name='code level model')
    agg_df.groupby(['file_name']) \
        .apply(lambda x: generate_metrics(df=x,
                                          value_col='prediction_with_threshold',
                                          sim_col='match_prediction',
                                          gold_col='gold_value',
                                          threshold=SIM_THRESHOLD)) \
        .sort_values('file_name') \
        .to_excel(writer, sheet_name='file level model')
    agg_df.to_excel(writer, sheet_name='annotated')
    merge_df.to_excel(writer, sheet_name='all')
    write_text_to_excel(lines=output.split('\n'),
                        sheet_name='prediction_details', writer=writer)
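A minimal sketch of how the helper slots in next to ordinary DataFrame sheets (file name and contents are illustrative):

with pd.ExcelWriter('report.xlsx', engine='xlsxwriter') as writer:
    pd.DataFrame({'acc': [0.9]}).to_excel(writer, sheet_name='summary')
    write_text_to_excel(lines=['first line', 'second line'],
                        sheet_name='log', writer=writer)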
# A variant of write_text_to_excel: hides only printed gridlines (the
# hide_gridlines() default) and uses the 'Courier New' font.
def excel_text_writer(lines, sheet, writer):
    ws = writer.book.add_worksheet(sheet)
    ws.hide_gridlines()
    cell_format = writer.book.add_format({'font_name': 'Courier New'})
    for line_id, line in enumerate(lines):
        ws.write(line_id, 0, line, cell_format)
# A registry-based factory for YAML post-processing strategies.
import logging
from abc import ABCMeta, abstractmethod
from typing import Callable, Type

import pydash

logger = logging.getLogger(__name__)


class YamlProcessFactory:
    """The factory class for creating YAML process strategies."""

    registry = {}
    """Internal registry of available YAML process strategies."""

    @classmethod
    def register(cls, name: str) -> Callable:
        def inner_wrapper(wrapped_class: Type['BaseYamlProcess']) -> Type['BaseYamlProcess']:
            if name in cls.registry:
                logger.warning('Yaml %s already exists. Will replace it', name)
            cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create_processor(cls, name: str = 'default',
                         **kwargs) -> 'BaseYamlProcess':
        """Factory command to create a processor by registered name."""
        if name not in cls.registry:
            raise ValueError('Processor %s does not exist in the registry' % name)
        exec_class = cls.registry[name]
        return exec_class(**kwargs)


class BaseYamlProcess(metaclass=ABCMeta):
    @abstractmethod
    def process(self, f):
        pass


@YamlProcessFactory.register('default')
class DefaultYamlProcess(BaseYamlProcess):
    """Pass the loaded YAML document through unchanged."""

    def process(self, f):
        return f


@YamlProcessFactory.register('model')
class ModelYamlProcess(BaseYamlProcess):
    """Expand '+PLACEHOLDER' field codes into one field entry per value."""

    @staticmethod
    def to_flatten_deep(root):
        # Map every leaf value to its dotted path, e.g. '[0].field_code'.
        flatten_dict = {}

        def dfs(o, prefix=''):
            prefix_dot = f'{prefix}.' if prefix != '' else prefix
            if isinstance(o, dict):
                return {key: dfs(value, f'{prefix_dot}{key}')
                        for key, value in o.items()}
            elif isinstance(o, list):
                return [dfs(it, f'{prefix_dot}[{i}]') for i, it in enumerate(o)]
            else:
                flatten_dict[prefix] = o
                return o

        dfs(root)
        return flatten_dict

    def process(self, f):
        """Duplicate field entries whose field_code references a template list.

        Before:
            template: model
            template_params:
              FIELD_CODES: ['F_MT700_1', 'F_MT700_2']
            fields:
              - field_code: +FIELD_CODES
                strategy:
                  chain:
                    - action: fe_model.convert_and_tokenize
                      inputs:
                        context: document

        After:
            template: model
            template_params:
              FIELD_CODES: ['F_MT700_1', 'F_MT700_2']
            fields:
              - field_code: F_MT700_1
                strategy:
                  chain:
                    - action: fe_model.convert_and_tokenize
                      inputs:
                        context: document
              - field_code: F_MT700_2
                strategy:
                  chain:
                    - action: fe_model.convert_and_tokenize
                      inputs:
                        context: document

        :param f: the loaded YAML document as a dict
        :return: the document with '+' placeholders expanded
        """
        assert 'template_params' in f
        duplicate_keys = f['template_params']
        all_fields = f['fields']
        # Map each leaf value to its flattened path.
        flatten_map = self.to_flatten_deep(all_fields)
        # Keep only the leaves whose value starts with '+'.
        filter_paths = {k: v for k, v in flatten_map.items()
                        if str(v).startswith('+')}
        generated_strategy_templates = []
        marked_as_drop = []
        for path, code in filter_paths.items():
            replace_code_key = code[1:]  # strip the leading '+'
            assert replace_code_key in duplicate_keys
            to_replace_lst = duplicate_keys[replace_code_key]
            p1 = path.split('.')[0]             # index of the field entry, e.g. '[0]'
            p2 = '.'.join(path.split('.')[1:])  # path within that entry
            one_field = pydash.get(all_fields, p1)
            for replace_str in to_replace_lst:
                clone_fields = pydash.clone_deep(one_field)
                pydash.set_(clone_fields, p2, replace_str)
                generated_strategy_templates.append(clone_fields)
            marked_as_drop.append(one_field)
        # Drop the template entries and append the expanded ones.
        all_fields = [field for field in all_fields
                      if field not in marked_as_drop]
        all_fields.extend(generated_strategy_templates)
        f['fields'] = all_fields
        return f
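An end-to-end sketch of the 'model' processor on a toy document (the YAML is shown already loaded as a dict):

doc = {
    'template': 'model',
    'template_params': {'FIELD_CODES': ['F_MT700_1', 'F_MT700_2']},
    'fields': [{'field_code': '+FIELD_CODES', 'strategy': {}}],
}
processor = YamlProcessFactory.create_processor('model')
expanded = processor.process(doc)
print([fld['field_code'] for fld in expanded['fields']])
# ['F_MT700_1', 'F_MT700_2']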
# Parallel tokenization with a process pool; texts are chunked so each
# worker tokenizes one batch per task. Written against 2020-era transformers:
# newer versions spell these options padding='max_length' and
# return_attention_mask.
import concurrent.futures
import os

import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer

MAX_LEN = 192
pad_to_max_length = True
CPU_COUNT = os.cpu_count()
data_dir = '/kaggle/input/jigsaw-multilingual-toxic-comment-classification'

# edit this to change to another tokenizer
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')


def regular_encode(texts, tokenizer, seq_len, pad_to_max_length=False):
    enc_di = tokenizer.batch_encode_plus(
        texts,
        return_attention_masks=False,
        return_token_type_ids=False,
        pad_to_max_length=pad_to_max_length,
        max_length=seq_len,
    )
    return np.array(enc_di['input_ids'])


def process_pool_tokenizer(input_text, tokenizer, max_len, pad_to_max_length):
    # Submit one tokenization job per chunk and collect results in order.
    context = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=CPU_COUNT) as executor:
        with tqdm(total=len(input_text)) as progress:
            futures = []
            for x in input_text:
                future = executor.submit(regular_encode, x, tokenizer,
                                         max_len, pad_to_max_length)
                future.add_done_callback(lambda p: progress.update())
                futures.append(future)
            for future in futures:
                context.extend(future.result())
    return context


def chunks(lst, n=5000):
    """Return successive n-sized chunks from lst."""
    return [lst[i:i + n] for i in range(0, len(lst), n)]


# %%time
# process_pool_tokenizer(chunks(sample.comment_text.values), tokenizer, MAX_LEN, pad_to_max_length)
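Chunking matters here: each submitted task pays the cost of pickling its arguments (including the tokenizer) over to a worker process, so one task per chunk amortizes that overhead. A small sketch with toy data:

texts = ['hello world', 'bonjour le monde'] * 10
encoded = process_pool_tokenizer(chunks(texts, n=5),
                                 tokenizer, MAX_LEN, pad_to_max_length)
print(len(encoded))  # 20 encoded sequences, chunk order preserved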
# A record class for labelled extraction examples.
import copy


class LabelField(object):
    def __init__(self,
                 uuid,
                 query,
                 context,
                 file_name,
                 case_id,
                 file_id,
                 page_id,
                 paragraph_id,
                 sentence_id,
                 target=None):
        self.uuid = uuid
        self.query = query
        self.context = context
        self.file_name = file_name
        self.case_id = case_id
        self.file_id = file_id
        self.page_id = page_id
        self.paragraph_id = paragraph_id
        self.sentence_id = sentence_id
        self.target = target

    def __str__(self):
        return f'uuid: {self.uuid}, ' \
               f'query: {self.query}, ' \
               f'context: {self.context}, ' \
               f'file_name: {self.file_name}, ' \
               f'case_id: {self.case_id}, ' \
               f'file_id: {self.file_id}, ' \
               f'page_id: {self.page_id}, ' \
               f'paragraph_id: {self.paragraph_id}, ' \
               f'sentence_id: {self.sentence_id}'

    def to_dict(self):
        # Unlike Documents above, this class takes no **kwargs, so a plain
        # deep copy of the instance attributes is all that is needed.
        return copy.deepcopy(self.__dict__)
# Console logging with the PID in the format string. The logger's own level
# must be set too; otherwise it inherits WARNING from the root logger and
# the DEBUG handler never sees debug records.
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

# create formatter and add it to the handler
formatter = logging.Formatter(
    '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s')
ch.setFormatter(formatter)

# add the handler to the logger
logger.addHandler(ch)
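With the logger level set, a call like the following produces one formatted line per record (timestamp and PID here are illustrative):

logger.debug('starting run')
# 2020-11-06 07:25:00,123 - DEBUG - __main__ - PID: 4242 - starting run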
# Generates one out<N>.txt per $(DATA_DIR)/field_keywords_<N> input file.
DATA_DIR=./data
field_keywords=$(wildcard $(DATA_DIR)/field_keywords_*)
out_keywords=$(patsubst $(DATA_DIR)/field_keywords_%,out%.txt,$(field_keywords))
REAL_PATH=$(abspath $(DATA_DIR))
NN_FRAMEWORK_DATA=~/.nn_framework_Data
field_keyword=$(patsubst $(DATA_DIR)/field_keywords_%,$(DATA_DIR)/field_keywords_\%,$(word 1,$(field_keywords)))

## all: build every out<N>.txt
all: $(out_keywords)

## variables: print derived variables
variables:
	@echo $(out_keywords)
	@echo $(field_keyword)

## star_%: mkdir star_%, then write out.txt and out2.txt inside it
star_%: $(field_keyword)
	echo star_$*
	mkdir -p $@
	cat $^ | paste -sd ',' - > $@/out.txt
	echo 'asd' > $@/out2.txt

out%.txt: star_%
	echo 'run' > out$*.txt

.PHONY: clean help
clean:
	rm -rf star*
	rm -rf out*.txt

help: Makefile
	@sed -n 's/^##/make /p' $<
# Cache the expensive field-extraction step as a pickle in /tmp.
# data_annotation_path, fas_files, field_keyword and output_path come from
# the surrounding script.
import pickle
from pathlib import Path

tmp_write_path = Path(f'/tmp/{data_annotation_path.name}.pkl')
if not tmp_write_path.exists():
    extracted_fields = field_extraction(fas_files, field_keyword,
                                        output_path)
    pickle.dump(extracted_fields, tmp_write_path.open('wb'))
else:
    extracted_fields = pickle.load(tmp_write_path.open('rb'))
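The same cache-or-compute pattern generalizes into a small helper; a hedged sketch (the cached function is mine, not part of the gist):

def cached(path, compute):
    """Load a pickled result from path, or compute and pickle it."""
    p = Path(path)
    if p.exists():
        return pickle.load(p.open('rb'))
    result = compute()
    pickle.dump(result, p.open('wb'))
    return result

extracted_fields = cached(tmp_write_path,
                          lambda: field_extraction(fas_files, field_keyword,
                                                   output_path))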
# Two argparse helpers. Note: both functions below are named
# get_args_parser; if they live in one module, the second definition
# shadows the first, so rename one before reuse.
import argparse
import sys


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input",
                        type=str,
                        required=True,
                        help='input filename')
    parser.add_argument("--extracted_fpath",
                        type=str,
                        required=True,
                        help='input extracted field folder')
    return parser


def get_print_args():
    # Parse known args and print them as an aligned ARG / VALUE table.
    args, unknowns = get_args_parser().parse_known_args()
    param_str = '\n'.join(
        ['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())])
    print('usage: %s\n%20s %s\n%s\n%s\n' %
          (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str))
    return args


def get_args_parser():
    # From bert-as-service: builds the BertServer CLI. check_max_seq_len is
    # a validator defined alongside this function in that project.
    from . import __version__
    from .graph import PoolingStrategy

    parser = argparse.ArgumentParser(description='Start a BertServer for serving')

    group1 = parser.add_argument_group('File Paths',
                                       'config the path, checkpoint and filename of a pretrained/fine-tuned BERT model')
    group1.add_argument('-model_dir', type=str, required=True,
                        help='directory of a pretrained BERT model')
    group1.add_argument('-tuned_model_dir', type=str,
                        help='directory of a fine-tuned BERT model')
    group1.add_argument('-ckpt_name', type=str, default='bert_model.ckpt',
                        help='filename of the checkpoint file. By default it is "bert_model.ckpt", '
                             'but for a fine-tuned model the name could be different.')
    group1.add_argument('-config_name', type=str, default='bert_config.json',
                        help='filename of the JSON config file for BERT model.')
    group1.add_argument('-graph_tmp_dir', type=str, default=None,
                        help='path to graph temp file')

    group2 = parser.add_argument_group('BERT Parameters',
                                       'config how BERT model and pooling works')
    group2.add_argument('-max_seq_len', type=check_max_seq_len, default=25,
                        help='maximum length of a sequence; a longer sequence will be trimmed on the right side. '
                             'Set it to NONE for dynamically using the longest sequence in a (mini)batch.')
    group2.add_argument('-cased_tokenization', dest='do_lower_case', action='store_false', default=True,
                        help='whether the tokenizer should skip the default lowercasing and accent removal. '
                             'Should be used for e.g. the multilingual cased pretrained BERT model.')
    group2.add_argument('-pooling_layer', type=int, nargs='+', default=[-2],
                        help='the encoder layer(s) that receives pooling. '
                             'Give a list in order to concatenate several layers into one')
    group2.add_argument('-pooling_strategy', type=PoolingStrategy.from_string,
                        default=PoolingStrategy.REDUCE_MEAN, choices=list(PoolingStrategy),
                        help='the pooling strategy for generating encoding vectors')
    group2.add_argument('-mask_cls_sep', action='store_true', default=False,
                        help='mask the embedding on [CLS] and [SEP] with zero. '
                             'When pooling_strategy is in {CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN} '
                             'the embedding is preserved, otherwise it is masked to zero before pooling')
    group2.add_argument('-no_special_token', action='store_true', default=False,
                        help='do not add [CLS] and [SEP] to every sequence; '
                             'sequences go to the model without [CLS] and [SEP] when True and '
                             'is_tokenized=True in Client')
    group2.add_argument('-show_tokens_to_client', action='store_true', default=False,
                        help='send tokenization results to the client')
    group2.add_argument('-no_position_embeddings', action='store_true', default=False,
                        help='whether to add position embeddings for the position of each token in the sequence')
    group2.add_argument('-num_labels', type=int, default=2,
                        help='number of labels')

    group3 = parser.add_argument_group('Serving Configs',
                                       'config how server utilizes GPU/CPU resources')
    group3.add_argument('-port', '-port_in', '-port_data', type=int, default=5555,
                        help='server port for receiving data from client')
    group3.add_argument('-port_out', '-port_result', type=int, default=5556,
                        help='server port for sending result to client')
    group3.add_argument('-http_port', type=int, default=None,
                        help='server port for receiving HTTP requests')
    group3.add_argument('-http_max_connect', type=int, default=10,
                        help='maximum number of concurrent HTTP connections')
    group3.add_argument('-cors', type=str, default='*',
                        help='setting "Access-Control-Allow-Origin" for HTTP requests')
    group3.add_argument('-num_worker', type=int, default=1,
                        help='number of server instances')
    group3.add_argument('-max_batch_size', type=int, default=256,
                        help='maximum number of sequences handled by each worker')
    group3.add_argument('-priority_batch_size', type=int, default=16,
                        help='a batch smaller than this size will be labeled as high priority '
                             'and jumps forward in the job queue')
    group3.add_argument('-cpu', action='store_true', default=False,
                        help='run on CPU (default on GPU)')
    group3.add_argument('-xla', action='store_true', default=False,
                        help='enable XLA compiler (experimental)')
    group3.add_argument('-fp16', action='store_true', default=False,
                        help='use float16 precision (experimental)')
    group3.add_argument('-gpu_memory_fraction', type=float, default=0.5,
                        help='the fraction of the overall amount of memory '
                             'that each visible GPU should be allocated per worker. '
                             'Should be in range [0.0, 1.0]')
    group3.add_argument('-device_map', type=int, nargs='+', default=[],
                        help='specify the list of GPU device ids that will be used (id starts from 0). '
                             'If num_worker > len(device_map), devices will be reused; '
                             'if num_worker < len(device_map), then device_map[:num_worker] will be used')
    group3.add_argument('-prefetch_size', type=int, default=10,
                        help='the number of batches to prefetch on each worker. When running on a CPU-only machine, '
                             'this is set to 0 for compatibility')
    group3.add_argument('-fixed_embed_length', action='store_true', default=False,
                        help='when "max_seq_len" is set to None, the server determines the "max_seq_len" according to '
                             'the actual sequence lengths within each batch. When "pooling_strategy=NONE", '
                             'this may cause two ".encode()" calls from the same client to return different sizes [B, T, D]. '
                             'Turn this on to fix the "T" in [B, T, D] to "max_position_embeddings" in the BERT JSON config.')

    parser.add_argument('-verbose', action='store_true', default=False,
                        help='turn on tensorflow logging for debug')
    parser.add_argument('-version', action='version', version='%(prog)s ' + __version__)
    return parser
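A hedged usage sketch of the first, generic helper (used on its own, since the second definition would shadow it inside one module):

parser = get_args_parser()
args = parser.parse_args(['--input', 'a.txt', '--extracted_fpath', 'fields/'])
print(args.input)  # a.txt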
# A metaclass-based singleton: the first instantiation is cached and
# returned for every later call.
class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class MyClass(metaclass=Singleton):
    def __init__(self):
        pass
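Every instantiation after the first returns the cached object:

a = MyClass()
b = MyClass()
assert a is b  # one shared instance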
# Compress a .npy file (or a whole directory) into a gzip'd tar, compare
# sizes, then remove the uncompressed original.
import tarfile
from pathlib import Path

with tarfile.open(f"{save_path}.npy.tar.gz", "w:gz") as tar:
    tar.add(f"{save_path}.npy")

# Archiving an entire directory works the same way:
with tarfile.open((process_dir.parent / (process_dir.name + '.tar.gz')).as_posix(),
                  mode='w:gz') as archive:
    archive.add(process_dir.as_posix(), recursive=True)

fs = Path(f"{save_path}.npy.tar.gz")
tar_size = fs.stat().st_size / 1024 / 1024
fs = Path(f"{save_path}.npy")
org_size = fs.stat().st_size / 1024 / 1024
print(f'original {org_size} after {tar_size}')  # sizes in MiB
fs.unlink()  # delete the uncompressed .npy once archived
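For completeness, the matching read side (path illustrative):

with tarfile.open(f"{save_path}.npy.tar.gz", "r:gz") as tar:
    tar.extractall()  # restores the .npy into the working directory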