Created
November 3, 2024 07:52
-
-
Save huynhbaoan/8c218c788068b9d855021189c002209a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Run a CloudWatch Insights query, then save the results as CSV | |
""" | |
from __future__ import print_function | |
import sys | |
import time | |
from datetime import datetime | |
import csv | |
import re | |
import json | |
import boto3 | |
# Upper bound on rows per Insights query. It is passed as ``limit`` to
# start_query and compared against result counts to detect truncation.
LIMIT = 10000

# Shared CloudWatch Logs client used by every query helper in this file.
# NOTE(review): the helpers also read a module-level LOG_GROUP that is not
# defined anywhere in the visible source -- confirm where it comes from.
LOG_CLIENT = boto3.client('logs')

# Column-name lists. None of them is referenced in the visible part of this
# file; presumably they are CSV headers for report output -- verify.
HOST_HEADERS = [
    'domain',
    'requests',
]
TENANT_HEADERS = [
    'tenant',
    'host',
    'requests',
]
HEADERS = [
    'tenant',
    'clientip',
    'host',
    'requests',
]
# CloudWatch Logs Insights query, in three stages:
#   1. parse   - extract srcAddr, destAddr, srcPort, destPort and action from
#                messages containing an "eni-..." token and ending in "OK"
#                (presumably VPC flow log records -- confirm);
#   2. filter  - keep destination ports 25 and 465 whose destination address
#                lies in 10.39.32.0/23;
#   3. stats   - count matching records per (srcAddr, destAddr, destPort).
# A raw string keeps the regex backslashes literal without relying on
# Python's handling of invalid escape sequences.
QUERY_STRING = r"""parse @message /^.*eni-\w+\s+(?<srcAddr>\d+\.\d+\.\d+\.\d+)\s+(?<destAddr>\d+\.\d+\.\d+\.\d+)\s+(?<srcPort>\d+)\s+(?<destPort>\d+)\s+\d+\s+\d+\s+\d+\s+\d+\s+\d+\s+(?<action>\w+)\s+OK/
| filter (destPort="25" or destPort="465") and isIpv4InSubnet(destAddr, "10.39.32.0/23")
| stats count(*) as requests by srcAddr, destAddr, destPort"""
class TenantCidr(object):
    """Map IP addresses to the IPAM subnet CIDR that contains them.

    The constructor reads a CSV file (first row is a header and is skipped)
    in which column 1 holds a CIDR such as ``10.1.4.0/23`` and column 2 holds
    a name tag.  Every network is expanded into its /24 prefixes
    (``"a.b.c"`` strings) which are used as lookup keys.

    Attributes:
        cidr_maps: ``"a.b.c"`` -> dict with ``acc_short_name`` and
            ``vpc_subnet_name`` (the lower-cased, space-free name tag).
        subnet_maps: ``"a.b.c"`` -> original CIDR string from the CSV.
        src_ip_maps: cache of previous ``get_ip_subnet`` answers.
    """

    # Trailing ".<1-3 digits>" of a dotted address (the last octet).
    _LAST_OCTET = re.compile(r'\.\d{1,3}$')
    _NAME_PATTERN = re.compile(r'\w+-\w+-.*prod-\w+')
    # Networks needing more than 2**4 /24 prefixes (wider than /20) are skipped.
    _MAX_EXPANSION_BITS = 4

    def __init__(self, csvfile):
        """Load *csvfile* and build the /24-prefix lookup tables.

        Blank rows and rows with fewer than three columns are ignored, and an
        empty file yields empty tables.  A malformed CIDR value still raises
        ``ValueError``.
        """
        self.cidr_maps = {}
        self.subnet_maps = {}
        self.src_ip_maps = {}
        # newline='' is what the csv module expects for files it reads.
        with open(csvfile, newline='') as f_handler:
            f_csv = csv.reader(f_handler)
            next(f_csv, None)  # skip header; tolerate a completely empty file
            for row in f_csv:
                if len(row) < 3:
                    continue
                cidr = row[1]
                net_address, net_mask = cidr.split('/')
                # Number of mask bits short of /24; 0 for /24 and narrower.
                expansion_bits = 24 - min(24, int(net_mask))
                if expansion_bits > self._MAX_EXPANSION_BITS:
                    continue
                name_tag = row[2].lower().replace(' ', '')
                # NOTE(review): this stores the regex match object (or None),
                # not a string -- confirm whether .group() was intended.
                acc_short_name = self._NAME_PATTERN.match(name_tag)
                class_c_address = self._LAST_OCTET.sub('', net_address)
                class_b_address = self._LAST_OCTET.sub('', class_c_address)
                first_c_network = int(class_c_address.split('.')[2])
                for offset in range(2 ** expansion_bits):
                    subnet = '%s.%s' % (class_b_address, first_c_network + offset)
                    # The first row that claims a /24 prefix wins.
                    if subnet not in self.cidr_maps:
                        self.cidr_maps[subnet] = {
                            'acc_short_name': acc_short_name,
                            'vpc_subnet_name': name_tag
                        }
                        self.subnet_maps[subnet] = cidr

    def get_ip_subnet(self, src_ip):
        """Return the CSV CIDR whose /24 expansion contains *src_ip*.

        Returns the string ``'UnknownAccount'`` when no loaded network covers
        the address.  Answers are cached per input string in ``src_ip_maps``.
        """
        try:
            return self.src_ip_maps[src_ip]
        except KeyError:
            pass
        # Drop the last octet (and any newline) to get the "a.b.c" key.
        net_address = self._LAST_OCTET.sub('', src_ip).replace("\n", "")
        subnet = self.subnet_maps.get(net_address, 'UnknownAccount')
        self.src_ip_maps[src_ip] = subnet
        return subnet
def start_query(start_time, end_time, query_string):
    """Submit one Insights query and return its query ID.

    The request targets LOG_GROUP, covers start_time..end_time and is capped
    at LIMIT rows.  Any exception is printed and turned into a ``None``
    return, so callers only need to check the result for truthiness.
    """
    try:
        request = dict(
            logGroupName=LOG_GROUP,
            startTime=start_time,
            endTime=end_time,
            queryString=query_string,
            limit=LIMIT,
        )
        query_id = LOG_CLIENT.start_query(**request).get('queryId')
        if not query_id:
            return query_id
        print("\tStarting Query ID:", query_id)
        return query_id
    except Exception as err:
        print("Error starting query:", err)
    return None
def wait_for_all_queries_to_complete(query_ids, timeout=3600):
    """Poll until every query in *query_ids* reports ``Complete``.

    Args:
        query_ids: iterable of Insights query IDs to wait for.
        timeout: maximum number of seconds to keep polling.

    Progress dots are written to stdout, one per poll.  If any query is
    ``Failed``, ``Cancelled`` or ``Timeout`` the process exits with status 1.
    If the timeout elapses first, a warning is printed and the function
    returns normally.
    """
    started = time.time()
    all_queries_complete = False
    sys.stdout.write('\tQuerying: ')
    while (time.time() - started) < timeout:
        statuses = {}
        if query_ids:
            # One listing per poll is enough for all IDs.
            # NOTE(review): describe_queries is a paged listing; an ID that
            # is not on the first page is treated as still running -- confirm
            # this is acceptable for the account's query volume.
            response = LOG_CLIENT.describe_queries(logGroupName=LOG_GROUP)
            statuses = {
                q.get('queryId'): q.get('status') for q in response['queries']
            }
        all_queries_complete = True
        for query_id in query_ids:
            query_status = statuses.get(query_id)
            if query_status in ('Failed', 'Cancelled', 'Timeout'):
                print(f"Query {query_id} {query_status}")
                sys.exit(1)
            if query_status != 'Complete':
                all_queries_complete = False
                break
        sys.stdout.write('.')
        sys.stdout.flush()
        if all_queries_complete:
            # Everything is done: return now instead of sleeping once more.
            break
        time.sleep(15)
    if not all_queries_complete:
        print("Warning: Some queries did not complete in time.")
    print("\n\tQuerying time: %s seconds" % (time.time() - started))
def get_query_results(query_id, start_time, end_time):
    """Fetch the rows of a finished Insights query as a list of dicts.

    Args:
        query_id: ID returned by start_query.
        start_time: epoch seconds of the queried range start (log output only).
        end_time: epoch seconds of the queried range end (log output only).

    Returns:
        One ``{field name: value}`` dict per result row.

    GetQueryResults takes only ``queryId`` and is not paginated, so a single
    call returns everything available; passing ``nextToken`` is rejected by
    boto3 parameter validation.
    """
    response = LOG_CLIENT.get_query_results(queryId=query_id)
    batch_query_results = [
        {field['field']: field['value'] for field in result}
        for result in response.get('results', [])
    ]
    total_records = len(batch_query_results)
    start_dt = datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S')
    end_dt = datetime.fromtimestamp(end_time).strftime('%Y-%m-%d %H:%M:%S')
    print(f"Query ID {query_id}: Total records retrieved = {total_records} for time range {start_dt} - {end_dt}")
    return batch_query_results
def run_segmented_queries(start, end, query_string, interval=86400):
    """Query start..end in *interval*-second segments, splitting on overflow.

    Args:
        start: first epoch second to query.
        end: last epoch second to query (treated as inclusive, matching the
            ``start + interval - 1`` segment ends used here).
        query_string: Insights query text.
        interval: segment length in seconds; values below 1 are raised to 1.

    Returns:
        Concatenated result rows of all segments.  Rows from sub-segments are
        concatenated, not merged, so the same grouping key can appear more
        than once.

    A segment that returns LIMIT rows is assumed truncated and is re-run as
    two halves.  Segments whose query could not be started are skipped.
    """
    # A zero-length step would never advance the loop below.
    interval = max(int(interval), 1)
    batch_query_results = []
    current_start = start
    # Inclusive bound so a final one-second segment at `end` is not dropped.
    while current_start <= end:
        current_end = min(current_start + interval - 1, end)
        query_id = start_query(current_start, current_end, query_string)
        if query_id:
            wait_for_all_queries_to_complete([query_id])
            segment_results = get_query_results(query_id, current_start, current_end)
            span = current_end - current_start + 1
            if len(segment_results) < LIMIT:
                batch_query_results.extend(segment_results)
            elif span > 1:
                print(f"Limit reached for {datetime.fromtimestamp(current_start)} - {datetime.fromtimestamp(current_end)}, further segmenting")
                # Halve the actual span (rounded up) so each recursion is
                # strictly smaller and the split always terminates.
                batch_query_results.extend(run_segmented_queries(current_start, current_end, query_string, interval=(span + 1) // 2))
            else:
                # A one-second segment cannot be split any further.
                print(f"Warning: limit reached for single second {datetime.fromtimestamp(current_start)}, results may be truncated")
                batch_query_results.extend(segment_results)
        current_start += interval
    return batch_query_results
def cwlog_batch_queries(start_days_ago, end_days_ago, query_string):
    """Run the Insights query one 24-hour window at a time.

    Windows start ``start_days_ago`` days before the current time and there
    are ``start_days_ago - end_days_ago`` of them; each window is handed to
    run_segmented_queries and all returned rows are concatenated.
    """
    first_window_start = int(time.time()) - start_days_ago * 86400
    window_count = start_days_ago - end_days_ago
    collected = []
    for window_start in range(first_window_start,
                              first_window_start + window_count * 86400,
                              86400):
        window_end = window_start + 86399
        collected += run_segmented_queries(window_start, window_end, query_string)
    return collected
def _report_mail_clients(port, src_ips, tenant_cidrs):
    """Print each source IP with its subnet for *port*; return the subnets."""
    print(f"\nmail client on port {port}")
    print("========================")
    subnets = []
    for src_ip in src_ips:
        subnet = tenant_cidrs.get_ip_subnet(src_ip)
        print(src_ip, " ", subnet)
        subnets.append(subnet)
    print(f"Number of src IPs from port {port} are {len(subnets)}\n")
    return subnets


def _report_whitelist(port, subnets):
    """Print the de-duplicated subnet list for *port* with its totals."""
    print("========================")
    print(f"Prod whitelist Port {port}")
    print("========================")
    # Sorted so repeated runs over the same data print the same order.
    unique_subnets = sorted(set(subnets))
    print(*unique_subnets, sep="\n")
    print(f"Total number is {len(subnets)}\n")
    print(f"Total unique number is {len(unique_subnets)}\n")


def main():
    """Query the logs, then report mail clients and subnets per port.

    Command line: ``[start_days_ago [end_days_ago]]`` (defaults 1 and 0).
    Any further arguments are ignored.  Subnets are resolved through
    ``./ipam_subnets.csv``.  The process exits with status 0 when done.
    """
    start_days_ago = int(sys.argv[1]) if len(sys.argv) >= 2 else 1
    end_days_ago = int(sys.argv[2]) if len(sys.argv) >= 3 else 0
    tenant_cidrs = TenantCidr('./ipam_subnets.csv')
    query_results = cwlog_batch_queries(start_days_ago, end_days_ago, QUERY_STRING)

    # Distinct source addresses per port, kept in first-seen order.
    # Anything that is not port 25 is filed under 465.
    clients_by_port = {'25': {}, '465': {}}
    for rec in query_results:
        port = '25' if rec['destPort'] == '25' else '465'
        clients_by_port[port][rec['srcAddr']] = rec['destPort']

    subnets_25 = _report_mail_clients('25', clients_by_port['25'], tenant_cidrs)
    subnets_465 = _report_mail_clients('465', clients_by_port['465'], tenant_cidrs)
    _report_whitelist('25', subnets_25)
    _report_whitelist('465', subnets_465)
    sys.exit(0)
# Script entry point: run main() and treat Ctrl-C as a clean exit (status 0).
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print("User interrupted")
        raise SystemExit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment