Created
June 12, 2023 09:10
-
-
Save dgomes/3879a46538adb06af92b498e9b53ad31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
__author__ = 'Mário Antunes' | |
__version__ = '1.0' | |
__email__ = '[email protected]' | |
__status__ = 'Production' | |
__license__ = 'MIT' | |
import sys | |
import json | |
import time | |
import logging | |
import argparse | |
import requests | |
from multiprocessing.pool import ThreadPool | |
class CustomFormatter(logging.Formatter):
    """Formatter that wraps each log message in an ANSI colour chosen by level.

    DEBUG and INFO use grey, WARNING yellow, ERROR red and CRITICAL bold red.
    Records with any other level number are rendered without colour.
    Only ``%(message)s`` is rendered (no timestamp, level name, etc.).
    """

    # ANSI SGR escape sequences.
    grey = '\x1b[38;20m'
    yellow = '\x1b[33;20m'
    red = '\x1b[31;20m'
    bold_red = '\x1b[31;1m'
    reset = '\x1b[0m'
    # Named ``message_format`` rather than ``format`` so it is not silently
    # overwritten by the ``format`` method defined below.
    message_format = '%(message)s'

    FORMATS = {
        logging.DEBUG: grey + message_format + reset,
        logging.INFO: grey + message_format + reset,
        logging.WARNING: yellow + message_format + reset,
        logging.ERROR: red + message_format + reset,
        logging.CRITICAL: bold_red + message_format + reset,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build one formatter per level once, instead of one per record.
        self._level_formatters = {
            level: logging.Formatter(fmt) for level, fmt in self.FORMATS.items()
        }
        # Used for level numbers not present in FORMATS (e.g. custom levels).
        self._fallback_formatter = logging.Formatter(self.message_format)

    def format(self, record):
        """Return *record* rendered with the colour format for its level."""
        formatter = self._level_formatters.get(record.levelno,
                                               self._fallback_formatter)
        return formatter.format(record)
def progressbar(i, prefix="", n=100, size=100, out=None):  # Python3.6+
    """Draw a one-line text progress bar of the form ``prefix[████....] i/n``.

    The line ends with ``\\r`` so successive calls overwrite each other.
    Once ``i >= n`` a trailing newline is printed so the finished bar stays
    visible.

    :param i: current progress value (int or float).
    :param prefix: text printed before the bar.
    :param n: value that represents completion.
    :param size: width of the bar in characters.
    :param out: writable text stream; ``None`` (default) means the
        ``sys.stdout`` current at call time, so ``redirect_stdout`` works.
    """
    if out is None:
        out = sys.stdout
    # A non-positive ``n`` has no meaningful ratio; show it as complete
    # instead of dividing by zero.
    filled = int(size * i / n) if n > 0 else size
    # Clamp so over-shooting (i > n) or negative progress never draws a bar
    # wider or narrower than ``size``.
    filled = max(0, min(size, filled))
    print(f"{prefix}[{'█' * filled}{'.' * (size - filled)}] {i}/{n}",
          end='\r', file=out, flush=True)
    if i >= n:
        print("\n", flush=True, file=out)
# Console handler that colours messages via CustomFormatter and lets every
# record from DEBUG upwards through.
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
ch.setLevel(logging.DEBUG)

# Module logger used by the request helpers and main(); also DEBUG and up.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(ch)
def post_music(data):
    """Upload raw audio bytes with ``POST {args.u}/music``.

    Reads the module-level ``args`` (``u`` = base URL, ``t`` = timeout) set
    in the ``__main__`` block.

    :param data: audio file contents, sent as ``application/octet-stream``.
    :returns: the ``music_id`` from the JSON response, or ``None`` if the
        server replied with an error status or omitted ``music_id``.
    """
    logger.info(f'POST {args.u}/music')
    r = requests.post(f'{args.u}/music',
                      data=data,
                      headers={'Content-Type': 'application/octet-stream'},
                      timeout=args.t)
    if not r.ok:
        # Include the status code so a failed upload can be diagnosed.
        logger.error(f'Error: POST /music (HTTP {r.status_code})')
        return None
    json_response = r.json()
    logger.info(f'{json_response}')
    music_id = json_response.get('music_id')
    if music_id is None:
        logger.error('Error: POST /music response has no music_id')
    return music_id
def post_music_id(music_id, music_list, tracks=None):
    """Request processing of an uploaded music via ``POST {args.u}/music/{id}``.

    Reads the module-level ``args`` (``u`` = base URL, ``t`` = timeout) set
    in the ``__main__`` block.

    :param music_id: id returned by ``post_music``.
    :param music_list: result of ``GET /music``; not used here, kept so
        existing callers that pass it keep working.
    :param tracks: JSON list of track ids to send; ``None`` (default) sends
        ``[1, 2]``, the list this function always used.
    :returns: the decoded JSON response, or ``None`` on an error status.
    """
    if tracks is None:
        tracks = [1, 2]
    logger.info(f'POST {args.u}/music/{music_id}')
    r = requests.post(f'{args.u}/music/{music_id}',
                      json=tracks,
                      headers={'Content-Type': 'application/json'},
                      timeout=args.t)
    if not r.ok:
        logger.error(f'Error: POST /music/{music_id} (HTTP {r.status_code})')
        return None
    json_response = r.json()
    logger.info(f'{json_response}')
    return json_response
def main(args):
    """Run the end-to-end client scenario against the audio API.

    1. Upload ``test1.mp3`` and ``test2.mp3`` concurrently (``POST /music``).
    2. List the known musics (``GET /music``).
    3. Request processing of both uploads concurrently (``POST /music/{id}``).
    4. Poll ``GET /music/{id}`` for the second upload every 3 s, drawing a
       progress bar, until ``progress`` reaches 100.
    5. Download the ``final`` URL of that response into ``final.wav``.

    :param args: parsed CLI namespace with ``u`` (base URL) and ``t``
        (request timeout in seconds).
    """
    # Load test and eval music
    with open('test1.mp3', 'rb') as f:
        test_music_data = f.read()
    with open('test2.mp3', 'rb') as f:
        eval_music_data = f.read()

    # The context manager shuts the worker threads down when we are done.
    with ThreadPool(processes=2) as pool:
        # Post two musics at the same time. map() preserves input order, so
        # the last id belongs to the eval music.
        music_ids = pool.map(post_music, [test_music_data, eval_music_data])
        if None in music_ids:
            logger.error('Error: upload failed, aborting')
            return
        eval_music_id = music_ids[-1]

        # GET /music
        logger.info(f'GET {args.u}/music')
        r = requests.get(f'{args.u}/music', timeout=args.t)
        music_list = None
        if r.ok:
            music_list = r.json()
            logger.info(f'{music_list}')
        else:
            logger.error('Error: GET /music')

        # Process two musics at the same time. post_music_id takes two
        # positional arguments, so starmap (not map) must unpack each tuple.
        pool.starmap(post_music_id,
                     [(music_id, music_list) for music_id in music_ids])

    # GET /music/{id} until processing completes. HTTP errors are logged and
    # polling continues.
    logger.info(f'GET {args.u}/music/{eval_music_id}')
    progressbar(0)
    while True:
        time.sleep(3)
        r = requests.get(f'{args.u}/music/{eval_music_id}', timeout=args.t)
        if not r.ok:
            logger.error(f'Error: GET /music/{eval_music_id}')
            continue
        json_response = r.json()
        progress = json_response['progress']
        progressbar(progress)
        if progress >= 100:
            final_download_url = json_response['final']
            print('\n')
            logger.info(f'{json_response}')
            break

    # Download final music
    r = requests.get(final_download_url, timeout=args.t)
    if r.ok:
        with open('final.wav', 'wb') as f:
            f.write(r.content)
    else:
        logger.error(f'Error: GET {final_download_url}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Split an audio track',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-u', type=str, help='API URL',
                        default='http://localhost:5000')
    # float so sub-second timeouts such as ``-t 0.5`` are accepted; whole
    # numbers keep working exactly as before.
    parser.add_argument('-t', type=float, help='Request Timeout', default=3)
    # Module-level on purpose: post_music and post_music_id read the global
    # ``args`` for the base URL and timeout.
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment