Created July 31, 2023 20:00
-
-
Save mikegrima/e53b6de4394c1fe4ef4e8eb9185115e8 to your computer and use it in GitHub Desktop.
Paginated and batched AWS Config resource fetching and listing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This gist covers how to: | |
- Make pytest fixtures for querying S3 buckets in AWS Config | |
- This tests the querying AWS Config's aggregated advanced query with full pagination (this works for non-aggregated queries as well) | |
- This also tests batch fetching resource configuration data out of the aggregator (this works for non-aggregated as well) | |
""" | |
import json
import os
from typing import Any, Dict, Generator, List, Optional
from unittest.mock import MagicMock

import boto3
import pytest
from botocore.client import BaseClient
from botocore.paginate import PageIterator
from moto import mock_config, mock_s3
@pytest.fixture
def aws_credentials() -> Generator[None, None, None]:
    """Mocked AWS Credentials for moto.

    Sets dummy AWS credential/region environment variables for the duration of one test. Afterwards it puts back
    whatever values the environment had before (removing any variable that was previously absent), so the fake
    values do not leak into later tests or shadow a developer's real configuration for the rest of the session.
    """
    fake_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-2",
    }
    # Snapshot first so teardown can distinguish "was unset" (None) from "had a value".
    previous = {name: os.environ.get(name) for name in fake_env}
    os.environ.update(fake_env)
    yield
    for name, value in previous.items():
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value
@pytest.fixture
def moto_account() -> Generator[None, None, None]:
    """Mock out the Moto account ID for our unit tests to be 000000000001.

    After the test, any MOTO_ACCOUNT_ID that was already set is restored; if it was not set, it is removed.
    """
    previous = os.environ.get("MOTO_ACCOUNT_ID")
    os.environ["MOTO_ACCOUNT_ID"] = "000000000001"
    yield
    if previous is None:
        # pop() instead of del: the test body may already have removed the variable.
        os.environ.pop("MOTO_ACCOUNT_ID", None)
    else:
        os.environ["MOTO_ACCOUNT_ID"] = previous
@pytest.fixture
def aws_s3(aws_credentials: None, moto_account: None) -> Generator[BaseClient, None, None]:
    """Yield a us-east-1 boto3 S3 client that is backed by moto's S3 mock for the whole test.

    Requests the fake-credential and moto-account fixtures so they are active before the client is created.
    """
    with mock_s3():
        s3_client = boto3.client("s3", region_name="us-east-1")
        yield s3_client
@pytest.fixture
def aws_config(aws_credentials: None, moto_account: None) -> Generator[BaseClient, None, None]:
    """Yield a moto-backed AWS Config client (us-east-1) that already has an aggregator named "myaggregator".

    The aggregator aggregates account 000000000001 across all regions and exists for the duration of the test.
    """
    aggregation_sources = [{"AccountIds": ["000000000001"], "AllAwsRegions": True}]
    with mock_config():
        config_client = boto3.client("config", region_name="us-east-1")
        config_client.put_configuration_aggregator(
            ConfigurationAggregatorName="myaggregator",
            AccountAggregationSources=aggregation_sources,
        )
        yield config_client
@pytest.fixture
def bucket_list() -> List[str]:
    """Build the JSON-string rows that AWS Config's advanced query would return for 200 S3 buckets.

    Rows describe ``bucket-number-0`` .. ``bucket-number-199`` in account 000000000001 / us-east-1, and each row is
    serialized with ``json.dumps``.
    """

    def _row(number: int) -> str:
        # Key order matters: it determines the exact JSON text produced by json.dumps.
        name = f"bucket-number-{number}"
        return json.dumps(
            {
                "accountId": "000000000001",
                "resourceId": name,
                "awsRegion": "us-east-1",
                "resourceName": name,
                "arn": f"arn:aws:s3:::{name}",
                "configurationItemCaptureTime": "2023-05-16T10:00:01Z",
                "availabilityZone": "Regional",
                "version": "1.3",
                "resourceCreationTime": "2023-05-16T10:00:01Z",
                "resourceType": "AWS::S3::Bucket",
            }
        )

    return [_row(number) for number in range(200)]
@pytest.fixture
def config_select_buckets_client(bucket_list: List[str]) -> BaseClient:
    """Return a Config client whose ``select_aggregate_resource_config`` paginator serves ``bucket_list``.

    Pagination requests never reach AWS: the paginator's page-iterator class is replaced with one that slices
    ``bucket_list`` into pages of 50 rows, using the stringified start index of the next page as ``NextToken``.
    """
    page_size = 50
    config_client = boto3.client("config", region_name="us-east-1")
    select_paginator = config_client.get_paginator("select_aggregate_resource_config")

    class _CannedSelectPages(PageIterator):
        """PageIterator that answers each request from ``bucket_list`` instead of calling AWS."""

        def _make_request(self, current_kwargs: Dict[str, Any]) -> Dict[str, Any]:
            """Return the page starting at the index encoded in ``NextToken`` (the first page when absent/empty)."""
            first = int(current_kwargs.get("NextToken") or 0)
            last = first + page_size
            page: Dict[str, Any] = {
                "Results": bucket_list[first:last],
                "QueryInfo": {"SelectFields": [{"Name": "*"}]},
            }
            # Only advertise another page while rows remain; leaving NextToken out ends pagination.
            if last < len(bucket_list):
                page["NextToken"] = str(last)
            return page

    # Swapping PAGE_ITERATOR_CLS on this paginator instance is all that is needed to use the canned pages.
    select_paginator.PAGE_ITERATOR_CLS = _CannedSelectPages
    # Every get_paginator() call (for any operation name) now returns the patched paginator.
    config_client.get_paginator = MagicMock(return_value=select_paginator)
    return config_client
# AWS Config advanced-query template selecting every S3 bucket of one account. The "{replaceme}" placeholder is
# substituted with the account ID via str.replace() (not str.format) before the query is sent.
AWS_CONFIG_BUCKET_QUERY = """
SELECT * WHERE
 resourceType = 'AWS::S3::Bucket'
 AND accountId = '{replaceme}'
"""
def select_buckets(
    config_client: Optional[BaseClient] = None,
    account_id: Optional[str] = None,
    aggregator: str = "myaggregator",
) -> List[Dict[str, Any]]:
    """Select all S3 buckets of ``account_id`` from an AWS Config aggregator, following every result page.

    :param config_client: The boto3 AWS Config client to query. If ``None``, a new ``us-east-1`` client is created.
    :param account_id: The AWS account ID whose S3 buckets are selected.
    :param aggregator: The name of the configuration aggregator to query.
    :return: The un-wrapped rows from every page of the advanced query, in page order.
    :raises TypeError: If no account ID is supplied.

    The legacy call form ``select_buckets("<account id>")`` is still accepted for backwards compatibility: a lone
    string first argument is treated as the account ID, and a fresh ``us-east-1`` client plus the
    ``"myaggregator"`` aggregator are used.
    """
    if account_id is None and isinstance(config_client, str):
        # Legacy signature: select_buckets(account_id)
        account_id, config_client = config_client, None
    if account_id is None:
        raise TypeError("select_buckets() missing required argument: 'account_id'")
    if config_client is None:
        config_client = boto3.client("config", region_name="us-east-1")

    paginator = config_client.get_paginator("select_aggregate_resource_config")
    pages = paginator.paginate(
        Expression=AWS_CONFIG_BUCKET_QUERY.replace("{replaceme}", account_id),
        ConfigurationAggregatorName=aggregator,
    )
    raw_rows: List[Any] = []
    for page in pages:
        raw_rows.extend(page.get("Results", []))

    # For un_wrap_json, see this gist: https://github.com/gemini-oss/starfleet/blob/main/src/starfleet/worker_ships/niceties.py#L16
    return un_wrap_json(raw_rows)
def test_select_buckets(config_select_buckets_client: BaseClient) -> None:
    """This tests the Select Buckets code and verifies that it will paginate with AWS Config properly."""
    # Wrap the mocked paginator's paginate() so the arguments select_buckets() passes in can be inspected.
    # The fixture's get_paginator() always returns this same paginator object, so the wrapper is what gets called.
    paginator = config_select_buckets_client.get_paginator("select_aggregate_resource_config")
    wrapped_paginate = MagicMock(side_effect=paginator.paginate)
    paginator.paginate = wrapped_paginate

    bucket_result = select_buckets(config_select_buckets_client, "000000000001", "myaggregator")

    # Every row from every page must come back, in order:
    expected = [
        {
            "accountId": "000000000001",
            "resourceId": f"bucket-number-{count}",
            "awsRegion": "us-east-1",
            "resourceName": f"bucket-number-{count}",
            "arn": f"arn:aws:s3:::bucket-number-{count}",
            "configurationItemCaptureTime": "2023-05-16T10:00:01Z",
            "availabilityZone": "Regional",
            "version": "1.3",
            "resourceCreationTime": "2023-05-16T10:00:01Z",
            "resourceType": "AWS::S3::Bucket",
        }
        for count in range(200)
    ]
    assert len(bucket_result) == 200
    assert list(bucket_result) == expected

    # paginate() is invoked once, with only the aggregator name and the account-scoped query:
    wrapped_paginate.assert_called_once()
    call_kwargs = wrapped_paginate.call_args.kwargs
    assert call_kwargs.keys() == {"ConfigurationAggregatorName", "Expression"}
    assert call_kwargs["ConfigurationAggregatorName"] == "myaggregator"
    # Collapse whitespace so the check does not depend on how AWS_CONFIG_BUCKET_QUERY is indented:
    assert " ".join(call_kwargs["Expression"].split()) == (
        "SELECT * WHERE resourceType = 'AWS::S3::Bucket' AND accountId = '000000000001'"
    )
def fetch_buckets(
    bucket_list: List[Dict[str, Any]],
    config_client: BaseClient,
    aggregator: str,
    batch_size: int = 100,
) -> Generator[List[Dict[str, Any]], None, None]:
    """Yield the AWS Config configurations for ``bucket_list``, one batch of ``batch_size`` buckets at a time.

    Each batch is fetched with ``get_bucket_configs``. Only non-empty batches are produced, so an empty
    ``bucket_list`` yields nothing.

    :param bucket_list: Bucket rows (dicts with ``accountId``, ``awsRegion`` and ``resourceId``) to look up.
    :param config_client: The boto3 AWS Config client to query.
    :param aggregator: The name of the configuration aggregator to query.
    :param batch_size: Number of buckets per batch request. Defaults to 100, the maximum number of resource
                       identifiers BatchGetAggregateResourceConfig accepts per call.
    :raises ValueError: If ``batch_size`` is less than 1 (raised when iteration starts, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    for start in range(0, len(bucket_list), batch_size):
        yield get_bucket_configs(bucket_list[start:start + batch_size], config_client, aggregator)
def get_bucket_configs(batch: List[Dict[str, Any]], config_client: BaseClient, aggregator: str) -> List[Dict[str, Any]]:
    """Fetch the bucket configuration items for one batch of buckets from an AWS Config aggregator.

    :param batch: Bucket rows; each must have ``accountId``, ``awsRegion`` and ``resourceId`` keys.
    :param config_client: The boto3 AWS Config client to query.
    :param aggregator: The name of the configuration aggregator to query.
    :return: The un-wrapped ``BaseConfigurationItems`` of the response, or ``[]`` for an empty batch.
    """
    if not batch:
        # BatchGetAggregateResourceConfig requires at least one resource identifier; skip the call entirely.
        return []

    resource_identifiers = [
        {
            "SourceAccountId": bucket["accountId"],
            "SourceRegion": bucket["awsRegion"],
            "ResourceId": bucket["resourceId"],
            "ResourceType": "AWS::S3::Bucket",
        }
        for bucket in batch
    ]
    result = config_client.batch_get_aggregate_resource_config(ConfigurationAggregatorName=aggregator, ResourceIdentifiers=resource_identifiers)
    # NOTE(review): any "UnprocessedResourceIdentifiers" in the response are neither retried nor reported here.
    config_items = result.get("BaseConfigurationItems", [])

    # For un_wrap_json, see this gist: https://github.com/gemini-oss/starfleet/blob/main/src/starfleet/worker_ships/niceties.py#L16
    return un_wrap_json(config_items)
def test_fetch_and_get(aws_config: BaseClient, aws_s3: BaseClient, bucket_list: List[str]) -> None:
    """This tests the logic's ability to get and fetch AWS Config resource configuration details from the aggregator."""
    # The bucket_list fixture holds JSON strings; un-wrap them into the dict rows that fetch_buckets() consumes.
    # For un_wrap_json, see this gist: https://github.com/gemini-oss/starfleet/blob/main/src/starfleet/worker_ships/niceties.py#L16
    buckets = un_wrap_json(bucket_list)
    for bucket in buckets:
        aws_s3.create_bucket(Bucket=bucket["resourceId"])

    batches = list(fetch_buckets(buckets, aws_config, "myaggregator"))

    # 200 buckets at 100 per batch -> exactly two full batches:
    assert [len(batch) for batch in batches] == [100, 100]

    all_configs = [config_item for batch in batches for config_item in batch]
    for count, config_item in enumerate(all_configs):
        assert config_item["resourceName"] == f"bucket-number-{count}"
        assert config_item["accountId"] == "000000000001"
        assert config_item["awsRegion"] == "us-east-1"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.