#!/usr/bin/env python3
import argparse
import os
import sys

import oci
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def am_in_dataflow():
    # Data Flow sets HOME to /home/dataflow in its runtime environment.
    return os.environ.get("HOME") == "/home/dataflow"

def get_dataflow_spark_session(file_location=None, profile_name=None):
    if am_in_dataflow():
        # In Data Flow the session is preconfigured; no credentials needed.
        spark = SparkSession.builder.appName("adw_example").getOrCreate()
    else:
        # Running locally: authenticate with an API key from the OCI config file.
        try:
            import oci
        except ImportError:
            raise Exception(
                "You need to install the OCI Python library to run this script locally"
            )
        # Use defaults for anything unset.
        if file_location is None:
            file_location = oci.config.DEFAULT_LOCATION
        if profile_name is None:
            profile_name = oci.config.DEFAULT_PROFILE
        # Load the config file.
        try:
            oci_config = oci.config.from_file(
                file_location=file_location, profile_name=profile_name
            )
        except Exception as e:
            print(
                "You need to set up your OCI config properly to run this script locally"
            )
            raise e
        # These fs.oci.* settings configure the OCI HDFS connector,
        # which Spark uses to read and write oci:// paths.
        conf = SparkConf()
        conf.set("fs.oci.client.auth.tenantId", oci_config["tenancy"])
        conf.set("fs.oci.client.auth.userId", oci_config["user"])
        conf.set("fs.oci.client.auth.fingerprint", oci_config["fingerprint"])
        conf.set("fs.oci.client.auth.pemfilepath", oci_config["key_file"])
        conf.set(
            "fs.oci.client.hostname",
            "https://objectstorage.{0}.oraclecloud.com".format(oci_config["region"]),
        )
        spark = (
            SparkSession.builder.appName("adw_example").config(conf=conf).getOrCreate()
        )
    return spark

def get_authenticated_client(spark, client):
    if not am_in_dataflow():
        # We are running locally; use our API key.
        config = oci.config.from_file()
        authenticated_client = client(config)
    else:
        # We are running in Data Flow; use the delegation token it provides.
        conf = spark.sparkContext.getConf()
        token_path = conf.get("spark.hadoop.fs.oci.client.auth.delegationTokenPath")
        with open(token_path) as fd:
            delegation_token = fd.read()
        signer = oci.auth.signers.InstancePrincipalsDelegationTokenSigner(
            delegation_token=delegation_token
        )
        authenticated_client = client(config={}, signer=signer)
    return authenticated_client
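
# Note: any OCI client class can be passed to get_authenticated_client the
# same way, e.g. (hypothetical usage, not part of the original gist):
#   identity = get_authenticated_client(spark, oci.identity.IdentityClient)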

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-bucket", default="input_sample_data")
    parser.add_argument("--output-bucket", default="output_sample_data")
    parser.add_argument("--reset", action="store_true")
    args = parser.parse_args()
    input_bucket = args.input_bucket
    output_bucket = args.output_bucket
    target_file = "joined.csv"

    if args.reset:
        print("Resetting output data.")
        command = f"oci os object bulk-delete --bucket-name {output_bucket} --prefix {target_file} --force"
        retval = os.system(command)
        sys.exit(retval)

    # Get our Spark session.
    spark = get_dataflow_spark_session()

    # Get our OCI Object Storage namespace.
    client = get_authenticated_client(spark, oci.object_storage.ObjectStorageClient)
    namespace = client.get_namespace().data

    # Generate URIs for our CSV files.
    files = ["auto-mpg.csv", "manufacturers.csv"]
    uris = {file: f"oci://{input_bucket}@{namespace}/{file}" for file in files}

    # Load our DataFrames.
    print("Loading MPG information from " + uris["auto-mpg.csv"])
    auto_mpg_df = (
        spark.read.format("csv").option("header", True).load(uris["auto-mpg.csv"])
    )
    print("Loading manufacturer information from " + uris["manufacturers.csv"])
    manufacturers_df = (
        spark.read.format("csv").option("header", True).load(uris["manufacturers.csv"])
    )

    # Add a manufacturer column (the first word of the car name) to join on.
    first_word_udf = udf(lambda x: x.split()[0], StringType())
    auto_mpg_df = auto_mpg_df.withColumn(
        "manufacturer", first_word_udf(auto_mpg_df.carname)
    )

    # Join the DataFrames.
    joined = auto_mpg_df.join(manufacturers_df, "manufacturer")

    # Output the results as a single CSV file.
    output_uri = f"oci://{output_bucket}@{namespace}/{target_file}"
    print("Writing joined DataFrame to " + output_uri)
    joined.coalesce(1).write.csv(output_uri, header=True)
    print("Wrote {} rows".format(joined.count()))


if __name__ == "__main__":
    main()
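
Because the helpers pick API key authentication locally and delegation token authentication inside Data Flow, they can be reused from other PySpark jobs. A minimal sketch, assuming the script above is saved as adw_example.py (the module name is an assumption, not part of the gist):

# Hypothetical reuse of the helpers from another job.
import oci
from adw_example import get_authenticated_client, get_dataflow_spark_session

spark = get_dataflow_spark_session()
client = get_authenticated_client(spark, oci.object_storage.ObjectStorageClient)
print("Object Storage namespace: " + client.get_namespace().data)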
Pinned Python dependencies for running the script locally:
-i https://pypi.org/simple
certifi==2020.6.20
cffi==1.14.0
configparser==4.0.2
cryptography==2.8
oci==2.18.0
pycparser==2.20
pyopenssl==19.1.0
python-dateutil==2.8.1
pytz==2020.1
six==1.15.0
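These pins date from mid-2020. Assuming the list is saved as requirements.txt, a local run would typically start with pip install -r requirements.txt plus a PySpark installation; reading oci:// URIs locally also requires the OCI HDFS connector JAR on the Spark classpath, which is not distributed through pip.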