Last active
May 30, 2017 20:45
-
-
Save mk-fg/396745b36f77a9fe84c0b27d92349fe7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import itertools as it, operator as op, functools as ft
import os, sys, errno, argparse, logging, secrets, re, json, shutil
import contextlib, inspect, tempfile, pathlib, collections, enum
import asyncio, asyncio.subprocess, socket, signal, heapq, struct
import aiohttp
class TVFConfig:
    """Default settings for the fetcher; override attributes on an instance to customize."""

    # --- Which part of the playlist to download (None = no limit/scatter)
    slice_start = 0
    slice_len = None
    slice_scatter_len = None
    slice_scatter_interval = None

    # --- Options for external tools and the http client
    aria2c_opts = None  # overrides opts in aria2_conf_func
    ytdl_opts = None
    aiohttp_opts = dict(read_timeout=30, conn_timeout=60)

    # --- Output/tempfile handling
    verbose = False
    use_temp_dirs = True
    keep_tempfiles = False
    file_part_suffix = False
    file_append_batch = 20
    file_append_max = 150  # should not be needed normally

    compat_windows = sys.platform == 'win32'

    # --- aria2c daemon control
    aria2_cmd = 'aria2c'

    def aria2_conf_func(self, opts):
        """Build contents of the aria2c config file.

        opts must provide log_level, port, key and ua attributes.
        Lines from self.aria2c_opts (if any) go after the built-in ones,
        and the result always ends with a newline.
        """
        builtin_lines = [
            'summary-interval=0',
            f'console-log-level={opts.log_level}',
            'enable-rpc=true',
            f'rpc-listen-port={opts.port}',
            f'rpc-secret={opts.key}',
            f'stop-with-process={os.getpid()}',
            'no-netrc',
            'always-resume=false',
            'max-concurrent-downloads=5',
            'max-connection-per-server=5',
            'max-file-not-found=5',
            'max-tries=8',
            'timeout=15',
            'connect-timeout=10',
            'lowest-speed-limit=100K',
            f'user-agent={opts.ua}',
        ]
        extra_lines = self.aria2c_opts or list()
        return '\n'.join(builtin_lines + extra_lines + [''])

    aria2_term_wait = 10, 20  # (clean, term), clean should be >3s
    aria2_startup_checks = 8, 10.0
    aria2_queue_batch = 50
    aria2_ws_heartbeat = 20.0
    aria2_ws_debug = False  # very noisy
    aria2_ws_debug_signals = False
class LogMessage(object):
    """Log message that applies str.format() lazily, only when it gets rendered."""

    def __init__(self, fmt, a, k):
        self.fmt, self.a, self.k = fmt, a, k

    def __str__(self):
        # Plain template is returned untouched if there is nothing to format it with
        return self.fmt.format(*self.a, **self.k) if self.a or self.k else self.fmt


class LogStyleAdapter(logging.LoggerAdapter):
    """LoggerAdapter that uses "{}"-style (str.format) templates instead of "%"-style."""

    def __init__(self, logger, extra=None):
        super(LogStyleAdapter, self).__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        """Log msg.format(*args, **kws) at level; exc_info keyword is passed to the logger."""
        if not self.isEnabledFor(level): return
        log_kws = {} if 'exc_info' not in kws else dict(exc_info=kws.pop('exc_info'))
        msg, kws = self.process(msg, kws)
        # exc_info has to be passed by keyword - as a positional argument the
        # (always truthy) dict itself would be taken for exc_info value,
        # making e.g. exc_info=False produce a traceback anyway
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)


def get_logger(name):
    """Return "{}"-style logging adapter for the logger with specified name."""
    return LogStyleAdapter(logging.getLogger(name))
class AsyncExitStack:
    """Stack of exit callbacks, usable as an async context manager.

    Accepts both sync and async context managers and callbacks.
    Callbacks are run in reverse order of registration when the stack is closed.
    """
    # Might be merged to 3.7, see https://bugs.python.org/issue29302
    # Implementation from https://gist.github.com/thehesiod/b8442ed50e27a23524435a22f10c04a0
    def __init__(self):
        self._exit_callbacks = collections.deque()
    def pop_all(self):
        """Move all registered callbacks to a new stack instance and return it."""
        new_stack = type(self)()
        new_stack._exit_callbacks = self._exit_callbacks
        self._exit_callbacks = collections.deque()
        return new_stack
    def push(self, exit_obj):
        """Register context manager's exit method or a callback with __exit__ signature.

        Returns exit_obj, so that this can be used as a decorator.
        """
        _cb_type = type(exit_obj)
        try:
            # Async exit method takes precedence over the sync one
            exit_method = getattr(_cb_type, '__aexit__', None)
            if exit_method is None: exit_method = _cb_type.__exit__
        # No exit methods on the type - exit_obj is stored as-is, to be called directly
        except AttributeError: self._exit_callbacks.append(exit_obj)
        else: self._push_cm_exit(exit_obj, exit_method)
        return exit_obj
    @staticmethod
    def _create_exit_wrapper(cm, cm_exit):
        """Return sync or async function (matching cm_exit) that calls cm_exit on cm."""
        if inspect.iscoroutinefunction(cm_exit):
            async def _exit_wrapper(exc_type, exc, tb):
                return await cm_exit(cm, exc_type, exc, tb)
        else:
            def _exit_wrapper(exc_type, exc, tb):
                return cm_exit(cm, exc_type, exc, tb)
        return _exit_wrapper
    def _push_cm_exit(self, cm, cm_exit):
        """Register exit method cm_exit (looked up on the type) bound to cm instance."""
        _exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
        # NOTE(review): looks like this mimics bound method for introspection, as contextlib.ExitStack does
        _exit_wrapper.__self__ = cm
        self.push(_exit_wrapper)
    @staticmethod
    def _create_cb_wrapper(callback, *args, **kwds):
        """Return sync or async exit function that ignores exception details and runs callback."""
        if inspect.iscoroutinefunction(callback):
            async def _exit_wrapper(exc_type, exc, tb): await callback(*args, **kwds)
        else:
            def _exit_wrapper(exc_type, exc, tb): callback(*args, **kwds)
        return _exit_wrapper
    def callback(self, callback, *args, **kwds):
        """Register arbitrary (sync or async) callback with arguments, return the callback itself.

        Result of such callback is discarded, so it cannot suppress exceptions.
        """
        _exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
        _exit_wrapper.__wrapped__ = callback
        self.push(_exit_wrapper)
        return callback
    def _shutdown_loop(self, *exc_details):
        """Generator that runs all callbacks, last-registered first.

        Yields result of each callback call (which can be an awaitable) and
        expects to be sent the final value for it, or have an exception thrown in.
        Generator return value is True if passed exception ended up suppressed.
        """
        received_exc = exc_details[0] is not None
        frame_exc = sys.exc_info()[1]
        def _fix_exception_context(new_exc, old_exc):
            # Walk __context__ chain of new_exc to its end and attach old_exc there,
            # unless old_exc is already a part of that chain
            while True:
                exc_context = new_exc.__context__
                if exc_context is old_exc: return
                if exc_context is None or exc_context is frame_exc: break
                new_exc = exc_context
            new_exc.__context__ = old_exc
        suppressed_exc = pending_raise = False
        while self._exit_callbacks:
            cb = self._exit_callbacks.pop()
            try:
                cb_result = yield cb(*exc_details)
                if cb_result:
                    # True-ish result suppresses exception, so remaining callbacks see none
                    suppressed_exc, pending_raise = True, False
                    exc_details = (None, None, None)
            except:
                # Exception from callback replaces current one for remaining callbacks
                new_exc_details = sys.exc_info()
                _fix_exception_context(new_exc_details[1], exc_details[1])
                pending_raise, exc_details = True, new_exc_details
        if pending_raise:
            try:
                # "raise" will overwrite __context__, so it is saved and restored here
                fixed_ctx = exc_details[1].__context__
                raise exc_details[1]
            except BaseException:
                exc_details[1].__context__ = fixed_ctx
                raise
        return received_exc and suppressed_exc
    async def enter(self, cm):
        """Enter sync or async context manager, register its exit method, return enter-result."""
        _cm_type = type(cm)
        _exit = getattr(_cm_type, '__aexit__', None)
        if _exit is not None: result = await _cm_type.__aenter__(cm)
        else:
            _exit = _cm_type.__exit__
            result = _cm_type.__enter__(cm)
        self._push_cm_exit(cm, _exit)
        return result
    async def close(self):
        """Run all registered callbacks right away, without any exception info."""
        await self.__aexit__(None, None, None)
    async def __aenter__(self): return self
    async def __aexit__(self, *exc_details):
        # Drives _shutdown_loop generator, awaiting results of
        # async callbacks and passing final values or exceptions back into it
        gen = self._shutdown_loop(*exc_details)
        try:
            result = next(gen)
            while True:
                try:
                    if inspect.isawaitable(result): result = await result
                    result = gen.send(result)
                # Generator is done - must reach outer handler instead of being thrown back in
                except StopIteration: raise
                except BaseException as e: result = gen.throw(e)
        except StopIteration as e: return e.value
def add_stack_wrappers(cls):
    """Class decorator to auto-provide AsyncExitStack to methods that ask for one.

    Any non-dunder routine on the class with second parameter (one after "self")
    annotated as AsyncExitStack gets replaced by a coroutine wrapper, which
    creates the stack, passes it as that parameter and closes it when method is done.
    """
    def _with_stack(method):
        @ft.wraps(method)
        async def _stacked(self, *args, **kws):
            async with AsyncExitStack() as ctx:
                return await method(self, ctx, *args, **kws)
        return _stacked

    for attr, method in inspect.getmembers(cls, inspect.isroutine):
        if attr.startswith('__'):
            continue
        params = list(inspect.signature(method).parameters.values())
        if len(params) > 1 and params[1].annotation is AsyncExitStack:
            setattr(cls, attr, _with_stack(method))
    return cls
def it_adjacent(seq, n, fill=None):
    """Iterate over seq in consecutive n-tuples, padding the last one with fill."""
    shared_iter = iter(seq)
    return it.zip_longest(*(shared_iter for _ in range(n)), fillvalue=fill)


def it_adjacent_nofill(seq, n):
    """Same as it_adjacent, but last tuple is left shorter instead of being padded."""
    # it_adjacent function itself is used as a unique padding marker here
    padded_chunks = it_adjacent(seq, n, fill=it_adjacent)
    return (
        tuple(v for v in chunk if v is not it_adjacent)
        for chunk in padded_chunks)
class adict(dict):
    """Dict with keys also accessible (and assignable) as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Instance attribute storage and dict contents become the same object
        self.__dict__ = self
def log_lines(log_func, lines, log_func_last=False):
    """Log multiple lines via log_func, prefixed by the same short random "[uid]" tag.

    lines can be a multiline string or a list, where each element is either
    a string or (template, arg1, arg2, ...) tuple to pass to log_func.
    If log_func_last is specified, it is used instead of log_func for the last line.
    """
    if isinstance(lines, str):
        lines = [text.rstrip() for text in lines.rstrip().split('\n')]
    uid = secrets.token_urlsafe(3)
    for n, entry in enumerate(lines, 1):
        if isinstance(entry, str):
            call_args = ('[{}] {}', uid, entry)
        else:
            # Tag is baked into the template, so that its args can be passed as-is
            call_args = ['[{}] {}'.format(uid, entry[0]), *entry[1:]]
        is_last = log_func_last and n == len(lines)
        emit = log_func_last if is_last else log_func
        emit(*call_args)
def retries_within_timeout( tries, timeout,
        backoff_func=lambda e,n: ((e**n-1)/e), slack=1e-2 ):
    """Return list of delays to make exactly n tries within timeout, with backoff_func.

    Value for the first argument of backoff_func(base, n) is picked by bisection,
    so that sum of delays for n in range(tries) is within slack from timeout.
    backoff_func is expected to produce larger delays for larger base.

    Raises ValueError if no such value can be found, e.g. when tries
    is too low for default backoff_func to ever add up to timeout.
    """
    def delays_for(base):
        return list(backoff_func(base, n) for n in range(tries))

    a, b = 0, timeout
    # Upper bound is doubled until it is large enough for the sum to reach
    # timeout, as bisection would otherwise never terminate (e.g. tries=3 timeout=1)
    for _ in range(64):
        if sum(delays_for(b)) > timeout - slack: break
        b *= 2
    else:
        raise ValueError( f'Unable to fit {tries} tries'
            f' into timeout={timeout} with specified backoff_func' )

    # Way more steps than float precision allows for, only to guarantee termination
    for _ in range(200):
        m = (a + b) / 2
        delays = delays_for(m)
        error = sum(delays) - timeout
        if abs(error) < slack: return delays
        elif error > 0: b = m
        else: a = m
    raise ValueError( f'Failed to find delays for {tries} tries'
        f' within timeout={timeout} (slack={slack})' )
def parse_pos_spec(pos):
    """Convert "[[hours:]minutes:]seconds" string to a number of seconds (float).

    Returns None for empty/missing value.
    Raises argparse.ArgumentTypeError if any component is not a number.
    """
    if not pos:
        return
    # Missing leading components (hours, minutes) are taken to be zero
    fields = pos.rsplit(':', 2)
    fields = [0] * (3 - len(fields)) + fields
    try:
        hrs, mins, secs = map(float, fields)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f'Failed to parse {pos!r} as [[hours:]minutes:]seconds value: {err}' )
    return 3600 * hrs + 60 * mins + secs
def suppress_err(exc_type, func, *args, **kws):
    """Return no-argument function, calling func(*args, **kws) and ignoring exc_type errors.

    Result of the func call is discarded, returned function always returns None.
    """
    def _call_quietly():
        try:
            func(*args, **kws)
        except exc_type:
            return None
    return _call_quietly
class TVFError(Exception):
    """Base class for all errors raised by the fetcher code."""


class TVFWSClosed(TVFError):
    """Raised in a task when aria2c websocket gets closed while it is being used."""


class TVFReqError(TVFError):
    """Raised when response to aria2c json-rpc request has no result in it."""
async def vod_fetch(loop, conf, vod_queue, list_formats_only=False):
    """Fetch (or only list formats for) each VoD in vod_queue, one after another.

    vod_queue is a list of (url, file_prefix) tuples.
    Processing stops at the first failed VoD or when the task gets cancelled,
    which is also what SIGINT/SIGTERM handlers installed here do (not on windows).

    Returns exit code for the script: 0 if everything was processed
    or on SIGTERM, 1 on failure or SIGINT.
    """
    exit_code, log = 1, get_logger('tvf.fetcher')
    # asyncio.Task.current_task was removed in python-3.9, asyncio.current_task is 3.7+
    current_task = getattr(asyncio, 'current_task', None) or asyncio.Task.current_task
    task = current_task(loop)

    def sig_handler(sig, code):
        log.info('Exiting on {} signal with code: {}', sig, code)
        nonlocal exit_code
        exit_code = code
        task.cancel()

    if not conf.compat_windows:
        for sig, code in ('INT', 1), ('TERM', 0):
            loop.add_signal_handler(getattr(signal, f'SIG{sig}'), ft.partial(sig_handler, sig, code))

    async with TVF(loop, conf, log) as fetcher:
        fetch = fetcher.get if not list_formats_only else fetcher.ytdl_probe_formats
        for n, (url, prefix) in enumerate(vod_queue, 1):
            info_suffix = None if len(vod_queue) == 1 else f' [{n} / {len(vod_queue)}]'
            try: await fetch(url, prefix, info_suffix=info_suffix)
            # exit_code is already set by sig_handler in this case
            except asyncio.CancelledError: break
            except TVFError as err:
                # Errors that are expected are raised without any args
                if err.args: log.exception('BUG: {}', err)
                log.error('Failed to fetch VoD #{} {!r} (url: {!r}), exiting', n, prefix, url)
                break
        else: exit_code = 0
    return exit_code
class TVFFileCache:
    """Single value stored in a "{prefix}.{ext}" file, as text (default) or bytes.

    File can only be updated within "with" block for the instance.
    """

    update_lock = True  # True when updates are NOT allowed, i.e. outside of "with" block

    def __init__(self, prefix, ext, text=True):
        self.b = '' if text else 'b'  # suffix for open() mode
        self.path = f'{prefix}.{ext}'

    def __enter__(self):
        assert self.update_lock
        self.update_lock = False
        return self

    def __exit__(self, *err):
        self.update_lock = True

    @property
    def cached(self):
        """Contents of the cache file, or None if it does not exist."""
        if os.path.exists(self.path):
            with open(self.path, 'r' + self.b) as src:
                return src.read()
        return None

    def update(self, data):
        """Replace cache file contents with data and return it."""
        assert not self.update_lock
        with open(self.path, 'w' + self.b) as dst:
            dst.write(data)
        return data
class TVFAria2Proc:
    """Async context manager for aria2c subprocess, controlled via json-rpc over websocket.

    Subprocess is started upon entering the context, connect() must
    be called after that to open the websocket and start processing its messages.
    On exit, aria2c is asked to shut down cleanly first, then gets terminated/killed.
    """

    def __init__(self, loop, conf, http, cmd_conf, key, port, path_func):
        """
        http is aiohttp session to open websocket with, cmd_conf - path to aria2c config file,
        key/port - rpc secret and port from that config, path_func(gid) - download path for gid.
        """
        self.loop, self.conf, self.proc, self.http = loop, conf, None, http
        self.ws = self.ws_ctx = self.ws_poll_task = self.ws_handlers = self.chunks = None
        self.ws_closed = asyncio.Event()
        self.cmd_conf, self.ws_key, self.ws_url, self.path_func = (
            cmd_conf, f'token:{key}', f'ws://localhost:{port}/jsonrpc', path_func )
        self.log = get_logger('tvf.aria2')

    async def __aenter__(self):
        cmd = [self.conf.aria2_cmd, '--conf-path', self.cmd_conf]
        self.log.debug('Starting aria2c daemon: {}', ' '.join(cmd))
        self.proc = await asyncio.create_subprocess_exec(*cmd)
        self.ws_ctx = AsyncExitStack()
        return self

    async def __aexit__(self, *err):
        proc_term_delay = 0
        if self.ws_ctx:
            # Clean shutdown can only be requested over established and still-open websocket,
            #  which might not be there if connect() failed, was interrupted or not called at all
            if self.ws and not self.ws_closed.is_set():
                await self.req('shutdown', sync=False)
                proc_term_delay = self.conf.aria2_term_wait[0]
            await self.ws_ctx.close()
            if self.ws_handlers:
                # Cancels all futures that something might still be waiting on
                for func in self.ws_handlers.values(): func(StopIteration)
            if self.ws_poll_task: await self.ws_poll_task
            self.ws = self.ws_ctx = self.ws_poll_task = self.ws_handlers = None
        if self.proc:
            if proc_term_delay > 0:
                # Timeout here should not skip SIGTERM/SIGKILL steps below
                try: await asyncio.wait_for(self.proc.wait(), proc_term_delay)
                except asyncio.TimeoutError: pass
            try:
                self.proc.send_signal(0) # raises OSError if process is gone already
                self.log.debug( 'Sending SIGTERM to aria2c pid'
                    ' (timeout={:.2f}s): {}', self.conf.aria2_term_wait[1], self.proc.pid )
                self.proc.terminate()
                try: await asyncio.wait_for(self.proc.wait(), timeout=self.conf.aria2_term_wait[1])
                except asyncio.TimeoutError:
                    self.log.debug('Sending SIGKILL to aria2c pid: {}', self.proc.pid)
                    self.proc.kill()
            except OSError: pass
            exit_code = await self.proc.wait()
            if exit_code: self.log.warning('aria2c has exited with error code: {}', exit_code)
            self.proc = None

    async def connect(self):
        """Connect to aria2c websocket, retrying with delays from conf.aria2_startup_checks.

        Starts background task to process incoming messages on success.
        Raises TVFError (and sets ws_closed event) if all attempts have failed.
        """
        ts, delays = self.loop.time(), retries_within_timeout(*self.conf.aria2_startup_checks)
        ts_next, ts_max, delay_iter = ts, ts + sum(delays), iter(delays)
        while True:
            ts = self.loop.time()
            if ts >= ts_max: break
            timeout = ts_max - ts
            self.log.debug('Connecting to aria2c ws url (timeout={:.2f}s): {}', timeout, self.ws_url)
            try:
                self.ws = await self.ws_ctx.enter(self.http.ws_connect(
                    self.ws_url, heartbeat=self.conf.aria2_ws_heartbeat, timeout=timeout ))
            except aiohttp.ClientError as err: self.log.debug('aria2c conn attempt error: {}', err)
            else:
                self.log.debug('Connected to aria2c json-rpc ws')
                break
            ts = self.loop.time()
            if ts_next > ts:
                await asyncio.sleep(ts_next - ts)
                ts = self.loop.time()
            # Pick time for the next attempt, skipping ones that are in the past already
            while ts_next < ts:
                try: ts_next += next(delay_iter)
                except StopIteration: break
        if not self.ws:
            self.ws_closed.set()
            raise TVFError( 'aria2c connection failed'
                ' (max_tries={}, timeout={})'.format(*self.conf.aria2_startup_checks) )
        jrpc_uid_ns = secrets.token_urlsafe(3)
        self.ws_jrpc_uid_iter = (f'{jrpc_uid_ns}.{n}' for n in range(1, 2**30))
        self.ws_handlers, self.ws_poll_task = dict(), self.loop.create_task(self._ws_poller())

    @contextlib.contextmanager
    def ws_close_wrap(self):
        'Raises TVFWSClosed in current task on ws_closed event.'
        async def _ev_wait():
            nonlocal triggered
            await self.ws_closed.wait()
            triggered = True
            task.cancel()
        # asyncio.Task.current_task was removed in python-3.9, asyncio.current_task is 3.7+
        current_task = getattr(asyncio, 'current_task', None) or asyncio.Task.current_task
        triggered, task = False, current_task(self.loop)
        ev_wait_task = self.loop.create_task(_ev_wait())
        try: yield
        except asyncio.CancelledError:
            if not triggered: raise # cancelled by something else
            raise TVFWSClosed() from None
        finally: ev_wait_task.cancel()

    async def _ws_poller(self):
        """Pass all json messages from websocket to registered handlers, until it is closed."""
        try:
            async for msg in self.ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.conf.aria2_ws_debug: self.log.debug('rpc-msg: {}', msg.data)
                    msg_data, hs_discard = json.loads(msg.data), set()
                    if self.conf.aria2_ws_debug_signals and msg_data.get('method'):
                        self.log.debug('rpc-signal: {}', msg.data)
                    # Handlers that return true-ish value get removed after processing the message
                    for k, func in self.ws_handlers.items():
                        if func(msg_data): hs_discard.add(k)
                    for k in hs_discard: del self.ws_handlers[k]
                elif msg.type == aiohttp.WSMsgType.closed: break
                elif msg.type == aiohttp.WSMsgType.error:
                    self.log.error('ws protocol error, aborting: {}', msg)
                    break
                else: self.log.warning('Unhandled ws msg type {}, ignoring: {}', msg.type, msg)
        except Exception as err:
            self.log.exception('Unhandled ws handler failure, aborting: {}', err)
        await self.ws.close()
        self.ws_closed.set()

    def _add_handler(self, func): self.ws_handlers[id(func)] = func

    def _expect_handler(self, uid_or_func, fut_or_cb, oneshot, msg):
        """Handler for a single expect() call.

        Checks msg against uid_or_func and either calls fut_or_cb or sets it as future result.
        Returns oneshot flag if message was matched, to indicate whether handler should be removed.
        StopIteration (class) can be passed instead of msg to cancel the future.
        """
        if msg is StopIteration:
            if not callable(fut_or_cb): fut_or_cb.cancel()
            return
        if callable(uid_or_func):
            if not uid_or_func(msg): return
        else:
            # Responses are matched by request id, notifications - by method name
            uid = msg.get('id') or msg.get('method')
            if uid.startswith('aria2.'): uid = uid[6:]
            if uid != uid_or_func: return
        if callable(fut_or_cb): fut_or_cb(msg)
        # Future can be cancelled already, if task that was waiting on it got cancelled
        elif not fut_or_cb.done(): fut_or_cb.set_result(msg)
        return oneshot

    def expect(self, uid_or_func, fut_or_cb=None, oneshot=False):
        """Register handler for messages matching uid_or_func.

        uid_or_func is either request id / notification method name
        (without "aria2." prefix) or a function to call on message to check it.
        fut_or_cb is either a callback for matching messages or a future to set result on,
        with new future created by default. Handler is removed after first match with oneshot=True.
        Returns fut_or_cb.
        """
        if not fut_or_cb: fut_or_cb = asyncio.Future()
        self._add_handler(ft.partial(
            self._expect_handler, uid_or_func, fut_or_cb, oneshot ))
        return fut_or_cb

    async def req(self, method, *params, sync=True):
        """Send json-rpc request, with "aria2." prefix and rpc secret added to it where needed.

        With sync=True (default), waits for response and returns "result" value
        from it, raising TVFReqError if it is missing there, or TVFWSClosed if websocket gets closed.
        Returns request id without waiting for anything otherwise.
        """
        if method != 'system.multicall':
            method, params = f'aria2.{method}', [self.ws_key] + list(params)
        res = req_uid = next(self.ws_jrpc_uid_iter)
        data_req = dict(
            jsonrpc='2.0', id=req_uid,
            method=method, params=params )
        if self.conf.aria2_ws_debug:
            self.log.debug( 'rpc-req{} [{}]: {}',
                '-sync' if sync else '', req_uid, json.dumps(data_req) )
        await self.ws.send_str(json.dumps(data_req))
        if sync:
            # oneshot - there will be only one response, no need to keep handler after it
            with self.ws_close_wrap(): res = await self.expect(req_uid, oneshot=True)
            if self.conf.aria2_ws_debug:
                self.log.debug('rpc-res-sync [{}]: {}', req_uid, json.dumps(res))
            if 'result' not in res:
                raise TVFReqError(f'Request failed (method={method}, params={params}): {res}')
            res = res['result']
        return res

    async def req_ok(self, method, *params):
        """Same as req(), but raises TVFError if result is anything other than "OK"."""
        res = await self.req(method, *params)
        if res != 'OK': raise TVFError(f'aria2c command failed: {method}{params}')

    def _queue_params(self, gid, url, pos=None):
        """Return aria2.addUri parameters (without rpc secret) for gid/url and queue position."""
        return (
            [[url], dict(gid=gid, out=str(self.path_func(gid)))]
            # pos=0 is a valid position (front of the queue) and should not be dropped
            + ([] if pos is None else [pos]) )

    async def queue(self, gid, url, pos=None):
        """Add URL to aria2c download queue with specified gid, at position pos or the end."""
        res = await self.req('addUri', *self._queue_params(gid, url, pos))
        if res != gid:
            self.log.error('addUri error (gid={}): {}', gid, res)
            raise TVFError('Failed to queue chunk URL to aria2c')

    async def queue_batch(self, *gid_urls, pos=None):
        """Add (gid, url) tuples to aria2c download queue in one request.

        With pos specified, URLs are put at sequential queue positions starting from pos.
        Raises TVFError if aria2c does not return same gids in the same order.
        """
        if pos is not None: pos = iter(range(pos, 2**30))
        res = await self.req('system.multicall', list(
            dict(methodName='aria2.addUri', params=(
                [self.ws_key] + self._queue_params(gid, url, pos and next(pos)) ))
            for gid, url in gid_urls ))
        res_chk = list([gid] for gid, url in gid_urls)
        if res != res_chk:
            log_lines(self.log.error, [
                'Result gid match check failed for submitted urls.',
                ('  expected: {}', ', '.join(str(r[0]) for r in res_chk)),
                ('  returned: {}', ', '.join(( str(r[0])
                    if isinstance(r, list) else repr(r) ) for r in res)) ])
            raise TVFError(f'Failed to queue {len(gid_urls)} chunk URLs to aria2c')

    async def queue_chunks(self, chunks):
        """Queue all (gid, url) tuples from chunks, in batches of conf.aria2_queue_batch.

        Returns (count, gid_last) tuple - number of queued
        URLs and gid of the last one, which is None if there were none.
        """
        self.chunks = chunks
        count, gid_last = 0, None
        with self.ws_close_wrap():
            for gid_url_batch in it_adjacent_nofill(chunks, self.conf.aria2_queue_batch):
                await self.queue_batch(*gid_url_batch)
                count, gid_last = count + len(gid_url_batch), gid_url_batch[-1][0]
        return count, gid_last
class TVFChunkIter:
    """Chunks from m3u8 playlist that are selected for download, with done/needed tracking.

    Iterating over it produces (gid, url) tuples for chunks that are not marked as done.
    """

    def __init__(self, conf, url_base, pls_text, gids_done=None):
        self.conf = conf
        self.gid_urls = self.parse_pls(url_base, pls_text)
        self.gids_done = set(gids_done or list())

    def __iter__(self):
        for gid, url in self.gid_urls.items():
            if gid not in self.gids_done:
                yield gid, url

    def scatter_check_coro(self, *scatter_opts):
        """Generator to send chunk durations to, getting True/False for whether to use each one.

        scatter_opts are (length, interval) values, with counters for both
        reduced by each duration, and reset when interval counter runs out.
        """
        (left_use, left_interval), use_chunk = scatter_opts, True
        while True:
            td = yield use_chunk
            if left_use > 0:
                left_use -= td
            if left_use <= 0:
                use_chunk = False
            left_interval -= td
            if left_interval <= 0:
                (left_use, left_interval), use_chunk = scatter_opts, True

    def parse_pls(self, url_base, pls_text):
        """Return ordered gid-to-url mapping for playlist chunks selected by slice/scatter opts."""
        offset, span = self.conf.slice_start, self.conf.slice_len
        scatter = None
        if self.conf.slice_scatter_len is not None:
            scatter = self.scatter_check_coro(
                self.conf.slice_scatter_len, self.conf.slice_scatter_interval )
        if scatter:
            next(scatter)

        # aria2c requires 16-char gid, format used here fits number in
        #  first 6 chars because it looks nice in (truncated) aria2c output
        gids = iter(map('{:06d}0000000000'.format, range(1, 999999)))
        selected = collections.OrderedDict()
        for line in pls_text.splitlines():
            m = re.search(r'^#EXTINF:([\d.]+),', line)
            if m:
                # Chunk duration tag comes before its URL line
                td = float(m.group(1))
                offset -= td
            if offset > 0:
                continue
            if not line or line.startswith('#'):
                continue
            if span and span + offset < 0:
                break
            if scatter and not scatter.send(td):
                continue
            selected[next(gids)] = f'{url_base}/{line}'
        return selected

    def mark_needed(self, gid, check=True):
        """Un-mark gid as done, raising TVFError if it was not and check is enabled."""
        if check and gid not in self.gids_done:
            raise TVFError(f'BUG: gid was not marked as done: {gid}')
        self.gids_done.discard(gid)

    def mark_done(self, gid):
        self.gids_done.add(gid)

    @property
    def finished(self):
        return len(self.gids_done) == len(self.gid_urls)
@add_stack_wrappers
class TVF:
    """VoD fetcher - gets playlist via youtube-dl, downloads chunks from it via aria2c."""

    def __init__(self, loop, conf, log):
        self.loop, self.conf = loop, conf
        self.log = log or get_logger('tvf.fetcher')

    async def __aenter__(self): return self
    async def __aexit__(self, *err): pass

    async def ytdl_run(self, ytdl_op, *args, check=True, out=False, **popen_kws):
        """Run "youtube-dl <ytdl_op>" with conf.ytdl_opts and args added after it.

        Returns stripped and decoded stdout of the command with out=True, its exit code otherwise.
        Raises TVFError on non-zero exit code, unless check=False is specified.
        """
        cmd = ['youtube-dl']
        if self.conf.verbose: cmd.append('--verbose')
        cmd = cmd + [ytdl_op] + (self.conf.ytdl_opts or list()) + list(args)
        self.log.debug('Running "youtube-dl {}" command: {}', ytdl_op, ' '.join(cmd))
        if out: popen_kws.setdefault('stdout', asyncio.subprocess.PIPE)
        proc = await asyncio.create_subprocess_exec(*cmd, **popen_kws)
        if out: proc_stdout = (await proc.stdout.read()).strip()
        exit_code = await proc.wait()
        if check and exit_code != 0:
            raise TVFError(f'"youtube-dl {ytdl_op}" command exited with error (code: {exit_code})')
        return exit_code if not out else proc_stdout.decode()

    async def ytdl_probe_formats(self, url, file_prefix, info_suffix=None):
        """Run "youtube-dl --list-formats" for url and return its exit code."""
        self.log.info( '--- Listing formats available'
            ' for VoD {} (url: {}){}', file_prefix, url, info_suffix or '' )
        return await self.ytdl_run('--list-formats', url, check=False)

    async def get(self, ctx: AsyncExitStack, url, file_prefix, info_suffix=None):
        """Download VoD from url to "{file_prefix}.mp4" file, resuming earlier download if possible.

        ctx is provided by add_stack_wrappers, callers should not pass it.
        Raises TVFError on any kind of download/consistency failure.
        """
        file_prefix = pathlib.Path(file_prefix)
        file_dst_path = file_prefix.with_suffix('.mp4')
        if self.conf.use_temp_dirs:
            name = file_prefix.name
            file_prefix_dir = file_prefix.with_suffix('.tmp')
            file_prefix = file_prefix_dir / name

        # Assuming file_dst is fully downloaded if there's no playlist tempfile around
        # It's not a problem even with --keep-tempfiles, as file_dst_mark will be kept there too
        if file_dst_path.exists() and (
                self.conf.file_part_suffix or not file_prefix.with_suffix('.m3u8').exists() ):
            self.log.info( '--- Skipping download'
                ' for existing file: {} (rename/remove it to force)', file_dst_path )
            return

        if self.conf.use_temp_dirs: file_prefix_dir.mkdir(exist_ok=True)
        if self.conf.file_part_suffix:
            file_dst_done, file_dst_path = file_dst_path, file_dst_path.with_suffix('.part.mp4')
        vod_cache = ft.partial(TVFFileCache, file_prefix)
        http = await ctx.enter(aiohttp.ClientSession(**self.conf.aiohttp_opts))

        self.log.info('--- Downloading VoD {} (url: {}){}', file_dst_path, url, info_suffix or '')

        with vod_cache('m3u8.url') as vc:
            url_pls = vc.cached or vc.update(await self.ytdl_run('--get-url', url, out=True))
        assert ' ' not in url_pls, url_pls # can return URLs of multiple flvs for *really* old VoDs
        url_base = url_pls.rsplit('/', 1)[0]

        with vod_cache('m3u8.ua') as vc:
            ua = vc.cached or vc.update(await self.ytdl_run('--dump-user-agent', out=True))

        with vod_cache('m3u8') as vc:
            pls = vc.cached
            if not pls:
                self.log.debug('Fetching playlist from URL: {}', url_pls)
                async with http.get(url_pls, headers={'User-Agent': ua}) as res:
                    pls = vc.update(await res.text())

        # Config is always updated between aria2c runs, to account for any cli opts changes
        with vod_cache('aria2c_conf') as vc:
            key = secrets.token_urlsafe(18)
            # Free port for aria2c rpc is picked by binding socket to port 0 and closing it
            with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                s.bind(('localhost', 0))
                addr, port = s.getsockname()
            vc.update(self.conf.aria2_conf_func(adict(
                key=key, port=port, ua=ua,
                log_level='notice' if self.conf.verbose else 'warn' )))
            aria2c_conf_path = vc.path

        # file_dst_mark stores (pos, seq) tuples, appending each update to the end
        # It is assumed here that such tiny writes will be atomic
        file_dst_pos = file_dst_seq = 0 # position in dst file, sequential gid int
        file_dst_mark_fmt = '>QI'
        file_dst_mark = await ctx.enter(file_prefix.with_suffix('.pos').open('a+b'))
        if file_dst_mark.tell() != 0:
            file_dst_mark.seek(-struct.calcsize(file_dst_mark_fmt), os.SEEK_END)
            file_dst_pos, file_dst_seq = struct.unpack(file_dst_mark_fmt, file_dst_mark.read())

        if not file_dst_path.exists() and file_dst_pos > 0:
            raise TVFError( f'Missing partial-dst-file'
                f' (expecting bytes={file_dst_pos:,} gid-seq={file_dst_seq}): {file_dst_path}' )
        # Registered in ctx, same as file_dst_mark, to have it closed on any errors below
        file_dst = await ctx.enter(file_dst_path.open('a+b'))
        if file_dst.tell() < file_dst_pos:
            raise TVFError( f'Missing chunk of partial-dst-file'
                f' (expecting size>={file_dst_pos:,} actual-size={file_dst.tell():,}): {file_dst_path}' )
        # Anything past the last recorded position is discarded
        file_dst.seek(file_dst_pos)
        file_dst.truncate()

        # seq - gid converted to a sequential int, to sort appended chunks
        seq_heap, seq_chunks = list(), list()
        gid_to_seq = lambda gid: int(gid[:6])
        seq_to_gid = lambda seq: f'{seq:06d}0000000000'
        gid_pretty = lambda gid: gid[:6]

        chunks = TVFChunkIter(self.conf, url_base, pls)
        # Chunks that are already in dst file will not be downloaded again
        for gid, url in chunks:
            if gid_to_seq(gid) > file_dst_seq: break
            chunks.mark_done(gid)
        chunk_path = lambda gid: file_prefix.with_suffix(f'.{gid}.mp4.chunk')

        async with TVFAria2Proc( self.loop,
                self.conf, http, aria2c_conf_path, key, port, chunk_path ) as aria2c:
            await aria2c.connect()

            # All chunks from iter(chunks) get queued for download here
            info = await aria2c.req('getVersion')
            self.log.debug( 'Starting downloads'
                ' (rpc-port={}, aria2-version={})...', port, info['version'] )
            count, gid_last = await aria2c.queue_chunks(chunks)
            if gid_last:
                self.log.info( '\n\n  ------ Started {} downloads,'
                    ' last gid: {} ------  \n', count, gid_pretty(gid_last) )
            else: self.log.info('No extra downloads were necessary')

            ### Main download-control loop ahead
            # Download control strategy:
            #  - listen for aria2.onDownloadComplete event for each gid
            #   - append sequential-gid chunk to partial file, update mark-file, remove chunk
            #   - non-sequential gids - just remember and check/append next time
            #   - if >file_append_max non-sequential-gid chunks downloaded - abort
            #  - for all error/paused/stopped events: prepend to download queue
            #  - stop on chunks.finished (all chunks marked as "done")

            ev_t = enum.Enum('ev_type', 'complete stop pause error')
            evq, ev_tuple = asyncio.Queue(), collections.namedtuple('ev', 't data')
            for t in ev_t:
                aria2c.expect( f'onDownload{t.name.title()}',
                    lambda e,t=t: evq.put_nowait(ev_tuple(t, e)) )

            while not chunks.finished:
                with aria2c.ws_close_wrap():
                    ev = await evq.get()

                    for e in ev.data['params']:
                        if ev.t == ev_t.complete: # completed chunk goes to seq_heap
                            heapq.heappush(seq_heap, gid_to_seq(e['gid']))
                            await aria2c.req('removeDownloadResult', e['gid'], sync=False)
                            chunks.mark_done(e['gid'])
                        else: # stop/pause/error - retry immediately
                            # Limit on forced retries (on top of what aria2 does) for same gid can be added here
                            if self.log.isEnabledFor(logging.DEBUG):
                                status = await aria2c.req( 'tellStatus', e['gid'],
                                    ['status', 'errorCode', 'errorMessage', 'completedLength', 'totalLength'] )
                                self.log.debug( 'Chunk {} download'
                                    ' stopped or failed, retrying: {}', e['gid'], status )
                            await aria2c.req_ok('removeDownloadResult', e['gid'])
                            await aria2c.queue(e['gid'], chunks.gid_urls[e['gid']], 0)

                if chunks.finished or len(seq_heap) > self.conf.file_append_batch:
                    while seq_heap and seq_heap[0] == file_dst_seq + 1:
                        seq = heapq.heappop(seq_heap)
                        p = chunk_path(seq_to_gid(seq))
                        with p.open('rb') as chunk: shutil.copyfileobj(chunk, file_dst)
                        file_dst_pos, file_dst_seq = file_dst.tell(), seq
                        seq_chunks.append(p)
                    if seq_chunks:
                        # Mark-file is only updated after data is flushed, chunks removed after that
                        file_dst.flush()
                        file_dst_mark.write(struct.pack(file_dst_mark_fmt, file_dst_pos, file_dst_seq))
                        file_dst_mark.flush()
                        for p in seq_chunks: os.unlink(p)
                        # self.log.debug('\nExtended partial dst file by {} chunk(s)', len(seq_chunks))
                        seq_chunks.clear()

                if len(seq_heap) > self.conf.file_append_max:
                    err_msg = ( 'Chunk-append sequence is broken'
                        f' (pending={len(seq_heap)} max={self.conf.file_append_max})' )
                    # Next chunk to append is the one right after the last appended seq
                    log_lines(self.log.error, [ ('{}:', err_msg),
                        ('  expecting gid: {}', seq_to_gid(file_dst_seq + 1)),
                        ('  first gid among pending: {}', seq_to_gid(seq_heap[0])) ])
                    raise TVFError(err_msg)

        ### done!
        # Files are closed before rename/cleanup, closing them again in ctx later is harmless
        file_dst.close()
        file_dst_mark.close()
        if self.conf.file_part_suffix:
            file_dst_path.rename(file_dst_done)
            file_dst_path = file_dst_done
        if not self.conf.keep_tempfiles:
            self.log.debug('Cleanup for prefix (dir={}): {}', self.conf.use_temp_dirs, file_prefix)
            if self.conf.use_temp_dirs: shutil.rmtree(file_prefix_dir)
            else:
                for p in file_prefix.parent.glob(f'{file_prefix.name}.*'):
                    if p != file_dst_path: os.unlink(p)
        self.log.info('--- Downloaded VoD file: {}{}', file_dst_path, info_suffix or '')
def main(args=None, conf=None):
    """Script entry point - parse command-line args, run vod_fetch loop, return exit code.

    args is a list of arguments to use instead of sys.argv[1:],
    conf - TVFConfig instance to use as a base for settings from command line.
    """
    if not conf: conf = TVFConfig()

    parser = argparse.ArgumentParser(
        usage='%(prog)s [options] url file_prefix [url-2 file_prefix-2 ...]',
        description='Grab a VoD or a specified slice of it from twitch.tv, properly.')
    parser.add_argument('url', help='URL for a VoD to fetch.')
    parser.add_argument('file_prefix', help='File prefix to assemble temp files under.')
    parser.add_argument('more_url_and_prefix_pairs', nargs='*',
        help='Any number of extra "url file_prefix" arguments can be specified.')

    group = parser.add_argument_group('Download slice options')
    group.add_argument('-s', '--start-pos',
        type=parse_pos_spec, metavar='[[hours:]minutes:]seconds',
        help='Only download video chunks after specified start position.'
            ' If multiple url/prefix args are specified, this option will be applied to all of them.')
    group.add_argument('-l', '--length',
        type=parse_pos_spec, metavar='[[hours:]minutes:]seconds',
        help='Only download specified length of the video (from specified start or beginning).'
            ' If multiple url/prefix args are specified, this option will be applied to all of them.')
    group.add_argument('-x', '--scatter',
        metavar='[[hours:]minutes:]seconds/[[hours:]minutes:]seconds',
        help='Out of whole video (or a chunk specified by --start-pos and --length),'
                ' download only every N seconds (or mins/hours) out of M.'
            ' E.g. "1:00/10:00" spec here will download 1 first min of video out of every 10.'
            ' Idea here is to produce something like preview of the video to allow'
                ' to easily narrow down which part of it is interesting and worth downloading in full.')

    group = parser.add_argument_group('youtube-dl/aria2c options')
    group.add_argument('-F', '--ytdl-list-formats',
        action='store_true', help='Do not download anything,'
            ' just list formats available for each specified URL and exit.')
    group.add_argument('-y', '--ytdl-opts',
        action='append', metavar='opts',
        help='Extra opts for youtube-dl --get-url command.'
            ' Will be split on spaces, unless option is used multiple times.')
    group.add_argument('-a', '--aria2c-opts',
        action='append', metavar='opts',
        help='Extra options to store in aria2c configuration file.'
            ' These are same as long-form command-line options, but without double-dash prefix.'
            ' Will be split on spaces, unless option is used multiple times.'
            ' Example: --aria2c-opts "lowest-speed-limit=20K max-overall-download-limit=1M"')

    group = parser.add_argument_group('Misc/debug options')
    group.add_argument('-p', '--use-part-suffix',
        action='store_true', help='Use .part suffix for'
            ' partial destination file until it is fully downloaded.')
    group.add_argument('-n', '--no-temp-dirs',
        action='store_true', help='Do not create temporary'
            ' directories for download parts, keep everything in the destination one.')
    group.add_argument('-k', '--keep-tempfiles',
        action='store_true', help='Do not remove all the'
            ' temporary files after successfully assembling resulting mp4.'
            ' Chunks in particular might be useful to download different but overlapping video slices.')
    group.add_argument('--debug', action='store_true', help='Verbose operation mode.')

    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    logging.basicConfig(
        datefmt='%Y-%m-%d %H:%M:%S',
        format='%(asctime)s :: {}%(levelname)s :: %(message)s'\
            .format('%(name)s ' if opts.debug else ''),
        level=logging.DEBUG if opts.debug else logging.INFO )
    log = get_logger('tvf.main')

    conf.slice_start, conf.slice_len = opts.start_pos or 0, opts.length
    if opts.scatter:
        scatter = opts.scatter.split('/', 1)
        if len(scatter) != 2:
            parser.error( 'Invalid value for -x/--scatter option'
                f' - {opts.scatter!r} - should be in "len/interval" format.' )
        conf.slice_scatter_len, conf.slice_scatter_interval = map(parse_pos_spec, scatter)

    # Options used only once are split on spaces, otherwise each one is used as-is
    ytdl_opts = opts.ytdl_opts or list()
    if len(ytdl_opts) == 1: ytdl_opts = ytdl_opts[0].strip().split()
    aria2c_opts = opts.aria2c_opts or list()
    if len(aria2c_opts) == 1: aria2c_opts = aria2c_opts[0].strip().split()
    # Values from passed conf are only replaced if options are actually specified
    if ytdl_opts: conf.ytdl_opts = ytdl_opts
    if aria2c_opts: conf.aria2c_opts = aria2c_opts

    conf.verbose = opts.debug
    conf.keep_tempfiles = opts.keep_tempfiles
    conf.use_temp_dirs = not opts.no_temp_dirs
    conf.file_part_suffix = opts.use_part_suffix

    vod_queue, args = list(),\
        [opts.url, opts.file_prefix] + (opts.more_url_and_prefix_pairs or list())
    if len(args) % 2:
        parser.error( 'Odd number of url/prefix args specified'
            f' ({len(args)}), while these should always come in pairs.' )
    for url, prefix in it_adjacent(args, 2):
        if re.search(r'^https?:', prefix):
            if re.search(r'^https?:', url):
                parser.error( 'Both url/file_prefix args seem'
                    f' to be an URL, only first one should be: {url} {prefix}' )
            prefix, url = url, prefix
            # Logger here uses "{}" templates, not "%s" ones
            log.warning( 'Looks like url/prefix args got'
                ' mixed-up, correcting that to prefix={} url={}', prefix, url )
        if not re.search(r'^https?://[^/]+/([^/]+/v|videos)/', url):
            parser.error( 'Provided URL appears to be for the unsupported'
                f' VoD format (only /user/v/ or /videos/ VoDs are supported): {url}' )
        vod_queue.append((url, prefix))

    if conf.compat_windows:
        # Use the Proactor event loop on Windows for subprocess support
        loop = asyncio.ProactorEventLoop()
        asyncio.set_event_loop(loop)

    log.debug('Starting vod_fetch loop...')
    with contextlib.closing(asyncio.get_event_loop()) as loop:
        exit_code = loop.run_until_complete(vod_fetch(
            loop, conf, vod_queue, list_formats_only=opts.ytdl_list_formats ))
    log.debug('Finished')
    return exit_code

if __name__ == '__main__': sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment