Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active January 2, 2019 02:30
Show Gist options
  • Save crowsonkb/ec4ea19cab9b1f5a5bddbbddc8e237e7 to your computer and use it in GitHub Desktop.
Save crowsonkb/ec4ea19cab9b1f5a5bddbbddc8e237e7 to your computer and use it in GitHub Desktop.
import inspect
import typing
from typing import get_type_hints, TypeVar, Any, AnyStr, Generic, Union
from sphinx.util import logging
from sphinx.util.inspect import Signature
try:
from typing_extensions import Protocol
except ImportError:
Protocol = None
try:
    from inspect import unwrap
except ImportError:
    def unwrap(func, *, stop=None):
        """Follow a chain of ``__wrapped__`` attributes and return the innermost object.

        Fallback re-implementation of :func:`inspect.unwrap` (as found in the
        Python 3.5 standard library) for interpreters whose ``inspect`` module
        does not provide it.

        :param func: the (possibly wrapped) object to unwrap.
        :param stop: optional predicate; unwrapping stops at the first object
            for which ``stop(obj)`` is true.
        :return: the last object reached in the ``__wrapped__`` chain.
        :raises ValueError: if the ``__wrapped__`` chain loops back on itself.
        """
        start = func  # the caller's object, reported if a cycle is found
        visited = {id(start)}  # track ids so unhashable objects are tolerated
        while hasattr(func, '__wrapped__') and (stop is None or not stop(func)):
            func = func.__wrapped__
            key = id(func)
            if key in visited:
                raise ValueError('wrapper loop when unwrapping {!r}'.format(start))
            visited.add(key)
        return func
# Logger for this extension, obtained through Sphinx's ``sphinx.util.logging`` wrapper
# and named after this module.
logger = logging.getLogger(__name__)
def format_annotation(annotation):
    """Render a type annotation as reStructuredText cross-reference markup for Sphinx.

    The checks run in this order:

    * builtin classes become ``:py:class:`` references (``NoneType`` is shown as
      a literal ``None``);
    * objects whose class is defined in the ``typing`` module become
      ``:py:class:``/``:py:data:`` references into ``typing`` (or into the origin
      class's own module for ``Generic``/``Protocol`` subclasses), with any type
      parameters appended in escaped square brackets;
    * ``Ellipsis`` becomes ``...``;
    * functions from ``typing`` carrying ``__supertype__`` (``NewType`` helpers)
      become a ``NewType(...)`` rendering;
    * other classes, or objects whose ``__origin__`` is a class, become
      ``:py:class:`` references (with parameters for ``Generic``/``Protocol``
      subclasses);
    * anything else falls back to ``str(annotation)``.

    Type parameters are formatted by calling this function recursively.

    :param annotation: the annotation object to render.
    :return: the reST markup as a string.
    """
    if inspect.isclass(annotation) and annotation.__module__ == 'builtins':
        if annotation.__qualname__ == 'NoneType':
            return '``None``'
        else:
            return ':py:class:`{}`'.format(annotation.__qualname__)

    annotation_cls = annotation if inspect.isclass(annotation) else type(annotation)
    class_name = None
    if annotation_cls.__module__ == 'typing':
        params = None
        prefix = ':py:class:'
        module = 'typing'
        extra = ''

        # Subscripted generics: describe the unsubscripted origin class instead. When the
        # origin derives from Generic/Protocol, link into the origin's own module.
        if inspect.isclass(getattr(annotation, '__origin__', None)):
            annotation_cls = annotation.__origin__
            try:
                mro = annotation_cls.mro()
                if Generic in mro or (Protocol and Protocol in mro):
                    module = annotation_cls.__module__
            except TypeError:
                pass  # annotation_cls was either the "type" object or typing.Type

        if annotation is Any:
            return ':py:data:`~typing.Any`'
        elif annotation is AnyStr:
            return ':py:data:`~typing.AnyStr`'
        elif isinstance(annotation, TypeVar):
            # Backslash-escaped repr of the TypeVar.
            # NOTE(review): presumably escaped so the repr's leading marker is not read as markup.
            return '\\%r' % annotation
        elif (annotation is Union or getattr(annotation, '__origin__', None) is Union or
              hasattr(annotation, '__union_params__')):
            prefix = ':py:data:'
            class_name = 'Union'
            # NOTE(review): ``__union_params__`` appears to be an older typing attribute;
            # newer typing versions expose the members as ``__args__``.
            if hasattr(annotation, '__union_params__'):
                params = annotation.__union_params__
            elif hasattr(annotation, '__args__'):
                params = annotation.__args__

            # A two-member Union whose second member is NoneType is shown as Optional[X].
            if params and len(params) == 2 and (hasattr(params[1], '__qualname__') and
                                                params[1].__qualname__ == 'NoneType'):
                class_name = 'Optional'
                params = (params[0],)
        elif annotation_cls.__qualname__ == 'Tuple' and hasattr(annotation, '__tuple_params__'):
            # Tuple representation exposing ``__tuple_params__`` / ``__tuple_use_ellipsis__``.
            params = annotation.__tuple_params__
            if annotation.__tuple_use_ellipsis__:
                params += (Ellipsis,)
        elif annotation_cls.__qualname__ == 'Callable':
            prefix = ':py:data:'
            arg_annotations = result_annotation = None
            if hasattr(annotation, '__result__'):
                # Representation with a separate ``__result__`` attribute.
                arg_annotations = annotation.__args__
                result_annotation = annotation.__result__
            elif getattr(annotation, '__args__', None):
                # Representation where the result type is the last element of ``__args__``.
                arg_annotations = annotation.__args__[:-1]
                result_annotation = annotation.__args__[-1]

            # Rendered as [args..., result]; an Ellipsis argument list stays as "...".
            if arg_annotations in (Ellipsis, (Ellipsis,)):
                params = [Ellipsis, result_annotation]
            elif arg_annotations is not None:
                params = [
                    '\\[{}]'.format(
                        ', '.join(format_annotation(param) for param in arg_annotations)),
                    result_annotation
                ]
        elif hasattr(annotation, 'type_var'):
            # Type alias
            class_name = annotation.name
            params = (annotation.type_var,)
        elif getattr(annotation, '__args__', None) is not None:
            params = annotation.__args__
        elif hasattr(annotation, '__parameters__'):
            params = annotation.__parameters__

        if params:
            extra = '\\[{}]'.format(', '.join(format_annotation(param) for param in params))

        # Without an explicit name, use the class's qualified name, title-cased.
        if not class_name:
            class_name = annotation_cls.__qualname__.title()

        return '{}`~{}.{}`{}'.format(prefix, module, class_name, extra)
    elif annotation is Ellipsis:
        return '...'
    elif (inspect.isfunction(annotation) and annotation.__module__ == 'typing' and
          hasattr(annotation, '__name__') and hasattr(annotation, '__supertype__')):
        # Function-based NewType: show the new name and the formatted supertype.
        return ':py:func:`~typing.NewType`\\(:py:data:`~{}`, {})'.format(
            annotation.__name__, format_annotation(annotation.__supertype__))
    elif inspect.isclass(annotation) or inspect.isclass(getattr(annotation, '__origin__', None)):
        if not inspect.isclass(annotation):
            annotation_cls = annotation.__origin__

        extra = ''
        mro = annotation_cls.mro()
        # Only Generic/Protocol subclasses get their type parameters listed.
        if Generic in mro or (Protocol and Protocol in mro):
            params = (getattr(annotation, '__parameters__', None) or
                      getattr(annotation, '__args__', None))
            if params:
                extra = '\\[{}]'.format(', '.join(format_annotation(param) for param in params))

        return ':py:class:`~{}.{}`{}'.format(annotation.__module__, annotation_cls.__qualname__,
                                             extra)

    return str(annotation)
def process_signature(app, what: str, name: str, obj, options, signature, return_annotation):
    """Strip type annotations from the signature that autodoc renders.

    Handler for Sphinx's ``autodoc-process-signature`` event.

    * Every parameter annotation and the return annotation are cleared.
    * For classes and exceptions, the first parameter of ``__init__``/``__new__``
      is dropped.
    * For methods, the first parameter is dropped unless the raw class attribute
      is a ``classmethod`` or ``staticmethod`` object.

    :param what: autodoc object type (``'class'``, ``'method'``, ...).
    :param name: fully qualified name of the documented object.
    :param obj: the documented object.
    :return: ``None`` to leave autodoc's signature untouched; otherwise a
        ``(args, None)`` tuple. ``args`` is the formatted argument list with
        backslashes doubled.
    """
    if not callable(obj):
        return

    if what in ('class', 'exception'):
        obj = getattr(obj, '__init__', getattr(obj, '__new__', None))

    if not getattr(obj, '__annotations__', None):
        return

    obj = unwrap(obj)
    signature = Signature(obj)
    parameters = [
        param.replace(annotation=inspect.Parameter.empty)
        for param in signature.signature.parameters.values()
    ]

    if parameters:
        if what in ('class', 'exception'):
            # Drop the first parameter of __init__/__new__.
            del parameters[0]
        elif what == 'method':
            qualname_parts = obj.__qualname__.split('.')
            method_name = obj.__name__
            if method_name.startswith('__') and not method_name.endswith('__'):
                # Private names are stored in the class __dict__ under their mangled
                # form. Python strips leading underscores from the class name when
                # mangling (``_Foo.__bar`` -> ``_Foo__bar``). It does not mangle at
                # all when the class name consists only of underscores.
                class_name = qualname_parts[-2].lstrip('_')
                if class_name:
                    method_name = '_{c}{m}'.format(c=class_name, m=method_name)

            # Fetch the raw attribute from the defining class's __dict__ so that
            # classmethod/staticmethod wrappers are visible.
            try:
                outer = inspect.getmodule(obj)
                for clsname in qualname_parts[:-1]:
                    outer = getattr(outer, clsname)
                method_object = outer.__dict__[method_name]
            except (AttributeError, KeyError) as exc:
                # Some classes cannot be reached from their module, e.g. ones
                # defined inside a function (``<locals>`` in the qualname). Leave
                # autodoc's own signature in place rather than aborting the build.
                logger.warning('Cannot locate the class attribute for method "%s" (%r); '
                               'leaving its signature unchanged', name, exc)
                return

            if not isinstance(method_object, (classmethod, staticmethod)):
                del parameters[0]

    signature.signature = signature.signature.replace(
        parameters=parameters,
        return_annotation=inspect.Signature.empty)

    return signature.format_args().replace('\\', '\\\\'), None
def process_docstring(app, what, name, obj, options, lines):
    """Add ``:type:``/``:rtype:`` fields derived from annotations to a docstring.

    Handler for Sphinx's ``autodoc-process-docstring`` event. *lines* (the
    docstring split into lines) is modified in place.

    * Properties are documented through their getter.
    * Classes and exceptions are documented through ``__init__``; its return
      annotation is ignored.
    * A ``:type name:`` field is inserted directly before the first line
      starting with ``:param name:``. Parameters without such a line get no
      type field.
    * An ``:rtype:`` field is inserted before the last ``:return:``/``:returns:``
      line, or appended after a blank line at the end. This is skipped when the
      docstring already contains an ``:rtype:`` line.

    String annotations (e.g. under ``from __future__ import annotations``) are
    rendered verbatim as ``:py:obj:`` references. Other annotation objects are
    rendered with ``format_annotation``.
    """
    if isinstance(obj, property):
        obj = obj.fget

    if not callable(obj):
        return

    if what in ('class', 'exception'):
        obj = getattr(obj, '__init__')

    obj = unwrap(obj)
    # The raw ``__annotations__`` mapping is used; nothing is evaluated. Some
    # callables have no ``__annotations__`` at all, e.g. the ``object.__init__``
    # slot wrapper reached for a class that defines no ``__init__``. They have
    # nothing to add, so return instead of raising AttributeError.
    type_hints = getattr(obj, '__annotations__', None)
    if not type_hints:
        return

    for argname, annotation in type_hints.items():
        if argname.endswith('_'):
            # NOTE(review): presumably escaped so reST does not read ``name_`` as a
            # hyperlink reference.
            argname = '{}\\_'.format(argname[:-1])

        if isinstance(annotation, str):
            formatted_annotation = f':py:obj:`{annotation}`'
        else:
            # Real annotation objects would otherwise render as e.g. ``<class 'int'>``.
            formatted_annotation = format_annotation(annotation)

        if argname == 'return':
            if what in ('class', 'exception'):
                # Don't add return type None from __init__()
                continue

            insert_index = len(lines)
            for i, line in enumerate(lines):
                if line.startswith(':rtype:'):
                    insert_index = None
                    break
                elif line.startswith(':return:') or line.startswith(':returns:'):
                    insert_index = i

            if insert_index is not None:
                if insert_index == len(lines):
                    # Ensure that :rtype: doesn't get joined with a paragraph of text, which
                    # prevents it being interpreted.
                    lines.append('')
                    insert_index += 1

                lines.insert(insert_index, ':rtype: {}'.format(formatted_annotation))
        else:
            searchfor = ':param {}:'.format(argname)
            for i, line in enumerate(lines):
                if line.startswith(searchfor):
                    lines.insert(i, ':type {}: {}'.format(argname, formatted_annotation))
                    break
def builder_ready(app):
    """Handle Sphinx's ``builder-inited`` event.

    When the ``set_type_checking_flag`` configuration value is true, set
    :data:`typing.TYPE_CHECKING` to ``True``; otherwise leave it untouched.
    """
    enabled = app.config.set_type_checking_flag
    if not enabled:
        return
    typing.TYPE_CHECKING = True
def setup(app):
    """Sphinx extension entry point.

    Registers the ``set_type_checking_flag`` configuration value (default
    ``False``, rebuild setting ``'html'``). Connects this module's handlers to
    the ``builder-inited``, ``autodoc-process-signature`` and
    ``autodoc-process-docstring`` events.

    :return: extension metadata declaring the extension parallel-read safe.
    """
    app.add_config_value('set_type_checking_flag', False, 'html')
    event_handlers = [
        ('builder-inited', builder_ready),
        ('autodoc-process-signature', process_signature),
        ('autodoc-process-docstring', process_docstring),
    ]
    for event_name, handler in event_handlers:
        app.connect(event_name, handler)
    return {'parallel_read_safe': True}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment