#!/usr/bin/env python
# encoding: utf-8

# This file lives in tests/project_test.py in the usual distutils structure
# Remember to set the SPARK_HOME environment variable to the path of your
# Spark installation
import logging
import sys
import unittest

from nose.tools import eq_, set_trace
def add_pyspark_path():
    """
    Add PySpark to the PYTHONPATH
    Thanks go to this project: https://github.com/holdenk/sparklingpandas
    """
    import sys
    import os
    try:
        sys.path.append(os.path.join(os.environ['SPARK_HOME'], "python"))
        # NOTE: the py4j zip name depends on the Spark release you have installed
        sys.path.append(os.path.join(os.environ['SPARK_HOME'],
                                     "python", "lib", "py4j-0.9-src.zip"))
    except KeyError:
        print("SPARK_HOME not set")
        sys.exit(1)
add_pyspark_path()  # Now we can import pyspark

from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SQLContext, HiveContext
from pyspark.sql.window import Window
import pyspark.sql.functions as func
def quiet_py4j():
    """ turn down spark logging for the test context """
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)
class GSparkTestCase(unittest.TestCase):
    def setUp(self):
        quiet_py4j()
        # Set up a new spark context for each test
        conf = SparkConf()
        conf.set("spark.executor.memory", "1g")
        conf.set("spark.cores.max", "1")
        # conf.set("spark.master", "spark://192.168.1.2:7077")
        conf.set("spark.app.name", "nosetest")
        self.sc = SparkContext(conf=conf)
        self.sqlContext = HiveContext(self.sc)

    def tearDown(self):
        self.sc.stop()
# This would go in tests/project_test.py
class BasicSparkTests(GSparkTestCase):
    def null_test(self):
        """Forward-fill null ids within a session using window functions."""
        df = self.sqlContext.createDataFrame([
            (1, 1, None),
            (1, 2, 109),
            (1, 3, None),
            (1, 4, None),
            (1, 5, 109),
            (1, 6, None),
            (1, 7, 110),
            (1, 8, None),
            (1, 9, None),
        ], ("session", "timestamp", "id"))
        eq_(df.count(), 9)
        def process(df):
            # Replace nulls with a -1 sentinel so the comparisons below work
            df_na = df.na.fill(-1)
            # Previous id within each session, ordered by timestamp
            lag = df_na.withColumn('id_lag', func.lag('id', default=-1)
                                   .over(Window.partitionBy('session')
                                         .orderBy('timestamp')))
            # Flag rows where a new (non-sentinel) id appears
            switch = lag.withColumn('id_change',
                                    ((lag['id'] != lag['id_lag']) &
                                     (lag['id'] != -1)).cast('integer'))
            # Running sum of the change flags numbers the sub-sessions
            switch_sess = switch.withColumn(
                'sub_session',
                func.sum("id_change")
                .over(
                    Window.partitionBy("session")
                    .orderBy("timestamp")
                    .rowsBetween(-sys.maxsize, 0)))
            # The first id of each sub-session is the forward-filled value
            fid = switch_sess.withColumn('nn_id',
                                         func.first('id')
                                         .over(Window.partitionBy('sub_session')
                                               .orderBy('timestamp')))
            # Replace the -1 sentinel values again
            fid_na = fid.replace(-1, 'null')
            # Drop the helper columns and restore the original column name
            ff = (fid_na.drop('id').drop('id_lag')
                        .drop('id_change')
                        .drop('sub_session')
                        .withColumnRenamed('nn_id', 'id'))
            return ff
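
        # Sketch of how the intermediate columns work out for the sample
        # session above (ordered by timestamp; not asserted by the test):
        #   id (filled):  -1 109  -1  -1 109  -1 110  -1  -1
        #   id_lag:       -1  -1 109  -1  -1 109  -1 110  -1
        #   id_change:     0   1   0   0   1   0   1   0   0
        #   sub_session:   0   1   1   1   2   2   3   3   3
        #   nn_id:        -1 109 109 109 109 109 110 110 110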

        df_filled = process(df)

        df_exp = self.sqlContext.createDataFrame([
            (1, 1, None),
            (1, 2, 109),
            (1, 3, 109),
            (1, 4, 109),
            (1, 5, 109),
            (1, 6, 109),
            (1, 7, 110),
            (1, 8, 110),
            (1, 9, 110),
        ], ("session", "timestamp", "id"))
        eq_(df_filled.collect(), df_exp.collect())
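

# One way to run these tests (assuming nose is installed and SPARK_HOME is set):
#
#     nosetests tests/project_test.py
#
# nose discovers null_test because the name matches its default *_test pattern;
# plain unittest discovery only picks up methods whose names start with "test".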