-
-
Save lbragstad/a0b30f15b92798df6141 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Offline validation of oauth access_keys. | |
""" | |
import base64 | |
import string | |
import unittest | |
import urllib | |
import uuid | |
import zlib | |
from Crypto.Cipher import AES | |
from Crypto import Random | |
# 256-bit AES key generated fresh at import time, so data encrypted by one
# process cannot be decrypted by another process.
KEY = Random.new().read(32)
# 16 random bytes (one AES block) generated at import time.
IV = Random.new().read(16)
# Base-62 digit set used by change_base(): 0-9, then a-z, then A-Z.
# ascii_letters (unlike Python 2's string.letters) does not depend on the
# current locale and also exists on Python 3.
ALPHABET = string.digits + string.ascii_letters
def encode(s):
    """Encode data as URL-safe base64 text.

    The output alphabet is ``A-Z a-z 0-9 - _`` with ``.`` standing in for
    the ``=`` padding character, so the result is left unchanged by
    ``urllib.quote`` and ``urllib.quote_plus``.

    :param s: the data to encode; text (``unicode`` on Python 2, ``str`` on
        Python 3) is UTF-8 encoded first
    :returns: the encoded data as a native ``str``
    """
    if not isinstance(s, bytes):
        s = s.encode('utf-8')
    # urlsafe_b64encode already maps '+' -> '-' and '/' -> '_'; only the
    # padding character still needs replacing.
    encoded = base64.urlsafe_b64encode(s).replace(b'=', b'.')
    if not isinstance(encoded, str):
        # Python 3 returns bytes; hand back text as Python 2 does.
        encoded = encoded.decode('ascii')
    return encoded
def decode(s):
    """Reverse :func:`encode`: turn URL-safe base64 back into bytes.

    :param s: text produced by ``encode`` (with ``.`` in place of ``=``
        padding); a bytes value is accepted as well
    :returns: the decoded bytes
    :raises: whatever ``base64.urlsafe_b64decode`` raises for malformed
        input (e.g. incorrect padding)
    """
    if isinstance(s, bytes):
        s = s.replace(b'.', b'=')
    else:
        s = s.replace('.', '=')
    # urlsafe_b64decode maps '-' -> '+' and '_' -> '/' before decoding.
    return base64.urlsafe_b64decode(s)
def compress(s):
    """Deflate ``s`` using zlib's strongest compression level.

    :param s: the bytes to compress
    :returns: the zlib-compressed bytes
    """
    level = zlib.Z_BEST_COMPRESSION  # numerically 9
    compressed = zlib.compress(s, level)
    return compressed
def decompress(s):
    """Inflate zlib-compressed data, e.g. the output of :func:`compress`.

    :param s: zlib-compressed bytes
    :returns: the original uncompressed bytes
    :raises zlib.error: if ``s`` is not valid zlib data
    """
    inflated = zlib.decompress(s)
    return inflated
def encrypt(plaintext):
    """Encrypt ``plaintext`` with AES in CFB mode under the module ``KEY``.

    A fresh random IV is drawn for every message and prepended to the
    ciphertext. Reusing a single IV under the same key would make the first
    CFB block of any two ciphertexts XOR to the XOR of their plaintexts.
    Note that CFB provides confidentiality only, not integrity.

    :param plaintext: the bytes to encrypt
    :returns: ``iv + ciphertext``, i.e. ``AES.block_size`` bytes longer
        than ``plaintext``
    """
    iv = Random.new().read(AES.block_size)
    cipher = AES.new(KEY, AES.MODE_CFB, iv)
    return iv + cipher.encrypt(plaintext)


def decrypt(ciphertext):
    """Reverse :func:`encrypt`.

    :param ciphertext: ``iv + ciphertext`` as returned by ``encrypt``
    :returns: the decrypted bytes
    :raises ValueError: if ``ciphertext`` is too short to hold an IV
    """
    iv = ciphertext[:AES.block_size]
    body = ciphertext[AES.block_size:]
    if len(iv) != AES.block_size:
        raise ValueError('ciphertext is shorter than the AES IV')
    cipher = AES.new(KEY, AES.MODE_CFB, iv)
    return cipher.decrypt(body)
def change_base(number, alphabet=None):
    """Convert a non-negative integer to a string in base ``len(alphabet)``.

    :param number: the non-negative integer to convert
    :param alphabet: digit symbols, lowest value first; defaults to the
        module-level ``ALPHABET``
    :returns: the string representation (``alphabet[0]`` for zero)
    :raises ValueError: if ``number`` is negative or ``alphabet`` has fewer
        than two symbols
    """
    if alphabet is None:
        alphabet = ALPHABET
    base = len(alphabet)
    if base < 2:
        # base 1 never shrinks the number and base 0 divides by zero.
        raise ValueError('alphabet must contain at least two symbols')
    if number < 0:
        # divmod floors toward -infinity, so a negative number never
        # reaches zero and the loop below would never terminate.
        raise ValueError('number must be non-negative: %r' % (number,))
    if number == 0:
        return alphabet[0]
    digits = []
    while number:
        number, remainder = divmod(number, base)
        digits.append(alphabet[remainder])
    # Digits were produced least-significant first.
    return ''.join(reversed(digits))
def generate_access_key(secret):
    """Derive a URL-safe access key that carries ``secret`` inside it.

    The OAuth secret is embedded in the access key so that OAuth middleware
    can validate the OAuth signature before making backend or remote calls.
    The secret is compressed, encrypted, then URL-safe encoded.

    :param secret: the OAuth secret to embed
    :returns: the encoded access key
    """
    return encode(encrypt(compress(secret)))
def verify_access_key(access_key):
    """Recover the secret embedded by :func:`generate_access_key`.

    This undoes each step of ``generate_access_key`` in reverse order
    (decode, decrypt, decompress). It performs no separate integrity check;
    a malformed key will typically fail inside ``decode`` or
    ``decompress``.

    :param access_key: a key produced by ``generate_access_key``
    :returns: the embedded secret
    """
    ciphertext = decode(access_key)
    return decompress(decrypt(ciphertext))
class Tests(unittest.TestCase):
    """Round-trip tests for the encoding, encryption and access-key helpers."""

    def setUp(self):
        self.secret = uuid.uuid4().hex

    def test_encode_decode(self):
        s = 'a /+-_\xe8'
        self.assertNotEqual(s, encode(s))
        self.assertEqual(s, decode(encode(s)))

    def test_encrypt_decrypt(self):
        s = uuid.uuid4().hex
        self.assertEqual(s, decrypt(encrypt(s)))

    def test_encrypt_encode_decode_decrypt(self):
        s = 'a /+-_\xe8'
        self.assertEqual(s, decrypt(decode(encode(encrypt(s)))))

    def test_decrypt_access_key(self):
        access_key = generate_access_key(self.secret)
        # access keys should be url friendly as-is
        self.assertEqual(urllib.quote(access_key), access_key)
        self.assertEqual(urllib.quote_plus(access_key), access_key)
        secret = verify_access_key(access_key)
        self.assertEqual(self.secret, secret)

    def test_access_key_length_reasonable(self):
        access_key = generate_access_key(self.secret)
        self.assertLessEqual(len(access_key), 255)

    def test_change_base(self):
        # ALPHABET starts with the ten digits, followed by letters.
        base = len(ALPHABET)
        self.assertEqual('1', change_base(1))
        self.assertEqual(ALPHABET[10], change_base(10))
        self.assertEqual('10', change_base(base))
        self.assertEqual('11', change_base(base + 1))
        # A random 128-bit value must round-trip through the digit string.
        n = uuid.uuid4().int
        digits = change_base(n)
        value = 0
        for ch in digits:
            value = value * base + ALPHABET.index(ch)
        self.assertEqual(n, value)
if __name__ == '__main__':
    def wrap(s):
        """Format ``s`` on its own indented line between blank lines."""
        # Format a 1-tuple so a tuple argument is printed, not unpacked.
        return '\n\n %s\n' % (s,)

    secret = change_base(uuid.uuid4().int)
    print('Generate the secret key first: %s' % wrap(secret))
    access_key = generate_access_key(secret)
    print('Then we can derive the access key: %s' % wrap(access_key))
    print('Which has a reasonable length of: %s' % wrap(len(access_key)))
    verified_secret = verify_access_key(access_key)
    print('Later, middleware receives an OAuth-signed request with an access '
          'key, and can independently extract the secret used to sign the '
          'request: %s' % wrap(verified_secret))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import datetime | |
import hashlib | |
import hmac | |
import uuid | |
import msgpack | |
import demo | |
def wrap(s):
    """Format ``s`` on its own indented line between blank lines (demo output).

    :param s: any object; rendered with ``%s``
    :returns: the formatted string
    """
    # Formatting a 1-tuple stops a tuple argument from being unpacked into
    # the format string (``'%s' % (1, 2)`` raises TypeError).
    return '\n\n %s\n' % (s,)
def make_digest(payload, key=b'super-secret-key'):
    """HMAC-SHA1 sign a payload.

    :param payload: the information to generate a digest for; bytes, or
        text which is UTF-8 encoded first
    :param key: the HMAC key (bytes or text); defaults to the demo's
        hard-coded key
    :returns: the base64 encoded digest as a native ``str``; always 28
        characters, since a SHA-1 digest is 20 bytes
    """
    if not isinstance(key, bytes):
        key = key.encode('utf-8')
    if not isinstance(payload, bytes):
        payload = payload.encode('utf-8')
    digest = hmac.new(key, payload, hashlib.sha1).digest()
    # b64encode produces the same characters as the old encodestring (removed
    # in Python 3.9) but without the trailing newline, so nothing to strip.
    encoded = base64.b64encode(digest)
    if not isinstance(encoded, str):
        # Python 3 returns bytes; return text as Python 2 does.
        encoded = encoded.decode('ascii')
    return encoded
def check_digest(payload):
    """Validate the information hasn't been tampered with.

    :param payload: the entire token string: a 1-character version, then
        the 28-character digest (indices 1-28), then the encoded encrypted
        information.
    :returns: True if the digest recomputed over the encrypted information
        matches the digest embedded in the token, and False otherwise.
    """
    _, token_digest, encrypted_info = break_down_token(payload)
    new_digest = make_digest(encrypted_info)
    # Compare as bytes: hmac.compare_digest rejects non-ASCII text and
    # mixed text/bytes arguments with TypeError.
    if not isinstance(token_digest, bytes):
        token_digest = token_digest.encode('utf-8')
    if not isinstance(new_digest, bytes):
        new_digest = new_digest.encode('utf-8')
    # Constant-time comparison: ``==`` stops at the first differing
    # character, leaking through timing how much of a forged digest is right.
    return hmac.compare_digest(token_digest, new_digest)
def break_down_token(token):
    """Split a token string into its version, digest and encrypted parts.

    How a token is laid out is ultimately determined by its version, whose
    first character always identifies it. Eventually this should only strip
    off the version and hand the remainder to a version-specific formatter;
    for now it slices the single known layout directly:

    * index 0: the token version
    * indices 1-28: the digest computed when the token was created (the
      base64 form of a 20-byte SHA-1 HMAC is 28 characters)
    * everything after: the encoded, encrypted token information

    :param token: the entire payload string that is represented as a
        token.
    :returns: tuple of ``(token_version, token_digest, encrypted_info)``
    """
    version_end = 1
    digest_end = version_end + 28
    return (token[:version_end],
            token[version_end:digest_end],
            token[digest_end:])
def decrypt_token(encrypted_info):
    """Decrypt the information carried in a token.

    How the decrypted fields are interpreted will eventually depend on the
    token version.

    :param encrypted_info: a string of encoded, encrypted data
    :returns: a dictionary of information based on the token format and
        version.
    """
    plaintext = demo.decrypt(demo.decode(encrypted_info))
    fields = msgpack.unpackb(plaintext)
    # The token format and version matter here: the values are positional,
    # so the layout must be known in order to assign each one properly.
    return {
        'user_id': fields[0],
        'project_id': fields[1],
        'created_at': convert_timestamp_to_datetime(fields[2]),
        'token_ttl': fields[3],
        'audit_id': fields[4],
    }
def generate_token(message, version):
    """Encrypt and generate a token.

    The token is the version character, followed by an HMAC digest of the
    encoded ciphertext, followed by the URL-safe encoded ciphertext of the
    msgpack-serialized ``message``.

    :param message: token information to encrypt (anything msgpack can
        serialize)
    :param version: token version format; ``str(version)`` must be exactly
        one character, because ``break_down_token`` reads only the first
        character as the version
    :returns: an authenticated encrypted token string
    :raises ValueError: if ``str(version)`` is not exactly one character
    """
    version = str(version)
    if len(version) != 1:
        # A longer version would shift the digest boundary and produce a
        # token that break_down_token silently misparses.
        raise ValueError(
            'token version must be a single character, got %r' % (version,))
    msgpack_message = msgpack.packb(message)
    msgpack_ciphertext = demo.encrypt(msgpack_message)
    msgpack_encode = demo.encode(msgpack_ciphertext)
    digest = make_digest(msgpack_encode)
    result = version + digest + msgpack_encode
    print('Encoded payload using msgpack: %s' % wrap(result))
    print('Encoded payload length using msgpack: %s' %
          wrap(len(result)))
    return result
def convert_datetime_to_timestamp(date_object):
    """Convert from datetime to timestamp integer.

    Naive datetimes are interpreted as local time, the inverse of
    ``datetime.datetime.fromtimestamp``. Aware datetimes are converted
    using their own UTC offset. Sub-second precision is truncated.

    :param date_object: datetime object to convert to timestamp
    :returns: timestamp integer (seconds since the Unix epoch)
    """
    import calendar
    import time

    # strftime('%s') is an undocumented glibc extension: it is unavailable
    # on some platforms (e.g. Windows) and silently ignores tzinfo.
    if date_object.utcoffset() is None:
        return int(time.mktime(date_object.timetuple()))
    return calendar.timegm(date_object.utctimetuple())
def convert_timestamp_to_datetime(timestamp):
    """Turn a timestamp into a naive local-time datetime object.

    :param timestamp: seconds since the epoch; anything ``int()`` accepts
        (int, float, numeric string)
    :return: datetime object
    """
    seconds = int(timestamp)
    return datetime.datetime.fromtimestamp(seconds)
def generate_audit_id():
    """Create a dummy audit_id for the token.

    The 16 bytes of a random UUID4 encode to 24 URL-safe base64 characters,
    the last two being ``==`` padding. The padding is dropped, leaving 22
    characters.

    :returns: an audit id
    """
    raw = uuid.uuid4().bytes
    padded = base64.urlsafe_b64encode(raw)
    return padded.rstrip(b'=')
if __name__ == '__main__':
    # Populate some data
    user_id = uuid.uuid4().hex
    project_id = uuid.uuid4().hex
    created_at = datetime.datetime.now()
    token_life = datetime.timedelta(hours=2)
    audit_id = generate_audit_id()
    # The token stores a creation timestamp plus a TTL in seconds rather
    # than an absolute expiry; the expiry is recomputed after decryption.
    created_at_ts = convert_datetime_to_timestamp(created_at)
    token_ttl = int(token_life.total_seconds())
    message = [user_id, project_id, created_at_ts, token_ttl, audit_id]
    print('Initial payload: %s' % wrap(message))
    # Generate token with msgpack
    token = generate_token(message, 1)
    # Validate the token was not tampered with
    print('Was the token tampered with? %s' % wrap(not check_digest(token)))
    _, _, encrypted_info = break_down_token(token)
    token_dict = decrypt_token(encrypted_info)
    expires = (token_dict['created_at'] +
               datetime.timedelta(seconds=token_dict['token_ttl']))
    print('Token user: %s' % wrap(token_dict['user_id']))
    print('Token project: %s' % wrap(token_dict['project_id']))
    print('Token created at: %s' % wrap(token_dict['created_at']))
    print('Token expires at: %s' % wrap(expires))
    print('Token audit_id: %s' % wrap(token_dict['audit_id']))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment