Last active
August 29, 2015 14:08
-
-
Save dutc/b7e82f587662d9f9a6a1 to your computer and use it in GitHub Desktop.
Did you mean? in Python (Bonus Round!)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Python.h> | |
#include "didyoumean-safe.h" | |
/* Attribute lookups that call the type's getattr slots directly instead of
   going through PyObject_GetAttr. */
static PyObject *safe_PyObject_GetAttr(PyObject *v, PyObject *name);
static PyObject *safe_PyObject_GetAttrString(PyObject *v, const char *name);

/* Helpers that accumulate attribute names into a dict. */
static int safe_merge_list_attr(PyObject *dict, PyObject *obj, const char *attrname);
static int safe_merge_class_dict(PyObject *dict, PyObject *aclass);

/* dir() strategies, selected by safe__dir_object(). */
static PyObject *safe__generic_dir(PyObject *obj);
static PyObject *safe__specialized_dir_type(PyObject *obj);
static PyObject *safe__specialized_dir_module(PyObject *obj);
static PyObject *safe__dir_object(PyObject *obj);
static int | |
safe_merge_list_attr(PyObject* dict, PyObject* obj, const char *attrname) | |
{ | |
PyObject *list; | |
int result = 0; | |
assert(PyDict_Check(dict)); | |
assert(obj); | |
assert(attrname); | |
list = safe_PyObject_GetAttrString(obj, attrname); | |
if (list == NULL) | |
PyErr_Clear(); | |
else if (PyList_Check(list)) { | |
int i; | |
for (i = 0; i < PyList_GET_SIZE(list); ++i) { | |
PyObject *item = PyList_GET_ITEM(list, i); | |
if (PyString_Check(item)) { | |
result = PyDict_SetItem(dict, item, Py_None); | |
if (result < 0) | |
break; | |
} | |
} | |
if (Py_Py3kWarningFlag && | |
(strcmp(attrname, "__members__") == 0 || | |
strcmp(attrname, "__methods__") == 0)) { | |
if (PyErr_WarnEx(PyExc_DeprecationWarning, | |
"__members__ and __methods__ not " | |
"supported in 3.x", 1) < 0) { | |
Py_XDECREF(list); | |
return -1; | |
} | |
} | |
} | |
Py_XDECREF(list); | |
return result; | |
} | |
static int | |
safe_merge_class_dict(PyObject* dict, PyObject* aclass) | |
{ | |
PyObject *classdict; | |
PyObject *bases; | |
assert(PyDict_Check(dict)); | |
assert(aclass); | |
/* Merge in the type's dict (if any). */ | |
classdict = safe_PyObject_GetAttrString(aclass, "__dict__"); | |
if (classdict == NULL) | |
PyErr_Clear(); | |
else { | |
int status = PyDict_Update(dict, classdict); | |
Py_DECREF(classdict); | |
if (status < 0) | |
return -1; | |
} | |
/* Recursively merge in the base types' (if any) dicts. */ | |
bases = safe_PyObject_GetAttrString(aclass, "__bases__"); | |
if (bases == NULL) | |
PyErr_Clear(); | |
else { | |
/* We have no guarantee that bases is a real tuple */ | |
Py_ssize_t i, n; | |
n = PySequence_Size(bases); /* This better be right */ | |
if (n < 0) | |
PyErr_Clear(); | |
else { | |
for (i = 0; i < n; i++) { | |
int status; | |
PyObject *base = PySequence_GetItem(bases, i); | |
if (base == NULL) { | |
Py_DECREF(bases); | |
return -1; | |
} | |
status = safe_merge_class_dict(dict, base); | |
Py_DECREF(base); | |
if (status < 0) { | |
Py_DECREF(bases); | |
return -1; | |
} | |
} | |
} | |
Py_DECREF(bases); | |
} | |
return 0; | |
} | |
/*
 * Look up attribute `name` on `v` by calling the type's tp_getattro /
 * tp_getattr slot directly (mirrors PyObject_GetAttr without calling it).
 *
 * `name` must be a str; with Py_USING_UNICODE a unicode name is first
 * converted with _PyUnicode_AsDefaultEncodedString. Returns a new reference,
 * or NULL with TypeError / AttributeError (or the slot's error) set.
 */
static PyObject *
safe_PyObject_GetAttr(PyObject *v, PyObject *name)
{
    PyTypeObject *type = Py_TYPE(v);

    if (!PyString_Check(name)) {
#ifdef Py_USING_UNICODE
        if (!PyUnicode_Check(name)) {
            PyErr_Format(PyExc_TypeError,
                         "attribute name must be string, not '%.200s'",
                         Py_TYPE(name)->tp_name);
            return NULL;
        }
        /* Existing tp_getattro slots expect a str name, so convert here. */
        name = _PyUnicode_AsDefaultEncodedString(name, NULL);
        if (name == NULL)
            return NULL;
#else
        PyErr_Format(PyExc_TypeError,
                     "attribute name must be string, not '%.200s'",
                     Py_TYPE(name)->tp_name);
        return NULL;
#endif
    }

    if (type->tp_getattro != NULL)
        return type->tp_getattro(v, name);

    if (type->tp_getattr == NULL) {
        PyErr_Format(PyExc_AttributeError,
                     "'%.50s' object has no attribute '%.400s'",
                     type->tp_name, PyString_AS_STRING(name));
        return NULL;
    }
    return type->tp_getattr(v, PyString_AS_STRING(name));
}
/*
 * C-string variant of safe_PyObject_GetAttr: uses the type's char*-based
 * tp_getattr slot when present, otherwise interns `name` and defers to
 * safe_PyObject_GetAttr. Returns a new reference or NULL with an error set.
 */
static PyObject *
safe_PyObject_GetAttrString(PyObject *v, const char *name)
{
    getattrfunc by_cstring = Py_TYPE(v)->tp_getattr;
    PyObject *key;
    PyObject *attr;

    if (by_cstring != NULL)
        return by_cstring(v, (char *)name);

    key = PyString_InternFromString(name);
    if (key == NULL)
        return NULL;
    attr = safe_PyObject_GetAttr(v, key);
    Py_DECREF(key);
    return attr;
}
/*
 * Default dir() for ordinary objects. Collects, as keys of a scratch dict:
 *   - a copy of obj.__dict__ when that is a real dict (never mutated),
 *   - the str entries of obj.__members__ and obj.__methods__ lists,
 *   - everything reachable from obj.__class__ and its bases.
 * Returns a new (unsorted) list of those keys, or NULL with an error set.
 */
static PyObject *
safe__generic_dir(PyObject *obj)
{
    PyObject *keys = NULL;
    PyObject *klass = NULL;
    PyObject *attrs = safe_PyObject_GetAttrString(obj, "__dict__");

    if (attrs != NULL && PyDict_Check(attrs)) {
        /* Work on a copy so obj.__dict__ itself is left untouched. */
        PyObject *copy = PyDict_Copy(attrs);
        Py_DECREF(attrs);
        attrs = copy;
    }
    else {
        /* No __dict__, or one that isn't a dict: start empty. */
        if (attrs == NULL)
            PyErr_Clear();
        else
            Py_DECREF(attrs);
        attrs = PyDict_New();
    }
    if (attrs == NULL)
        return NULL;

    /* __members__ / __methods__ are Python 2 only (removed in 3.x). */
    if (safe_merge_list_attr(attrs, obj, "__members__") < 0 ||
        safe_merge_list_attr(attrs, obj, "__methods__") < 0)
        goto done;

    klass = safe_PyObject_GetAttrString(obj, "__class__");
    if (klass == NULL) {
        /* XXX(tomer): Perhaps fall back to obj->ob_type if no
           __class__ exists? */
        PyErr_Clear();
    }
    else if (safe_merge_class_dict(attrs, klass) != 0) {
        goto done;
    }

    keys = PyDict_Keys(attrs);

done:
    Py_XDECREF(klass);
    Py_DECREF(attrs);
    return keys;
}
/*
 * dir() for type and classic-class objects: a new list of every name found
 * in the class's __dict__ and, recursively, its bases' __dict__s.
 * Returns NULL with an error set on failure.
 */
static PyObject *
safe__specialized_dir_type(PyObject *obj)
{
    PyObject *keys = NULL;
    PyObject *names = PyDict_New();

    if (names == NULL)
        return NULL;
    if (safe_merge_class_dict(names, obj) == 0)
        keys = PyDict_Keys(names);
    Py_DECREF(names);
    return keys;
}
/* Helper for PyObject_Dir of module objects: returns the module's __dict__. */ | |
static PyObject * | |
safe__specialized_dir_module(PyObject *obj) | |
{ | |
PyObject *result = NULL; | |
PyObject *dict = safe_PyObject_GetAttrString(obj, "__dict__"); | |
if (dict != NULL) { | |
if (PyDict_Check(dict)) | |
result = PyDict_Keys(dict); | |
else { | |
char *name = PyModule_GetName(obj); | |
if (name) | |
PyErr_Format(PyExc_TypeError, | |
"%.200s.__dict__ is not a dictionary", | |
name); | |
} | |
} | |
Py_XDECREF(dict); | |
return result; | |
} | |
/*
 * dir() of a non-NULL object.
 *
 * Looks for a __dir__ method: as an ordinary attribute for old-style
 * instances (PyInstance_Check), otherwise via _PyObject_LookupSpecial. If
 * found, its result is returned and must be a list (else TypeError).
 * Without __dir__, falls back to the module / type / generic strategies.
 * Returns a new list, or NULL with an exception set.
 */
static PyObject *
safe__dir_object(PyObject *obj)
{
    /* Cache slot handed to _PyObject_LookupSpecial (presumably holds the
       interned "__dir__" string after the first call). */
    static PyObject *dir_str = NULL;
    PyObject *dirfunc;
    PyObject *names;

    assert(obj);

    if (!PyInstance_Check(obj)) {
        dirfunc = _PyObject_LookupSpecial(obj, "__dir__", &dir_str);
        if (PyErr_Occurred())
            return NULL;
    }
    else {
        dirfunc = safe_PyObject_GetAttrString(obj, "__dir__");
        if (dirfunc == NULL) {
            if (!PyErr_ExceptionMatches(PyExc_AttributeError))
                return NULL;
            PyErr_Clear();
        }
    }

    /* No __dir__: use the default implementation for this kind of object. */
    if (dirfunc == NULL) {
        if (PyModule_Check(obj))
            return safe__specialized_dir_module(obj);
        if (PyType_Check(obj) || PyClass_Check(obj))
            return safe__specialized_dir_type(obj);
        return safe__generic_dir(obj);
    }

    names = PyObject_CallFunctionObjArgs(dirfunc, NULL);
    Py_DECREF(dirfunc);
    /* XXX(gbrandl): could also check if all items are strings */
    if (names != NULL && !PyList_Check(names)) {
        PyErr_Format(PyExc_TypeError,
                     "__dir__() must return a list, not %.200s",
                     Py_TYPE(names)->tp_name);
        Py_DECREF(names);
        names = NULL;
    }
    return names;
}
/*
 * dir()-like introspection of `obj`, built on the safe_* lookups in this
 * file (which call type slots directly rather than PyObject_GetAttr).
 *
 * Differences from the builtin dir():
 *   - the result is NOT sorted;
 *   - `obj` must not be NULL (locals introspection is not supported); a
 *     NULL argument raises SystemError via PyErr_BadInternalCall.
 *
 * Returns a new list of names, or NULL with an exception set.
 */
PyObject* safe_PyObject_Dir(PyObject *obj)
{
    PyObject *result;

    if (obj == NULL) {
        /* Previously this dereferenced NULL; fail cleanly instead. */
        PyErr_BadInternalCall();
        return NULL;
    }

    result = safe__dir_object(obj);
    assert(result == NULL || PyList_Check(result));
    return result;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef DIDYOUMEAN_SAFE_H | |
#define DIDYOUMEAN_SAFE_H | |
PyObject* safe_PyObject_Dir(PyObject *obj); | |
#endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Python.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
#include "didyoumean-safe.h"
/* initdidyoumean() writes raw x86-64 machine code; refuse other targets. */
#if !defined(__x86_64__)
#error "This only works on x86_64"
#endif

/* Explicit re-declarations of Python C-API functions used below (both are
   also declared by <Python.h>). */
extern PyObject *PyErr_Occurred(void);
extern PyObject *PyObject_GetAttr(PyObject *v, PyObject *name);
/*
 * Levenshtein edit distance between NUL-terminated strings `a` and `b`: the
 * minimum number of single-character insertions, deletions and
 * substitutions turning one into the other.
 *
 * Keeps only two rows of the dynamic-programming table, on the heap, so
 * memory is O(strlen(a)) instead of an O(strlen(a) * strlen(b)) stack VLA.
 *
 * Returns INT_MAX if the scratch rows cannot be allocated, so a failed
 * comparison never looks like a close match.
 */
static int distance(const char *a, const char *b)
{
    size_t la = strlen(a);
    size_t lb = strlen(b);
    size_t *prev = malloc((la + 1) * sizeof *prev);
    size_t *cur = malloc((la + 1) * sizeof *cur);
    size_t result;

    if (prev == NULL || cur == NULL) {
        free(prev);
        free(cur);
        return INT_MAX;
    }

    /* Row 0: transforming the empty prefix of b into a[0..j) costs j. */
    for (size_t j = 0; j <= la; j++)
        prev[j] = j;

    for (size_t i = 1; i <= lb; i++) {
        cur[0] = i;
        for (size_t j = 1; j <= la; j++) {
            size_t del = prev[j] + 1;
            size_t ins = cur[j - 1] + 1;
            size_t sub = prev[j - 1] + (a[j - 1] == b[i - 1] ? 0 : 1);
            /* True minimum of all three (ties included). */
            size_t best = del < ins ? del : ins;
            cur[j] = sub < best ? sub : best;
        }
        size_t *tmp = prev;
        prev = cur;
        cur = tmp;
    }

    result = prev[la];
    free(prev);
    free(cur);
    return result > INT_MAX ? INT_MAX : (int)result;
}
/*
 * Best-effort "did you mean?" hint for a failed attribute lookup.
 *
 * Must be called with an AttributeError pending for the lookup of the str
 * `name` on `v`. If the pending exception value is a plain str (as produced
 * by PyErr_Format in the lookup paths), it is replaced by that message plus
 * the str entry of safe_PyObject_Dir(v) with the smallest edit distance to
 * `name`. Otherwise, or on any failure, the original exception is put back
 * unchanged; errors raised while building the hint are discarded.
 */
static void
append_suggestion(PyObject *v, PyObject *name)
{
    PyObject *type, *value, *traceback;
    PyObject *dir;
    PyObject *candidate = NULL;
    int candidate_dist = INT_MAX;
    Py_ssize_t i, n;

    /* Take ownership of the pending exception and clear the indicator, so
       the lookups below run with no error set. */
    PyErr_Fetch(&type, &value, &traceback);

    /* Only an unnormalized str value can be extended; exception instances
       raised from Python code (or a NULL value) are left alone. */
    if (value == NULL || !PyString_Check(value))
        goto restore;

    /* Computing dir() may perform lookups that fail and re-enter
       PyObject_GetAttr (i.e. this trampoline); bound that recursion. */
    if (Py_EnterRecursiveCall(" while looking for attribute suggestions"))
        goto restore;
    dir = safe_PyObject_Dir(v);
    Py_LeaveRecursiveCall();
    if (dir == NULL)
        goto restore;

    n = PyList_GET_SIZE(dir);
    for (i = 0; i < n; ++i) {
        PyObject *item = PyList_GET_ITEM(dir, i);
        int dist;
        /* dir() entries need not be str (unicode, arbitrary __dict__ keys). */
        if (!PyString_Check(item))
            continue;
        dist = distance(PyString_AS_STRING(name), PyString_AS_STRING(item));
        if (candidate == NULL || dist < candidate_dist) {
            candidate = item;
            candidate_dist = dist;
        }
    }

    if (candidate != NULL) {
        /* candidate is borrowed from dir: format before releasing dir. */
        PyObject *newvalue = PyString_FromFormat("%s\n\nMaybe you meant: .%s\n",
                                                 PyString_AS_STRING(value),
                                                 PyString_AS_STRING(candidate));
        if (newvalue != NULL) {
            Py_DECREF(value);
            value = newvalue;
        }
    }
    Py_DECREF(dir);

restore:
    /* The caller must see the original AttributeError, not a helper error. */
    PyErr_Clear();
    PyErr_Restore(type, value, traceback);
}

/*
 * Replacement for PyObject_GetAttr(v, name); initdidyoumean() redirects
 * PyObject_GetAttr here. Up to the attribute lookup it mirrors
 * PyObject_GetAttr. When the lookup fails with AttributeError, a
 * "Maybe you meant" hint is appended by append_suggestion().
 *
 * The leading `nop` is a patch marker: initdidyoumean() finds the first 0x90
 * byte of this function and shifts the bytes before it over it to make room
 * for a `pop %rax`. It must stay the first statement.
 */
PyObject* trampoline(PyObject *v, PyObject *name)
{
    __asm__("nop");
    PyObject *rv = NULL;
    PyTypeObject *tp = Py_TYPE(v);

    if (!PyString_Check(name)) {
#ifdef Py_USING_UNICODE
        /* The Unicode to string conversion is done here because the
           existing tp_getattro slots expect a string object as name
           and we wouldn't want to break those. */
        if (PyUnicode_Check(name)) {
            name = _PyUnicode_AsDefaultEncodedString(name, NULL);
            if (name == NULL)
                return NULL;
        }
        else
#endif
        {
            PyErr_Format(PyExc_TypeError,
                         "attribute name must be string, not '%.200s'",
                         Py_TYPE(name)->tp_name);
            return NULL;
        }
    }

    if (tp->tp_getattro != NULL)
        rv = (*tp->tp_getattro)(v, name);
    else if (tp->tp_getattr != NULL)
        rv = (*tp->tp_getattr)(v, PyString_AS_STRING(name));
    else
        PyErr_Format(PyExc_AttributeError,
                     "'%.50s' object has no attribute '%.400s'",
                     tp->tp_name, PyString_AS_STRING(name));

    if (rv == NULL && PyErr_Occurred() &&
        PyErr_ExceptionMatches(PyExc_AttributeError))
        append_suggestion(v, name);
    return rv;
}
/*
 * Machine-code stub that initdidyoumean() writes over the start of
 * PyObject_GetAttr:
 *
 *     50                      push %rax
 *     48 b8 <8-byte address>  movabs $addr, %rax
 *     ff e0                   jmp *%rax
 *
 * `addr` is filled in at runtime with the address of trampoline(), whose
 * patched first byte (pop %rax) undoes the push. Packed so the 13 bytes are
 * contiguous. Fields are unsigned char because opcode bytes such as
 * 0xb8/0xff/0xe0 do not fit in plain char where it is signed (x86-64 GCC).
 */
#pragma pack(push, 1)
struct {
    unsigned char push_rax;
    unsigned char mov_rax[2];
    unsigned char addr[8];
    unsigned char jmp_rax[2]; }
jump_asm = {
    .push_rax = 0x50,
    .mov_rax = {0x48, 0xb8},
    .jmp_rax = {0xff, 0xe0} };
#pragma pack(pop)
static PyMethodDef module_methods[] = { | |
{NULL} /* Sentinel */ | |
}; | |
PyDoc_STRVAR(module_doc, | |
"This module implements a \"did you mean?\" functionality on getattr/LOAD_ATTR.\n" | |
"(It's not so much what it does but how it does it.)"); | |
PyMODINIT_FUNC | |
initdidyoumean(void) { | |
__asm__(""); | |
Py_InitModule3("didyoumean", module_methods, module_doc); | |
void* target = PyObject_GetAttr; | |
char* page; | |
int rc; | |
int pagesize = sysconf(_SC_PAGE_SIZE); | |
void* addr = &trampoline; | |
page = (char *)addr; | |
page = (char *)((size_t) page & ~(pagesize - 1)); | |
rc = mprotect(page, pagesize, PROT_READ | PROT_WRITE | PROT_EXEC); | |
if(rc) { | |
fprintf(stderr, "mprotect() failed.\n"); | |
return; | |
} | |
int count; | |
for(count = 0; count < 255; ++count) | |
if(((unsigned char*)addr)[count] == 0x90) | |
break; // found the NOP | |
for(int i = count; i >= 0; --i) | |
((unsigned char*)addr)[i] = ((unsigned char*)addr)[i-1]; | |
*((unsigned char *)addr) = 0x58; | |
page = (char *)target; | |
page = (char *)((size_t) page & ~(pagesize - 1)); | |
rc = mprotect(page, pagesize, PROT_READ | PROT_WRITE | PROT_EXEC); | |
if(rc) { | |
fprintf(stderr, "mprotect() failed.\n"); | |
return; | |
} | |
memcpy(jump_asm.addr, &addr, sizeof (void *)); | |
memcpy(target, &jump_asm, sizeof jump_asm); | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python


class Foo(object):
    def bar(self):
        pass


def show(obj):
    """Print obj.bar, then the error raised by looking up the missing obj.baz."""
    print(obj.bar)
    try:
        obj.baz
    except Exception as exc:
        print(exc)


if __name__ == '__main__':
    foo = Foo()
    show(foo)
    # Importing the extension patches PyObject_GetAttr in-process, so the
    # same failing lookup now carries a "Maybe you meant" hint.
    import didyoumean
    show(foo)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Builds the didyoumean extension module (Python 2, x86_64 only).
CC       = gcc
CFLAGS   = -std=c99 -Wall
PYCFLAGS := $(shell python-config --cflags) $(shell python-config --includes)
PYLIBS   := $(shell python-config --libs)

SOURCES = didyoumean.c didyoumean-safe.c
HEADERS = didyoumean-safe.h

# The header is a prerequisite so edits to it trigger a rebuild, but only
# the .c files are handed to the compiler. No libdl symbols are used.
didyoumean.so: $(SOURCES) $(HEADERS)
	$(CC) $(CFLAGS) $(PYCFLAGS) -Wl,--export-dynamic -fPIC -shared -o $@ $(filter %.c,$^) $(PYLIBS)

.PHONY: clean
clean:
	rm -f didyoumean.so
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment