Created
May 6, 2020 16:49
-
-
Save moriyoshi/da4aff9844f8cce20d20e7b93be37b64 to your computer and use it in GitHub Desktop.
CPython bytecode instrumentation (convert ordinary methods to async methods)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bisect | |
import opcode | |
import dis | |
import inspect | |
import sys | |
import types | |
from collections import defaultdict | |
import httpx | |
# Cache of substitute globals dicts, keyed by a function's ``__module__``.
# instrument() fills it lazily and passes the cached dict to every function
# it rebuilds for that module.
alternative_globals = {}

# Numeric opcodes, resolved once through ``dis.opmap`` so the rewriter can
# compare and emit them by name.
# NOTE(review): each lookup raises KeyError at import time if the running
# interpreter does not define that opcode name -- presumably this targets the
# CPython release the gist was written against; confirm the supported range.
DUP_TOP = dis.opmap["DUP_TOP"]
POP_TOP = dis.opmap["POP_TOP"]
BUILD_TUPLE = dis.opmap["BUILD_TUPLE"]
LOAD_FAST = dis.opmap["LOAD_FAST"]
LOAD_ATTR = dis.opmap["LOAD_ATTR"]
LOAD_CONST = dis.opmap["LOAD_CONST"]
LOAD_METHOD = dis.opmap["LOAD_METHOD"]
LOAD_GLOBAL = dis.opmap["LOAD_GLOBAL"]
STORE_FAST = dis.opmap["STORE_FAST"]
GET_AWAITABLE = dis.opmap["GET_AWAITABLE"]
YIELD_FROM = dis.opmap["YIELD_FROM"]
# Prefix instruction that carries the higher bytes of a following argument.
EXTENDED_ARG = dis.opmap["EXTENDED_ARG"]
CALL_METHOD = dis.opmap["CALL_METHOD"]
CALL_FUNCTION = dis.opmap["CALL_FUNCTION"]
CALL_FUNCTION_KW = dis.opmap["CALL_FUNCTION_KW"]
CALL_FUNCTION_EX = dis.opmap["CALL_FUNCTION_EX"]
def render_ins(opcode, arg=None):
    """Encode a single instruction, including any EXTENDED_ARG prefixes.

    Args:
        opcode: Numeric opcode to emit.
        arg: Instruction argument, or ``None`` for an instruction that takes
            no argument. Must satisfy ``0 <= arg <= 0xffffffff``.

    Returns:
        A new ``bytearray`` holding one ``EXTENDED_ARG``/byte pair for every
        significant byte of *arg* above the lowest one (most significant
        first), followed by *opcode* and the low byte of *arg*. When *arg* is
        ``None`` the opcode is followed by a zero pad byte on Python 3.6+.

    Raises:
        ValueError: If *arg* is negative or does not fit in 32 bits. A
            negative argument used to be encoded silently as its low byte,
            producing a wrong instruction.
    """
    if arg is not None and not 0 <= arg <= 0xffffffff:
        raise ValueError(
            "instruction argument out of range: {!r}".format(arg)
        )

    b = bytearray()
    if arg is not None:
        # One prefix per non-empty higher byte, most significant first.
        for shift in (24, 16, 8):
            if arg >> shift:
                b.append(EXTENDED_ARG)
                b.append((arg >> shift) & 0xff)
    b.append(opcode)
    if arg is not None:
        b.append(arg & 0xff)
    elif sys.version_info >= (3, 6):
        # Since 3.6 every instruction occupies two bytes, argument or not.
        b.append(0)
    return b
def append_ins(b, opcode, arg=None):
    """Encode one instruction with render_ins() and append it to buffer *b*.

    *b* is mutated in place; nothing is returned.
    """
    encoded = render_ins(opcode, arg)
    b.extend(encoded)
def rewrite_to_async_call(fn, verbs):
    """Return a code object for *fn* in which recognised calls are awaited.

    The bytecode of ``fn.__code__`` is copied instruction by instruction while
    a small state machine watches for these patterns:

    * ``<local 0>.<verb>(...)`` or ``super().<verb>(...)`` with ``<verb>`` in
      *verbs* -- ``GET_AWAITABLE / LOAD_CONST None / YIELD_FROM`` is inserted
      right after the call. If the next instruction stores the result into a
      local, that local is remembered as a "response" variable.
    * ``<local 0>.refresh_token(...)`` -- awaited the same way.
    * ``<response local>.body`` / ``.text`` / ``.json`` (attribute loads) --
      an awaited ``.aread()`` call on the same object is inserted before the
      attribute load and its result discarded.

    Jump arguments are patched afterwards to account for inserted code.

    Args:
        fn: Function whose ``__code__`` is inspected.
        verbs: Container of attribute names whose calls must be awaited.

    Returns:
        ``fn.__code__`` unchanged if it already has ``CO_COROUTINE`` set or if
        nothing matched; otherwise a copy with new ``co_code``, the
        ``CO_COROUTINE`` flag, extended ``co_consts`` / ``co_names`` where
        needed and ``co_stacksize`` raised by 96.

    Raises:
        ValueError: If a patched jump no longer fits in the number of bytes
            its original encoding occupied. Writing it anyway would overwrite
            the neighbouring instruction.

    NOTE(review): ``co_lnotab`` is not passed to ``replace()``, so line
    numbers reported for rewritten code may be off after an insertion.
    """
    code = fn.__code__
    if code.co_flags & inspect.CO_COROUTINE:
        return code

    call_ops = (CALL_FUNCTION, CALL_FUNCTION_KW, CALL_FUNCTION_EX)

    state = 0
    rewritten = False
    aread_name = None
    resp_store = set()   # local slots holding an awaited response
    names = code.co_names
    consts = code.co_consts
    # Index of the None constant consumed by YIELD_FROM. Resolving it up
    # front is harmless: ``consts`` is only used when something is rewritten,
    # and every rewrite needs it.
    try:
        none_const = consts.index(None)
    except ValueError:
        none_const = len(consts)
        consts = consts + (None,)
    offset_addr_map = []  # original offsets in front of which code was inserted
    offset_map = []       # the matching offsets in ``result``
    jumps = []            # (offset in result, instruction, is_relative, encoded length)
    result = bytearray()
    sp = 0                # running stack depth
    ssp = 0               # depth at which the watched call will have returned

    for ins in dis.Bytecode(code):
        if ins.opcode == EXTENDED_ARG:
            # Prefixes are regenerated from the full argument when the real
            # instruction is re-encoded below.
            continue
        sp += opcode.stack_effect(ins.opcode, ins.arg)

        if state == 0:
            # Idle: look for the start of a pattern.
            if ins.opcode == LOAD_FAST and ins.arg == 0:
                state = 1
            elif ins.opcode == LOAD_FAST and ins.arg in resp_store:
                state = 5
            elif ins.opcode == LOAD_GLOBAL and ins.argval == "super":
                state = 8
                ssp = sp
        elif state == 1:
            # Receiver is on the stack: which attribute is fetched from it?
            if ins.opcode == LOAD_ATTR and ins.argval in verbs:
                state = 2
                ssp = sp
            elif ins.opcode == LOAD_ATTR and ins.argval == "refresh_token":
                state = 6
                ssp = sp
            elif ins.opcode == LOAD_METHOD and ins.argval in verbs:
                state = 3
                # LOAD_METHOD leaves one more slot than LOAD_ATTR and
                # CALL_METHOD removes it again, so the call returns one slot
                # below the current depth. Comparing against ``sp`` itself
                # could never match the call being watched.
                ssp = sp - 1
            else:
                state = 0
        elif state == 2:
            # Skip argument evaluation until the call that consumes the verb.
            if sp - ssp == 0 and ins.opcode in call_ops:
                state = 4
        elif state == 3:
            if sp - ssp == 0 and ins.opcode == CALL_METHOD:
                state = 4
        elif state == 4:
            # ``ins`` follows the call: await the call result in front of it.
            append_ins(result, GET_AWAITABLE)
            append_ins(result, LOAD_CONST, none_const)
            append_ins(result, YIELD_FROM)
            offset_addr_map.append(ins.offset)
            offset_map.append(len(result))
            if ins.opcode == STORE_FAST:
                resp_store.add(ins.arg)
            state = 0
            rewritten = True
        elif state == 5:
            # A response local was just loaded.
            if ins.opcode == LOAD_ATTR and ins.argval in ("body", "text", "json"):
                if aread_name is None:
                    try:
                        aread_name = names.index("aread")
                    except ValueError:
                        aread_name = len(names)
                        names = names + ("aread",)
                # await <response>.aread(), keeping <response> on the stack.
                append_ins(result, DUP_TOP)
                append_ins(result, LOAD_METHOD, aread_name)
                append_ins(result, CALL_METHOD, 0)
                append_ins(result, GET_AWAITABLE)
                append_ins(result, LOAD_CONST, none_const)
                append_ins(result, YIELD_FROM)
                append_ins(result, POP_TOP)
                offset_addr_map.append(ins.offset)
                offset_map.append(len(result))
                rewritten = True
            state = 0
        elif state == 6:
            if sp - ssp == 0 and ins.opcode in call_ops:
                state = 7
        elif state == 7:
            # ``ins`` follows the refresh_token(...) call: await its result.
            append_ins(result, GET_AWAITABLE)
            append_ins(result, LOAD_CONST, none_const)
            append_ins(result, YIELD_FROM)
            offset_addr_map.append(ins.offset)
            offset_map.append(len(result))
            state = 0
            rewritten = True
        elif state == 8:
            # After ``super(...)`` returns, treat its result like local 0.
            if sp - ssp == 0 and ins.opcode == CALL_FUNCTION:
                state = 1

        # Copy the current instruction; jumps are remembered for patching.
        co = len(result)
        append_ins(result, ins.opcode, ins.arg)
        rel = ins.opcode in dis.hasjrel
        if rel or ins.opcode in dis.hasjabs:
            jumps.append((co, ins, rel, len(result) - co))

    if not rewritten:
        # Nothing was inserted, so there is nothing to patch or replace.
        return code

    # Back-patching: shift jump arguments by the amount of inserted code.
    for addr, ins, rel, length in jumps:
        if rel:
            # Relative jumps count from the instruction after the jump. The
            # dis offset already points at the opcode itself (past any
            # prefixes), so the base is always two bytes further on.
            offset = ins.offset + 2
            target = ins.argval
            s = bisect.bisect_right(offset_addr_map, target) - 1
            e = bisect.bisect_right(offset_addr_map, offset) - 1
            if s >= 0:
                target += offset_map[s] - offset_addr_map[s]
            if e >= 0:
                offset += offset_map[e] - offset_addr_map[e]
            new_arg = target - offset
        else:
            i = bisect.bisect_right(offset_addr_map, ins.arg) - 1
            if i < 0:
                # Target lies before the first insertion: already correct.
                continue
            new_arg = ins.arg + offset_map[i] - offset_addr_map[i]
        patched = render_ins(ins.opcode, new_arg)
        if len(patched) != length:
            raise ValueError(
                "jump at offset {} needs {} bytes after rewriting but was "
                "encoded in {}".format(ins.offset, len(patched), length)
            )
        result[addr:addr + length] = patched

    return code.replace(
        co_code=bytes(result),
        co_flags=(code.co_flags | inspect.CO_COROUTINE),
        co_consts=consts,
        co_names=names,
        # Headroom for the temporaries pushed by the inserted sequences.
        co_stacksize=code.co_stacksize + 96,
    )
class CompatibleRequest(httpx.Request):
    """``httpx.Request`` with an additional read-only ``body`` attribute."""

    @property
    def body(self):
        """Call ``read()`` on the request, then return its ``content``."""
        self.read()
        payload = self.content
        return payload
class CompatibleAsyncClient(httpx.AsyncClient):
    """``httpx.AsyncClient`` whose verb helpers all funnel through request().

    request() discards the ``verify`` and ``proxies`` keyword arguments and
    re-classes the request attached to each response as CompatibleRequest.
    """

    async def request(self, *args, **kwargs):
        """Send a request, ignoring ``verify``/``proxies``, and return the response."""
        for dropped in ("verify", "proxies"):
            kwargs.pop(dropped, None)
        response = await super().request(*args, **kwargs)
        # Give the attached request object the ``body`` property.
        response.request.__class__ = CompatibleRequest
        return response

    async def get(self, *args, **kwargs):
        """Issue a GET via request()."""
        response = await self.request("GET", *args, **kwargs)
        return response

    async def options(self, *args, **kwargs):
        """Issue an OPTIONS via request()."""
        response = await self.request("OPTIONS", *args, **kwargs)
        return response

    async def head(self, *args, **kwargs):
        """Issue a HEAD via request(); ``allow_redirects`` defaults to False."""
        follow = kwargs.pop("allow_redirects", False)
        response = await self.request(
            "HEAD", *args, allow_redirects=follow, **kwargs
        )
        return response

    async def post(self, *args, **kwargs):
        """Issue a POST via request()."""
        response = await self.request("POST", *args, **kwargs)
        return response

    async def put(self, *args, **kwargs):
        """Issue a PUT via request()."""
        response = await self.request("PUT", *args, **kwargs)
        return response

    async def patch(self, *args, **kwargs):
        """Issue a PATCH via request()."""
        response = await self.request("PATCH", *args, **kwargs)
        return response

    async def delete(self, *args, **kwargs):
        """Issue a DELETE via request()."""
        response = await self.request("DELETE", *args, **kwargs)
        return response
def _retarget_cell(cell, old, new):
    """Return a fresh cell holding *new* if *cell* holds *old*, else *cell*."""
    try:
        contents = cell.cell_contents
    except ValueError:
        # Empty cell (its variable is not bound yet): nothing to replace.
        return cell
    return types.CellType(new) if contents is old else cell


def instrument(type_):
    """Create an async variant of *type_* based on CompatibleAsyncClient.

    A new class with the same name is built from ``type_.__dict__`` (minus
    ``__module__``, ``__dict__`` and ``__weakref__``) with
    CompatibleAsyncClient as its only base. Every plain function in it is
    then rebuilt:

    * its code goes through rewrite_to_async_call() for the HTTP verbs,
    * closure cells holding *type_* are replaced by cells holding the new
      class,
    * its globals are a per-module copy (cached in ``alternative_globals``)
      in which names bound to *type_* point at the new class.

    Args:
        type_: The class to convert.

    Returns:
        The newly created class.
    """
    new_type = type(
        type_.__name__,
        (CompatibleAsyncClient,),
        {
            k: v
            for k, v in type_.__dict__.items()
            if k not in ("__module__", "__dict__", "__weakref__")
        },
    )
    verbs = ("request", "get", "options", "head", "post", "put", "patch", "delete")
    retargeted = set()  # modules whose cached globals were patched for type_

    # Iterate over a snapshot: setattr() below mutates new_type.__dict__.
    for attr_name, member in list(new_type.__dict__.items()):
        if not isinstance(member, types.FunctionType):
            continue

        if member.__closure__ is not None:
            new_closure = tuple(
                _retarget_cell(cell, type_, new_type)
                for cell in member.__closure__
            )
        else:
            new_closure = None

        module_name = member.__module__
        fn_globals = alternative_globals.get(module_name)
        if fn_globals is None:
            fn_globals = alternative_globals[module_name] = dict(member.__globals__)
        if module_name not in retargeted:
            # Also needed on a cache hit: the cached dict may have been
            # created while instrumenting a different class of this module.
            for global_name, value in member.__globals__.items():
                if value is type_:
                    fn_globals[global_name] = new_type
            retargeted.add(module_name)

        new_fn = types.FunctionType(
            rewrite_to_async_call(member, verbs),
            fn_globals,
            member.__name__,
            member.__defaults__,
            new_closure,
        )
        # FunctionType() does not carry these over; without __kwdefaults__
        # keyword-only parameters would lose their default values.
        if member.__kwdefaults__ is not None:
            new_fn.__kwdefaults__ = dict(member.__kwdefaults__)
        new_fn.__qualname__ = member.__qualname__
        new_fn.__annotations__ = dict(member.__annotations__)
        new_fn.__dict__.update(member.__dict__)
        setattr(new_type, attr_name, new_fn)
    return new_type
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment