create database test_db;
create table t_random as select s, md5(random()::text) from generate_series(1, 5000) s;
In [1]: df = spark.read.jdbc(url="jdbc:postgresql://localhost:5432/test_db", table="t_random", properties={"driver": "org.postgresql.Driver"}).repartition(10)

In [2]: row_count = spark.sparkContext.accumulator(0)

In [3]: def onEachPart(part):
   ...:     count = 0
   ...:     for row in part:
   ...:         count += 1
   ...:         yield row
   ...:     print("Add " + str(count))
   ...:     row_count.add(count)
   ...:

In [4]: df = df.rdd.mapPartitions(onEachPart).toDF()

In [5]: df.write.parquet("/tmp/t_random12345")
import sys
from functools import reduce
from typing import Dict, Tuple

from pyspark import Accumulator
from pyspark.sql import DataFrame, SparkSession


def get_df_stats(df: DataFrame, spark: SparkSession) -> Tuple[DataFrame, Dict[str, Accumulator]]:
    row_count: Accumulator = spark.sparkContext.accumulator(0)
    size: Accumulator = spark.sparkContext.accumulator(0)

    def onEachPart(part):
        count = 0
        size_in_bytes = 0
        for row in part:
            # Approximate the in-memory size of the row as the sum of its field sizes.
            size_in_bytes += reduce(lambda a, b: a + b, map(lambda x: sys.getsizeof(x), list(row)))
            count += 1
            yield row
        # Runs once per partition, after the partition has been fully consumed.
        row_count.add(count)
        size.add(size_in_bytes)

    counters = {
        "row_count": row_count,
        "size": size
    }
    return df.rdd.mapPartitions(onEachPart).toDF(), counters
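
As a quick illustration, here is a minimal usage sketch of get_df_stats, assuming the same SparkSession (spark) and JDBC source as in the session above; the output path t_random_stats is made up for this example. Note that the counters only hold their final values after an action has consumed every partition.

# Minimal usage sketch (assumes the SparkSession and JDBC source from above;
# the output path below is hypothetical).
df = spark.read.jdbc(
    url="jdbc:postgresql://localhost:5432/test_db",
    table="t_random",
    properties={"driver": "org.postgresql.Driver"},
).repartition(10)

df_with_stats, counters = get_df_stats(df, spark)

# The accumulators are only populated once an action materializes the partitions.
df_with_stats.write.parquet("/tmp/t_random_stats")

print(counters["row_count"].value)  # 5000
print(counters["size"].value)       # approximate total size of the rows in bytes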
Back in the interactive session, the per-partition prints only show up once the parquet write in In [5] actually consumes each partition; with 10 partitions of 500 rows each, the console shows:

Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
Add 500
In [6]: row_count.value
Out[6]: 5000
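
As an optional sanity check (not part of the original session), the accumulator value can be cross-checked against the file written in In [5]:

# Optional cross-check: count the rows that were actually written to parquet.
written = spark.read.parquet("/tmp/t_random12345")
assert written.count() == row_count.value  # both should report 5000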