Created
April 5, 2018 21:42
-
-
Save rchamarthi/4f1fbb3f79048df655cf3f5f4437db31 to your computer and use it in GitHub Desktop.
Airflow S3 cross-account copy operator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from tempfile import NamedTemporaryFile | |
from airflow.exceptions import AirflowException | |
from airflow.hooks.S3_hook import S3Hook | |
from airflow.models import BaseOperator | |
from airflow.utils.decorators import apply_defaults | |
class S3CopyOperator(BaseOperator):
    """
    Copies data from a source to a target s3 location.

    This is useful when source and destination buckets are in different
    accounts and access is provided using two sets of AWS keys instead of
    cross-account access policies.

    The object is read through the source connection and written through the
    destination connection, so its full contents pass through this worker as
    a single string.

    :param source_s3_key: The key to be retrieved from S3
    :type source_s3_key: str
    :param source_aws_conn_id: source s3 connection
    :type source_aws_conn_id: str
    :param dest_s3_key: The key to be written to S3
    :type dest_s3_key: str
    :param dest_aws_conn_id: destination s3 connection
    :type dest_aws_conn_id: str
    :param replace: Replace dest S3 key if it already exists
    :type replace: bool
    """

    # Both keys are rendered by Airflow's templating before execute() runs.
    template_fields = ('source_s3_key', 'dest_s3_key')
    template_ext = ()
    ui_color = '#f9c915'

    @apply_defaults
    def __init__(
            self,
            source_s3_key,
            dest_s3_key,
            source_aws_conn_id='aws_default',
            dest_aws_conn_id='aws_default',
            replace=False,
            *args, **kwargs):
        super(S3CopyOperator, self).__init__(*args, **kwargs)
        self.source_s3_key = source_s3_key
        self.source_aws_conn_id = source_aws_conn_id
        self.dest_s3_key = dest_s3_key
        self.dest_aws_conn_id = dest_aws_conn_id
        self.replace = replace

    def execute(self, context):
        """
        Copy ``source_s3_key`` to ``dest_s3_key`` using the two connections.

        :param context: Airflow task context (not used by this operator)
        :raises AirflowException: if the source key does not exist
        """
        source_s3 = S3Hook(s3_conn_id=self.source_aws_conn_id)
        try:
            dest_s3 = S3Hook(s3_conn_id=self.dest_aws_conn_id)
            try:
                logging.info(
                    "Downloading source S3 file %s", self.source_s3_key)
                if not source_s3.check_for_key(self.source_s3_key):
                    raise AirflowException(
                        "The source key {0} does not exist".format(
                            self.source_s3_key))

                source_s3_key_object = source_s3.get_key(self.source_s3_key)
                dest_s3.load_string(
                    string_data=source_s3_key_object.get_contents_as_string(),
                    key=self.dest_s3_key,
                    replace=self.replace
                )
                logging.info("Copy successful")
            finally:
                # Release the destination connection even when the copy
                # fails (missing source key, existing dest key, I/O error).
                dest_s3.connection.close()
        finally:
            # Release the source connection on every exit path, including
            # a failure while constructing the destination hook.
            source_s3.connection.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment