KKBOX Feature Engineering
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer

def preprocess_age(age):
    # Ages outside a plausible 5-70 range, and missing values, are flagged as -1
    if age is None or age > 70 or age < 5:
        return -1
    else:
        return age
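# Example behavior:
#   preprocess_age(25)   -> 25
#   preprocess_age(120)  -> -1
#   preprocess_age(None) -> -1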
path = "s3://..."
spark = SparkSession.builder.appName("members-prep").getOrCreate()

# Read csv and fill missing values with the literal string 'null'
# so StringIndexer can treat them as their own category
member = spark.read.option("header", "true").csv(path + "/members.csv")
member = member.fillna('null')

# Preprocess the age feature (the raw column is named bd)
prep_age = udf(lambda a: preprocess_age(a), IntegerType())
member = member.withColumn("age", prep_age(member.bd.cast(IntegerType())))

# Categorical to numeric
indexer = StringIndexer(inputCol="gender", outputCol="gender_ix")
member = indexer.fit(member).transform(member)
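# Note: StringIndexer assigns indices by descending label frequency
# (the most frequent label gets index 0.0), so gender_ix is deterministic
# for a given dataset but not portable across datasets.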
# Cast column types and normalize the yyyyMMdd date strings to yyyy-MM-dd
member = member.drop("bd", "gender") \
    .withColumn("registration_init_time", date_format(unix_timestamp(member.registration_init_time, 'yyyyMMdd').cast(TimestampType()), 'yyyy-MM-dd')) \
    .withColumn("expiration_date", date_format(unix_timestamp(member.expiration_date, 'yyyyMMdd').cast(TimestampType()), 'yyyy-MM-dd')) \
    .withColumn("city", member.city.cast(IntegerType())) \
    .withColumn("gender_ix", member.gender_ix.cast(IntegerType())) \
    .withColumn("registered_via", member.registered_via.cast(IntegerType())) \
    .cache()

# Save to S3
member.printSchema()
member.repartition(2).write.mode("overwrite") \
    .option("header", "true") \
    .option("compression", "gzip") \
    .csv(path + "/members_prep", sep=",")

spark.stop()
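Before building features on top of members_prep, a quick read-back catches casting mistakes early. A minimal sketch, run as its own job (the checks shown are illustrative, not part of the original pipeline):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("members-prep-check").getOrCreate()
check = spark.read.option("header", "true").option("inferSchema", "true") \
    .csv("s3://.../members_prep")
check.printSchema()                              # city, age, gender_ix, registered_via should be int
check.groupBy("gender_ix").count().show()        # distribution of the indexed gender labels
check.selectExpr("min(age)", "max(age)").show()  # everything outside 5-70 was mapped to -1
spark.stop()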
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

# Window boundaries for the February 2017 churn cohort
# (note: relativedelta(months=1), plural -- the singular keyword would
# set the month absolutely instead of subtracting one month)
churn_month = '2017-02-01'
dt = datetime.strptime(churn_month, "%Y-%m-%d")
before_day = datetime.strftime(dt - timedelta(days=1), "%Y-%m-%d")
before_one = datetime.strftime(dt - relativedelta(months=1), "%Y-%m-%d")
before_half = datetime.strftime(dt - timedelta(days=30*6), "%Y-%m-%d")
before_year = datetime.strftime(dt.replace(year=dt.year-1), "%Y-%m-%d")
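# For churn_month = '2017-02-01' these resolve to:
#   before_day  -> '2017-01-31'
#   before_one  -> '2017-01-01'
#   before_half -> '2016-08-05'  (30*6 = 180 days, an approximate half year)
#   before_year -> '2016-02-01'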
path = "s3://..."
spark = SparkSession.builder.appName("train-feature").getOrCreate()

# Read csv
train = spark.read.option("header", "true").csv(path + "/train.csv").cache()
print("train count : {}".format(train.count()))
train.createOrReplaceTempView("train")

member = spark.read.option("header", "true").option("inferSchema", "true").csv(path + "/members_prep").cache()
print("member count : {}".format(member.count()))
member.createOrReplaceTempView("member")

# Transaction window (2016/02 ~ 2017/02)
trans = spark.read.option("header", "true").option("inferSchema", "true").csv(path + "/trans_prep") \
    .where("transaction_date between '" + before_year + "' and '" + before_day + "'").cache()
print("trans count : {}".format(trans.count()))
trans.createOrReplaceTempView("trans")
# Feature extraction: member profile joined with the churn labels
query_mem = """
SELECT a.msno, b.city, b.registered_via, b.registration_init_time, b.expiration_date, b.age, b.gender_ix,
    datediff(b.expiration_date, b.registration_init_time) as regist_duration,
    datediff(b.expiration_date, b.registration_init_time) / 365 as long_time_user,
    a.is_churn
FROM train as a
LEFT JOIN member as b
ON a.msno = b.msno
"""
mem = spark.sql(query_mem)

# Features of each user's most recent transaction in the window
# (is_discount = 1 when the user paid less than the plan list price)
query_last = """
SELECT *, month(transaction_date) as last_month, dayofmonth(transaction_date) as last_day,
    date_format(transaction_date, 'EEEE') as last_day_of_week,
    plan_list_price - actual_amount_paid as last_discount,
    CASE WHEN plan_list_price - actual_amount_paid > 0 THEN 1 ELSE 0 END as is_discount,
    round(actual_amount_paid / payment_plan_days, 1) as amt_per_day,
    datediff(transaction_date, '{}') as last_transaction_diff,
    datediff(membership_expire_date, transaction_date) as membership_duration
FROM
(
    SELECT b.*, row_number() OVER (PARTITION BY a.msno ORDER BY transaction_date DESC) as row_num
    FROM train as a
    LEFT JOIN trans as b
    ON a.msno = b.msno
    WHERE transaction_date between '{}' and '{}'
)
WHERE row_num = 1
"""
last = spark.sql(query_last.format(churn_month, before_year, before_day)).drop("row_num")
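# row_num = 1 keeps only each user's most recent transaction in the window,
# so `last` holds at most one row per msno.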
# Transaction history counts over the full one-year window.
# count(b.msno) rather than count(*): with a left join, count(*) would
# report 1 instead of 0 for users with no transactions at all.
query_hist = """
SELECT a.msno, count(b.msno) as transaction_cnt,
    count(CASE WHEN is_cancel = 1 THEN 1 END) as is_cancel_cnt
FROM train as a
LEFT JOIN
(
    SELECT * FROM trans
    WHERE transaction_date between '{}' and '{}'
) as b
ON a.msno = b.msno
GROUP BY a.msno
"""
hist = spark.sql(query_hist.format(before_year, before_day))

result = mem.join(last, "msno", "left").join(hist, "msno", "left").cache()
print("result count : {}".format(result.count()))
# Save to S3
result.printSchema()
result.coalesce(1).write.mode("overwrite") \
    .option("header", "true") \
    .option("compression", "gzip") \
    .csv(path + "/train_feature", sep=",")

spark.stop()
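The third script layers listening-behavior features from the raw user logs on top of train_feature. user_logs is typically the largest table in the KKBOX dataset, so while iterating on the queries it can help to swap the full read in the script below for a small sample; a sketch (the 1% fraction is arbitrary):

sample = spark.read.option("header", "true") \
    .csv(path + "/user_logs/*").sample(False, 0.01)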
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import TimestampType

# Window boundaries for the February 2017 churn cohort
# (relativedelta(months=1), plural, as in the previous script)
churn_month = '2017-02-01'
dt = datetime.strptime(churn_month, "%Y-%m-%d")
before_day = datetime.strftime(dt - timedelta(days=1), "%Y-%m-%d")
before_one = datetime.strftime(dt - relativedelta(months=1), "%Y-%m-%d")
before_half = datetime.strftime(dt - timedelta(days=30*6), "%Y-%m-%d")
path = "s3://..."
spark = SparkSession.builder.appName("train-feature").getOrCreate()

# Read csv
train = spark.read.option("header", "true").option("inferSchema", "true").csv(path + "/train_feature").cache()
print("train count : {}".format(train.count()))
train.createOrReplaceTempView("train")

# Userlog window (2016/09 ~ 2017/02)
userlog = spark.read.option("header", "true").csv(path + "/user_logs/*")
userlog = userlog.withColumn("date", date_format(unix_timestamp(userlog.date, 'yyyyMMdd').cast(TimestampType()), 'yyyy-MM-dd')).cache()
print("userlog count : {}".format(userlog.count()))
userlog.createOrReplaceTempView("userlog")
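# Note: both queries below only touch rows from before_half onward, so the
# date filter could be pushed into this read (as the transactions job does
# with trans_prep) to shrink what gets cached.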
# Feature extraction
query_half = """
SELECT a.msno, sum(num_25) as half_sum_num_25, sum(num_50) as half_sum_num_50,
    sum(num_75) as half_sum_num_75, sum(num_985) as half_sum_num_985, sum(num_100) as half_sum_num_100,
    sum(num_unq) as half_sum_num_unq, round(sum(total_secs) / 60, 1) as half_sum_total_min,
    max(date) as last_listen_date, datediff(max(date), '{}') as last_listen_diff
FROM train as a
LEFT JOIN
(
    SELECT * FROM userlog
    WHERE date between '{}' and '{}'
) as b
ON a.msno = b.msno
GROUP BY a.msno
"""
half = spark.sql(query_half.format(churn_month, before_half, before_day))

query_month = """
SELECT a.msno, sum(num_25) as month_sum_num_25, sum(num_50) as month_sum_num_50,
    sum(num_75) as month_sum_num_75, sum(num_985) as month_sum_num_985, sum(num_100) as month_sum_num_100,
    sum(num_unq) as month_sum_num_unq, round(sum(total_secs) / 60, 1) as month_sum_total_min
FROM train as a
LEFT JOIN
(
    SELECT * FROM userlog
    WHERE date between '{}' and '{}'
) as b
ON a.msno = b.msno
GROUP BY a.msno
"""
month = spark.sql(query_month.format(before_one, before_day))

result = half.join(month, "msno", "left").cache()
print("result count : {}".format(result.count()))

# Save to S3
result.printSchema()
result.coalesce(1).write.mode("overwrite") \
    .option("header", "true") \
    .option("compression", "gzip") \
    .csv(path + "/train_feature_last", sep=",")

spark.stop()
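Downstream, the two feature files join back together on msno to form the final training set. A minimal sketch, assuming the paths written above (the fillna(0) imputation is one possible choice, not part of the original scripts):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("assemble-dataset").getOrCreate()
path = "s3://..."
base = spark.read.option("header", "true").option("inferSchema", "true").csv(path + "/train_feature")
logs = spark.read.option("header", "true").option("inferSchema", "true").csv(path + "/train_feature_last")
dataset = base.join(logs, "msno", "left").fillna(0)
print("dataset count : {}".format(dataset.count()))
spark.stop()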