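"""
Browse an S3 prefix and return the listing as a Spark DataFrame: the folders and
files directly under the given path (one level deep, "/"-delimited), fetched via
the S3 ListObjectsV2 API, so a single call returns at most 1000 entries.
"""
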
import os
from urllib.parse import urlparse

from pyspark.sql.functions import desc, asc
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    LongType,
    TimestampType,
)

from yipit_databricks_utils.helpers import get_spark_session
from yipit_glue.sessions import get_s3_client

def browse_s3(s3_path, delimiter="/", token=None):
    s3_path = _normalize_s3_path(s3_path)
    bucket, prefix = parse_location(s3_path)
    resp = _get_s3_objects(bucket, prefix, delimiter=delimiter, token=token)
    if resp["IsTruncated"]:
        print("Limited to 1000 Results")
    df = _convert_resp_to_dataframe(bucket, resp, delimiter=delimiter)
    return df
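
# Illustrative usage (the bucket and prefix below are made up):
#
#     df = browse_s3("s3://my-bucket/raw/events/")
#     df.show(truncate=False)
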
def parse_location(s3_location):
    parsed = urlparse(s3_location)
    bucket = parsed.netloc
    prefix = parsed.path
    prefix = prefix[1:]  # Trim leading /
    return bucket, prefix
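
# For example, parse_location("s3://my-bucket/some/prefix/") returns
# ("my-bucket", "some/prefix/"): urlparse puts the bucket in netloc and the key
# prefix (minus its leading slash) in path.
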
def _normalize_s3_path(s3_path):
    """
    Normalize by ensuring the path starts with s3:// (or dbfs:/) and ends with a
    trailing slash.
    """
    if not s3_path.endswith("/"):
        s3_path += "/"
    if not (s3_path.startswith("s3://") or s3_path.startswith("dbfs:/")):
        s3_path = "s3://{}".format(s3_path)
    return s3_path
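
# For example, _normalize_s3_path("my-bucket/data") returns "s3://my-bucket/data/",
# while paths already starting with s3:// or dbfs:/ only gain the trailing slash.
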
def _convert_resp_to_dataframe(bucket, resp, delimiter="/"):
    bucket = _normalize_s3_path(bucket)
    cleaned_rows = []
    for row in resp.get("CommonPrefixes", []):
        # These paths always end with the delimiter
        path = row["Prefix"].split(delimiter)[-2]
        cleaned_rows.append(
            {
                "path": path,
                "type": "Folder",
                "modified": None,
                "size": None,
                "etag": "",
                "full_path": row["Prefix"],
                "full_s3_prefix": os.path.join(bucket, row["Prefix"]),
            }
        )
    for row in resp.get("Contents", []):
        path = row["Key"].split(delimiter)[-1]
        # Exclude empty paths
        if not path:
            continue
        cleaned_rows.append(
            {
                "path": path,
                "type": "File",
                "modified": row["LastModified"],
                "size": row["Size"],
                "etag": row["ETag"].replace('"', ""),
                "full_path": row["Key"],
                "full_s3_prefix": os.path.join(bucket, row["Key"]),
            }
        )
    schema = _get_s3_preview_schema()
    spark = get_spark_session()
    df = spark.createDataFrame(cleaned_rows, schema=schema)
    return df.sort(desc("type"), asc("path"))
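
# The sort above lists folders before files ("Folder" sorts after "File", so
# descending order puts it first), with each group ordered alphabetically by path.
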
def _get_s3_objects(bucket, prefix, delimiter="/", token=None):
    kwargs = {
        "Bucket": bucket,
        "Prefix": prefix,
        "Delimiter": delimiter,
    }
    if token is not None:
        kwargs["ContinuationToken"] = token
    client = get_s3_client()
    resp = client.list_objects_v2(**kwargs)
    return resp
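
# A minimal sketch (not part of the original gist) of how the token support above
# could be used to page past the 1000-key cap of a single ListObjectsV2 call;
# _collect_all_s3_objects is a hypothetical helper name.
def _collect_all_s3_objects(bucket, prefix, delimiter="/"):
    contents, common_prefixes = [], []
    token = None
    while True:
        resp = _get_s3_objects(bucket, prefix, delimiter=delimiter, token=token)
        contents.extend(resp.get("Contents", []))
        common_prefixes.extend(resp.get("CommonPrefixes", []))
        if not resp.get("IsTruncated"):
            break
        # NextContinuationToken points at the next page of results
        token = resp["NextContinuationToken"]
    return contents, common_prefixes
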
def _get_s3_preview_schema():
    columns = [
        StructField("path", StringType(), True),
        StructField("type", StringType(), True),
        StructField("modified", TimestampType(), True),
        StructField("size", LongType(), True),
        StructField("etag", StringType(), True),
        StructField("full_path", StringType(), True),
        StructField("full_s3_prefix", StringType(), True),
    ]
    return StructType(columns)