@munro
Created January 27, 2017 20:57

jdbc_fast_read

Eases loading data quickly into PySpark by automatically setting lowerBound, upperBound, and numPartitions for partitioned JDBC reads.
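
Under the hood it wraps a plain spark.read.jdbc call. For comparison, this is roughly the manual equivalent it saves you from writing; the bounds and partition count below are illustrative placeholders that you would normally have to look up or guess yourself:

from pyspark.sql import SparkSession

spark = SparkSession(sc)
df = spark.read.jdbc(
    url='jdbc:postgresql://HOST/DATABASE',
    table='TABLE',
    column='id',
    lowerBound=1,           # min(id), normally found with a separate query
    upperBound=10000000,    # max(id), normally found with a separate query
    numPartitions=200,      # guessed by hand
    properties={'user': 'USER', 'password': 'PASSWORD', 'driver': 'org.postgresql.Driver'}
)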

Install package from S3

sc.addPyFile('s3://circleup-oss/jdbc_fast_read-0.0.1-py2.7.egg')

Install package from pip

pip install jdbc_fast_read

Usage

from jdbc_fast_read import jdbc_fast_read

conn = dict(
    url='jdbc:postgresql://HOST/DATABASE',
    properties={'user': 'USER', 'password': 'PASSWORD', 'driver': 'org.postgresql.Driver'}
)

df = jdbc_fast_read(sc, column='id', table='TABLE', **conn)
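
If you would rather skip the sampling step, pass numPartitions explicitly (the value below is just an example); only the min/max query runs in that case:

df = jdbc_fast_read(sc, column='id', table='TABLE', numPartitions=200, **conn)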

License

MIT

from __future__ import unicode_literals

import random
from datetime import datetime, timedelta

from pyspark.sql import SparkSession


def jdbc_fast_read(sc, table, column, url, properties, numPartitions=None):
    spark = SparkSession(sc)
    # get the min & max of the partition column
    size = jdbc_min_max(sc, table, column, url, properties)
    # automatically decide the number of partitions: aim for roughly
    # 100 MiB of JSON per partition, clamped to the range [10, 4000]
    if numPartitions is None:
        mib = jdbc_estimate_json_size(
            sc=sc, table=table, column=column, url=url, properties=properties,
            lowerBound=size.min, upperBound=size.max
        ) / 1024.0 / 1024.0
        numPartitions = int(mib / 100.0)
        numPartitions = min(4000, max(10, numPartitions))
    # pull down the data quickly!
    return spark.read.jdbc(
        url=url,
        properties=properties,
        table=table,
        column=column,
        lowerBound=size.min,
        upperBound=size.max,
        numPartitions=numPartitions
    )


def jdbc_min_max(sc, table, column, url, properties):
    # returns a single Row with `min` and `max` of the partition column
    spark = SparkSession(sc)
    return spark.read.jdbc(
        url=url,
        properties=properties,
        table='''
            (select min({column}) as min, max({column}) as max from {table}) a
        '''.format(table=table, column=column)
    ).collect()[0]


def jdbc_estimate_json_size(
    sc, table, column, url, properties, lowerBound=None, upperBound=None,
    sample_size=1000, duration=timedelta(seconds=4)
):
    spark = SparkSession(sc)
    if lowerBound is None or upperBound is None:
        size = jdbc_min_max(sc, table, column, url, properties)
        lowerBound = lowerBound or size.min
        upperBound = upperBound or size.max
    sample_rows = 0
    sample_bytes = 0
    end_time = datetime.now() + duration
    # sample random id ranges until the time budget runs out
    while datetime.now() < end_time:
        start = random.randint(lowerBound, upperBound - sample_size)
        end = start + sample_size
        df = spark.read.jdbc(url=url, properties=properties, table='''(
            select * from {table} where {column} >= {start} and {column} <= {end}) a
        '''.format(
            table=table, column=column, start=start, end=end
        ))
        try:
            rows, nbytes = df.toJSON().map(lambda x: (1, len(x))).reduce(
                lambda a, b: (a[0] + b[0], a[1] + b[1])
            )
            sample_rows += sample_size
            sample_bytes += nbytes
        except ValueError:
            # reduce() raises ValueError when the sampled range is empty; skip it
            pass
    if not sample_rows:
        return None
    # extrapolate average bytes per id over the whole id range
    return float(upperBound - lowerBound) * (float(sample_bytes) / float(sample_rows))
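
The helpers can also be used on their own. A quick sketch, assuming the same conn dict as in the Usage section above:

# Min and max of the partition column, as a single Row.
bounds = jdbc_min_max(sc, table='TABLE', column='id', **conn)
print('min=%s max=%s' % (bounds.min, bounds.max))

# Rough size of the table serialized as JSON, in bytes (None if nothing
# was sampled within the time budget).
est_bytes = jdbc_estimate_json_size(sc, table='TABLE', column='id', **conn)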