Last active
April 21, 2020 22:55
-
-
Save beeftornado/7d1ded6989ba5d0d462b to your computer and use it in GitHub Desktop.
Python simulate broken dns for unit tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# # Broken Socket | |
""" A broken socket implementation. | |
Credits: | |
* [A. Jesse Jiryu Davis](http://emptysqua.re/blog/undoing-gevents-monkey-patching/) | |
Monkey-patches the DNS lookup functions of the built-in ``socket`` module so that various exceptions are
raised. Useful for running unit tests to validate behavior when connections | |
fail. | |
**Usage**: | |
If you have a directory structure like this for your tests: | |
.. sourcecode::

    src/
        ..project files..
    tests/
        __init__.py
        test_pkg.py
        utils/
            __init__.py
            broken_socket.py (this file)
Then in the test code you can use the broken socket implementation like this: | |
.. sourcecode:: python

    import unittest

    from nose.tools import timed

    from .utils import broken_socket

    # new-style
    class TestMyApp(object):
        def __init__(self):
            super(TestMyApp, self).__init__()

        @timed(10)
        def test_third_party_api_call(self):
            old_dns_attrs = broken_socket.patch_dns()
            try:
                pass  # do test ...
            finally:
                broken_socket.unpatch_dns(old_dns_attrs)

    # old-style
    class Test(unittest.TestCase):
        def test(self):
            old_dns_attrs = broken_socket.patch_dns()
            try:
                pass  # do test ...
            finally:
                broken_socket.unpatch_dns(old_dns_attrs)
""" | |
def patch_dns():
    """Replace the socket module's DNS lookups with always-failing versions.

    ``socket.getaddrinfo`` and ``socket.gethostbyname`` are swapped for this
    module's broken ``getaddrinfo`` / ``gethostbyname``, which raise a random
    socket error on every call.

    Returns:
        dict: Maps each replaced attribute name to the function it replaced.
        Pass it to ``unpatch_dns()`` to undo the patch. Attributes that are
        already patched are left out of the mapping. As a result, nested
        ``patch_dns()`` / ``unpatch_dns()`` pairs never record the broken
        versions as if they were the originals. Only the outermost pair
        restores the real functions.
    """
    import socket

    # Order matches the original implementation: getaddrinfo first.
    replacements = (
        ('getaddrinfo', getaddrinfo),
        ('gethostbyname', gethostbyname),
    )
    old_attrs = {}
    for name, broken in replacements:
        current = getattr(socket, name)
        if current is broken:
            # An outer patch_dns() call is still active and already holds the
            # real function. Saving the broken one here would make a later
            # unpatch_dns() "restore" the broken version.
            continue
        old_attrs[name] = current
        setattr(socket, name, broken)
    return old_attrs
def unpatch_dns(old_attrs):
    """Write the attributes saved by ``patch_dns()`` back onto ``socket``.

    Args:
        old_attrs: Mapping returned by ``patch_dns()``. Each key is the name
            of a ``socket`` module attribute, and each value is set back
            onto the module under that name.
    """
    import socket as socket_module

    for attr_name, saved in old_attrs.items():
        setattr(socket_module, attr_name, saved)
def raise_random_socket_error(*, errors=None, rng=None):
    """Raise an instance of a randomly chosen socket exception class.

    Args:
        errors: Optional iterable of exception classes to choose from.
            Defaults to ``socket.error``, ``socket.gaierror``,
            ``socket.timeout`` and ``socket.herror``.
        rng: Optional object with a ``choice`` method, e.g. a seeded
            ``random.Random``, so tests can make the pick reproducible.
            Defaults to the ``random`` module.

    Raises:
        The chosen exception class, instantiated with no arguments. This
            function never returns normally.
        ValueError: If ``errors`` is given but empty.
    """
    import random
    import socket

    if errors is None:
        # In Python 3, socket.error is OSError, and gaierror, herror and
        # timeout all subclass OSError, so ``except OSError`` catches any of
        # the defaults.
        errors = (socket.error, socket.gaierror, socket.timeout, socket.herror)
    else:
        errors = tuple(errors)
        if not errors:
            raise ValueError("errors must contain at least one exception class")
    chooser = random if rng is None else rng
    raise chooser.choice(errors)()
def gethostbyname(*_args, **_kwargs):
    """Stand-in for ``socket.gethostbyname`` that always fails.

    All arguments are accepted and ignored. Every call ends in the random
    socket exception raised by ``raise_random_socket_error()``.
    """
    return raise_random_socket_error()
def getaddrinfo(*_args, **_kwargs):
    """Stand-in for ``socket.getaddrinfo`` that always fails.

    All arguments are accepted and ignored. Every call ends in the random
    socket exception raised by ``raise_random_socket_error()``.
    """
    return raise_random_socket_error()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment