Created
May 17, 2013 15:07
-
-
Save prschmid/5599682 to your computer and use it in GitHub Desktop.
An example test harness for testing the response status codes returned by routes in your Flask application for different kinds of users (e.g. user, admin, super_user).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import unittest

# Import your Flask app from your module
from myapp import app
# A fictitious database object that has get/put methods for getting/adding
# data. In your code you will want to use whatever database you are using
# (E.g. SQLAlchemy, MongoDB, CouchDB, etc.)
from database import database
class FlaskAppRouteStatusCodeTestCase(unittest.TestCase):
    """Base class that creates various calls so that we can test all possible
    combinations of user types for the REST endpoints.

    Subclasses implement :meth:`get`, :meth:`post` and :meth:`patch` (or skip
    the matching ``test_*`` method) and override the ``__*_STATUS_CODES__``
    dicts, which map each user type to the status code it should receive.

    The base class itself is abstract: when unittest runs it directly, its
    tests are skipped instead of failing with ``NotImplementedError``.
    """

    # Expected status code per user type for each HTTP method. Every user type
    # defaults to 403; subclasses override the dicts for the methods they test.
    # These names end in a double underscore, so they are NOT name-mangled and
    # ``self.__GET_STATUS_CODES__`` resolves to a subclass override.
    __GET_STATUS_CODES__ = dict(
        user=403,
        admin=403,
        super_user=403,
    )

    __POST_STATUS_CODES__ = dict(
        user=403,
        admin=403,
        super_user=403,
    )

    __PATCH_STATUS_CODES__ = dict(
        user=403,
        admin=403,
        super_user=403,
    )

    def setUp(self):
        """The unit test setUp method that gets called before every test case."""
        # Only concrete subclasses describe a route; running the abstract base
        # would just hit the NotImplementedError stubs below.
        if type(self) is FlaskAppRouteStatusCodeTestCase:
            self.skipTest("abstract base class; subclass it to test a route")
        app.config['TESTING'] = True
        # The helpers below issue requests via self.app.get()/self.app.post()
        # and read ``.status_code`` from the result, which is the Flask test
        # client's API, not the Flask application object's.
        self.app = app.test_client()
        # .. setup any other stuff ..

    def _setup_for_test(self):
        """A simple setup to be run in addition to the default setUp().

        This is a good place to setup all of the database initialization
        so that we can 1) initialize the database and 2) get one user
        for each type (user, admin, super_user).

        Note, we run this here as opposed to in setUp since we want to return
        the things that are created in this method.

        :returns: A tuple of the form ``(test_data, db_data)`` such that
                  ``test_data`` is what is returned by :meth:`_get_testing_data`
                  and ``db_data`` is what is returned by :meth:`_init_db`
        """
        db_data = self._init_db()
        test_data = self._get_testing_data(db_data)
        return (test_data, db_data)

    def _init_db(self):
        """Initialize the database with the data that you want to have for
        all of your tests.

        :returns: The value handed to the other hooks as ``db_data``. This
                  template returns ``None``; override it to return whatever
                  objects your tests need.
        """
        # NOTE(review): ``User`` is not imported in this file; import your
        # user model before using this template.
        # NOTE(review): the three addresses below are identical (they look
        # like obfuscated placeholders), so _get_testing_data would fetch the
        # same record for every user type. Use distinct addresses.
        database.put(User(email='[email protected]'))
        database.put(User(email='[email protected]'))
        database.put(User(email='[email protected]'))
        # .. add all the other things that you may need/want ..
        return None

    def _get_testing_data(self, db_data):
        """Get all of the data needed for the individual tests.

        This looks up one user for each user type (user, admin, super_user)
        by e-mail address.

        :param db_data: The data returned by :meth:`_init_db`
        :returns: A dictionary mapping each user type name to its user object
        :raises:
            :RuntimeError: Implementations may raise this if there is an
                           error setting up the data for some reason.
        """
        return dict(
            user=database.get('[email protected]'),
            admin=database.get('[email protected]'),
            super_user=database.get('[email protected]'),
        )

    def before_all_gets(self, test_data, db_data):
        """Run right before all of the GETs are tested."""
        pass

    def get(self, user, test_data, db_data):
        """Perform the GET as ``user`` and return the response."""
        raise NotImplementedError("Implement me, or explicitly skip me.")

    def before_all_posts(self, test_data, db_data):
        """Run right before all of the POSTs are tested."""
        pass

    def post(self, user, test_data, db_data):
        """Perform the POST as ``user`` and return the response."""
        raise NotImplementedError("Implement me, or explicitly skip me.")

    def before_all_patches(self, test_data, db_data):
        """Run right before all of the PATCHes are tested."""
        pass

    def patch(self, user, test_data, db_data):
        """Perform the PATCH as ``user`` and return the response."""
        raise NotImplementedError("Implement me, or explicitly skip me.")

    # -------------------------------------------------------------------------
    # The things below are the test_ methods that get run by unittest
    # -------------------------------------------------------------------------

    def test_get(self):
        test_data, db_data = self._setup_for_test()
        self.before_all_gets(test_data, db_data)
        self._run_tests(
            self.get, self.__GET_STATUS_CODES__, test_data, db_data)

    def test_post(self):
        test_data, db_data = self._setup_for_test()
        self.before_all_posts(test_data, db_data)
        self._run_tests(
            self.post, self.__POST_STATUS_CODES__, test_data, db_data)

    def test_patch(self):
        test_data, db_data = self._setup_for_test()
        self.before_all_patches(test_data, db_data)
        self._run_tests(
            self.patch, self.__PATCH_STATUS_CODES__, test_data, db_data)

    def _run_tests(self, method, status_codes, test_data, db_data):
        """Run all of the tests for each of the user types.

        Calls ``method`` once per user type (the method itself is expected to
        log in as that user and perform the request: POST, GET, etc.) and
        checks the returned response's status code.

        :param method: The method to perform (self.get, self.post, etc)
        :param status_codes: Mapping of user type to expected status code
                             (e.g. ``__GET_STATUS_CODES__``)
        :param test_data: The data returned by :meth:`_get_testing_data`
        :param db_data: The data returned by :meth:`_init_db`
        :raises AssertionError: If a user type gets an unexpected status code
        :raises KeyError: If ``status_codes`` lacks a user type in ``test_data``
        """
        # items() works on Python 2 and 3; iteritems() only exists on Python 2.
        for user_type, user in test_data.items():
            rv = method(user, test_data, db_data)
            expected = status_codes[user_type]
            self.assertEqual(
                rv.status_code, expected,
                "Status codes did not match for {} during {}. Expected {} "
                "got {}".format(
                    user_type, method.__name__, expected, rv.status_code))

    # -------------------------------------------------------------------------
    # Helper methods for logging in and logging out
    # -------------------------------------------------------------------------

    def login(self, email=None, password=None):
        """POST the credentials as JSON to /login and return the response."""
        return self.app.post(
            "/login",
            data=json.dumps(dict(email=email, password=password)),
            content_type='application/json')

    def logout(self):
        """POST to /logout and return the response."""
        return self.app.post('/logout')
class SomeRouteTestCase(FlaskAppRouteStatusCodeTestCase):
    """Test for /foo.

    Every user type may GET /foo; only admins and super users may POST to it.
    /foo does not support PATCH, so that test is skipped.
    """

    __GET_STATUS_CODES__ = dict(
        user=200,
        admin=200,
        super_user=200,
    )

    # Must be named __POST_STATUS_CODES__: that is the attribute the base
    # class's test_post reads (a __PUT_STATUS_CODES__ dict would be ignored).
    __POST_STATUS_CODES__ = dict(
        user=403,
        admin=200,
        super_user=200,
    )

    def get(self, user, test_data, db_data):
        """Log in as ``user``, GET /foo, log out and return the response."""
        # NOTE(review): assumes the user objects expose ``password``.
        self.login(user.email, user.password)
        rv = self.app.get('/foo')
        self.logout()
        return rv

    def post(self, user, test_data, db_data):
        """Log in as ``user``, POST a JSON body to /foo, log out and return
        the response."""
        self.login(user.email, user.password)
        rv = self.app.post(
            '/foo',
            data=json.dumps(dict(bar="barbar", bam="bambam")),
            content_type='application/json')
        self.logout()
        # The base class's _run_tests reads rv.status_code, so the response
        # must be returned.
        return rv

    @unittest.skip("No PATCH")
    def test_patch(self):
        """Override the PATCH tester since /foo can't be patched."""
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment