from __future__ import unicode_literals

import random
from datetime import datetime, timedelta

from pyspark.sql import SparkSession


def jdbc_fast_read(sc, table, column, url, properties, numPartitions=None):
    """Read a JDBC table in parallel, partitioned on a numeric column.

    When numPartitions is not given, it is derived from an estimate of the
    table's JSON-serialized size (roughly one partition per 100 MiB,
    clamped to the range 10-4000).
    """
    spark = SparkSession(sc)

    # get the min & max of the partition column
    size = jdbc_min_max(sc, table, column, url, properties)

    # automatically decide the number of partitions
    if numPartitions is None:
        estimated_bytes = jdbc_estimate_json_size(
            sc=sc, table=table, column=column, url=url, properties=properties,
            lowerBound=size.min, upperBound=size.max
        )
        # fall back to the lower clamp if no sample could be taken
        mib = (estimated_bytes or 0) / 1024.0 / 1024.0
        numPartitions = min(4000, max(10, int(mib / 100.0)))

    # pull down the data quickly!
    return spark.read.jdbc(
        url=url,
        properties=properties,
        table=table,
        column=column,
        lowerBound=size.min,
        upperBound=size.max,
        numPartitions=numPartitions
    )


def jdbc_min_max(sc, table, column, url, properties):
    """Return a Row with the min and max values of `column` in `table`."""
    spark = SparkSession(sc)
    return spark.read.jdbc(
        url=url,
        properties=properties,
        table='''
            (select min({column}) as min, max({column}) as max from {table}) a
        '''.format(table=table, column=column)
    ).collect()[0]


def jdbc_estimate_json_size(
    sc, table, column, url, properties, lowerBound=None, upperBound=None,
    sample_size=1000, duration=timedelta(seconds=4)
):
    """Estimate the JSON-serialized size of `table` in bytes.

    Random windows of `sample_size` values of `column` are read for up to
    `duration`, and the observed bytes-per-value rate is extrapolated over
    the full [lowerBound, upperBound] range. Returns None if no rows were
    sampled in time.
    """
    spark = SparkSession(sc)

    if lowerBound is None or upperBound is None:
        size = jdbc_min_max(sc, table, column, url, properties)
        lowerBound = lowerBound or size.min
        upperBound = upperBound or size.max

    sampled_range = 0
    sampled_bytes = 0
    end_time = datetime.now() + duration
    while datetime.now() < end_time:
        # pick a random window of `sample_size` values of the column
        start = random.randint(lowerBound, max(lowerBound, upperBound - sample_size))
        end = start + sample_size

        df = spark.read.jdbc(url=url, properties=properties, table='''(
            select * from {table} where {column} >= {start} and {column} <= {end}) a
        '''.format(
            table=table, column=column, start=start, end=end
        ))
        try:
            _row_count, window_bytes = df.toJSON().map(lambda x: (1, len(x))).reduce(
                lambda a, b: (a[0] + b[0], a[1] + b[1])
            )
            # accumulate the sampled id range (not the row count) so that
            # gaps in the column's values are accounted for
            sampled_range += sample_size
            sampled_bytes += window_bytes
        except ValueError:
            # reduce() raises ValueError on an empty window; skip it
            pass
    if not sampled_range:
        return None
    return float(upperBound - lowerBound) * float(sampled_bytes) / sampled_range
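

# Example usage sketch. The JDBC URL, driver class, credentials, and
# table/column names below are placeholders, and the JDBC driver jar is
# assumed to already be on the Spark classpath.
if __name__ == '__main__':
    from pyspark import SparkContext

    sc = SparkContext(appName='jdbc_fast_read_example')
    df = jdbc_fast_read(
        sc,
        table='events',   # placeholder table name
        column='id',      # numeric column to partition on
        url='jdbc:postgresql://localhost:5432/mydb',
        properties={
            'user': 'example_user',
            'password': 'example_password',
            'driver': 'org.postgresql.Driver',
        },
    )
    print(df.count())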