Created
June 10, 2013 15:44
-
-
Save xfenix/5749783 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UsualCustomResourceTestCase(CustomResourceTestCase):
    """Generic CRUD API test case for one model/resource pair.

    Concrete subclasses are meant to be named after the model they test.
    ``setUp`` passes the class name through ``str_replace_multi`` with
    ``replace_parts`` (presumably stripping those substrings) and uses the
    result to look up the model on ``self.models``. It then looks up
    ``<Model>Resource`` on ``self.resources``. Set ``tested_object``,
    ``resource_object`` or ``model_name`` on a subclass to bypass that lookup.
    """

    # Model class under test; derived from the test class name when None.
    tested_object = None
    # Resource class under test; derived from ``tested_object`` when None.
    resource_object = None
    # Namespace on which resource classes are looked up by name.
    resources = None
    # Per-test bookkeeping. setUp gives every test its own copy, so these
    # class-level lists are only defaults and are never mutated in place.
    created_stack = []
    generated_numbers = []
    # Model fields that get_object_fields does not fill...
    exclude_fields = ['created', 'modified']
    # ...unless they are listed here (this also overrides the null check).
    exclude_ignore_fields = ['name', ]
    # Credentials; presumably consumed by the parent's login() — verify.
    username = 'dmitri'
    password = '12345'
    # Response keys that may carry the object's primary key.
    probable_primary_keys = ['id', 'pk']
    # Abstract base classes whose inherited tests must not run.
    skip_classes = (
        'UsualCustomResourceTestCase',
        'LocalUsualCase',
    )
    # Substrings handed to str_replace_multi when deriving the model name.
    replace_parts = (
        'Resource',
        'Test',
        'Case',
    )

    def setUp(self):
        """Skip abstract bases, reset per-test state, resolve model/resource.

        A lookup that fails leaves the corresponding attribute unchanged. A
        subclass that cannot rely on the naming convention should set it
        explicitly.
        """
        class_name = self.__class__.__name__
        if class_name in self.skip_classes:
            self.skipTest(u'Some test cases doesn\'t need to be tested')

        # Copy the class-level defaults. Appending to the shared lists would
        # leak generated data from one test (and test class) into the next.
        self.created_stack = list(self.created_stack)
        self.generated_numbers = list(self.generated_numbers)

        # Derive the tested model, model name and resource from the class
        # name, so concrete test cases can stay very compact.
        if self.tested_object is None:
            probable_object = str_replace_multi(class_name, self.replace_parts)
            try:
                self.tested_object = getattr(self.models, probable_object)
            except AttributeError:
                pass
        if self.model_name is None:
            try:
                self.model_name = self.tested_object.__name__.lower()
            except AttributeError:
                # tested_object could not be resolved (still None).
                pass
        if self.resource_object is None:
            try:
                self.resource_object = getattr(
                    self.resources,
                    self.tested_object.__name__ + 'Resource'
                )
            except AttributeError:
                # No resources namespace, no model, or no matching resource.
                pass
        super(UsualCustomResourceTestCase, self).setUp()

    # ------------------------------------------------------------------
    # Helper methods for random data generation
    # ------------------------------------------------------------------

    def get_random_string(self, length=25):
        """Return a string of ``length`` random lowercase ASCII letters."""
        # string.ascii_lowercase is always exactly a-z, unlike Python 2's
        # locale-dependent string.lowercase (which also no longer exists in
        # Python 3).
        return ''.join(
            random.choice(string.ascii_lowercase) for _ in range(length)
        )

    def get_random_email(self):
        """Return a random address of the form ``local@domain.tld``.

        The local part has 5-29 letters, the domain 5-14 and the TLD 2-3.
        """
        local_part = self.get_random_string(self.get_random_integer(5, 30))
        domain = self.get_random_string(self.get_random_integer(5, 15))
        # randrange excludes its upper bound, so (2, 4) yields 2- or
        # 3-letter TLDs; the former (2, 3) always produced 2 letters.
        tld = self.get_random_string(self.get_random_integer(2, 4))
        return '%s@%s.%s' % (local_part, domain, tld)

    def get_random_integer(self, start=100, end=100000):
        """Return a random integer in the half-open range [start, end)."""
        return random.randrange(start, end)

    def get_random_integer2(self, digits=4):
        """Return a random integer with exactly ``digits`` decimal digits."""
        start = 10 ** (digits - 1)
        end = 10 * start
        return self.get_random_integer(start, end)

    def get_random_decimal(self, start=100, end=1000, places=2):
        """Return a random Decimal in [start, end) with a 10**-places step.

        A non-positive scale factor (e.g. negative ``places``) falls back to
        two places.
        """
        div = int(10 ** places)
        div = div if div > 0 else 100
        integer = self.get_random_integer(start * div, end * div)
        return Decimal(integer) / div

    def get_random_decimal2(self, digits=2, places=2):
        """Return a random Decimal with ``digits`` integer digits.

        The fractional part uses a step of 10**-places.
        """
        start = 10 ** (digits - 1)
        end = start * 10
        return self.get_random_decimal(start, end, places)

    def get_random_number_dict(self):
        """Return a phone-number-like value not yet produced in this test.

        ``number10`` is a 10-digit string starting with '9'; ``number`` is
        the same value prefixed with '7'.
        """
        while True:
            number = '9' + str(self.get_random_integer2(9))
            if number not in self.generated_numbers:
                break
        self.generated_numbers.append(number)
        return {
            'number10': number,
            'number': '7' + number,
        }

    def get_random_number(self, full=False):
        """Return a unique 10-digit number, or its 11-digit form if ``full``."""
        result = self.get_random_number_dict()
        return result['number'] if full else result['number10']

    def get_random_bool(self):
        """Return True or False at random."""
        return bool(random.getrandbits(1))

    def get_random_date(self, withtime=False):
        """Return a random date from Jan 1 of the current year to today.

        Both ends are inclusive. With ``withtime`` the result is a naive
        datetime at midnight instead of a date.
        """
        date = datetime.date if not withtime else datetime.datetime
        start_date = date.today().replace(day=1, month=1).toordinal()
        end_date = date.today().toordinal()
        return date.fromordinal(random.randint(start_date, end_date))

    def get_random_datetime(self):
        """Return a random naive datetime (midnight) from this year."""
        return self.get_random_date(True)

    def get_random_datetime_tz(self):
        """Return a random datetime made aware in the current timezone."""
        data = self.get_random_datetime()
        return make_aware(data, get_current_timezone())

    def wrap_data(self, data):
        """Return ``data``, or freshly generated object fields if it is falsy."""
        return self.get_object_fields() if not data else data

    # ------------------------------------------------------------------
    # Almost non-overridable methods of the test case
    # ------------------------------------------------------------------

    def create_object(self, data=False):
        """Create and return a model instance.

        Uses ``data``, or random field values when ``data`` is falsy. The
        field dict is recorded on ``created_stack``.
        """
        data = self.wrap_data(data)
        self.created_stack.append(data)
        return self.tested_object.objects.create(**data)

    def get_count(self, *args, **kwargs):
        """Return the number of tested objects matching the given filter."""
        return self.tested_object.objects.filter(*args, **kwargs).count()

    def get_wrong_post_payload(self, steps=2):
        """Return a dict of ``steps`` random keys mapped to random values."""
        return {
            self.get_random_string(): self.get_random_string()
            for _ in range(steps)
        }

    def is_allow_method(self, method):
        """Return False only if the resource's ``Meta.list_allowed_methods``
        exists and does not contain ``method``.
        """
        # NOTE(review): only an explicit Meta.list_allowed_methods is
        # inspected; a Meta.allowed_methods shorthand is not considered.
        meta = getattr(self.resource_object, 'Meta', None)
        allowed = getattr(meta, 'list_allowed_methods', None)
        return allowed is None or method in allowed

    def prepare_filter_data(self, data):
        """Convert a sent value into something usable in an ORM filter.

        Dates and Decimals are stringified (see ``prepare_assert_data``).
        Resource URIs such as ``/.../api/v1/thing/42/`` are reduced to their
        last non-empty segment (presumably the primary key).
        """
        data = self.prepare_assert_data(data)
        if isinstance(data, basestring) and \
                re.match(r'/.*/api/v[0-9]*?/.*/', data):
            # A list comprehension works on both Python 2 and 3; on Python 3
            # filter() returns an iterator that cannot be indexed.
            segments = [
                part for part in data.split(self.url_separator) if part
            ]
            return segments[-1]
        return data

    def prepare_assert_data(self, data, resp=''):
        """Convert a sent value into the form expected in a JSON response.

        - Dates/datetimes become ISO-like strings (space replaced by 'T').
        - Decimals are stringified.
        - Any value is stringified when the response value ``resp`` is a
          non-empty string.
        """
        if isinstance(data, (datetime.date, datetime.datetime)):
            return str(data).replace(' ', 'T')
        if isinstance(data, Decimal) or (resp and isinstance(resp, basestring)):
            return str(data)
        return data

    # ------------------------------------------------------------------
    # Probably overridable methods of the test case
    # ------------------------------------------------------------------

    def get_put_payload(self):
        """Return the PUT payload; defaults to a fresh POST payload."""
        return self.get_post_payload()

    # ------------------------------------------------------------------
    # Overridable methods (usual) of the test case
    # ------------------------------------------------------------------

    def get_object_fields(self):
        """Build a dict of random, type-appropriate values for model fields.

        Fields in ``exclude_fields`` and nullable fields are skipped unless
        they are listed in ``exclude_ignore_fields``. Field types without a
        generator below (e.g. AutoField, ForeignKey) are left out.
        """
        fields = {}
        for field in self.tested_object._meta.fields:
            key = field.name
            if (key in self.exclude_fields or field.null) and \
                    key not in self.exclude_ignore_fields:
                continue
            field_type = field.get_internal_type()
            if field_type == 'DecimalField':
                # Fit the value to the column: keep at most 10 integer digits
                # as before, but never exceed max_digits - decimal_places, and
                # use the field's own number of decimal places.
                # NOTE(review): a field with max_digits == decimal_places
                # (values < 1) is still not representable here.
                places = field.decimal_places
                digits = min(10, max(1, field.max_digits - places))
                fields[key] = self.get_random_decimal2(digits, places)
            elif field_type == 'DateField':
                fields[key] = self.get_random_date()
            elif field_type == 'DateTimeField':
                fields[key] = self.get_random_datetime_tz()
            elif field_type in ('IntegerField', 'BigIntegerField'):
                fields[key] = self.get_random_integer2()
            elif field_type in ('CharField', 'TextField', 'SlugField'):
                # TextField has max_length=None; fall back to the default
                # length instead of failing in range(None).
                fields[key] = self.get_random_string(field.max_length or 25)
            elif field_type == 'BooleanField':
                fields[key] = self.get_random_bool()
            elif field_type == 'EmailField':
                fields[key] = self.get_random_email()
        return fields

    def get_post_payload(self):
        """Return the POST payload; defaults to random object fields."""
        return self.get_object_fields()

    # ------------------------------------------------------------------
    # Tests themselves
    # ------------------------------------------------------------------

    def test_unauthorized(self):
        """Anonymous GET on the list and detail endpoints returns 401."""
        obj = self.create_object()
        params = self.get_rest_params()
        self.assertHttpUnauthorized(
            self.api_client.get(self.get_rest_url(), **params)
        )
        # Use the real pk: database sequences are not guaranteed to restart
        # at 1 for every test.
        self.assertHttpUnauthorized(
            self.api_client.get(self.get_rest_url(obj.pk), **params)
        )

    def test_get_list(self):
        """A user with 'view' permission gets a valid JSON list."""
        self.create_perms('view')
        self.login()
        self.create_object()
        resp = self.api_client.get(self.get_rest_url(), **self.get_rest_params())
        self.assertValidJSONResponse(resp)

    def test_get_detail(self):
        """A user with 'view' permission gets a valid JSON detail."""
        self.create_perms('view')
        self.login()
        obj = self.create_object()
        resp = self.api_client.get(
            self.get_rest_url(obj.pk), **self.get_rest_params()
        )
        self.assertValidJSONResponse(resp)

    def test_post_list(self):
        """Test object creation with POST.

        If POST is not allowed by the resource, every attempt must get 405.
        """
        allow = self.is_allow_method('post')
        url = self.get_rest_url()
        params = self.get_rest_params(data=self.get_post_payload())

        def send():
            return self.api_client.post(url, **params)

        def post():
            # Previously a conditional lambda that, when POST was not
            # allowed, merely returned another lambda: no request was sent
            # and nothing was asserted.
            response = send()
            if allow:
                self.assertHttpUnauthorized(response)
            else:
                self.assertHttpMethodNotAllowed(response)

        post()  # anonymous
        self.login()
        post()  # logged in, but no 'add' permission granted yet
        self.create_perms('add')
        if allow:
            cnt = self.get_count()
            self.assertHttpCreated(send())
            self.assertEqual(cnt + 1, self.get_count())
        else:
            post()

    def test_deny_mass_list_operations(self):
        """PUT, DELETE and PATCH on the list endpoint all return 405."""
        url = self.get_rest_url()
        params = self.get_rest_params(data=self.get_wrong_post_payload())
        self.create_perms('change')
        self.login()
        self.assertHttpMethodNotAllowed(
            self.api_client.put(url, **params)
        )
        self.assertHttpMethodNotAllowed(
            self.api_client.delete(url, **params)
        )
        self.assertHttpMethodNotAllowed(
            self.api_client.patch(url, **params)
        )

    def test_delete_detail(self):
        """DELETE on a detail returns 202 and removes exactly one object."""
        self.create_perms('delete')
        self.login()
        obj = self.create_object()
        cnt = self.get_count()
        resp = self.api_client.delete(
            self.get_rest_url(obj.pk), **self.get_rest_params()
        )
        self.assertHttpAccepted(resp)
        self.assertEqual(cnt - 1, self.get_count())

    def test_put_detail(self):
        """PUT on a detail updates the object in place.

        The response must echo the sent values, and the stored row must
        match them.
        """
        obj = self.create_object()
        cnt = self.get_count()
        put_data = self.get_put_payload()
        put_items = list(put_data.items())
        params = self.get_rest_params(data=put_data)
        self.login()
        self.create_perms('change')
        response = self.api_client.put(self.get_rest_url(obj.pk), **params)
        self.assertHttpAccepted(response)
        dresp = self.deserialize(response)['objects'][0]
        # The primary key, under whichever key the response uses, is kept.
        for key in self.probable_primary_keys:
            if key in dresp:
                self.assertEqual(
                    dresp[key],
                    self.prepare_assert_data(obj.pk, dresp[key])
                )
        # Scalar values must be echoed back as sent. Nested dicts
        # (presumably related resources) are not compared.
        for key, value in put_items:
            if not isinstance(dresp[key], dict):
                self.assertEqual(
                    dresp[key],
                    self.prepare_assert_data(value, dresp[key])
                )
        # PUT on a detail must update, not create.
        self.assertEqual(cnt, self.get_count())
        # Filter on ALL sent fields, otherwise an unrelated row could match.
        condition = dict(
            (key, self.prepare_filter_data(value)) for key, value in put_items
        )
        self.assertEqual(1, self.get_count(**condition))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment