spark_session_aws.py
import os
import sys

import boto3
from pyspark.sql import SparkSession
def aws_ec2_s3_spark_session(master, num_cores=128, mem_gb=256):
    """Build a Spark session on an AWS EC2 instance, with S3A access through an assumed IAM role."""
    # ~/.aws/sparkconfig should be the minimal profile
    os.environ["AWS_CONFIG_FILE"] = os.path.expanduser("~") + "/.aws/sparkconfig"
    session = boto3.session.Session(profile_name="default")
    sts_connection = session.client("sts")
    response = sts_connection.assume_role(
        RoleArn="arn:aws:iam::842865360552:role/s3_access_from_ec2",
        RoleSessionName="hi",
        DurationSeconds=12 * 3600,
    )
    credentials = response["Credentials"]
    os.environ["PYSPARK_PYTHON"] = sys.executable
    os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
    main_memory = str(int(mem_gb * 0.9)) + "g"
    memory_overhead = str(mem_gb - int(mem_gb * 0.9)) + "g"
    spark = (
        SparkSession.builder.config("spark.submit.deployMode", "client")
        .config("spark.driver.cores", "20")
        .config("spark.driver.memory", "50g")
        .config("spark.driver.maxResultSize", "10g")
        .config("spark.executor.memory", main_memory)
        .config("spark.executor.cores", str(num_cores))  # can be set to the number of cores of the machine
        .config("spark.task.cpus", "1")
        .config("spark.executor.memoryOverhead", memory_overhead)
        .config("spark.task.maxFailures", "10")
        .config(
            "spark.jars.packages",
            "org.apache.hadoop:hadoop-aws:3.3.1,org.apache.spark:spark-hadoop-cloud_2.13:3.3.1,com.amazonaws:aws-java-sdk-bundle:1.12.353",
        )
        # change to the appropriate auth method, see https://hadoop.apache.org/docs/stable/hadoop-aws/tools/hadoop-aws/index.html
        .config("spark.hadoop.fs.s3a.access.key", credentials["AccessKeyId"])
        .config("spark.hadoop.fs.s3a.secret.key", credentials["SecretAccessKey"])
        .config("spark.hadoop.fs.s3a.session.token", credentials["SessionToken"])
        # a number of options to try to make s3a run faster
        .config("spark.hadoop.fs.s3a.threads.max", "512")
        .config("spark.hadoop.fs.s3a.connection.maximum", "2048")
        .config("spark.hadoop.fs.s3a.fast.upload", "true")
        .config("spark.sql.shuffle.partitions", "40000")
        .config("spark.hadoop.fs.s3a.directory.marker.retention", "keep")
        .config("spark.hadoop.fs.s3a.max.total.tasks", "512")
        .config("spark.hadoop.fs.s3a.multipart.threshold", "5M")
        .config("spark.hadoop.fs.s3a.multipart.size", "5M")
        .config("spark.hadoop.fs.s3a.fast.upload.active.blocks", "512")
        .config("spark.hadoop.fs.s3a.connection.establish.timeout", "5000")
        .config("spark.hadoop.fs.s3a.connection.timeout", "600000")
        .config("spark.hadoop.fs.s3a.readahead.range", "2M")
        .config("spark.hadoop.fs.s3a.socket.recv.buffer", "65536")
        .config("spark.hadoop.fs.s3a.socket.send.buffer", "65536")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.hadoop.fs.s3a.experimental.input.fadvise", "random")
        .config("spark.hadoop.fs.s3a.block.size", "2M")
        .config("spark.hadoop.fs.s3a.fast.buffer.size", "100M")
        .config("spark.hadoop.fs.s3a.fast.upload.buffer", "array")
        .config("spark.hadoop.fs.s3a.bucket.all.committer.magic.enabled", "true")
        .master(master)  # this should be set to the spark master url
        .appName("cc2dataset")
        .getOrCreate()
    )
    return spark
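
# Hedged addition, not part of the original gist: the credentials returned by
# assume_role expire after DurationSeconds (12 hours here), so a job running
# longer than that would need a fresh session. This standalone check repeats
# the same STS call as the function above and prints the remaining lifetime.
import datetime

os.environ["AWS_CONFIG_FILE"] = os.path.expanduser("~") + "/.aws/sparkconfig"
creds = (
    boto3.session.Session(profile_name="default")
    .client("sts")
    .assume_role(
        RoleArn="arn:aws:iam::842865360552:role/s3_access_from_ec2",
        RoleSessionName="hi",
        DurationSeconds=12 * 3600,
    )["Credentials"]
)
remaining = creds["Expiration"] - datetime.datetime.now(datetime.timezone.utc)
print(f"assumed-role credentials valid for {remaining}")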
# usage: connect to the spark master, read a parquet dataset from S3 through s3a, and count its rows
m = "spark://26.0.128.157:7077"
spark = aws_ec2_s3_spark_session(m, 48, 256)
p = "s3://s-laion/cc-proc-test/outputs_audio/2022-12-22-02-47-52/merged/"
p = p.replace("s3", "s3a")  # the hadoop-aws connector expects the s3a:// scheme
df = spark.read.parquet(p)
df.count()
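
# Hedged sketch (the output prefix below is hypothetical): write a small sample
# of the merged dataset back to S3 through the same s3a filesystem, then stop
# the session so the executors are released.
out = "s3a://s-laion/cc-proc-test/sample_check/"  # hypothetical output path
df.sample(fraction=0.001).repartition(10).write.mode("overwrite").parquet(out)
spark.stop()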