#!/usr/bin/env python
"""
Run a CloudWatch Insights query, then save the results as CSV
"""
from __future__ import print_function
import sys
import time
from datetime import datetime
import csv
import re
import json
import boto3
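
# Example invocation (illustrative; the script name is hypothetical and the
# argument meanings come from main() below):
#
#   python cw_smtp_clients.py 7 0 prod
#
# scans from 7 days ago up to today. The tenant map is read from
# ./ipam_subnets.csv; based on how the columns are consumed below, its layout
# is assumed to be (header row is skipped):
#
#   SubnetId,Cidr,NameTag
#   subnet-0abc123,10.39.40.0/24,acme-app-prod-web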
# Column sets reserved for CSV exports; the summary output below does not use them.
HEADERS = ['tenant', 'clientip', 'host', 'requests']
TENANT_HEADERS = ['tenant', 'host', 'requests']
HOST_HEADERS = ['domain', 'requests']

LOG_CLIENT = boto3.client('logs')
LOG_GROUP = 'vpc-flow-logs'  # assumption: replace with the actual flow-log group name
LIMIT = 10000  # CloudWatch Logs Insights caps a single query at 10,000 results

# Logs Insights query: parse VPC flow-log records and count SMTP connections
# (ports 25/465) towards the 10.39.32.0/23 mail relay subnet.
QUERY_STRING = r"""parse @message /^.*eni-\w+\s+(?<srcAddr>\d+\.\d+\.\d+\.\d+)\s+(?<destAddr>\d+\.\d+\.\d+\.\d+)\s+(?<srcPort>\d+)\s+(?<destPort>\d+)\s+\d+\s+\d+\s+\d+\s+\d+\s+\d+\s+(?<action>\w+)\s+OK/
| filter (destPort="25" or destPort="465") and isIpv4InSubnet(destAddr, "10.39.32.0/23")
| stats count(*) as requests by srcAddr, destAddr, destPort"""
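
# For reference, a default-format VPC flow-log record that the parse expression
# above is meant to match looks like this (illustrative values):
#
#   2 123456789012 eni-0a1b2c3d4e5f 10.39.40.17 10.39.32.10 54321 25 6 10 840 1699000000 1699000060 ACCEPT OK
#
# srcAddr, destAddr, srcPort, destPort and action are captured as named fields.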
class TenantCidr(object):
    """Map instance IP addresses to tenant names using IPAM CIDR data."""

    def __init__(self, csvfile):
        name_pattern = re.compile(r'\w+-\w+-.*prod-\w+')
        last_octet = re.compile(r'\.\d{1,3}$')
        self.cidr_maps = {}
        self.subnet_maps = {}
        self.src_ip_maps = {}  # per-IP lookup cache
        with open(csvfile) as f_handler:
            f_csv = csv.reader(f_handler)
            next(f_csv)  # skip the header row
            for row in f_csv:
                (net_address, net_mask) = row[1].split('/')
                # log2 of the number of /24 networks this CIDR covers;
                # skip anything larger than a /20 (more than 16 /24s).
                j = 24 - min(24, int(net_mask))
                if j > 4:
                    continue
                name_tag = row[2].lower().replace(' ', '')
                match = name_pattern.match(name_tag)
                acc_short_name = match.group(0) if match else name_tag
                # Strip the last two octets to get the /16 prefix, and keep
                # the third octet as the first /24 network in the CIDR.
                class_c_address = re.sub(last_octet, '', net_address)
                class_b_address = re.sub(last_octet, '', class_c_address)
                first_c_network = int(class_c_address.split('.')[2])
                # Register every /24 inside the CIDR.
                for i in range(0, 2 ** j):
                    subnet = '%s.%s' % (class_b_address, (first_c_network + i))
                    if subnet not in self.cidr_maps:
                        self.cidr_maps[subnet] = {
                            'acc_short_name': acc_short_name,
                            'vpc_subnet_name': name_tag
                        }
                        self.subnet_maps[subnet] = row[1]

    def get_ip_subnet(self, src_ip):
        """Find the tenant CIDR containing src_ip ('UnknownAccount' if none)."""
        last_octet = re.compile(r'\.\d{1,3}$')
        if src_ip in self.src_ip_maps:
            return self.src_ip_maps[src_ip]
        net_address = re.sub(last_octet, '', src_ip).replace("\n", "")
        if net_address in self.subnet_maps:
            self.src_ip_maps[src_ip] = self.subnet_maps[net_address]
            return self.subnet_maps[net_address]
        self.src_ip_maps[src_ip] = 'UnknownAccount'
        return 'UnknownAccount'
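
# A minimal sketch of the lookup behaviour, assuming a hypothetical CSV row
# that maps 10.39.40.0/23 to "acme-app-prod-web": get_ip_subnet('10.39.40.9')
# and get_ip_subnet('10.39.41.200') both return '10.39.40.0/23', while an IP
# outside every registered /24 returns 'UnknownAccount'.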
def start_query(start_time, end_time, query_string):
    """Kick off a Logs Insights query; return its queryId, or None on failure."""
    try:
        response = LOG_CLIENT.start_query(
            logGroupName=LOG_GROUP,
            startTime=start_time,
            endTime=end_time,
            queryString=query_string,
            limit=LIMIT
        )
        query_id = response.get('queryId')
        if query_id:
            print("\tStarting Query ID:", query_id)
            return query_id
    except Exception as e:
        print("Error starting query:", e)
    return None
def wait_for_all_queries_to_complete(query_ids, timeout=3600):
    """Poll describe_queries until every query completes, fails, or times out."""
    now = time.time()
    all_queries_complete = False
    sys.stdout.write('\tQuerying: ')
    while not all_queries_complete and (time.time() - now) < timeout:
        all_queries_complete = True
        for query_id in query_ids:
            response = LOG_CLIENT.describe_queries(logGroupName=LOG_GROUP)
            query_status = next((q['status'] for q in response['queries']
                                 if q['queryId'] == query_id), None)
            if query_status in ['Failed', 'Cancelled', 'Timeout']:
                print(f"Query {query_id} {query_status}")
                sys.exit(1)
            elif query_status != 'Complete':
                all_queries_complete = False
                break
        sys.stdout.write('.')
        sys.stdout.flush()
        time.sleep(15)
    if not all_queries_complete:
        print("Warning: Some queries did not complete in time.")
    print("\n\tQuerying time: %s seconds" % (time.time() - now))
def get_query_results(query_id, start_time, end_time):
    """Fetch the result set of a completed query as a list of dicts."""
    batch_query_results = []
    # get_query_results returns the complete result set (up to the query's
    # `limit`) in a single call; the API takes no pagination token.
    response = LOG_CLIENT.get_query_results(queryId=query_id)
    for result in response['results']:
        rec = {field['field']: field['value'] for field in result}
        batch_query_results.append(rec)
    start_dt = datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S')
    end_dt = datetime.fromtimestamp(end_time).strftime('%Y-%m-%d %H:%M:%S')
    print(f"Query ID {query_id}: Total records retrieved = {len(batch_query_results)} "
          f"for time range {start_dt} - {end_dt}")
    return batch_query_results
def run_segmented_queries(start, end, query_string, interval=86400):
    """Run queries over smaller time windows whenever results hit the limit."""
    batch_query_results = []
    current_start = start
    while current_start < end:
        current_end = min(current_start + interval - 1, end)
        query_id = start_query(current_start, current_end, query_string)
        if query_id:
            # Wait for the query to complete before fetching its results.
            wait_for_all_queries_to_complete([query_id])
            query_results_current = get_query_results(query_id, current_start, current_end)
            if len(query_results_current) == LIMIT:
                # The window hit the result cap, so rows were dropped:
                # recurse with half the interval to recover them.
                print(f"Limit reached for {datetime.fromtimestamp(current_start)} - "
                      f"{datetime.fromtimestamp(current_end)}, further segmenting")
                batch_query_results.extend(run_segmented_queries(
                    current_start, current_end, query_string,
                    interval=max(interval // 2, 1)))
            else:
                batch_query_results.extend(query_results_current)
        current_start += interval
    return batch_query_results
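
# Example of the halving strategy: if a full day (86400 s) returns exactly
# LIMIT rows, the day is re-queried as two ~43200 s halves; any half that
# still hits the cap is split again, so no window is silently truncated.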
def cwlog_batch_queries(start_days_ago, end_days_ago, query_string):
    """Run the query day by day, sub-segmenting any day that overflows the limit."""
    now = int(time.time())
    start_time = now - start_days_ago * 86400
    query_results = []
    for day in range(start_days_ago - end_days_ago):
        day_start = start_time + day * 86400
        day_end = day_start + 86399
        query_results.extend(run_segmented_queries(day_start, day_end, query_string))
    return query_results
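
def save_results_csv(results, path):
    """Optional CSV export (a sketch; main() does not call it): write the raw
    query records to `path`, taking column names from the first record."""
    if not results:
        return
    with open(path, 'w', newline='') as f_handler:
        writer = csv.DictWriter(f_handler, fieldnames=list(results[0]))
        writer.writeheader()
        writer.writerows(results)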
def main():
    start_days_ago = int(sys.argv[1]) if len(sys.argv) >= 2 else 1
    end_days_ago = int(sys.argv[2]) if len(sys.argv) >= 3 else 0
    env_type = sys.argv[3] if len(sys.argv) == 4 else 'nonprod'  # currently unused
    tenant_cidrs = TenantCidr('./ipam_subnets.csv')
    query_results = cwlog_batch_queries(start_days_ago, end_days_ago, QUERY_STRING)
    # Split source IPs by destination port.
    mail_client_25 = {}
    mail_client_465 = {}
    for rec in query_results:
        if rec['destPort'] == '25':
            mail_client_25[rec['srcAddr']] = rec['destPort']
        else:
            mail_client_465[rec['srcAddr']] = rec['destPort']

    print("\nmail client on port 25")
    print("========================")
    count_25 = 0
    mail_client_25_list = []
    for src_ip_25 in mail_client_25:
        subnet_25 = tenant_cidrs.get_ip_subnet(src_ip_25)
        print(src_ip_25, " ", subnet_25)
        mail_client_25_list.append(subnet_25)
        count_25 += 1
    print(f"Number of src IPs from port 25 is {count_25}\n")

    print("\nmail client on port 465")
    print("========================")
    count_465 = 0
    mail_client_465_list = []
    for src_ip_465 in mail_client_465:
        subnet_465 = tenant_cidrs.get_ip_subnet(src_ip_465)
        print(src_ip_465, " ", subnet_465)
        mail_client_465_list.append(subnet_465)
        count_465 += 1
    print(f"Number of src IPs from port 465 is {count_465}\n")

    print("========================")
    print("Prod whitelist Port 25")
    print("========================")
    uniqlist_25 = list(set(mail_client_25_list))
    print(*uniqlist_25, sep="\n")
    print(f"Total number is {len(mail_client_25_list)}\n")
    print(f"Total unique number is {len(uniqlist_25)}\n")

    print("========================")
    print("Prod whitelist Port 465")
    print("========================")
    uniqlist_465 = list(set(mail_client_465_list))
    print(*uniqlist_465, sep="\n")
    print(f"Total number is {len(mail_client_465_list)}\n")
    print(f"Total unique number is {len(uniqlist_465)}\n")
    sys.exit(0)
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print("User interrupted")
        sys.exit(0)