Skip to content

Instantly share code, notes, and snippets.

@mk-fg
Last active May 30, 2017 20:45
Show Gist options
  • Save mk-fg/396745b36f77a9fe84c0b27d92349fe7 to your computer and use it in GitHub Desktop.
Save mk-fg/396745b36f77a9fe84c0b27d92349fe7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import itertools as it, operator as op, functools as ft
import os, sys, errno, argparse, logging, secrets, re, json, shutil
import contextlib, inspect, tempfile, pathlib, collections, enum
import asyncio, asyncio.subprocess, socket, signal, heapq, struct
import aiohttp
class TVFConfig:
    """Default settings for the fetcher; main() overrides some of them from cli options."""

    # Which part of the VoD to fetch (all values are in seconds)
    slice_start = 0
    slice_len = None
    slice_scatter_len = None
    slice_scatter_interval = None

    # Extra options for external tools and the http client
    aria2c_opts = None # overrides opts in aria2_conf_func
    ytdl_opts = None
    aiohttp_opts = dict(read_timeout=30, conn_timeout=60)

    # Output and temporary files handling
    verbose = False
    use_temp_dirs = True
    keep_tempfiles = False
    file_part_suffix = False
    file_append_batch = 20
    file_append_max = 150 # should not be needed normally

    compat_windows = sys.platform == 'win32'

    # aria2c daemon control
    aria2_cmd = 'aria2c'
    aria2_term_wait = 10, 20 # (clean, term), clean should be >3s
    aria2_startup_checks = 8, 10.0
    aria2_queue_batch = 50
    aria2_ws_heartbeat = 20.0
    aria2_ws_debug = False # very noisy
    aria2_ws_debug_signals = False

    def aria2_conf_func(self, opts):
        """Return aria2c configuration file contents as a string.

        opts must provide log_level, port, key and ua attributes.
        Lines from self.aria2c_opts (if any) are appended after the built-in ones.
        """
        builtin_lines = [
            'summary-interval=0',
            f'console-log-level={opts.log_level}',
            'enable-rpc=true',
            f'rpc-listen-port={opts.port}',
            f'rpc-secret={opts.key}',
            f'stop-with-process={os.getpid()}',
            'no-netrc',
            'always-resume=false',
            'max-concurrent-downloads=5',
            'max-connection-per-server=5',
            'max-file-not-found=5',
            'max-tries=8',
            'timeout=15',
            'connect-timeout=10',
            'lowest-speed-limit=100K',
            f'user-agent={opts.ua}' ]
        extra_lines = self.aria2c_opts or list()
        # Trailing empty element makes the result end with a newline
        return '\n'.join(builtin_lines + extra_lines + [''])
class LogMessage(object):
    """Log message wrapper that applies str.format() lazily, when converted to str."""

    def __init__(self, fmt, a, k):
        self.fmt = fmt
        self.a = a
        self.k = k

    def __str__(self):
        # Template is returned verbatim if there is nothing to substitute into it
        if not (self.a or self.k):
            return self.fmt
        return self.fmt.format(*self.a, **self.k)
class LogStyleAdapter(logging.LoggerAdapter):
    """LoggerAdapter that formats messages with str.format() ({}-style) placeholders."""

    def __init__(self, logger, extra=None):
        super(LogStyleAdapter, self).__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        """Log msg at level, with args/kws used as str.format() arguments.

        "exc_info" keyword is not used for formatting, and is passed to the logger instead.
        """
        if not self.isEnabledFor(level): return
        log_kws = {} if 'exc_info' not in kws else dict(exc_info=kws.pop('exc_info'))
        msg, kws = self.process(msg, kws)
        # log_kws must be passed as keywords - when passed positionally, the dict itself
        #  lands in the exc_info parameter, so exc_info=False or an exception instance
        #  is treated same as exc_info=True
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    'Return {}-style formatting adapter for the logger with specified name.'
    return LogStyleAdapter(logging.getLogger(name))
class AsyncExitStack:
    """Exit stack that accepts both sync and async context managers and callbacks.

    Registered exit callbacks run in reverse order of registration
    on close() or when leaving "async with" block, awaiting results where needed.
    """
    # Might be merged to 3.7, see https://bugs.python.org/issue29302
    # Implementation from https://gist.github.com/thehesiod/b8442ed50e27a23524435a22f10c04a0

    def __init__(self):
        self._exit_callbacks = collections.deque()

    def pop_all(self):
        'Move all registered callbacks to a new stack of the same type and return it.'
        new_stack = type(self)()
        new_stack._exit_callbacks = self._exit_callbacks
        self._exit_callbacks = collections.deque()
        return new_stack

    def push(self, exit_obj):
        """Register exit handler and return exit_obj.

        If type of exit_obj has __aexit__ or __exit__, that method is registered,
        otherwise exit_obj itself is stored as a callback taking (exc_type, exc, tb).
        """
        _cb_type = type(exit_obj)
        try:
            # __aexit__ takes precedence over __exit__ when both are present
            exit_method = getattr(_cb_type, '__aexit__', None)
            if exit_method is None: exit_method = _cb_type.__exit__
        except AttributeError: self._exit_callbacks.append(exit_obj)
        else: self._push_cm_exit(exit_obj, exit_method)
        return exit_obj

    @staticmethod
    def _create_exit_wrapper(cm, cm_exit):
        'Return (async or plain, same as cm_exit) function calling cm_exit on cm.'
        if inspect.iscoroutinefunction(cm_exit):
            async def _exit_wrapper(exc_type, exc, tb):
                return await cm_exit(cm, exc_type, exc, tb)
        else:
            def _exit_wrapper(exc_type, exc, tb):
                return cm_exit(cm, exc_type, exc, tb)
        return _exit_wrapper

    def _push_cm_exit(self, cm, cm_exit):
        'Register exit method cm_exit, looked up on the type of cm, to be called on cm.'
        _exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
        _exit_wrapper.__self__ = cm
        # Wrapper is a plain function, so push() stores it as-is
        self.push(_exit_wrapper)

    @staticmethod
    def _create_cb_wrapper(callback, *args, **kwds):
        'Return wrapper which ignores exception details and calls callback(*args, **kwds).'
        if inspect.iscoroutinefunction(callback):
            async def _exit_wrapper(exc_type, exc, tb): await callback(*args, **kwds)
        else:
            def _exit_wrapper(exc_type, exc, tb): callback(*args, **kwds)
        return _exit_wrapper

    def callback(self, callback, *args, **kwds):
        'Register callback(*args, **kwds) to run on exit, returns callback unchanged.'
        _exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
        _exit_wrapper.__wrapped__ = callback
        self.push(_exit_wrapper)
        return callback

    def _shutdown_loop(self, *exc_details):
        """Generator running exit callbacks, last registered one first.

        Result of each callback is yielded, and whatever gets sent back
        (or thrown in) is used as the result of that callback - see __aexit__.
        Generator return value is True if received exception was suppressed.
        """
        received_exc = exc_details[0] is not None
        frame_exc = sys.exc_info()[1]
        def _fix_exception_context(new_exc, old_exc):
            # Append old_exc to the end of __context__ chain of new_exc
            while True:
                exc_context = new_exc.__context__
                if exc_context is old_exc: return
                if exc_context is None or exc_context is frame_exc: break
                new_exc = exc_context
            new_exc.__context__ = old_exc
        suppressed_exc = pending_raise = False
        while self._exit_callbacks:
            cb = self._exit_callbacks.pop()
            try:
                cb_result = yield cb(*exc_details)
                if cb_result:
                    # True-ish result suppresses exception for the remaining callbacks
                    suppressed_exc, pending_raise = True, False
                    exc_details = (None, None, None)
            except:
                # Exception raised by callback replaces the one passed to next callbacks
                new_exc_details = sys.exc_info()
                _fix_exception_context(new_exc_details[1], exc_details[1])
                pending_raise, exc_details = True, new_exc_details
        if pending_raise:
            try:
                # raise statement below changes __context__, so it is restored after that
                fixed_ctx = exc_details[1].__context__
                raise exc_details[1]
            except BaseException:
                exc_details[1].__context__ = fixed_ctx
                raise
        return received_exc and suppressed_exc

    async def enter(self, cm):
        """Enter context manager cm, register its exit method, return result of entering it.

        Async protocol (__aenter__/__aexit__) is used if type of cm has __aexit__,
        and sync one (__enter__/__exit__) otherwise.
        """
        _cm_type = type(cm)
        _exit = getattr(_cm_type, '__aexit__', None)
        if _exit is not None: result = await _cm_type.__aenter__(cm)
        else:
            _exit = _cm_type.__exit__
            result = _cm_type.__enter__(cm)
        self._push_cm_exit(cm, _exit)
        return result

    async def close(self):
        'Run all registered exit callbacks, without any exception passed to them.'
        await self.__aexit__(None, None, None)

    async def __aenter__(self): return self

    async def __aexit__(self, *exc_details):
        'Drive _shutdown_loop generator, awaiting any awaitable results yielded by it.'
        gen = self._shutdown_loop(*exc_details)
        try:
            result = next(gen)
            while True:
                try:
                    if inspect.isawaitable(result): result = await result
                    result = gen.send(result)
                # StopIteration from gen.send() is the end of generator, handled below
                except StopIteration: raise
                # Errors from awaiting callback are passed back into generator
                except BaseException as e: result = gen.throw(e)
        except StopIteration as e: return e.value
def add_stack_wrappers(cls):
    """Class decorator to provide fresh AsyncExitStack to methods that ask for it.

    Every non-dunder routine of cls with second parameter (one after self) annotated
    as AsyncExitStack gets replaced by a coroutine wrapper, which creates new stack,
    passes it as that parameter, and closes it when wrapped method is done.
    """
    def _make_wrapper(func):
        @ft.wraps(func)
        async def _wrapper(self, *args, **kws):
            async with AsyncExitStack() as ctx:
                return await func(self, ctx, *args, **kws)
        return _wrapper

    for name, func in inspect.getmembers(cls, inspect.isroutine):
        if name.startswith('__'):
            continue
        params = list(inspect.signature(func).parameters.values())
        if len(params) <= 1:
            continue
        if params[1].annotation is not AsyncExitStack:
            continue
        setattr(cls, name, _make_wrapper(func))
    return cls
def it_adjacent(seq, n, fill=None):
    'Return iterator of n-tuples of adjacent seq values, last one padded with fill.'
    shared_iter = iter(seq)
    return it.zip_longest(*([shared_iter] * n), fillvalue=fill)


def it_adjacent_nofill(seq, n):
    'Same as it_adjacent, but the last tuple is shorter instead of being padded.'
    # it_adjacent function itself is used as a unique padding marker here
    padded_chunks = it_adjacent(seq, n, fill=it_adjacent)
    return (
        tuple(v for v in chunk if v is not it_adjacent)
        for chunk in padded_chunks )
class adict(dict):
    'Dict with keys also accessible as attributes.'

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Instance attributes are stored in the dict itself
        self.__dict__ = self
def log_lines(log_func, lines, log_func_last=False):
    """Log each of the lines via log_func, with same random id prefix for all of them.

    lines can be a multiline string or a sequence, where each element
    is either a string or a (format, *args) tuple/list.
    If log_func_last is specified, it is used for the last line instead of log_func.
    """
    if isinstance(lines, str):
        lines = [text.rstrip() for text in lines.rstrip().split('\n')]
    uid = secrets.token_urlsafe(3)
    for idx, entry in enumerate(lines, 1):
        if isinstance(entry, str):
            # Plain line is passed as a format argument, not as a part of the format
            call_args = ('[{}] {}', uid, entry)
        else:
            call_args = ['[{}] {}'.format(uid, entry[0])] + list(entry[1:])
        if log_func_last and idx == len(lines):
            log_func_last(*call_args)
        else:
            log_func(*call_args)
def retries_within_timeout( tries, timeout,
        backoff_func=lambda e,n: ((e**n-1)/e), slack=1e-2 ):
    """Return list of delays to make exactly n tries within timeout, with backoff_func.

    Base value for backoff_func(base, n) is searched for via bisection
    in the (0, timeout) range, until sum of delays is within slack of the timeout.
    Raises ValueError if the search interval gets exhausted without finding
    such value (e.g. with tries=1), instead of looping there forever.
    """
    a, b = 0, timeout
    while True:
        m = (a + b) / 2
        # Midpoint hits one of the bounds only when the interval
        #  can't be split any further (or was empty from the start)
        if not a < m < b:
            raise ValueError( 'Failed to fit delays'
                f' into timeout (tries={tries}, timeout={timeout})' )
        delays = [backoff_func(m, n) for n in range(tries)]
        error = sum(delays) - timeout
        if abs(error) < slack: return delays
        if error > 0: b = m
        else: a = m
def parse_pos_spec(pos):
    """Convert "[[hours:]minutes:]seconds" string to a number of seconds (float).

    Returns None for empty/missing value.
    Raises argparse.ArgumentTypeError if any component is not a number.
    """
    if not pos:
        return
    head, sep, secs = pos.rpartition(':')
    if not sep:
        # No colons at all - the whole value is seconds
        hrs, mins = 0, 0
    else:
        hrs, sep, mins = head.rpartition(':')
        if not sep:
            hrs, mins = 0, head
    try:
        values = [float(v) for v in (hrs, mins, secs)]
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f'Failed to parse {pos!r} as [[hours:]minutes:]seconds value: {err}' )
    return sum(mult * v for mult, v in zip([3600, 60, 1], values))
def suppress_err(exc_type, func, *args, **kws):
    """Return no-argument callable which runs func(*args, **kws), ignoring exc_type errors.

    Result of the func call is discarded, returned callable always returns None.
    """
    def _wrapper():
        try:
            func(*args, **kws)
        except exc_type:
            pass
    return _wrapper
class TVFError(Exception):
    'Base class for all errors raised by this script.'


class TVFWSClosed(TVFError):
    'Raised when aria2c websocket connection gets closed while waiting on it.'


class TVFReqError(TVFError):
    'Raised when aria2c json-rpc response has no "result" in it.'
async def vod_fetch(loop, conf, vod_queue, list_formats_only=False):
    """Fetch (or list formats for) all VoDs in vod_queue, return process exit code.

    vod_queue is a sequence of (url, file_prefix) tuples, processed in order.
    Processing stops on first failure (exit code 1), or on SIGINT/SIGTERM
    (exit codes 1 and 0 respectively, signals are not handled on windows).
    Exit code is 0 if all entries were processed without errors.
    """
    exit_code, log = 1, get_logger('tvf.fetcher')
    # asyncio.Task.current_task() was removed in python-3.9,
    #  with asyncio.current_task() being the replacement since 3.7
    current_task = getattr(asyncio, 'current_task', None) or asyncio.Task.current_task
    task = current_task(loop)

    def sig_handler(sig, code):
        nonlocal exit_code
        log.info('Exiting on {} signal with code: {}', sig, code)
        exit_code = code
        task.cancel()

    if not conf.compat_windows:
        for sig, code in ('INT', 1), ('TERM', 0):
            loop.add_signal_handler(getattr(signal, f'SIG{sig}'), ft.partial(sig_handler, sig, code))

    async with TVF(loop, conf, log) as fetcher:
        fetch = fetcher.get if not list_formats_only else fetcher.ytdl_probe_formats
        for n, (url, prefix) in enumerate(vod_queue, 1):
            info_suffix = None if len(vod_queue) == 1 else f' [{n} / {len(vod_queue)}]'
            try: await fetch(url, prefix, info_suffix=info_suffix)
            # Cancellation comes from sig_handler, which sets exit_code as well
            except asyncio.CancelledError: break
            except TVFError as err:
                if err.args: log.exception('BUG: {}', err)
                log.error('Failed to fetch VoD #{} {!r} (url: {!r}), exiting', n, prefix, url)
                break
        # Only reached when there was no break in the loop, i.e. everything was fetched
        else: exit_code = 0
    return exit_code
class TVFFileCache:
    """Value cached in a "{prefix}.{ext}" file.

    Can be read at any time via "cached" property,
    but update() is only allowed within "with" block of the instance.
    """

    update_lock = True

    def __init__(self, prefix, ext, text=True):
        self.path = f'{prefix}.{ext}'
        # Suffix for open() mode - text or binary
        self.b = '' if text else 'b'

    def __enter__(self):
        assert self.update_lock
        self.update_lock = False
        return self

    def __exit__(self, *err):
        self.update_lock = True

    @property
    def cached(self):
        'File contents, or None if file does not exist.'
        if not os.path.exists(self.path):
            return None
        mode = 'r' + self.b
        with open(self.path, mode) as src:
            return src.read()

    def update(self, data):
        'Replace file contents with data and return it.'
        assert not self.update_lock
        mode = 'w' + self.b
        with open(self.path, mode) as dst:
            dst.write(data)
        return data
class TVFAria2Proc:
    """aria2c daemon subprocess, controlled via json-rpc over websocket.

    Intended to be used as an async context manager: subprocess is started on enter,
    connect() has to be awaited after that to establish rpc connection,
    and on exit daemon gets shutdown request, followed by SIGTERM/SIGKILL if necessary.
    """

    def __init__(self, loop, conf, http, cmd_conf, key, port, path_func):
        """
        loop - event loop to run tasks on.
        conf - configuration object (see TVFConfig).
        http - http client session with ws_connect() method.
        cmd_conf - path to aria2c configuration file.
        key, port - rpc secret and port that aria2c is configured to use.
        path_func - function to get destination path for specified gid.
        """
        self.loop, self.conf, self.proc, self.http = loop, conf, None, http
        self.ws = self.ws_ctx = self.ws_poll_task = self.ws_handlers = self.chunks = None
        self.ws_closed = asyncio.Event()
        self.cmd_conf, self.ws_key, self.ws_url, self.path_func = (
            cmd_conf, f'token:{key}', f'ws://localhost:{port}/jsonrpc', path_func )
        self.log = get_logger('tvf.aria2')

    @staticmethod
    def _current_task(loop):
        'Return currently running task for the loop.'
        # asyncio.Task.current_task() was removed in python-3.9,
        #  with asyncio.current_task() being the replacement since 3.7
        func = getattr(asyncio, 'current_task', None) or asyncio.Task.current_task
        return func(loop)

    async def __aenter__(self):
        cmd = [self.conf.aria2_cmd, '--conf-path', self.cmd_conf]
        self.log.debug('Starting aria2c daemon: {}', ' '.join(cmd))
        self.proc = await asyncio.create_subprocess_exec(*cmd)
        self.ws_ctx = AsyncExitStack()
        return self

    async def __aexit__(self, *err):
        proc_term_delay = 0
        if self.ws_ctx:
            if not self.ws_closed.is_set():
                await self.req('shutdown', sync=False)
                proc_term_delay = self.conf.aria2_term_wait[0]
            await self.ws_ctx.close()
            if self.ws_handlers:
                # StopIteration is used as a marker to cancel all pending futures
                for func in self.ws_handlers.values(): func(StopIteration)
            if self.ws_poll_task: await self.ws_poll_task
            self.ws = self.ws_ctx = self.ws_poll_task = self.ws_handlers = None
        if self.proc:
            if proc_term_delay > 0:
                # Timeout here must not prevent SIGTERM/SIGKILL below from being sent
                try: await asyncio.wait_for(self.proc.wait(), proc_term_delay)
                except asyncio.TimeoutError:
                    self.log.debug( 'aria2c did not exit on'
                        ' shutdown request (timeout={:.2f}s)', proc_term_delay )
            try:
                # Signal 0 checks whether process is still there, raises OSError if not
                self.proc.send_signal(0)
                self.log.debug( 'Sending SIGTERM to aria2c pid'
                    ' (timeout={:.2f}s): {}', self.conf.aria2_term_wait[1], self.proc.pid )
                self.proc.terminate()
                try: await asyncio.wait_for(self.proc.wait(), timeout=self.conf.aria2_term_wait[1])
                except asyncio.TimeoutError:
                    self.log.debug('Sending SIGKILL to aria2c pid: {}', self.proc.pid)
                    self.proc.kill()
            except OSError: pass
            exit_code = await self.proc.wait()
            if exit_code: self.log.warning('aria2c has exited with error code: {}', exit_code)
            self.proc = None

    async def connect(self):
        """Connect to aria2c json-rpc websocket and start polling it for messages.

        Connection attempts are repeated with delays from retries_within_timeout(),
        using conf.aria2_startup_checks as (tries, timeout).
        Raises TVFError if connection can't be established within that timeout.
        """
        ts, delays = self.loop.time(), retries_within_timeout(*self.conf.aria2_startup_checks)
        ts_next, ts_max, delay_iter = ts, ts + sum(delays), iter(delays)
        while True:
            ts = self.loop.time()
            if ts >= ts_max: break
            timeout = ts_max - ts
            self.log.debug('Connecting to aria2c ws url (timeout={:.2f}s): {}', timeout, self.ws_url)
            try:
                self.ws = await self.ws_ctx.enter(self.http.ws_connect(
                    self.ws_url, heartbeat=self.conf.aria2_ws_heartbeat, timeout=timeout ))
            except aiohttp.ClientError as err: self.log.debug('aria2c conn attempt error: {}', err)
            else:
                self.log.debug('Connected to aria2c json-rpc ws')
                break
            ts = self.loop.time()
            if ts_next > ts:
                await asyncio.sleep(ts_next - ts)
                ts = self.loop.time()
            # Advance time of the next attempt past the current time
            while ts_next < ts:
                try: ts_next += next(delay_iter)
                except StopIteration: break
        if not self.ws:
            self.ws_closed.set()
            raise TVFError( 'aria2c connection failed'
                ' (max_tries={}, timeout={})'.format(*self.conf.aria2_startup_checks) )
        jrpc_uid_ns = secrets.token_urlsafe(3)
        self.ws_jrpc_uid_iter = (f'{jrpc_uid_ns}.{n}' for n in range(1, 2**30))
        self.ws_handlers, self.ws_poll_task = dict(), self.loop.create_task(self._ws_poller())

    @contextlib.contextmanager
    def ws_close_wrap(self):
        'Raises TVFWSClosed in current task on ws_closed event.'
        async def _ev_wait():
            nonlocal triggered
            await self.ws_closed.wait()
            triggered = True
            task.cancel()
        triggered, task = False, self._current_task(self.loop)
        ev_wait_task = self.loop.create_task(_ev_wait())
        try: yield
        except asyncio.CancelledError:
            # Cancellation from anything other than _ev_wait is passed through as-is
            if not triggered: raise
            raise TVFWSClosed() from None
        finally: ev_wait_task.cancel()

    async def _ws_poller(self):
        'Task to pass all received json-rpc messages to handlers, until ws gets closed.'
        try:
            async for msg in self.ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.conf.aria2_ws_debug: self.log.debug('rpc-msg: {}', msg.data)
                    msg_data, hs_discard = json.loads(msg.data), set()
                    if self.conf.aria2_ws_debug_signals and msg_data.get('method'):
                        self.log.debug('rpc-signal: {}', msg.data)
                    # Handlers returning true value are removed after processing the message
                    for k, func in self.ws_handlers.items():
                        if func(msg_data): hs_discard.add(k)
                    for k in hs_discard: del self.ws_handlers[k]
                elif msg.type == aiohttp.WSMsgType.closed: break
                elif msg.type == aiohttp.WSMsgType.error:
                    self.log.error('ws protocol error, aborting: {}', msg)
                    break
                else: self.log.warning('Unhandled ws msg type {}, ignoring: {}', msg.type, msg)
        except Exception as err:
            self.log.exception('Unhandled ws handler failure, aborting: {}', err)
        await self.ws.close()
        self.ws_closed.set()

    def _add_handler(self, func): self.ws_handlers[id(func)] = func

    def _expect_handler(self, uid_or_func, fut_or_cb, oneshot, msg):
        """Message handler to pass matching msg to fut_or_cb.

        uid_or_func - either a function to check msg, or a string to compare
            to msg id or method (without "aria2." prefix).
        fut_or_cb - callback to call with matching msg, or a future to set result for.
        oneshot - value to return for matching msg, true one removes this handler.
        msg - decoded json-rpc message, or StopIteration to cancel the future.
        Returns None for non-matching messages.
        """
        if msg is StopIteration:
            if not callable(fut_or_cb): fut_or_cb.cancel()
            return
        if callable(uid_or_func):
            if not uid_or_func(msg): return
        else:
            uid = msg.get('id') or msg.get('method')
            # E.g. error response with id=null, which can't match any uid
            if not isinstance(uid, str): return
            if uid.startswith('aria2.'): uid = uid[6:]
            if uid != uid_or_func: return
        if callable(fut_or_cb): fut_or_cb(msg)
        # Future can be cancelled already, if task that was waiting on it got cancelled
        elif not fut_or_cb.done(): fut_or_cb.set_result(msg)
        return oneshot

    def expect(self, uid_or_func, fut_or_cb=None, oneshot=False):
        """Add handler for messages matching uid_or_func, return fut_or_cb.

        New future is created and returned if fut_or_cb is not specified.
        See _expect_handler for description of the arguments.
        """
        if not fut_or_cb: fut_or_cb = asyncio.Future()
        self._add_handler(ft.partial(
            self._expect_handler, uid_or_func, fut_or_cb, oneshot ))
        return fut_or_cb

    async def req(self, method, *params, sync=True):
        """Send json-rpc request to aria2c.

        "aria2." prefix and rpc secret get added for all methods except system.multicall.
        With sync=True, waits for response and returns its "result" value,
        raising TVFReqError if response has no result, or TVFWSClosed if ws gets closed.
        With sync=False, returns request id without waiting for anything.
        """
        if method != 'system.multicall':
            method, params = f'aria2.{method}', [self.ws_key] + list(params)
        res = req_uid = next(self.ws_jrpc_uid_iter)
        data_req = dict(
            jsonrpc='2.0', id=req_uid,
            method=method, params=params )
        if self.conf.aria2_ws_debug:
            self.log.debug( 'rpc-req{} [{}]: {}',
                '-sync' if sync else '', req_uid, json.dumps(data_req) )
        await self.ws.send_str(json.dumps(data_req))
        if sync:
            # oneshot - there is only one response for each request id,
            #  so handler should not be kept around after receiving it
            with self.ws_close_wrap(): res = await self.expect(req_uid, oneshot=True)
            if self.conf.aria2_ws_debug:
                self.log.debug('rpc-res-sync [{}]: {}', req_uid, json.dumps(res))
            if 'result' not in res:
                raise TVFReqError(f'Request failed (method={method}, params={params}): {res}')
            res = res['result']
        return res

    async def req_ok(self, method, *params):
        'Same as req(), but raises TVFError if result is anything but "OK".'
        res = await self.req(method, *params)
        if res != 'OK': raise TVFError(f'aria2c command failed: {method}{params}')

    def _queue_params(self, gid, url, pos=None):
        'Return list of aria2.addUri parameters (without rpc secret).'
        params = [[url], dict(gid=gid, out=str(self.path_func(gid)))]
        # pos=0 is a valid position (front of the queue), only None means "not specified"
        if pos is not None: params.append(pos)
        return params

    async def queue(self, gid, url, pos=None):
        'Add url to aria2c download queue under specified gid, at pos or at the end.'
        res = await self.req('addUri', *self._queue_params(gid, url, pos))
        if res != gid:
            self.log.error('addUri error (gid={}): {}', gid, res)
            raise TVFError('Failed to queue chunk URL to aria2c')

    async def queue_batch(self, *gid_urls, pos=None):
        """Add all (gid, url) tuples to aria2c download queue via one system.multicall request.

        If pos is specified, urls are queued in sequence, starting from that position.
        Raises TVFError if returned gids do not match requested ones.
        """
        if pos is not None: pos = iter(range(pos, 2**30))
        res = await self.req('system.multicall', list(
            dict(methodName='aria2.addUri', params=(
                [self.ws_key] + self._queue_params(
                    gid, url, None if pos is None else next(pos) ) ))
            for gid, url in gid_urls ))
        res_chk = list([gid] for gid, url in gid_urls)
        if res != res_chk:
            log_lines(self.log.error, [
                'Result gid match check failed for submitted urls.',
                ('  expected: {}', ', '.join(str(r[0]) for r in res_chk)),
                ('  returned: {}', ', '.join(( str(r[0])
                    if isinstance(r, list) else repr(r) ) for r in res)) ])
            raise TVFError(f'Failed to queue {len(gid_urls)} chunk URLs to aria2c')

    async def queue_chunks(self, chunks):
        """Queue all (gid, url) tuples from chunks iterable, in conf.aria2_queue_batch batches.

        Returns (count, gid_last) tuple - number of queued urls and gid
        of the last one, with the latter being None if there was nothing to queue.
        """
        self.chunks = chunks
        count, gid_last = 0, None
        with self.ws_close_wrap():
            for gid_url_batch in it_adjacent_nofill(chunks, self.conf.aria2_queue_batch):
                await self.queue_batch(*gid_url_batch)
                count, gid_last = count + len(gid_url_batch), gid_url_batch[-1][0]
        return count, gid_last
class TVFChunkIter:
    """Mapping of gids to chunk URLs from a playlist, with tracking of downloaded ones.

    Iterating over instance yields (gid, url) tuples for chunks not marked as done.
    """

    def __init__(self, conf, url_base, pls_text, gids_done=None):
        self.conf = conf
        self.gid_urls = self.parse_pls(url_base, pls_text)
        self.gids_done = set(gids_done) if gids_done else set()

    def __iter__(self):
        for gid, url in self.gid_urls.items():
            if gid not in self.gids_done:
                yield gid, url

    def scatter_check_coro(self, *scatter_opts):
        """Generator-coroutine to check which chunks to pick with scatter options.

        scatter_opts are (length, interval) values.
        Duration of each chunk should be sent to it, with bool result
        returned from send() indicating whether that chunk should be used.
        """
        (len_left, interval_left), verdict = scatter_opts, True
        while True:
            td = yield verdict
            if len_left > 0:
                len_left -= td
            if len_left <= 0:
                verdict = False
            interval_left -= td
            if interval_left <= 0:
                # Interval is over, start the next one
                (len_left, interval_left), verdict = scatter_opts, True

    def parse_pls(self, url_base, pls_text):
        """Return OrderedDict of gid -> url for chunks in the playlist text.

        Only chunks matching slice_start, slice_len
        and slice_scatter_* values from conf are returned.
        """
        start_left, slice_len = self.conf.slice_start, self.conf.slice_len
        scatter_check = None
        if self.conf.slice_scatter_len is not None:
            scatter_check = self.scatter_check_coro(
                self.conf.slice_scatter_len, self.conf.slice_scatter_interval )
            next(scatter_check)
        # aria2c requires 16-char gid, format used here fits number in
        #  first 6 chars because it looks nice in (truncated) aria2c output
        gid_iter = iter(map('{:06d}0000000000'.format, range(1, 999999)))
        chunks = collections.OrderedDict()
        for line in pls_text.splitlines():
            m = re.search(r'^#EXTINF:([\d.]+),', line)
            if m:
                td = float(m.group(1))
                start_left -= td
            # Everything before the start position is skipped
            if start_left > 0:
                continue
            if not line or line.startswith('#'):
                continue
            # start_left is negative offset from the start position here
            if slice_len and slice_len + start_left < 0:
                break
            if scatter_check and not scatter_check.send(td):
                continue
            chunks[next(gid_iter)] = f'{url_base}/{line}'
        return chunks

    def mark_needed(self, gid, check=True):
        'Remove "done" mark from gid, raises TVFError if check=True and it was not marked.'
        if check and gid not in self.gids_done:
            raise TVFError(f'BUG: gid was not marked as done: {gid}')
        self.gids_done.discard(gid)

    def mark_done(self, gid):
        self.gids_done.add(gid)

    @property
    def finished(self):
        'True if number of gids marked as done is same as number of chunks.'
        return len(self.gid_urls) == len(self.gids_done)
@add_stack_wrappers
class TVF:
    """VoD fetcher, using youtube-dl to get playlist URL and aria2c to download chunks from it.

    Methods with "ctx: AsyncExitStack" parameter get it passed by add_stack_wrappers
    decorator, and should be called without specifying it.
    """

    def __init__(self, loop, conf, log):
        self.loop, self.conf = loop, conf
        self.log = log or get_logger('tvf.fetcher')

    async def __aenter__(self): return self
    async def __aexit__(self, *err): pass

    async def ytdl_run(self, ytdl_op, *args, check=True, out=False, **popen_kws):
        """Run "youtube-dl <ytdl_op>" command with conf.ytdl_opts and args.

        Returns stripped and decoded stdout of the command if out=True, exit code otherwise.
        Raises TVFError on non-zero exit code, unless check=False is specified.
        Any extra keywords are passed to asyncio.create_subprocess_exec().
        """
        cmd = ['youtube-dl']
        if self.conf.verbose: cmd.append('--verbose')
        cmd = cmd + [ytdl_op] + (self.conf.ytdl_opts or list()) + list(args)
        self.log.debug('Running "youtube-dl {}" command: {}', ytdl_op, ' '.join(cmd))
        if out: popen_kws.setdefault('stdout', asyncio.subprocess.PIPE)
        proc = await asyncio.create_subprocess_exec(*cmd, **popen_kws)
        if out: proc_stdout = (await proc.stdout.read()).strip()
        exit_code = await proc.wait()
        if check and exit_code != 0:
            raise TVFError(f'"youtube-dl {ytdl_op}" command exited with error (code: {exit_code})')
        return exit_code if not out else proc_stdout.decode()

    async def ytdl_probe_formats(self, url, file_prefix, info_suffix=None):
        'Run youtube-dl to list formats available for the url, returns its exit code.'
        self.log.info( '--- Listing formats available'
            ' for VoD {} (url: {}){}', file_prefix, url, info_suffix or '' )
        return await self.ytdl_run('--list-formats', url, check=False)

    async def get(self, ctx: AsyncExitStack, url, file_prefix, info_suffix=None):
        """Download VoD from url to "<file_prefix>.mp4" file.

        Intermediate data (playlist, aria2c configuration, append-position marks
        and downloaded chunks) is stored in files with same prefix, either in the same
        directory or in "<file_prefix>.tmp" one, depending on conf.use_temp_dirs,
        and is reused to continue interrupted download if it is there.
        info_suffix is only used in log messages.
        Raises TVFError (or its subclasses) on any download failures.
        """
        file_prefix = pathlib.Path(file_prefix)
        file_dst_path = file_prefix.with_suffix('.mp4')
        if self.conf.use_temp_dirs:
            name = file_prefix.name
            file_prefix_dir = file_prefix.with_suffix('.tmp')
            file_prefix = file_prefix_dir / name

        # Assuming file_dst is fully downloaded if there's no playlist tempfile around
        # It's not a problem even with --keep-tempfiles, as file_dst_mark will be kept there too
        if file_dst_path.exists() and (
                self.conf.file_part_suffix or not file_prefix.with_suffix('.m3u8').exists() ):
            self.log.info( '--- Skipping download'
                ' for existing file: {} (rename/remove it to force)', file_dst_path )
            return

        if self.conf.use_temp_dirs: file_prefix_dir.mkdir(exist_ok=True)
        if self.conf.file_part_suffix:
            file_dst_done, file_dst_path = file_dst_path, file_dst_path.with_suffix('.part.mp4')

        vod_cache = ft.partial(TVFFileCache, file_prefix)
        http = await ctx.enter(aiohttp.ClientSession(**self.conf.aiohttp_opts))

        self.log.info('--- Downloading VoD {} (url: {}){}', file_dst_path, url, info_suffix or '')

        with vod_cache('m3u8.url') as vc:
            url_pls = vc.cached or vc.update(await self.ytdl_run('--get-url', url, out=True))
        assert ' ' not in url_pls, url_pls # can return URLs of multiple flvs for *really* old VoDs
        url_base = url_pls.rsplit('/', 1)[0]

        with vod_cache('m3u8.ua') as vc:
            ua = vc.cached or vc.update(await self.ytdl_run('--dump-user-agent', out=True))

        with vod_cache('m3u8') as vc:
            pls = vc.cached
            if not pls:
                self.log.debug('Fetching playlist from URL: {}', url_pls)
                async with http.get(url_pls, headers={'User-Agent': ua}) as res:
                    pls = vc.update(await res.text())

        # Config is always updated between aria2c runs, to account for any cli opts changes
        with vod_cache('aria2c_conf') as vc:
            key = secrets.token_urlsafe(18)
            # Binding to port 0 is used to get some unused port number for aria2c
            with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                s.bind(('localhost', 0))
                addr, port = s.getsockname()
            vc.update(self.conf.aria2_conf_func(adict(
                key=key, port=port, ua=ua,
                log_level='notice' if self.conf.verbose else 'warn' )))
            aria2c_conf_path = vc.path

        # file_dst_mark stores (pos, seq) tuples, appending each update to the end
        # It is assumed here that such tiny writes will be atomic
        file_dst_pos = file_dst_seq = 0 # position in dst file, sequential gid int
        file_dst_mark_fmt = '>QI'
        file_dst_mark = await ctx.enter(file_prefix.with_suffix('.pos').open('a+b'))
        if file_dst_mark.tell() != 0:
            file_dst_mark.seek(-struct.calcsize(file_dst_mark_fmt), os.SEEK_END)
            file_dst_pos, file_dst_seq = struct.unpack(file_dst_mark_fmt, file_dst_mark.read())
        if not file_dst_path.exists() and file_dst_pos > 0:
            raise TVFError( f'Missing partial-dst-file'
                f' (expecting bytes={file_dst_pos:,} gid-seq={file_dst_seq}): {file_dst_path}' )
        # Same as with file_dst_mark above, ctx makes sure that file is closed on errors
        file_dst = await ctx.enter(file_dst_path.open('a+b'))
        if file_dst.tell() < file_dst_pos:
            raise TVFError( f'Missing chunk of partial-dst-file'
                f' (expecting size>={file_dst_pos:,} actual-size={file_dst.tell():,}): {file_dst_path}' )
        # Anything past last recorded position is from interrupted append, and gets discarded
        file_dst.seek(file_dst_pos)
        file_dst.truncate()

        # seq - gid converted to a sequential int, to sort appended chunks
        seq_heap, seq_chunks = list(), list()
        gid_to_seq = lambda gid: int(gid[:6])
        seq_to_gid = lambda seq: f'{seq:06d}0000000000'
        gid_pretty = lambda gid: gid[:6]

        # Chunks that are in the file_dst already are marked as done
        chunks = TVFChunkIter(self.conf, url_base, pls)
        for gid, url in chunks:
            if gid_to_seq(gid) > file_dst_seq: break
            chunks.mark_done(gid)
        chunk_path = lambda gid: file_prefix.with_suffix(f'.{gid}.mp4.chunk')

        async with TVFAria2Proc( self.loop,
                self.conf, http, aria2c_conf_path, key, port, chunk_path ) as aria2c:
            await aria2c.connect()

            # All chunks from iter(chunks) get queued for download here
            info = await aria2c.req('getVersion')
            self.log.debug( 'Starting downloads'
                ' (rpc-port={}, aria2-version={})...', port, info['version'] )
            count, gid_last = await aria2c.queue_chunks(chunks)
            if gid_last:
                self.log.info( '\n\n  ------ Started {} downloads,'
                    ' last gid: {} ------  \n', count, gid_pretty(gid_last) )
            else: self.log.info('No extra downloads were necessary')

            ### Main download-control loop ahead
            # Download control strategy:
            #  - listen for aria2.onDownloadComplete event for each gid
            #   - append sequential-gid chunk to partial file, update mark-file, remove chunk
            #   - non-sequential gids - just remember and check/append next time
            #   - if >file_append_max non-sequential-gid chunks downloaded - abort
            #  - for all error/paused/stopped events: prepend to download queue
            #  - stop on chunks.finished (all chunks marked as "done")

            ev_t = enum.Enum('ev_type', 'complete stop pause error')
            evq, ev_tuple = asyncio.Queue(), collections.namedtuple('ev', 't data')
            for t in ev_t:
                aria2c.expect( f'onDownload{t.name.title()}',
                    lambda e,t=t: evq.put_nowait(ev_tuple(t, e)) )

            while not chunks.finished:
                with aria2c.ws_close_wrap():
                    ev = await evq.get()

                    for e in ev.data['params']:
                        if ev.t == ev_t.complete: # completed chunk goes to seq_heap
                            heapq.heappush(seq_heap, gid_to_seq(e['gid']))
                            await aria2c.req('removeDownloadResult', e['gid'], sync=False)
                            chunks.mark_done(e['gid'])

                        else: # stop/pause/error - retry immediately
                            # Limit on forced retries (on top of what aria2 does) for same gid can be added here
                            if self.log.isEnabledFor(logging.DEBUG):
                                status = await aria2c.req( 'tellStatus', e['gid'],
                                    ['status', 'errorCode', 'errorMessage', 'completedLength', 'totalLength'] )
                                self.log.debug( 'Chunk {} download'
                                    ' stopped or failed, retrying: {}', e['gid'], status )
                            await aria2c.req_ok('removeDownloadResult', e['gid'])
                            await aria2c.queue(e['gid'], chunks.gid_urls[e['gid']], 0)

                if chunks.finished or len(seq_heap) > self.conf.file_append_batch:
                    while seq_heap and seq_heap[0] == file_dst_seq + 1:
                        seq = heapq.heappop(seq_heap)
                        p = chunk_path(seq_to_gid(seq))
                        with p.open('rb') as chunk: shutil.copyfileobj(chunk, file_dst)
                        file_dst_pos, file_dst_seq = file_dst.tell(), seq
                        seq_chunks.append(p)
                    if seq_chunks:
                        # Chunks are only removed after position for these is stored in the mark-file
                        file_dst.flush()
                        file_dst_mark.write(struct.pack(file_dst_mark_fmt, file_dst_pos, file_dst_seq))
                        file_dst_mark.flush()
                        for p in seq_chunks: os.unlink(p)
                        # self.log.debug('\nExtended partial dst file by {} chunk(s)', len(seq_chunks))
                        seq_chunks.clear()
                    if len(seq_heap) > self.conf.file_append_max:
                        err_msg = ( 'Chunk-append sequence is broken'
                            f' (pending={len(seq_heap)} max={self.conf.file_append_max})' )
                        # Chunk to append next is the one right after the last appended seq
                        log_lines(self.log.error, [ ('{}:', err_msg),
                            ('  expecting gid: {}', seq_to_gid(file_dst_seq + 1)),
                            ('  first gid among pending: {}', seq_to_gid(seq_heap[0])) ])
                        raise TVFError(err_msg)

        ### done!

        # Files are closed here explicitly, before rename/cleanup operations below
        file_dst.close()
        file_dst_mark.close()

        if self.conf.file_part_suffix:
            file_dst_path.rename(file_dst_done)
            file_dst_path = file_dst_done

        if not self.conf.keep_tempfiles:
            self.log.debug('Cleanup for prefix (dir={}): {}', self.conf.use_temp_dirs, file_prefix)
            if self.conf.use_temp_dirs: shutil.rmtree(file_prefix_dir)
            else:
                for p in file_prefix.parent.glob(f'{file_prefix.name}.*'):
                    if p != file_dst_path: os.unlink(p)

        self.log.info('--- Downloaded VoD file: {}{}', file_dst_path, info_suffix or '')
def main(args=None, conf=None):
    """Script entry point - parse command-line args and run vod_fetch() for them.

    args - list of command-line arguments, sys.argv[1:] is used if not specified.
    conf - configuration object to update from cli options, TVFConfig() by default.
    Returns exit code for the script.
    """
    if not conf: conf = TVFConfig()

    parser = argparse.ArgumentParser(
        usage='%(prog)s [options] url file_prefix [url-2 file_prefix-2 ...]',
        description='Grab a VoD or a specified slice of it from twitch.tv, properly.')
    parser.add_argument('url', help='URL for a VoD to fetch.')
    parser.add_argument('file_prefix', help='File prefix to assemble temp files under.')
    parser.add_argument('more_url_and_prefix_pairs', nargs='*',
        help='Any number of extra "url file_prefix" arguments can be specified.')

    group = parser.add_argument_group('Download slice options')
    group.add_argument('-s', '--start-pos',
        type=parse_pos_spec, metavar='[[hours:]minutes:]seconds',
        help='Only download video chunks after specified start position.'
            ' If multiple url/prefix args are specified, this option will be applied to all of them.')
    group.add_argument('-l', '--length',
        type=parse_pos_spec, metavar='[[hours:]minutes:]seconds',
        help='Only download specified length of the video (from specified start or beginning).'
            ' If multiple url/prefix args are specified, this option will be applied to all of them.')
    group.add_argument('-x', '--scatter',
        metavar='[[hours:]minutes:]seconds/[[hours:]minutes:]seconds',
        help='Out of whole video (or a chunk specified by --start and --length),'
                ' download only every N seconds (or mins/hours) out of M.'
            ' E.g. "1:00/10:00" spec here will download 1 first min of video out of every 10.'
            ' Idea here is to produce something like preview of the video to allow'
                ' to easily narrow down which part of it is interesting and worth downloading in full.')

    group = parser.add_argument_group('youtube-dl/aria2c options')
    group.add_argument('-F', '--ytdl-list-formats',
        action='store_true', help='Do not download anything,'
            ' just list formats available for each specified URL and exit.')
    group.add_argument('-y', '--ytdl-opts',
        action='append', metavar='opts',
        help='Extra opts for youtube-dl --get-url command.'
            ' Will be split on spaces, unless option is used multiple times.')
    group.add_argument('-a', '--aria2c-opts',
        action='append', metavar='opts',
        help='Extra options to store in aria2c configuration file.'
            ' These are same as long-form command-line options, but without double-dash prefix.'
            ' Will be split on spaces, unless option is used multiple times.'
            ' Example: --aria2c-opts "lowest-speed-limit=20K max-overall-download-limit=1M"')

    group = parser.add_argument_group('Misc/debug options')
    group.add_argument('-p', '--use-part-suffix',
        action='store_true', help='Use .part suffix for'
            ' partial destination file until it is fully downloaded.')
    group.add_argument('-n', '--no-temp-dirs',
        action='store_true', help='Do not create temporary'
            ' directories for download parts, keep everything in the destination one.')
    group.add_argument('-k', '--keep-tempfiles',
        action='store_true', help='Do not remove all the'
            ' temporary files after successfully assembling resulting mp4.'
            ' Chunks in particular might be useful to download different but overlapping video slices.')
    group.add_argument('--debug', action='store_true', help='Verbose operation mode.')

    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    logging.basicConfig(
        datefmt='%Y-%m-%d %H:%M:%S',
        format='%(asctime)s :: {}%(levelname)s :: %(message)s'\
            .format('%(name)s ' if opts.debug else ''),
        level=logging.DEBUG if opts.debug else logging.INFO )
    log = get_logger('tvf.main')

    conf.slice_start, conf.slice_len = opts.start_pos or 0, opts.length
    if opts.scatter:
        scatter = opts.scatter.split('/', 1)
        if len(scatter) != 2:
            parser.error( 'Invalid value for -x/--scatter option'
                f' - {opts.scatter!r} - should be in "len/interval" format.' )
        # This option is not parsed by argparse, so parsing errors have to be handled here
        try: scatter = list(map(parse_pos_spec, scatter))
        except argparse.ArgumentTypeError as err:
            parser.error(f'Invalid value for -x/--scatter option: {err}')
        if None in scatter:
            parser.error( 'Invalid value for -x/--scatter option'
                f' - {opts.scatter!r} - both len and interval must be specified.' )
        conf.slice_scatter_len, conf.slice_scatter_interval = scatter

    # Values from conf are only replaced here if corresponding cli options are used
    ytdl_opts = opts.ytdl_opts or list()
    if len(ytdl_opts) == 1: ytdl_opts = ytdl_opts[0].strip().split()
    if ytdl_opts: conf.ytdl_opts = ytdl_opts
    aria2c_opts = opts.aria2c_opts or list()
    if len(aria2c_opts) == 1: aria2c_opts = aria2c_opts[0].strip().split()
    if aria2c_opts: conf.aria2c_opts = aria2c_opts

    conf.verbose = opts.debug
    conf.keep_tempfiles = opts.keep_tempfiles
    conf.use_temp_dirs = not opts.no_temp_dirs
    conf.file_part_suffix = opts.use_part_suffix

    vod_queue, args = list(),\
        [opts.url, opts.file_prefix] + (opts.more_url_and_prefix_pairs or list())
    if len(args) % 2:
        parser.error( 'Odd number of url/prefix args specified'
            f' ({len(args)}), while these should always come in pairs.' )
    for url, prefix in it_adjacent(args, 2):
        if re.search(r'^https?:', prefix):
            if re.search(r'^https?:', url):
                parser.error( 'Both url/file_prefix args seem'
                    f' to be an URL, only first one should be: {url} {prefix}' )
            prefix, url = url, prefix
            # Logger here uses {}-style formatting, not %-style one
            log.warning( 'Looks like url/prefix args got'
                ' mixed-up, correcting that to prefix={} url={}', prefix, url )
        if not re.search(r'^https?://[^/]+/([^/]+/v|videos)/', url):
            parser.error( 'Provided URL appears to be for the unsupported'
                f' VoD format (only /user/v/ or /videos/ VoDs are supported): {url}' )
        vod_queue.append((url, prefix))

    if conf.compat_windows:
        # Use the Proactor event loop on Windows for subprocess support
        loop = asyncio.ProactorEventLoop()
        asyncio.set_event_loop(loop)

    log.debug('Starting vod_fetch loop...')
    with contextlib.closing(asyncio.get_event_loop()) as loop:
        exit_code = loop.run_until_complete(vod_fetch(
            loop, conf, vod_queue, list_formats_only=opts.ytdl_list_formats ))
    log.debug('Finished')
    return exit_code
if __name__ == '__main__':
    # Value returned from main() is used as the process exit code
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment