Skip to content

Instantly share code, notes, and snippets.

@jewer
Last active March 22, 2016 15:39
Show Gist options
  • Select an option

  • Save jewer/db4ac6057884ea4f943b to your computer and use it in GitHub Desktop.

Select an option

Save jewer/db4ac6057884ea4f943b to your computer and use it in GitHub Desktop.
Sample multiclass logistic regression in PySpark
{
"cells": [
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Spark SQL entry point.\n",
"# NOTE(review): `sc` is not created anywhere in this notebook; this assumes the\n",
"# pyspark shell/kernel pre-defines it as the SparkContext.\n",
"from pyspark.sql import SQLContext\n",
"\n",
"sqlContext = SQLContext(sparkContext=sc)\n"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Explicit names instead of a wildcard import.\n",
"from pyspark.sql.types import IntegerType, StringType, StructField, StructType\n",
"\n",
"# Schema for data/segmentation-training.csv (the users file reuses it minus SEG).\n",
"# - SEG: the target segment. It is kept as the FIRST field because the\n",
"#   user-scoring cell derives its schema from this one with SEG removed.\n",
"# - UU_ID, post_evar3: string columns, excluded from the feature vector.\n",
"# - Every other column is an integer feature.\n",
"customSchema = StructType([\n",
"    StructField(\"SEG\", IntegerType(), True),\n",
"    StructField(\"UU_ID\", StringType(), True),\n",
"    StructField(\"post_evar3\", StringType(), True),\n",
"    StructField(\"MCS_BUSINESS_OWNER_\", IntegerType(), True),\n",
"    StructField(\"MCS_BUSINESS_OWNER_CONSUMER_US\", IntegerType(), True),\n",
"    StructField(\"MCS_MARKETING_PROGRAM_\", IntegerType(), True),\n",
"    StructField(\"MCS_MARKETING_PROGRAM_N3___12FOR12_JUL13\", IntegerType(), True),\n",
"    StructField(\"MCS_MARKETING_PROGRAM_DIGITAL_ENTITLEMENTS\", IntegerType(), True),\n",
"    StructField(\"MCS_MARKETING_PROGRAM_NP___3FOR1_NOV12\", IntegerType(), True),\n",
"    StructField(\"MCS_SUBSCR_STATE_NAME_Active\", IntegerType(), True),\n",
"    StructField(\"MCS_SUBSCR_STATE_NAME_Terminated\", IntegerType(), True),\n",
"    StructField(\"MCS_SUBSCR_STATE_NAME_Migrated\", IntegerType(), True),\n",
"    StructField(\"MCS_CARD_TYPE_\", IntegerType(), True),\n",
"    StructField(\"MCS_CARD_TYPE_AX\", IntegerType(), True),\n",
"    StructField(\"MCS_CARD_TYPE_MC\", IntegerType(), True),\n",
"    StructField(\"MCS_CARD_TYPE_VI\", IntegerType(), True),\n",
"    StructField(\"MCS_AUTORENEW_IND_0\", IntegerType(), True),\n",
"    StructField(\"MCS_AUTORENEW_IND_1\", IntegerType(), True),\n",
"    StructField(\"MCS_BUNDLE_NAME_\", IntegerType(), True),\n",
"    StructField(\"MCS_BUNDLE_NAME_Digital_Plus__Print\", IntegerType(), True),\n",
"    StructField(\"MCS_BUNDLE_NAME_Digital_Plus\", IntegerType(), True),\n",
"    StructField(\"MCS_BUNDLE_NAME_Online__Print\", IntegerType(), True),\n",
"    StructField(\"MCS_BUNDLE_NAME_WSJ_Online__Print_Journa\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_TYPE_MOBILE\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_TYPE_PRINT\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_TYPE_ONLINE\", IntegerType(), True),\n",
"    StructField(\"MCS_PAY_TYPE_\", IntegerType(), True),\n",
"    StructField(\"MCS_PAY_TYPE_Credit\", IntegerType(), True),\n",
"    StructField(\"MCS_AUP_TMPL_CD_\", IntegerType(), True),\n",
"    StructField(\"MCS_AUP_TMPL_CD_WSJ\", IntegerType(), True),\n",
"    StructField(\"MCS_AUP_TMPL_CD_WSJ_IPAD\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_TYPE_CODE_ENTITLEMENTS\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_TYPE_CODE_MARKETPLACE\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_TYPE_CODE_OLF\", IntegerType(), True),\n",
"    StructField(\"MCS_TENURE_BY_CUST_MTH_24\", IntegerType(), True),\n",
"    StructField(\"MCS_TENURE_BY_CUST_MTH_55\", IntegerType(), True),\n",
"    StructField(\"MCS_TENURE_BY_CUST_MTH_36\", IntegerType(), True),\n",
"    StructField(\"MCS_TENURE_BY_CUST_MTH_12\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_SUBTYPE_TABLET\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_SUBTYPE_STANDARD\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_SUBTYPE_SMART_PHONE\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_INTERNAL_EMAIL\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_CUSTOMER_SERVICE\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_IN_APP\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_MISC_INTERNET\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_DIRECT_MAIL\", IntegerType(), True),\n",
"    StructField(\"MCS_CHANNEL_ONSITE\", IntegerType(), True),\n",
"    StructField(\"MCS_GRP_TYPE_NAME_SINGLE\", IntegerType(), True),\n",
"    StructField(\"MCS_GRP_TYPE_NAME_BUNDLE\", IntegerType(), True),\n",
"    StructField(\"MCS_CAMPAIGN_TYPE_\", IntegerType(), True),\n",
"    StructField(\"MCS_CAMPAIGN_TYPE_ACQUISITION\", IntegerType(), True),\n",
"    StructField(\"MCS_CAMPAIGN_TYPE_RETENTION\", IntegerType(), True),\n",
"    StructField(\"MCS_FREE_PAID_TYPE_\", IntegerType(), True),\n",
"    StructField(\"MCS_FREE_PAID_TYPE_IFP\", IntegerType(), True),\n",
"    StructField(\"MCS_FREE_PAID_TYPE_FREE\", IntegerType(), True),\n",
"    StructField(\"MCS_FREE_PAID_TYPE_PAID\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_STATE_CODE_CANCELLED_HARD_DECLINED\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_STATE_CODE_CANCELLED_BY_CSR\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_STATE_CODE_CANCELLED_MIGRATED\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_STATE_CODE_CANCELLED_EXPIRED\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_STATE_CODE_ACTIVE\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_TERM_TYPE_CODE_FREE\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_TERM_TYPE_CODE_PAID\", IntegerType(), True),\n",
"    StructField(\"MCS_BRAND_NAME_WSJ\", IntegerType(), True),\n",
"    StructField(\"MCS_CAMPAIGN_SUBTYPE_\", IntegerType(), True),\n",
"    StructField(\"MCS_CAMPAIGN_SUBTYPE_REGISTRATION\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_GRP_TYPE_ID_1\", IntegerType(), True),\n",
"    StructField(\"MCS_SBSCR_GRP_TYPE_ID_2\", IntegerType(), True),\n",
"    StructField(\"MCS_OFFER_TYPE_UNKNOWN\", IntegerType(), True),\n",
"    StructField(\"MCS_OFFER_TYPE_DIGITAL_PLUS\", IntegerType(), True),\n",
"    StructField(\"MCS_OFFER_TYPE_DIGITAL_PLUS_PRINT\", IntegerType(), True),\n",
"    StructField(\"MCS_OFFER_TYPE_OFFLINE_REGISTRATION\", IntegerType(), True),\n",
"    StructField(\"MCS_OFFER_TYPE_ENTITLEMENTS_TO_PRINT\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_Complementary_Mobile_Subscription\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_Online\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_Print_Edition\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_WSJ_PRINT\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_Mobile_Reader\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_Tablet_Edition\", IntegerType(), True),\n",
"    StructField(\"MCS_PROD_NAME_WSJ_iPad\", IntegerType(), True),\n",
"    StructField(\"Content_searchresults\", IntegerType(), True),\n",
"    StructField(\"Content_video\", IntegerType(), True),\n",
"    StructField(\"Content_home_page\", IntegerType(), True),\n",
"    StructField(\"Content_summaries\", IntegerType(), True),\n",
"    StructField(\"Content_\", IntegerType(), True),\n",
"    StructField(\"Content_video_emb\", IntegerType(), True),\n",
"    StructField(\"Content_marketing_and_support\", IntegerType(), True),\n",
"    StructField(\"Content_slideshow\", IntegerType(), True),\n",
"    StructField(\"Content_blogs_\", IntegerType(), True),\n",
"    StructField(\"Content_article\", IntegerType(), True),\n",
"    StructField(\"Content_article_preview\", IntegerType(), True),\n",
"    StructField(\"Content_video_embedded_onsite\", IntegerType(), True),\n",
"    StructField(\"Content_blogs_article\", IntegerType(), True),\n",
"    StructField(\"Content_login\", IntegerType(), True),\n",
"    StructField(\"Access_sub\", IntegerType(), True),\n",
"    StructField(\"Access_subscriber\", IntegerType(), True),\n",
"    StructField(\"Access_free_Search\", IntegerType(), True),\n",
"    StructField(\"Access_\", IntegerType(), True),\n",
"    StructField(\"Access_free\", IntegerType(), True),\n",
"    StructField(\"Access_free_Blogs\", IntegerType(), True),\n",
"    StructField(\"Access_free_Multimedia\", IntegerType(), True),\n",
"    StructField(\"Access_paid_Article\", IntegerType(), True),\n",
"    StructField(\"Access_free_Summaries\", IntegerType(), True),\n",
"    StructField(\"Access_free_Home\", IntegerType(), True),\n",
"    StructField(\"Access_free_Article\", IntegerType(), True),\n",
"    StructField(\"Section_\", IntegerType(), True),\n",
"    StructField(\"Section_Home\", IntegerType(), True),\n",
"    StructField(\"Section_Home_Page\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Markets\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Home_Page_Subscriber\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_US\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Life_and_Style\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Life\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Business\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Opinion\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_World\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Tech\", IntegerType(), True),\n",
"    StructField(\"Section_Login\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Home\", IntegerType(), True),\n",
"    StructField(\"Section_WSJ_Login\", IntegerType(), True),\n",
"    StructField(\"Traffic_searchresults\", IntegerType(), True),\n",
"    StructField(\"Traffic_home\", IntegerType(), True),\n",
"    StructField(\"Traffic_home_page\", IntegerType(), True),\n",
"    StructField(\"Traffic_summaries\", IntegerType(), True),\n",
"    StructField(\"Traffic_blogs_\", IntegerType(), True),\n",
"    StructField(\"Traffic_\", IntegerType(), True),\n",
"    StructField(\"Traffic_marketing_and_support\", IntegerType(), True),\n",
"    StructField(\"Traffic_slideshow\", IntegerType(), True),\n",
"    StructField(\"Traffic_article\", IntegerType(), True),\n",
"    StructField(\"Traffic_account\", IntegerType(), True),\n",
"    StructField(\"Traffic_blogs_article\", IntegerType(), True),\n",
"    StructField(\"Traffic_login\", IntegerType(), True),\n",
"    StructField(\"Channel_Multimedia\", IntegerType(), True),\n",
"    StructField(\"Channel_Research_and_Tools\", IntegerType(), True),\n",
"    StructField(\"Channel_Customer_Resources\", IntegerType(), True),\n",
"    StructField(\"Channel_\", IntegerType(), True),\n",
"    StructField(\"Channel_Article\", IntegerType(), True),\n",
"    StructField(\"Channel_Search\", IntegerType(), True),\n",
"    StructField(\"Channel_Blogs\", IntegerType(), True),\n",
"    StructField(\"Channel_Home\", IntegerType(), True),\n",
"    StructField(\"Channel_Video\", IntegerType(), True),\n",
"    StructField(\"Channel_Summaries\", IntegerType(), True),\n",
"])\n"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Load the labelled training data with the explicit schema.\n",
"# NOTE(review): no `header` option is passed -- confirm the CSV has no header row.\n",
"training_df = (sqlContext.read\n",
"               .format('com.databricks.spark.csv')\n",
"               .load('data/segmentation-training.csv', schema=customSchema))"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pyspark.ml.feature import VectorAssembler\n",
"from pyspark.mllib.regression import LabeledPoint\n",
"\n",
"# Target and string columns; every remaining column becomes a feature.\n",
"ignore = ['UU_ID', 'post_evar3', 'SEG']\n",
"\n",
"assembler = VectorAssembler(\n",
"    inputCols=[column for column in training_df.columns if column not in ignore],\n",
"    outputCol='features')\n",
"\n",
"features_df = assembler.transform(training_df)\n",
"\n",
"# MLlib multiclass labels are 0-based, so SEG is shifted down by one here and\n",
"# shifted back up when users are scored. NOTE(review): assumes SEG is in 1..8.\n",
"# `.rdd.map` is what Spark 1.x `DataFrame.map` did, and it also works on Spark 2+.\n",
"labeled_points = (features_df\n",
"                  .select('SEG', 'features')\n",
"                  .rdd\n",
"                  .map(lambda row: LabeledPoint(row['SEG'] - 1, row['features'])))\n",
"\n",
"# train_df / test_df are RDDs of LabeledPoint (names kept for the cells below).\n",
"# The fixed seed makes the 80/20 split reproducible.\n",
"train_df, test_df = labeled_points.randomSplit([0.8, 0.2], seed=200)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pyspark.mllib.classification import LogisticRegressionWithLBFGS\n",
"\n",
"# Multinomial logistic regression over 8 classes (labels must lie in 0..7).\n",
"# TODO(review): 10 L-BFGS iterations is a small budget; consider tuning it.\n",
"lr = LogisticRegressionWithLBFGS.train(\n",
"    train_df,\n",
"    iterations=10,\n",
"    intercept=True,\n",
"    numClasses=8)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(3.0,(141,[1,6,8,10,14,20,22,24,26,28,29,35,36,37,45,46,53,56,58,59,60,61,64,65,66,71,72,75,76,77,87,99,113,127,132,135],[3.0,5.0,1.0,3.0,6.0,5.0,1.0,3.0,6.0,3.0,3.0,2.0,1.0,3.0,1.0,5.0,3.0,1.0,5.0,1.0,5.0,6.0,1.0,5.0,3.0,1.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0,39.0,1.0])) 3\n",
"(0.0,(141,[1,7,8,12,14,16,20,21,22,24,26,28,29,35,36,37,45,46,53,55,56,59,60,61,64,65,66,68,71,72,73,75,76,80,90,97,101,115,121,129,137,138],[8.0,8.0,1.0,8.0,9.0,8.0,5.0,2.0,2.0,8.0,9.0,1.0,8.0,2.0,4.0,3.0,1.0,8.0,8.0,4.0,1.0,1.0,8.0,9.0,1.0,8.0,1.0,8.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,1.0,2.0,1.0,2.0,1.0,1.0,2.0])) 0\n",
"(3.0,(141,[1,3,6,7,8,12,13,14,16,18,19,20,21,22,24,28,29,30,35,36,37,42,45,46,48,51,53,55,56,58,60,61,64,65,66,68,72,73,75,76,78,82,87,88,91,93,96,99,102,109,113,118,119,127,128,130,133,135],[9.0,1.0,1.0,8.0,4.0,5.0,4.0,9.0,4.0,4.0,4.0,6.0,4.0,3.0,5.0,6.0,5.0,2.0,3.0,7.0,3.0,4.0,1.0,12.0,4.0,4.0,5.0,4.0,4.0,1.0,13.0,13.0,1.0,12.0,4.0,8.0,3.0,4.0,3.0,3.0,1.0,3.0,3.0,5.0,28.0,20.0,11.0,7.0,1.0,2.0,2.0,6.0,1.0,3.0,25.0,6.0,31.0,8.0])) 3\n",
"(1.0,(141,[1,6,7,13,14,17,19,20,21,22,26,28,30,35,36,37,45,46,53,55,58,59,60,61,64,65,66,69,71,72,73,74,75,76,78,81,87,88,91,96,99,100,101,102,103,106,108,111,112,113,114,115,118,119,120,122,127,130,131,133,135,138,140],[8.0,4.0,9.0,4.0,9.0,8.0,4.0,7.0,3.0,3.0,13.0,10.0,3.0,3.0,6.0,4.0,1.0,12.0,8.0,9.0,4.0,1.0,12.0,13.0,1.0,12.0,5.0,8.0,1.0,3.0,1.0,2.0,3.0,3.0,5.0,4.0,19.0,1.0,1.0,3.0,18.0,4.0,6.0,2.0,2.0,2.0,1.0,1.0,2.0,13.0,2.0,1.0,1.0,5.0,6.0,4.0,19.0,1.0,2.0,1.0,20.0,6.0,4.0])) 1\n",
"(0.0,(141,[1,6,10,14,21,24,26,29,36,45,53,58,60,61,64,73,81,85,87,88,90,97,98,99,100,101,102,108,112,113,114,120,122,126,127,129,131,135,137,138,140],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,13.0,2.0,4.0,4.0,4.0,4.0,2.0,5.0,13.0,13.0,3.0,4.0,5.0,3.0,2.0,13.0,13.0,2.0,4.0,4.0,2.0,8.0,4.0,13.0,13.0])) 2\n"
]
}
],
"source": [
"# Spot-check a few training points next to their predicted (0-based) class.\n",
"# Presumably the training split has ~250 rows, so 10/250 samples about 10 of them.\n",
"# The fixed seed makes the sample, and therefore this cell's output, reproducible.\n",
"samples = train_df.sample(False, 10.0 / 250.0, seed=200).collect()\n",
"for point in samples:\n",
"    print('%s %s' % (point, lr.predict(point.features)))"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LBFGS training error = 0.166666666667\n",
"LBFGS test error = 0.785714285714\n"
]
}
],
"source": [
"def misclassification_rate(data, model):\n",
"    \"\"\"Return the fraction of points in `data` whose label differs from the prediction.\n",
"\n",
"    data:  RDD of LabeledPoint.\n",
"    model: any object with a `predict(features)` method (here the L-BFGS model `lr`).\n",
"    \"\"\"\n",
"    label_and_pred = data.map(lambda p: (p.label, model.predict(p.features)))\n",
"    return label_and_pred.map(lambda lp: float(lp[0] != lp[1])).mean()\n",
"\n",
"\n",
"trainingError1 = misclassification_rate(train_df, lr)\n",
"print('LBFGS training error = %s' % trainingError1)\n",
"\n",
"# NOTE(review): the saved output shows ~17% training vs ~79% test error, so the\n",
"# model does not generalise; investigate before trusting the user predictions.\n",
"testError1 = misclassification_rate(test_df, lr)\n",
"print('LBFGS test error = %s' % testError1)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Score the unlabelled users: their CSV has the training columns minus SEG.\n",
"# Drop SEG by name rather than by position so this does not depend on field order.\n",
"schema_no_seg = StructType([field for field in customSchema.fields if field.name != 'SEG'])\n",
"users_df = (sqlContext.read\n",
"            .format('com.databricks.spark.csv')\n",
"            .load('data/segmentation-users.csv', schema=schema_no_seg))"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Build the same feature vector as for training, then predict a segment per user.\n",
"# `+ 1` undoes the `SEG - 1` label shift applied when the LabeledPoints were built.\n",
"# Result: RDD of (UU_ID, predicted SEG) tuples.\n",
"user_vectors_df = assembler.transform(users_df)\n",
"user_features_df = (user_vectors_df\n",
"                    .select('UU_ID', 'features')\n",
"                    .rdd\n",
"                    .map(lambda row: (row['UU_ID'], lr.predict(row['features']) + 1)))"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[(u'a49386ed-c0a2-4be3-9c43-57e7150bc1d6', 8),\n",
" (u'gregmadie', 5),\n",
" (u'074a37bc-ca46-45e5-94e9-b6bcd97378b9', 5)]"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Peek at a few (UU_ID, predicted SEG) pairs.\n",
"# NOTE(review): the output shows what look like real user identifiers; clear it before sharing.\n",
"user_features_df.take(num=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment