Created
August 23, 2022 11:51
-
-
Save timvandam/bb6992156686c7eeafeab7a59a1db6de to your computer and use it in GitHub Desktop.
DDP training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import operator | |
from multiprocessing import Pool | |
from typing import List | |
import torch | |
import random | |
import os | |
import numpy as np | |
from fuzzywuzzy import fuzz | |
from torch.utils.data import DataLoader, Dataset, RandomSampler, DistributedSampler, SequentialSampler | |
from tqdm import tqdm | |
from transformers import RobertaConfig, RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup | |
from Seq2Seq import Seq2Seq | |
import json | |
import sys | |
from functools import lru_cache | |
import argparse | |
import torch.multiprocessing | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
import torch.distributed as dist | |
import datetime | |
from time import time | |
# Share tensors between processes (e.g. DataLoader workers) through named shared-memory
# files instead of passing file descriptors; presumably chosen to avoid exhausting the
# open-file limit when many workers are used -- see the torch.multiprocessing docs.
torch.multiprocessing.set_sharing_strategy('file_system')
def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: optional list of argument strings to parse instead of ``sys.argv[1:]``
            (useful for tests and for calling the script programmatically).

    Returns:
        argparse.Namespace with all options. The DDP options default to -1, which
        ``main`` treats as "not distributed" (``local_rank == -1``) / "primary process"
        (``rank <= 0``).
    """
    parser = argparse.ArgumentParser(
        description="Trains a model, optionally using multiple GPUs using DistributedDataParallel. "
                    "Requires <dataset_folder>/datasets/train.txt and "
                    "<dataset_folder>/datasets/validation.jsonl to be present.")
    parser.add_argument(
        "--slurm",
        action='store_true',
        help='Enables automatic DistributedDataParallel based on Slurm options. '
             'Enabling this will automatically fill in all DDP options based on environment variables'
    )
    multi_gpu_manual_options = parser.add_argument_group(
        title="Multi GPU Options",
        description="Options for manually configuring DistributedDataParallel"
    )
    multi_gpu_manual_options.add_argument('--world_size', type=int, default=-1,
                                          help='Number of processes participating in the job')
    multi_gpu_manual_options.add_argument('--gpus_per_node', type=int, default=-1,
                                          help='Number of GPUs per node (only checked against the '
                                               'available GPUs when --slurm is used)')
    multi_gpu_manual_options.add_argument('--rank', type=int, default=-1, help='The global rank')
    multi_gpu_manual_options.add_argument('--local_rank', type=int, default=-1,
                                          help='The local rank (determines which GPU to use)')
    multi_gpu_manual_options.add_argument('--dist_backend', type=str, default=dist.Backend.NCCL,
                                          help='Backend passed to torch.distributed.init_process_group')
    multi_gpu_manual_options.add_argument('--dist_url', default='env://', type=str,
                                          help='init_method passed to torch.distributed.init_process_group')
    parser.add_argument("dataset_folder", type=str, help="The folder containing the dataset", action='store')
    parser.add_argument("--model_name", default="microsoft/unixcoder-base", type=str,
                        help="The name or the path of the model to be trained")
    parser.add_argument("--learning_rate", default=2e-4, type=float, help="The learning rate")
    parser.add_argument("--max_input_length", default=936, type=int,
                        help="The maximum length of the input (validation input is left-truncated if over this length)")
    parser.add_argument("--max_output_length", default=64, type=int, help="The maximum length of the output")
    parser.add_argument("--chunk_overlap", default=100, type=int,
                        help="The number of tokens shared by consecutive windows when a train input "
                             "is longer than the max allowed length")
    parser.add_argument("--seed", default=42, type=int, help="The seed used for randomized things")
    parser.add_argument("--beam_size", default=3, type=int, help="The beam size for beam search")
    parser.add_argument("--batch_size", default=8, type=int, help="The batch size")
    parser.add_argument("--num_epochs", default=10, type=int, help="The number of epochs")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
                        help="The number of batches to accumulate gradients over "
                             "(each loaded batch holds batch_size // gradient_accumulation_steps examples)")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="The weight decay")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="The epsilon for Adam optimizer")
    parser.add_argument("--cpu_count", default=os.cpu_count(), type=int,
                        help="The number of CPUs to use for loading data")
    return parser.parse_args(argv)
def set_slurm_args(args):
    """Fill the DDP-related options of *args* in place from Slurm environment variables.

    Reads SLURM_PROCID, SLURM_LOCALID, SLURM_CPUS_PER_TASK, WORLD_SIZE and
    SLURM_GPUS_ON_NODE, in that order; a missing variable raises KeyError. WORLD_SIZE
    is not a SLURM_* variable, so presumably the job script exports it -- verify.

    Raises:
        Exception: if ``args.slurm`` is falsy, or if more GPUs per node are configured
            than ``torch.cuda.device_count()`` reports.
    """
    if not args.slurm:
        raise Exception("Slurm args can only be set if slurm is enabled")

    # (attribute on args, environment variable it is parsed from), in read order
    slurm_sources = (
        ("rank", "SLURM_PROCID"),
        ("local_rank", "SLURM_LOCALID"),
        ("cpu_count", "SLURM_CPUS_PER_TASK"),
        ("world_size", "WORLD_SIZE"),
        ("gpus_per_node", "SLURM_GPUS_ON_NODE"),
    )
    for attribute, variable in slurm_sources:
        setattr(args, attribute, int(os.environ[variable]))

    available_gpus = torch.cuda.device_count()
    if args.gpus_per_node > available_gpus:
        raise Exception("The number of GPUs per node is greater than the number of GPUs available")
def get_current_date_time_string():
    """Return the local wall-clock time formatted as ``%d.%b %Y %H:%M:%S``."""
    now = datetime.datetime.now()
    # a non-empty format spec makes datetime.__format__ delegate to strftime
    return format(now, "%d.%b %Y %H:%M:%S")
def log(message):
    """Print *message* prefixed with a ``[timestamp]`` and flush stdout immediately."""
    timestamp = get_current_date_time_string()
    print(f'[{timestamp}] {message}', flush=True)
def progress_enumerator(enumerable, create_message, step=100, total=None, report_eta=False):
    """Yield every element of *enumerable* unchanged, logging progress along the way.

    Every *step*-th element (positions 0, step, 2*step, ...) triggers
    ``log(create_message(position, element))``; a falsy *step* disables logging.
    With *report_eta*, from the second report on, an estimate of the remaining time
    (mean time per element so far times the elements left) is appended. *total* is the
    expected number of elements; when omitted it is taken from ``len(enumerable)``,
    which only works for sized inputs.
    """
    started_at = time()
    for position, element in enumerate(enumerable):
        should_report = bool(step) and position % step == 0
        if should_report:
            eta_suffix = ""
            if report_eta and position > 0:
                if total is None:
                    total = len(enumerable)
                elapsed = time() - started_at
                remaining = (elapsed / position) * (total - position)
                eta_suffix = f" [{datetime.timedelta(seconds=round(remaining))} remaining]"
            log(create_message(position, element) + eta_suffix)
        yield element
def set_seed(seed: int):
    """Seed the random number generators used by training, for reproducible runs.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), and puts cuDNN into
    deterministic mode.

    ``PYTHONHASHSEED`` is exported so that child processes which start a fresh
    interpreter inherit it. Setting it at runtime does not change hash randomization
    of the already-running interpreter.
    """
    random.seed(seed)
    # exact spelling matters: the previous 'PYHTONHASHSEED' was silently ignored
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
def get_train_file_path(dataset_folder: str):
    """Return ``<dataset_folder>/datasets/train.txt``, the training corpus path."""
    datasets_dir = os.path.join(dataset_folder, "datasets")
    return os.path.join(datasets_dir, "train.txt")
def get_validation_file_path(dataset_folder: str):
    """Return ``<dataset_folder>/datasets/validation.jsonl``, the validation set path."""
    datasets_dir = os.path.join(dataset_folder, "datasets")
    return os.path.join(datasets_dir, "validation.jsonl")
def get_model_folder_path(dataset_folder: str):
    """Return ``<dataset_folder>/models``, where checkpoints and training details live."""
    models_dir_name = "models"
    return os.path.join(dataset_folder, models_dir_name)
def get_model_file_path(dataset_folder: str, model_name: str):
    """Return the checkpoint path ``<dataset_folder>/models/<model_name>.bin``."""
    models_dir = get_model_folder_path(dataset_folder)
    file_name = model_name + ".bin"
    return os.path.join(models_dir, file_name)
def get_training_details_file_path(dataset_folder: str):
    """Return ``<dataset_folder>/models/training_details.json``."""
    models_dir = get_model_folder_path(dataset_folder)
    return os.path.join(models_dir, "training_details.json")
def get_training_details(dataset_folder):
    """Load the per-epoch training record stored in the models folder.

    Returns the parsed ``training_details.json`` if it exists, otherwise a fresh
    ``{"models": []}``. Each entry of ``"models"`` is written by ``main`` as
    ``{"modelName": str, "exactMatch": number, "editSim": number}``.
    """
    details_path = get_training_details_file_path(dataset_folder)
    if os.path.exists(details_path):
        with open(details_path, "r", encoding='utf-8') as details_file:
            return json.load(details_file)
    # contains { modelName: string, exactMatch: number, editSim: number }
    return {"models": []}
def save_training_details(dataset_folder, training_details):
    """Persist *training_details* as pretty-printed JSON in the models folder.

    The JSON is written to a temporary sibling file and then moved over the real file
    with ``os.replace`` (atomic on POSIX when both paths are on the same filesystem).
    A job killed mid-write therefore leaves the previous, still-valid file in place,
    instead of a truncated one that ``get_training_details`` could not parse when
    training is resumed.
    """
    training_details_json_path = get_training_details_file_path(dataset_folder)
    # serialize first so an unserializable value never leaves a partial file behind
    serialized = json.dumps(training_details, indent=2)
    temporary_path = training_details_json_path + ".tmp"
    with open(temporary_path, "w", encoding='utf-8') as f:
        f.write(serialized)
    os.replace(temporary_path, training_details_json_path)
def verify_files_exist(dataset_folder: str):
    """Fail fast if the dataset folder, train file or validation file is missing.

    Checks are made in that order and the first missing path raises.

    Raises:
        Exception: naming the first missing item.
    """
    # paths are resolved lazily so each is only computed if the previous check passed
    required_paths = (
        (lambda: dataset_folder, "Dataset folder does not exist"),
        (lambda: get_train_file_path(dataset_folder), "Train file does not exist"),
        (lambda: get_validation_file_path(dataset_folder), "Validation file does not exist"),
    )
    for resolve_path, error_message in required_paths:
        if not os.path.exists(resolve_path()):
            raise Exception(error_message)
def prepare_model(model_name: str, max_output_length: int, beam_size: int):
    """Load a RoBERTa-style checkpoint and wrap it in a Seq2Seq model.

    The config is loaded with ``is_decoder = True``. One ``RobertaModel`` instance is
    passed as both encoder and decoder, so the two share all weights. The tokenizer's
    special tokens are checked against the values createTrainSet.ts assumes.

    Returns:
        (model, tokenizer)

    Raises:
        Exception: if bos/cls token is not ``<s>`` or sep/eos token is not ``</s>``.
    """
    config = RobertaConfig.from_pretrained(model_name)
    config.is_decoder = True
    tokenizer = RobertaTokenizer.from_pretrained(model_name)

    # confirm assumptions that are made in createTrainSet.ts
    expected_special_tokens = (
        ("bos_token", "<s>"),
        ("cls_token", "<s>"),
        ("sep_token", "</s>"),
        ("eos_token", "</s>"),
    )
    for token_name, expected_value in expected_special_tokens:
        if getattr(tokenizer, token_name) != expected_value:
            raise Exception(f"Tokenizer {token_name} is not {expected_value}")

    shared_transformer = RobertaModel.from_pretrained(model_name, config=config)
    model = Seq2Seq(
        encoder=shared_transformer,
        decoder=shared_transformer,
        config=config,
        beam_size=beam_size,
        max_length=max_output_length,
        sos_id=tokenizer.cls_token_id,
        eos_id=[tokenizer.sep_token_id],
    )
    return model, tokenizer
class IndexedFileDataset(Dataset):
    """
    Map-style dataset over a line-oriented text file.

    All lines of the file are read into memory on first use. Turning a line into model
    inputs is left to subclasses and happens lazily. A line can yield several inputs
    (e.g. overlapping chunks), so a flat index maps every dataset position to a
    ``(line_index, input_index)`` pair.

    Subclasses implement:
        _get_line_input_count(line_index) -> number of inputs on that line
        _get_line_input(line_index, input_index) -> one input

    ``_get_line_input_count`` runs inside a ``multiprocessing.Pool`` while the index is
    built, so a subclass instance must be usable in (and, under the spawn start method,
    picklable to) worker processes.
    """

    # Dataset being indexed inside a Pool worker process. Set once per worker by
    # _init_index_worker; it stays None in the parent process.
    _worker_dataset = None

    def __init__(self, file_path: str, cpu_count: int):
        self.file_path = file_path
        self.cpu_count = cpu_count
        # Per-instance state (these used to be mutable class attributes, shared by all
        # instances until first reassigned).
        # input_index -> (line_index, input_index_within_line)
        # handy when a line contains multiple inputs/chunks
        self.index = []
        self.index_loaded = False
        # all lines in the input file are cached
        self.lines = []
        self.lines_loaded = False

    def _read_file(self):
        """Read and cache every line of ``file_path`` (line endings are kept)."""
        if self.lines_loaded:
            return
        with open(self.file_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
        self.lines_loaded = True

    def _create_index(self):
        """Build the flat input index, counting the inputs of each line in parallel."""
        if not self.lines_loaded:
            self._read_file()
        if self.index_loaded:
            return
        # Hand the dataset to every worker once (via the initializer) and map a static
        # function. Mapping the bound method self._get_line_input_count instead pickles
        # self -- including every cached line -- for each task chunk sent to the pool.
        with Pool(
            self.cpu_count,
            initializer=IndexedFileDataset._init_index_worker,
            initargs=(self,),
        ) as pool:
            input_counts = progress_enumerator(
                pool.imap(
                    IndexedFileDataset._count_line_inputs_in_worker,
                    range(len(self.lines)),
                    chunksize=1000,
                ),
                lambda i, _: f"Creating index [{i} / {len(self.lines) - 1}]",
                step=1000,
                total=len(self.lines),
                report_eta=True,
            )
            # must be fully consumed inside the with-block: leaving it terminates the pool
            self.index = [
                (line_index, i)
                for line_index, input_count in enumerate(input_counts)
                for i in range(input_count)
            ]
        self.index_loaded = True

    @staticmethod
    def _init_index_worker(dataset):
        """Pool initializer: remember the dataset being indexed in this worker process."""
        IndexedFileDataset._worker_dataset = dataset

    @staticmethod
    def _count_line_inputs_in_worker(line_index: int):
        """Pool task: number of inputs on ``line_index`` of this worker's dataset."""
        return IndexedFileDataset._worker_dataset._get_line_input_count(line_index)

    def __len__(self):
        """Number of inputs across all lines; builds the index on first call."""
        if not self.index_loaded:
            self._create_index()
        return len(self.index)

    # TODO: Force it to get all items to test how much ram it takes
    # @lru_cache(maxsize=65536)
    def __getitem__(self, idx):
        """Return input ``idx``; requires the index to exist (e.g. after ``len()``)."""
        if not self.index_loaded:
            raise Exception("Index not loaded")
        if idx < 0 or idx >= len(self.index):
            raise IndexError("Index out of range")
        line_index, input_index = self.index[idx]
        line_input = self._get_line_input(line_index, input_index)
        return line_input

    def _get_line(self, line_index: int):
        """Return cached line ``line_index`` (reading the file first if needed)."""
        if not self.lines_loaded:
            self._read_file()
        if line_index < 0 or line_index >= len(self.lines):
            raise IndexError("Line index out of range")
        return self.lines[line_index]

    def _get_line_input_count(self, line_index: int):
        """
        Should return the amount of inputs that are in some line
        """
        raise NotImplementedError()

    def _get_line_input(self, line_index: int, input_index: int):
        """
        Should return some input from some line
        """
        raise NotImplementedError()
def str_to_tokens(string: str, tokenizer: RobertaTokenizer):
    """Tokenize *string* and drop every token that is exactly U+0120 ('Ġ').

    In byte-level BPE vocabularies such as RoBERTa's, U+0120 marks a preceding space,
    so a token made of only that character stands for a bare space.
    """
    bare_space_token = '\u0120'
    pieces = tokenizer.tokenize(string)
    return list(filter(lambda piece: piece != bare_space_token, pieces))
def tokens_to_token_ids(tokenizer: RobertaTokenizer, max_length: int, tokens: List[str]):
    """Convert *tokens* to ids and right-pad with ``pad_token_id`` to exactly *max_length*.

    No truncation is done here; callers must cut the tokens down beforehand.

    Raises:
        Exception: if there are more than *max_length* tokens.
    """
    too_long = len(tokens) > max_length
    if too_long:
        raise Exception("Input is too long")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    padding = [tokenizer.pad_token_id] * (max_length - len(ids))
    ids += padding
    return ids
def rindex(items, item, start=0, end=None):
    """Return the highest index ``i`` with ``start <= i < end`` and ``items[i] == item``.

    *end* defaults to, and is clamped to, ``len(items)``. Negative indices are not
    interpreted Python-style: a negative *end* simply yields an empty search range.

    Raises:
        ValueError: if *start* is outside ``[0, len(items)]`` or *item* is not found.
    """
    size = len(items)
    if not 0 <= start <= size:
        raise ValueError("Item not found")
    stop = size if end is None else min(end, size)
    for position in reversed(range(start, stop)):
        if items[position] == item:
            return position
    raise ValueError("Item not found")
class TrainDataset(IndexedFileDataset):
    """
    Training inputs cut from the train file with overlapping windows.

    Each line has its first whitespace-separated token removed (a leading ``<s>``) and is
    tokenized. It is then split into windows of at most ``max_length - 3`` tokens,
    because 3 positions are taken by the ``<s> <decoder-only> </s>`` prefix. A window is
    shortened to end right after the last ``</s>`` it contains. The next window starts
    ``chunk_overlap`` tokens before the previous window's end, or right at that end when
    the previous window is not longer than ``chunk_overlap``. Every window becomes one
    input, right-padded to ``max_length`` token ids.

    NOTE(review): the ``lru_cache(maxsize=10)`` method caches key on ``self``, are shared
    by all instances, and keep the cached instances alive.
    """

    def __init__(
        self,
        train_file_path: str,
        cpu_count: int,
        max_length: int,
        tokenizer: RobertaTokenizer,
        chunk_overlap: int,
    ):
        # Raises Exception if chunk_overlap is negative or not smaller than the usable
        # window length (max_length minus the 3 prefix tokens).
        super().__init__(train_file_path, cpu_count)
        if chunk_overlap < 0:
            raise Exception("Chunk overlap must be >= 0")
        if chunk_overlap >= max_length - 3:
            raise Exception("Chunk overlap must be < max_length - 3")
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.chunk_overlap = chunk_overlap

    @lru_cache(maxsize=10)
    def _get_line_tokens(self, line_index: int):
        # Tokens of one line, without its first whitespace-separated token.
        line = self._get_line(line_index)
        # remove leading <s>
        line = " ".join(line.strip().split()[1:])
        line_tokens = str_to_tokens(line, self.tokenizer)
        return line_tokens

    @lru_cache(maxsize=10)
    def _get_line_chunk_ranges(self, line_index: int):
        # Returns a list of (start, end) token ranges (end exclusive) covering the line.
        line_tokens = self._get_line_tokens(line_index)
        # 3 positions are reserved for the '<s> <decoder-only> </s>' prefix
        chunk_length = self.max_length - 3
        chunk_start = 0
        chunk_end = chunk_start + chunk_length
        chunk_ranges = []
        while True:
            try:
                # ensure that we always end with an eos
                last_eos_idx = rindex(line_tokens, "</s>", chunk_start, chunk_end)
                chunk_end = last_eos_idx + 1
            except ValueError:
                # ValueError means there is no eos inside the window, i.e. a single line of
                # the source does not fit in the model.
                # NOTE(review): the intent was to skip such lines (including them partially
                # would lead to partial predictions, which is not the point of line
                # completion), but nothing is skipped: the window keeps
                # chunk_end = chunk_start + chunk_length and is appended below, ending mid-line.
                pass
            chunk_ranges.append((chunk_start, chunk_end))
            if chunk_end >= len(line_tokens):
                break
            current_chunk_length = chunk_end - chunk_start
            if current_chunk_length > self.chunk_overlap:
                # the current chunk is larger than the chunk overlap. good!
                chunk_start = chunk_end - self.chunk_overlap
            else:
                # The window is not longer than the overlap, typically because the next
                # line is very long: rindex could then only find the previous window's
                # final eos, so the window just appended is exactly the previous window's
                # last chunk_overlap tokens (a near-duplicate input).
                # Continue without overlap so the loop still makes progress.
                chunk_start = chunk_end
            chunk_end = chunk_start + chunk_length
        # Debug output for the first 10 lines. It is printed by whichever process computes
        # the ranges (e.g. the Pool workers used while indexing).
        if line_index < 10:
            print(f"Line {line_index} has {len(chunk_ranges)} chunks")
            for i in range(len(chunk_ranges)):
                chunk_start, chunk_end = chunk_ranges[i]
                print(f"*** {i} ***")
                print(" ".join(line_tokens[chunk_start:chunk_end]))
            print("\n\n\n")
        return chunk_ranges

    @lru_cache(maxsize=10)
    def _get_line_input_count(self, line_index: int):
        # One input per window; called from Pool workers while the index is built.
        return len(self._get_line_chunk_ranges(line_index))

    @lru_cache(maxsize=10)
    def _get_line_input(self, line_index: int, input_index: int):
        # Returns the padded token ids (length max_length) of one window of one line.
        # chunked approach with overlapping chunks
        line_tokens = self._get_line_tokens(line_index)
        line_chunk_ranges = self._get_line_chunk_ranges(line_index)
        if input_index < 0 or input_index >= len(line_chunk_ranges):
            # NOTE(review): the message prints a half-open range "[0, N-1)", but valid
            # indices are 0..N-1 inclusive.
            raise IndexError("Input index out of range. "
                             f"Got {input_index}, expected value in range [0, {len(line_chunk_ranges) - 1})")
        chunk_start, chunk_end = line_chunk_ranges[input_index]
        chunk = ['<s>', '<decoder-only>', '</s>'] + line_tokens[chunk_start:chunk_end]
        return tokens_to_token_ids(self.tokenizer, self.max_length, chunk)
class ValidationDataset(IndexedFileDataset):
    """
    Validation examples read from a JSON-lines file, one example per line.

    Every line is a JSON object with a ``leftContext`` string (the model input) and a
    ``groundTruth`` value (the expected completion). Newlines in the context become
    ``</s>`` separators, whitespace is normalized, and the context is left-truncated so
    that, after the ``<s> <decoder-only> </s>`` prefix, it fits in ``max_length`` tokens.
    """

    def __init__(self, validation_file_path: str, cpu_count: int, max_length: int, tokenizer: RobertaTokenizer):
        super().__init__(validation_file_path, cpu_count)
        self.max_length = max_length
        self.tokenizer = tokenizer

    def _get_line_input_count(self, line_index: int):
        # one input -> output pair per line; long inputs are truncated, never split
        return 1

    def _get_line_input(self, line_index: int, input_index: int):
        # Returns (padded input token ids, ground truth) for one line.
        if input_index != 0:
            raise IndexError("Index out of range")
        record = json.loads(self._get_line(line_index))

        # replace \n with </s>, then collapse all whitespace runs to single spaces
        normalized_context = " ".join(record["leftContext"].replace("\n", " </s> ").split())
        context_tokens = str_to_tokens(normalized_context, self.tokenizer)

        # keep only the rightmost tokens that fit next to the 3-token prefix
        budget = self.max_length - 3
        prefixed_tokens = ["<s>", "<decoder-only>", "</s>"] + context_tokens[-budget:]
        return tokens_to_token_ids(self.tokenizer, self.max_length, prefixed_tokens), record["groundTruth"]
def _run_validation(model, validation_dataloader, tokenizer, device):
    """Generate predictions for the whole validation set on the current process.

    Args:
        model: the (possibly DDP-wrapped) Seq2Seq model, already switched to eval mode.
        validation_dataloader: yields ``(batch, ground_truths)`` built from ValidationDataset.
        tokenizer: used to decode predicted token ids.
        device: device the input ids are moved to.

    Returns:
        ``(exact_match, edit_sim)``:
        - ``exact_match`` is the fraction (0..1) of scored examples whose
          whitespace-normalized prediction equals the normalized ground truth.
        - ``edit_sim`` is the mean ``fuzz.ratio`` over the same examples.
        Both are 0.0 when nothing was scored.
    """
    exact_matches = 0
    edit_sim_total = 0.0
    scored_examples = 0
    # Call the underlying module directly: only one rank validates, and going through the
    # DDP wrapper may run collectives in forward that the waiting ranks never join.
    predict = model.module if hasattr(model, 'module') else model
    for i, (batch, ground_truths) in enumerate(progress_enumerator(
            validation_dataloader,
            lambda i, _: f"Validating [{i} / {len(validation_dataloader) - 1}]",
            total=len(validation_dataloader),
            report_eta=True,
    )):
        # batch holds one tensor per token position (default collate of lists of ints);
        # stacking and transposing yields (batch_size, max_length)
        source_ids = torch.transpose(torch.stack(batch), 0, 1).to(device).contiguous()
        with torch.no_grad():
            preds = predict(source_ids=source_ids)
        for j, (gt, pred) in enumerate(zip(ground_truths, preds)):
            # pred[0] is presumably the best beam; 0 is treated as the padding id that
            # follows the generated sequence -- TODO confirm against Seq2Seq
            t = list(pred[0].cpu().numpy())
            if 0 in t:
                t = t[:t.index(0)]
            pred = tokenizer.decode(t, clean_up_tokenization_spaces=False)
            if "</s>" in pred:
                pred = pred[:pred.index("</s>")]
            pred = " ".join(pred.strip().split())
            gt = " ".join(gt.strip().split())
            if i == 0 and j < 5:
                log(f"Validation example {j}:\n*** Prediction ***\n{pred}\n\n*** Ground Truth ***\n{gt}\n***")
            if pred == gt:
                exact_matches += 1
            edit_sim_total += fuzz.ratio(pred, gt)
            scored_examples += 1
    if scored_examples == 0:
        return 0.0, 0.0
    # average over every scored example (not just the size of the last batch)
    return exact_matches / scored_examples, edit_sim_total / scored_examples


def main(args):
    """Train the model for ``args.num_epochs`` epochs, resuming after the last saved epoch.

    Each epoch ends with validation on the primary process (``args.rank <= 0``). That
    process saves ``models/epoch-<n>.bin`` and appends the epoch's scores to
    ``training_details.json``. When ``args.local_rank != -1`` the model is wrapped in
    DistributedDataParallel and all ranks meet at barriers around validation and saving.

    NOTE(review): resuming restores only the model weights; optimizer and LR-scheduler
    state start from scratch.
    """
    set_seed(args.seed)
    verify_files_exist(args.dataset_folder)
    os.makedirs(get_model_folder_path(args.dataset_folder), exist_ok=True)

    previous_training_details = get_training_details(args.dataset_folder)
    previous_epochs = len(previous_training_details["models"])
    remaining_epochs = args.num_epochs - previous_epochs
    if remaining_epochs <= 0:
        log(f"This model has already trained for {len(previous_training_details['models'])} "
            f"out of {args.num_epochs} epochs, exiting")
        sys.exit(0)

    model, tokenizer = prepare_model(args.model_name, args.max_output_length, args.beam_size)
    if previous_epochs > 0:
        log(f"Found a model that was already trained for {previous_epochs} epoch(s), loading it")
        # map_location="cpu": the checkpoint was saved from the primary rank's GPU; without
        # it every rank would first materialise a copy on that same GPU.
        model.load_state_dict(torch.load(
            get_model_file_path(args.dataset_folder, previous_training_details["models"][-1]["modelName"]),
            map_location="cpu",
        ))

    if torch.cuda.is_available():
        log("CUDA available, using GPU")
        if args.local_rank != -1:
            log(f"Using distributed GPU [{args.local_rank}]")
            torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda")
    else:
        log("CUDA not available, using CPU")
        device = torch.device("cpu")
    model = model.to(device)

    if args.local_rank != -1:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            rank=args.rank,
            world_size=args.world_size,
        )
        model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True)

    log("Loading train dataset")
    train_dataset = TrainDataset(
        train_file_path=get_train_file_path(args.dataset_folder),
        cpu_count=args.cpu_count,
        max_length=args.max_input_length + args.max_output_length,
        tokenizer=tokenizer,
        chunk_overlap=args.chunk_overlap
    )
    # both samplers call len(train_dataset), which builds the index used below
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset, num_replicas=args.world_size, rank=args.rank)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.batch_size // args.gradient_accumulation_steps,
        num_workers=args.cpu_count,
        pin_memory=True,
    )
    log(f"Window size avg: {len(train_dataset.index)/len(train_dataset.lines)}")
    log(f"Loaded train dataset: {len(train_dataloader)} batches")

    if args.rank <= 0:
        log("Loading validation dataset")
        validation_dataset = ValidationDataset(
            validation_file_path=get_validation_file_path(args.dataset_folder),
            cpu_count=args.cpu_count,
            max_length=args.max_input_length,
            tokenizer=tokenizer
        )
        validation_sampler = SequentialSampler(validation_dataset)
        validation_dataloader = DataLoader(
            validation_dataset,
            sampler=validation_sampler,
            batch_size=args.batch_size,
            num_workers=args.cpu_count,
            pin_memory=True,
        )
        log(f"Loaded validation dataset: {len(validation_dataloader)} batches")

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # TODO: Use the non deprecated optimizer
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        eps=args.adam_epsilon
    )
    # The scheduler advances once per optimizer step: one per gradient_accumulation_steps
    # batches, plus one for a trailing partial window at the end of each epoch.
    optimizer_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    total_optimizer_steps = optimizer_steps_per_epoch * args.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_optimizer_steps * 0.1),
        num_training_steps=total_optimizer_steps
    )

    def apply_optimizer_step():
        # apply the accumulated gradients and advance the LR schedule
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    if args.local_rank != -1:
        dist.barrier()
    if args.rank <= 0:
        log(
            "***** Running Training *****\n" +
            "\tNum examples = %d\n" % len(train_dataset) +
            "\tBatch size = %d\n" % args.batch_size +
            "\tNum epochs = %d\n" % args.num_epochs +
            "\tSteps per epoch = %d" % len(train_dataloader)
        )
    model.train()
    losses = []
    for epoch in range(previous_epochs, args.num_epochs):
        if args.rank <= 0:
            log(f"Starting epoch {epoch}")
        if isinstance(train_sampler, DistributedSampler):
            # without this, every epoch reuses the same shuffle order on every rank
            train_sampler.set_epoch(epoch)

        pending_batches = 0  # batches whose gradients have not been applied yet
        for idx, batch in enumerate(train_dataloader):
            source_ids = torch.transpose(torch.stack(batch), 0, 1).to(device).contiguous()
            loss, _, _ = model(source_ids, True)
            losses.append(loss.item())
            if (idx + 1) % 100 == 0:
                # TODO: Add loss plot
                log("epoch %d step %d loss %f" % (epoch, idx + 1, round(np.mean(losses[-100:]), 4)))
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            pending_batches += 1
            # step after exactly gradient_accumulation_steps batches
            if pending_batches == args.gradient_accumulation_steps:
                apply_optimizer_step()
                pending_batches = 0
        if pending_batches:
            # flush a trailing partial accumulation window instead of leaking it into the
            # next epoch
            apply_optimizer_step()

        # Eval model with validation dataset
        if args.local_rank != -1:
            dist.barrier()
        if args.rank <= 0:
            log("***** Running Validation *****")
            model.eval()
            exact_match, edit_sim = _run_validation(model, validation_dataloader, tokenizer, device)
            model.train()
            validation_accuracy = round(exact_match * 100, 2)
            log("\t%s = %s " % ("Acc", str(validation_accuracy)))
            log("\t%s = %s " % ("Edit sim", str(round(edit_sim, 2))))
            log(" " + "*" * 20)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            model_name = f"epoch-{epoch}"
            torch.save(model_to_save.state_dict(), get_model_file_path(args.dataset_folder, model_name))
            training_details = get_training_details(args.dataset_folder)
            training_details["models"].append({
                "modelName": model_name,
                "exactMatch": validation_accuracy,
                "editSim": edit_sim,
            })
            save_training_details(args.dataset_folder, training_details)
        # Wait for all ranks to reach this barrier, so every process waits while the primary
        # rank validates and saves. NOTE(review): if validation can outlast the
        # process-group timeout, pass a larger `timeout` to init_process_group.
        if args.local_rank != -1:
            dist.barrier()
if __name__ == '__main__':
    args = parse_args()
    if args.slurm:
        set_slurm_args(args)

    # Log every parsed option (after Slurm overrides) before training starts;
    # each process launched for the job prints this.
    config_lines = ["*** Config ***"]
    config_lines.extend(f"{key}: {value}" for key, value in vars(args).items())
    config_lines.append("**************")
    log("\n".join(config_lines))

    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment