Last active
July 21, 2021 23:03
-
-
Save rom1504/5a93e1b9e22b15049de8961c95be17d2 to your computer and use it in GitHub Desktop.
parquet_to_tfrecord_pyspark.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Convert a directory of parquet files into TFRecord files with PySpark.
#
# Running this in an interactive environment (Python shell, Jupyter, ...) is
# advised, so that each step can be inspected and understood.
from pyspark.sql import SparkSession

# Fetch the spark-tfrecord jar and the RAPIDS jars. RAPIDS is not necessary:
# remove every mention of it if it is not wanted.
# The first URL is quoted so the shell does not treat "?" as a glob character.
# wget "https://search.maven.org/remotecontent?filepath=com/linkedin/sparktfrecord/spark-tfrecord_2.12/0.3.2/spark-tfrecord_2.12-0.3.2.jar" -O tfrecord.jar
# wget https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/21.06.0/rapids-4-spark_2.12-21.06.0.jar -O rapids.jar
# wget https://repo1.maven.org/maven2/ai/rapids/cudf/21.06.1/cudf-21.06.1-cuda11.jar -O cudf.jar

# Create the Spark session with the tfrecord jar, the RAPIDS plugin and some
# basic options, running against a local master ("local[16]").
spark = (
    SparkSession.builder
    .config("spark.jars", "tfrecord.jar,rapids.jar,cudf.jar")
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .config("spark.rapids.sql.incompatibleOps.enabled", "true")
    .config("spark.driver.memory", "16G")
    .master("local[16]")
    .appName("spark-stats")
    .getOrCreate()
)
# Without RAPIDS the session would be created like this instead:
# spark = (
#     SparkSession.builder
#     .config("spark.jars", "tfrecord.jar")
#     .config("spark.driver.memory", "16G")
#     .master("local[16]")
#     .appName("spark-stats")
#     .getOrCreate()
# )

# Example parquet files come from http://3080.rom1504.fr/uniref90_with_annotations/
# wget --recursive --no-parent -nd -P uniref90_with_annotations http://3080.rom1504.fr/uniref90_with_annotations/
df = spark.read.parquet("uniref90_with_annotations")

# Keep only the columns that should end up in the TFRecord output.
df = df.select("uniprot_id", "go_annotations", "seq")

# The repartition number is the number of output files.
# You may want to add .option("compression", "gzip") next to the other option.
# NOTE(review): confirm the exact compression option name against the
# spark-tfrecord documentation before relying on it.
(
    df.repartition(100)
    .write.mode("overwrite")
    .format("tfrecord")
    .option("recordType", "Example")
    .save("uniref90_tfrecords")
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment