Created
December 20, 2016 17:51
-
-
Save Melraidin/7952fbafc4fd8aa0aaf8784d3bb70a3f to your computer and use it in GitHub Desktop.
Athena query base class for Luigi
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Base classes for Athena queries and loads. | |
""" | |
import datetime | |
import jaydebeapi; | |
import luigi | |
import luigi.s3 | |
import sources | |
class AthenaQuery(luigi.Task):
    """
    Base class for Athena queries.

    Provides methods to execute a query against Athena. Typically the
    only Luigi methods needed to be implemented will be requires() and
    run(). Run will usually only require a call to query_store() to
    generate its output at S3.

    Subclasses must define a ``results_path_template`` class attribute:
    a %-format string that is filled in with ``self.date`` and appended
    to ``results_path_base`` to build ``self.results_path``.

    Note that the ".csv.metadata" key at S3 will be removed to allow
    later loads to Redshift. This may impact viewing the query's
    results in the Athena web UI.
    """

    # Prefix shared by every query's S3 results location.
    results_path_base = "s3://500px-emr/athena-results/"

    # Luigi resource claim for this task (one unit of "athena_query").
    resources = {"athena_query": 1}

    # NOTE: the default is computed once, when this module is imported,
    # not each time a task is instantiated.
    date = luigi.DateParameter(default=datetime.date.today())

    # AWS credentials, read from the [s3] section of the Luigi config.
    access_key = luigi.Parameter(
        default="",
        config_path={"section": "s3", "name": "aws_access_key_id"})
    secret_key = luigi.Parameter(
        default="",
        config_path={"section": "s3", "name": "aws_secret_access_key"})

    # Location of the Athena JDBC driver jar, read from the [athena]
    # section of the Luigi config.
    athena_jdbc_driver_path = luigi.Parameter(
        default="",
        config_path={"section": "athena", "name": "jdbc_driver_path"})

    def __init__(self, *args, **kwargs):
        """
        Initialise the task and derive ``self.results_path``.

        The path is ``results_path_base`` followed by the subclass's
        ``results_path_template`` formatted with ``self.date``, with any
        trailing "/" characters removed.
        """
        super(AthenaQuery, self).__init__(*args, **kwargs)

        path = self.results_path_base + self.results_path_template % self.date
        # rstrip() is a no-op when there is no trailing slash, so no
        # endswith() guard is needed.
        self.results_path = path.rstrip("/")

    def run(self):
        """Must be overridden by subclasses."""
        raise NotImplementedError("must override run() method")

    def _execute_query(self, sql):
        """
        Execute ``sql`` against Athena and return the result object.

        After the query has been executed, every key under
        ``self.results_path`` ending in ".csv.metadata" is removed from
        S3.
        """
        props = {"s3_staging_dir": self.results_path,
                 "user": self.access_key,
                 "password": self.secret_key}

        # NOTE(review): the connect() arguments and conn.execute() are
        # kept exactly as originally written; confirm them against the
        # installed jaydebeapi version. The connection is never closed
        # here because the returned result object is still in use.
        conn = jaydebeapi.connect(
            "com.amazonaws.athena.jdbc.AthenaDriver",
            ["jdbc:awsathena://athena.us-east-1.amazonaws.com:443"],
            # Fixed: this was the bare name ``athena_jdbc_driver_path``,
            # which is not defined in this scope and raised NameError.
            jars=self.athena_jdbc_driver_path,
            props=props)

        rs = conn.execute(sql)

        s3_client = luigi.s3.S3Client()

        metadata = [k for k in s3_client.list(self.results_path)
                    if k.endswith(".csv.metadata")]
        for k in metadata:
            # Assumes list() returns keys relative to results_path, so
            # the prefix is re-attached before removal - TODO confirm.
            s3_client.remove("%s/%s" % (self.results_path, k))

        return rs

    def query_store(self, sql):
        """
        Query Athena and close result set immediately. Results will be
        at self.results_path.
        """
        rs = self._execute_query(sql)
        rs.close()

    def query_retrieve(self, sql):
        """
        Query Athena and iterate over the result set.

        This is a generator: the same result-set object is yielded once
        per successful ``rs.next()`` call, so callers should read the
        current row from it before advancing the iteration.
        """
        rs = self._execute_query(sql)
        while rs.next():
            yield rs

    def output(self):
        """Return the S3 target at ``self.results_path``."""
        return luigi.s3.S3Target(self.results_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment