Skip to content

Instantly share code, notes, and snippets.

@moriyoshi
Created October 31, 2015 04:33
Show Gist options
  • Save moriyoshi/6ec572d8ea3b8db6561b to your computer and use it in GitHub Desktop.
Save moriyoshi/6ec572d8ea3b8db6561b to your computer and use it in GitHub Desktop.
Downloads a log file from an RDS instance.
#!/usr/bin/python
from __future__ import print_function
import httplib
import urllib
import hashlib
import urlparse
import hmac
from datetime import datetime
from dateutil import tz
def pairs(pairs):
    """Return an iterator (or view) over the ``(key, value)`` pairs of *pairs*.

    Objects exposing ``iteritems()`` (Python 2 mappings) use it; otherwise
    objects exposing ``items()`` use that.  Anything else is assumed to already
    be an iterable of pairs and is handed to ``iter()``.
    """
    for method_name in ('iteritems', 'items'):
        if hasattr(pairs, method_name):
            return getattr(pairs, method_name)()
    return iter(pairs)
def _encode_form_urlencoded(params, encoding, add_header):
    """Encode *params* as an ``application/x-www-form-urlencoded`` body.

    Every key and value is encoded with *encoding* before being passed to
    ``urllib.urlencode``.  *add_header* is part of the encoder calling
    convention (AWSHTTPRequest passes its ``add_header``) but is unused here.
    """
    encoded_pairs = []
    for key, value in pairs(params):
        encoded_pairs.append((key.encode(encoding), value.encode(encoding)))
    return urllib.urlencode(encoded_pairs)


# Body encoders keyed by lower-cased content type; each is called as
# encoder(params, encoding, add_header) and returns the encoded payload.
param_encoders = {
    'application/x-www-form-urlencoded': _encode_form_urlencoded,
}
# Signing hash methods, keyed by the ``hash_method`` name that AWSAPI passes
# down to AWSHTTPRequest._seal.  Each value is a tuple of
# (SigV4 algorithm identifier, header name carrying the payload hash,
#  hashlib constructor).
hashers = dict(
    sha256=('AWS4-HMAC-SHA256', 'X-Amz-Content-Sha256', hashlib.sha256),
)
def denaive_datetime(now):
    """Return *now* as a timezone-aware datetime.

    An aware datetime is returned unchanged.  A naive one is interpreted as
    local time by attaching ``dateutil.tz.tzlocal()``.
    """
    if now.tzinfo is not None:
        return now
    return now.replace(tzinfo=tz.tzlocal())
class AWSHTTPRequest(object):
    """One HTTP request to an AWS endpoint, signed with Signature Version 4.

    The request is described by its method, host, path (``request_url``) and
    an optional query string.  A body can be given either directly as
    ``payload`` or as ``params``, which are encoded lazily with the
    module-level ``param_encoders`` entry for ``content_type``.
    ``_put_request`` signs the request and writes it to an ``httplib``-style
    connection object.
    """

    # Lower-cased header names that may be covered by the signature.
    signed_headers = set([
        'host',
        'content-type',
        'x-amz-date',
    ])
    # Lower-cased names related to the signing process itself.
    reserved_keys = set([
        'x-amz-date',
        'x-amz-algorithm',
        'x-amz-signature',
        'x-amz-expires',
        'x-amz-credential',
    ])
    # Caller-supplied headers with these names are skipped by the
    # pass-through loop in _put_request (signed ones are emitted from the
    # signed-header list instead).
    reserved_headers = signed_headers | reserved_keys | set([
        'authorization',
    ])

    def __init__(self, method, host, request_url, query_string, content_type='application/x-www-form-urlencoded', params=None, encoding='utf-8', payload=None, headers=None):
        """Describe a request.

        :param method: HTTP method, e.g. ``'GET'``.
        :param host: value sent (and signed) as the Host header.
        :param request_url: request path, used as the canonical URI.
        :param query_string: raw query string without ``?``, or None.
        :param content_type: body content type; also selects the params encoder.
        :param params: body parameters encoded on demand (see ``payload``).
        :param encoding: text encoding used for params and the charset.
        :param payload: pre-encoded body; takes precedence over *params*.
        :param headers: iterable of extra ``(name, value)`` header pairs.
        """
        self.method = method
        self.host = host
        self.request_url = request_url
        self.query_string = query_string
        self.content_type = content_type
        # Private copy: the former ``headers=[]`` default was one list shared
        # by every instance, and add_header would have mutated it.
        self._headers = list(headers) if headers is not None else []
        self.encoding = encoding
        self._params = params
        self._payload = payload
        # Avoids reformatting the same datetime several times per signature.
        self._amz_date_cache = {}

    def format_amz_date(self, now):
        """Return aware datetime *now* in UTC as ``YYYYMMDDTHHMMSSZ`` (memoised)."""
        if now in self._amz_date_cache:
            retval = self._amz_date_cache[now]
        else:
            retval = self._amz_date_cache[now] = now.astimezone(tz.tzutc()).strftime("%Y%m%dT%H%M%SZ")
        return retval

    @property
    def payload(self):
        """Encoded request body, or None when there is neither payload nor params."""
        if self._payload is None and self._params is not None:
            content_type = self.content_type.lower()
            self._payload = param_encoders[content_type](self._params, self.encoding, self.add_header)
        return self._payload

    @payload.setter
    def payload(self, value):
        self._params = None
        self._payload = value

    @property
    def params(self):
        """Unencoded body parameters (setting them discards any cached payload)."""
        return self._params

    @params.setter
    def params(self, value):
        self._params = value
        self._payload = None

    def add_header(self, k, v):
        """Append the extra header ``(k, v)`` to this request."""
        # The original appended to ``self.headers``, which does not exist.
        self._headers.append((k, v))

    def _get_signed_headers(self, id, now, region, service):
        """Return the headers covered by the signature.

        The result is a list of ``(name, lower_name, value)`` tuples sorted by
        ``lower_name``.  SigV4 requires this order for both the canonical
        header block and the SignedHeaders list.  Host and X-Amz-Date are
        always included.  Content-Type is included when there is a payload.
        Extra headers whose lower-cased name is in ``signed_headers`` are
        included unless they duplicate one of those.
        """
        amz_date = self.format_amz_date(now)
        headers = [
            ('Host', 'host', self.host),
            ('X-Amz-Date', 'x-amz-date', amz_date),
        ]
        if self.payload is not None:
            headers.append(('Content-Type', 'content-type', '%s; charset=%s' % (self.content_type, self.encoding)))
        basic_header_names = set(lk for _, lk, _ in headers)
        for k, v in self._headers:
            lk = k.lower()
            if lk not in basic_header_names and lk in self.signed_headers:
                headers.append((k, lk, v))
        # Canonical headers must be ordered by lower-cased name; the original
        # insertion order put content-type after x-amz-date.
        headers.sort(key=lambda header: header[1])
        return headers

    def _get_canonical_query_string(self):
        """Return the query string in SigV4 canonical form.

        Parameters are sorted by name, then by value.  Every name is followed
        by ``=`` even when its value is empty.
        NOTE(review): names and values are assumed to be URI-encoded already;
        they are not re-encoded here.
        """
        if not self.query_string:
            return ''
        params = []
        for part in self.query_string.split('&'):
            if not part:
                continue
            name, _, value = part.partition('=')
            params.append((name, value))
        params.sort()
        return '&'.join('%s=%s' % param for param in params)

    def _get_credential_scope(self, now, region, service):
        """Return ``<YYYYMMDD>/<region>/<service>/aws4_request``."""
        return '%s/%s/%s/aws4_request' % (self.format_amz_date(now)[0:8], region, service)

    def _get_hashed_canonical_request(self, id, now, region, service, hasher):
        """Build and hash the SigV4 canonical request.

        Returns a tuple of (hex digest of the canonical request, hex digest of
        the payload, signed-header tuples, SignedHeaders string).
        """
        h = hasher()
        h.update(self.method.upper())
        h.update('\n')
        h.update(self.request_url)
        h.update('\n')
        h.update(self._get_canonical_query_string())
        h.update('\n')
        signed_headers = self._get_signed_headers(id, now, region, service)
        for _, lk, v in signed_headers:
            # SigV4 trims values and collapses runs of spaces to one space.
            h.update('%s:%s\n' % (lk, ' '.join(v.split())))
        h.update('\n')
        # Sorted to match the canonical header order; a bare set has no
        # defined iteration order.
        signed_headers_str = ';'.join(sorted(set(lk for _, lk, _ in signed_headers)))
        h.update(signed_headers_str)
        h.update('\n')
        content_hash = hasher(self.payload or '').hexdigest()
        h.update(content_hash)
        return h.hexdigest(), content_hash, signed_headers, signed_headers_str

    def _seal(self, hash_method, id, now, region, service, secret):
        """Compute the SigV4 signature plus the values needed to send it."""
        hasher_name, content_hash_header_name, hasher = hashers[hash_method]
        # Signing key derivation: HMAC chain over date, region, service and
        # the literal "aws4_request", seeded with "AWS4" + secret.
        signing_key = 'AWS4%s' % secret
        for scope_part in (self.format_amz_date(now)[0:8], region, service, "aws4_request"):
            signing_key = hmac.HMAC(signing_key, scope_part, hasher).digest()
        h = hmac.HMAC(signing_key, None, hasher)
        h.update(hasher_name)
        h.update('\n')
        h.update(self.format_amz_date(now))
        h.update('\n')
        credential_scope = self._get_credential_scope(now, region, service)
        h.update(credential_scope)
        h.update('\n')
        hashed_canonical_request, content_hash, signed_headers, signed_headers_str = self._get_hashed_canonical_request(id, now, region, service, hasher)
        h.update(hashed_canonical_request)
        return dict(
            signature=h.hexdigest(),
            hasher_name=hasher_name,
            content_hash_header_name=content_hash_header_name,
            content_hash=content_hash,
            credential_scope=credential_scope,
            signed_headers=signed_headers,
            signed_headers_str=signed_headers_str
        )

    def _put_request(self, hash_method, id, now, region, service, expiry, secret, http):
        """Sign this request and write it (headers and body) to *http*.

        *http* must provide ``putrequest``/``putheader``/``endheaders``/``send``
        as ``httplib.HTTPConnection`` does.  X-Amz-Expires is sent only when
        *expiry* is not None.
        """
        real_request_url = self.request_url
        if self.query_string is not None:
            real_request_url += '?' + self.query_string
        seal = self._seal(hash_method, id, now, region, service, secret)
        authorization_header_value = '{hasher_name} ' \
            'Credential={id}/{credential_scope},' \
            'SignedHeaders={signed_headers_str},' \
            'Signature={signature}' \
            .format(id=id, **seal)
        # skip_host / skip_accept_encoding: Host is sent (and signed) below.
        http.putrequest(
            self.method,
            real_request_url,
            True,
            True
        )
        for k, _, v in seal['signed_headers']:
            http.putheader(k, v)
        for k, v in self._headers:
            if k.lower() not in self.reserved_headers:
                http.putheader(k, v)
        http.putheader('X-Amz-Algorithm', seal['hasher_name'])
        http.putheader('X-Amz-Signature', seal['signature'])
        if expiry is not None:
            http.putheader('X-Amz-Expires', '%d' % expiry)
        http.putheader('Authorization', authorization_header_value)
        http.putheader(seal['content_hash_header_name'], seal['content_hash'])
        http.endheaders()
        http.send(self.payload or '')
class AWSAPI(object):
    """Credentials plus signing settings; sends AWSHTTPRequest objects.

    ``now_getter`` is called with this instance to obtain the signing
    timestamp.  A naive result is treated as local time.
    """

    def __init__(self, id, secret, hash_method='sha256', region='us-east-1', now_getter=lambda _:datetime.now()):
        self.id, self.secret = id, secret
        self.hash_method, self.region = hash_method, region
        self.now_getter = now_getter

    def do_request(self, request, service, expiry=86400):
        """Sign *request* for *service*, send it over HTTPS and return the response.

        The connection targets ``request.host``.  The returned value is the
        ``httplib`` response object, which the caller reads.  *expiry* is
        forwarded to ``request._put_request``.
        """
        connection = httplib.HTTPSConnection(request.host)
        timestamp = denaive_datetime(self.now_getter(self))
        signing_args = dict(
            now=timestamp,
            id=self.id,
            hash_method=self.hash_method,
            region=self.region,
            service=service,
            expiry=expiry,
            secret=self.secret,
            http=connection,
        )
        request._put_request(**signing_args)
        return connection.getresponse()
class RDSEndpointURLBUilder(object):
    """Builds RDS REST endpoint URLs for one regional host."""

    # API version path segment.
    version = 'v13'
    _download_complete_log_file_template = (
        'https://{host}/{version}/downloadCompleteLogFile/'
        '{DBInstanceIdentifier}/{LogFileName}'
    )

    def __init__(self, host):
        self.host = host

    def download_complete_log_file(self, db_instance_identifier, log_file_name):
        """Return the downloadCompleteLogFile URL for one instance's log file.

        Both arguments are interpolated verbatim (no percent-encoding), so a
        log file name containing ``/`` keeps its slashes.
        """
        fields = {
            'host': self.host,
            'version': self.version,
            'DBInstanceIdentifier': db_instance_identifier,
            'LogFileName': log_file_name,
        }
        return self._download_complete_log_file_template.format(**fields)
class AWSEndpoints(object):
    """Registry of per-service regional hosts with cached URL builders.

    ``endpoints`` maps a service name to ``(builder_factory, {region: host})``.
    """

    endpoints = {
        'rds': (
            RDSEndpointURLBUilder,
            {
                'us-east-1': 'rds.us-east-1.amazonaws.com',
                'us-west-2': 'rds.us-west-2.amazonaws.com',
                'us-west-1': 'rds.us-west-1.amazonaws.com',
                'eu-west-1': 'rds.eu-west-1.amazonaws.com',
                'ap-southeast-1': 'rds.ap-southeast-1.amazonaws.com',
                'ap-southeast-2': 'rds.ap-southeast-2.amazonaws.com',
                'ap-northeast-1': 'rds.ap-northeast-1.amazonaws.com',
                'sa-east-1': 'rds.sa-east-1.amazonaws.com',
            }
        )
    }

    def __init__(self):
        # service -> {region: builder instance}
        self.instances = {}

    def get_endpoint(self, region, service):
        """Return the host for *service* in *region* (KeyError if unknown)."""
        return self.endpoints[service][1][region]

    def get_endpoint_url_builder(self, region, service):
        """Return the URL builder for *service* in *region*, creating it once.

        Builders are cached per (service, region), so repeated calls return
        the same instance.  Raises KeyError for an unknown service or region.
        """
        # Look the service up first so an unknown service leaves no cache entry.
        factory = self.endpoints[service][0]
        builders = self.instances.setdefault(service, {})
        # The original looked the cache up by *service* while storing by
        # *region*, so the cache never hit.
        builder = builders.get(region)
        if builder is None:
            builder = builders[region] = factory(self.get_endpoint(region, service))
        return builder
def main(argv):
    """Download an RDS log file and copy it to standard output.

    Usage: ``<prog> DB_INSTANCE_IDENTIFIER LOG_FILE_NAME``.  Credentials are
    read from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY and the region from
    AWS_REGION (default ``us-east-1``).

    Returns a process exit status:
    - 0 on success;
    - 1 when the API answers with a non-200 status (the response is reported
      on stderr);
    - 2 on bad usage or missing credentials.
    """
    import os, sys, shutil
    if len(argv) < 3:
        prog = argv[0] if argv else 'rds-download-log'
        print('usage: %s DB_INSTANCE_IDENTIFIER LOG_FILE_NAME' % prog, file=sys.stderr)
        return 2
    try:
        access_key_id = os.environ['AWS_ACCESS_KEY_ID']
        secret_access_key = os.environ['AWS_SECRET_ACCESS_KEY']
    except KeyError as missing:
        print('environment variable %s is not set' % missing.args[0], file=sys.stderr)
        return 2
    region = os.environ.get('AWS_REGION', 'us-east-1')
    aws = AWSAPI(
        id=access_key_id,
        secret=secret_access_key,
        region=region
    )
    endpoints = AWSEndpoints()
    service = 'rds'
    endpoints_for_rds = endpoints.get_endpoint_url_builder(region, service)
    url = endpoints_for_rds.download_complete_log_file(argv[1], argv[2])
    parsed_url = urlparse.urlparse(url)
    resp = aws.do_request(
        AWSHTTPRequest(
            method='GET',
            host=parsed_url.netloc,
            request_url=parsed_url.path,
            query_string=parsed_url.query or None,
        ),
        service=service
    )
    if resp.status != 200:
        print('HTTP %d %s' % (resp.status, resp.reason), file=sys.stderr)
        print(resp.read(), file=sys.stderr)
        return 1
    shutil.copyfileobj(resp, sys.stdout)
    return 0
if __name__ == '__main__':
    import sys
    # Propagate main()'s return value as the process exit status
    # (a None return exits with status 0).
    sys.exit(main(sys.argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment