# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Author: Letenkov, Eugene
# Copyright 2024 Letenkov, Eugene. All rights reserved.
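"""Count the total number of rows across Avro files under an S3 prefix.

Objects are listed with a paginator, then read concurrently by a thread
pool; Ctrl+C sets a stop flag that workers poll between records.

Example invocation (bucket, prefix, and script name are illustrative):

    python count_avro_rows.py --bucket_name my-bucket --prefix data/ --threads 10
"""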
import argparse
import io
import signal
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
from fastavro import reader

# Cooperative stop flag: set by the SIGINT handler, polled by worker threads.
stop_processing = False


def signal_handler(sig, frame):
    global stop_processing
    print('Signal received, stopping...')
    stop_processing = True
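
# NOTE: CPython delivers SIGINT to the main thread only, so worker threads are
# never interrupted directly; they cooperate by polling stop_processing
# between records (see count_rows_in_avro_file below).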


def get_all_avro_files(bucket_name, prefix, s3_client):
    """Return the keys of all .avro objects under the given prefix."""
    print(f"Fetching list of Avro files from bucket: {bucket_name} with prefix: {prefix}")
    paginator = s3_client.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    avro_files = []
    for page in page_iterator:
        file_info = page.get('Contents', [])
        avro_files.extend([file['Key'] for file in file_info if file['Key'].endswith('.avro')])

    print(f"Found {len(avro_files)} Avro files.")
    return avro_files


def count_rows_in_avro_file(s3_client, bucket_name, file_key):
    """Download one Avro object and count its records, honoring the stop flag."""
    if stop_processing:
        return 0
    print(f"Processing file: {file_key}\n")
    obj = s3_client.get_object(Bucket=bucket_name, Key=file_key)
    data = obj['Body'].read()

    with io.BytesIO(data) as file_io:
        avro_reader = reader(file_io)
        row_count = 0
        for _ in avro_reader:
            if stop_processing:
                print(f"Interrupted processing of file: {file_key}")
                return row_count
            row_count += 1
        return row_count


def process_files(bucket_name, avro_files, num_threads):
    """Fan per-file row counts out over a thread pool and sum the results."""
    # boto3 low-level clients are generally thread-safe, so one client is shared.
    s3_client = boto3.client('s3')
    total_rows = 0
    num_files = len(avro_files)

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {
            executor.submit(count_rows_in_avro_file, s3_client, bucket_name, file_key): file_key
            for file_key in avro_files
        }
        completed_files = 0

        try:
            for future in as_completed(futures):
                if stop_processing:
                    print("Process interrupted. Attempting to cancel remaining tasks...")
                    for f in futures:
                        # cancel() only affects tasks that have not started running
                        f.cancel()
                    break
                try:
                    total_rows += future.result()
                    completed_files += 1
                    remaining_files = num_files - completed_files
                    print(f"Processed {completed_files}/{num_files} files. Remaining: {remaining_files}\n")
                except Exception as e:
                    print(f"Error processing file {futures[future]}: {e}")
        finally:
            print("Shutting down executor")
            # Exiting the `with` block still waits for tasks that are already running.
            executor.shutdown(wait=False)

    return total_rows


def main():
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser(description="Process Avro files in S3 bucket.")
    parser.add_argument('--bucket_name', required=True, help='Name of the S3 bucket')
    parser.add_argument('--prefix', required=True, help='Prefix of the S3 objects')
    parser.add_argument('--threads', type=int, default=10, help='Number of threads to use for processing files')
    args = parser.parse_args()

    bucket_name = args.bucket_name
    prefix = args.prefix
    num_threads = args.threads

    s3_client = boto3.client('s3')
    avro_files = get_all_avro_files(bucket_name, prefix, s3_client)
    total_rows = process_files(bucket_name, avro_files, num_threads)

    if not stop_processing:
        print(f"Total number of rows across all files: {total_rows}")
    else:
        print("Processing was interrupted.")


if __name__ == '__main__':
    main()