@lovasoa
Created November 17, 2016 12:53
Spark star JOIN with small dimension tables
# Spark star JOIN with all dimension tables in cache
# Ophir LOJKINE, 2016
from __future__ import print_function
from functools import reduce
from operator import add
from pyspark import SparkContext

sc = SparkContext(appName="OphirJoin")

def readTable(tname):
    """Read a single-column table into an RDD of (row_key, value) pairs.
    Row keys are 1-based line numbers; numeric values are parsed as ints."""
    def parseVal(v):
        try:
            return int(v)
        except ValueError:
            return v
    lines = sc.textFile("column-db/" + tname + ".txt")
    return lines.zipWithIndex().map(lambda kv: (kv[1] + 1, parseVal(kv[0])))

products = readTable("products")

# Each entry: (table type, table name, filter predicate on the column value).
tojoin = (
    ("dimension", "products", lambda p: True),
    ("dimension", "customers", lambda c: c == "Alice"),
    ("fact", "orders_quantity", lambda q: q < 30),
)

def doJoin(join_spec):
    """Return an RDD of (order_key, value) pairs for one table of the star schema."""
    typ, tname, condition = join_spec
    table = readTable(tname)
    def dofilter(kv):
        return condition(kv[1])
    if typ == "dimension":
        # Small dimension table: filter it, collect it as a dict on the driver,
        # and broadcast it to every executor (a broadcast hash join).
        thash = sc.broadcast(table.filter(dofilter).collectAsMap())
        # The fact side stores a foreign-key column pointing into this dimension.
        bigTable = readTable("orders_fk_" + tname)
        def getVal(fk):
            # Keep the order only if its foreign key survived the dimension filter.
            res = thash.value.get(fk)
            return [res] if res is not None else []
        bigTable = bigTable.flatMapValues(getVal)
    else:
        # Fact column: filter it directly, keyed by order number.
        bigTable = table.filter(dofilter)
    return bigTable

# Union the per-table RDDs (RDD "+" is union), then keep only the orders that
# satisfied every filter, i.e. that contributed one value per joined table.
results = reduce(add, map(doJoin, tojoin))

def finalize(values):
    return [tuple(values)] if len(values) == len(tojoin) else []

results = results.groupByKey().flatMapValues(finalize)
print(results.collect())
sc.stop()
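
The script assumes a column-db/ directory next to it, with one plain-text file per column and one value per line; the 1-based line number serves as the row key, so line n of orders_fk_products holds the products key referenced by order n. As a hedged illustration (file names are taken from the code above, the data values are invented), the following standalone snippet creates a minimal data set the job can run against:

# Hypothetical sample-data generator for the gist above.
# The file names match what readTable() expects; the values are invented.
import os

if not os.path.isdir("column-db"):
    os.makedirs("column-db")

tables = {
    "products": ["apple", "banana", "cherry"],  # dimension values
    "customers": ["Alice", "Bob"],               # dimension values
    "orders_fk_products": [1, 3, 2],    # order n -> products row key
    "orders_fk_customers": [1, 1, 2],   # order n -> customers row key
    "orders_quantity": [10, 50, 5],     # order n -> quantity
}

for name, column in tables.items():
    with open(os.path.join("column-db", name + ".txt"), "w") as f:
        f.write("\n".join(str(v) for v in column) + "\n")

With such files in place, the job can be launched as usual with spark-submit, e.g. spark-submit star_join.py (the script name is arbitrary).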