# GitHub Gist by @ipeirotis (ipeirotis/ce26e0e76a5192f89c2e),
# last active September 2, 2022 12:32.
### Cohort Analysis
import matplotlib.pyplot as plt
# Connect to the BigQuery API
from googleapiclient.discovery import build
from oauth2client import client
# Load the credentials from the JSON key file through oauth2client's public
# GoogleCredentials.from_stream entry point (rather than a private, underscore
# helper), then scope them to BigQuery and build the v2 API client.
# NOTE(review): oauth2client is deprecated upstream; google-auth is the
# suggested replacement if this script is modernized.
credentials = client.GoogleCredentials.from_stream('client_secrets.json')
credentials = credentials.create_scoped('https://www.googleapis.com/auth/bigquery')
bigquery_service = build('bigquery', 'v2', credentials=credentials)
# Run a SQL query against the HITgroup data to find the months of first and
# last activity for each requester. Then aggregate on top of that and report
# how many requesters were first seen in month X and last seen in month Y.
#
# The string manipulation only formats the dates as YYYY-MM.
#
# The SQL uses legacy BigQuery functions (STRING, YEAR, MONTH), so the legacy
# dialect is requested explicitly. timeoutMs bounds how long each request
# waits for the job. Per the BigQuery API, a request that times out returns
# jobComplete == False and no rows, so keep polling the same job until done.
QUERY_TIMEOUT_MS = 60000
query_request = bigquery_service.jobs()
query_data = {
    'query': (
'''
SELECT
firstSeen,
lastSeen,
COUNT(*) AS cnt
FROM (
SELECT
requesterId,
STRING(YEAR(MIN(firstSeen))) + '-' + RIGHT('0' + STRING(MONTH(MIN(firstSeen))), 2) AS firstSeen,
STRING(YEAR(MAX(lastSeen))) + '-' + RIGHT('0' + STRING(MONTH(MAX(lastSeen))), 2) AS lastSeen,
COUNT(groupId) AS HITgroupsPosted
FROM
entities.HITgroup
GROUP BY
requesterId)
GROUP BY
firstSeen,
lastSeen
ORDER BY
firstSeen,
lastSeen
'''),
    'useLegacySql': True,
    'timeoutMs': QUERY_TIMEOUT_MS,
}
query_response = query_request.query(
    projectId='mturk-tracker',
    body=query_data).execute()
# Wait for the job to finish; getQueryResults returns the same response shape
# (schema, rows, jobComplete) as jobs.query.
while not query_response.get('jobComplete', False):
    query_response = query_request.getQueryResults(
        projectId='mturk-tracker',
        jobId=query_response['jobReference']['jobId'],
        timeoutMs=QUERY_TIMEOUT_MS).execute()
# Put the SQL results in a pandas DataFrame
import pandas as pd
# Per the BigQuery API, a response without jobComplete carries no 'rows';
# fail loudly rather than silently building an empty table.
if not query_response.get('jobComplete', True):
    raise RuntimeError('BigQuery query has not completed; no result rows available yet')
schema_fields = query_response['schema']['fields']
columns = [field.get('name') for field in schema_fields]
# The REST API returns every cell value as a string. An empty result set has
# no 'rows' key at all, hence the default.
rows = [tuple(cell['v'] for cell in row['f'])
        for row in query_response.get('rows', [])]
# Build the frame without a blanket dtype: firstSeen/lastSeen are 'YYYY-MM'
# strings that cannot be cast to int (recent pandas raises instead of silently
# leaving them as object). Convert only the columns the schema declares as
# integers (presumably including the COUNT(*) column 'cnt').
df = pd.DataFrame(data=rows, columns=columns)
for field in schema_fields:
    if field.get('type') in ('INTEGER', 'INT64'):
        df[field['name']] = df[field['name']].astype(int)
# Everything above fetches the data through the BigQuery API. Without API
# access, download the CSV export from
# https://gist.github.com/ipeirotis/6af638e971537b0f9524
# and load it into a dataframe instead:
# df = pd.read_csv("cohort_analysis.csv", sep="\t")

# Reshape into a month-by-cohort table: each row is a lastSeen month, each
# column a firstSeen month (the cohort), and each cell the number of
# requesters from that cohort whose last activity fell in that month.
pivot = df.pivot_table(values='cnt', index=['lastSeen'], columns=['firstSeen'])
# Transform the pivot table into one column per cohort (requesters first seen
# in a given month) and one row per month. Each cell counts how many
# requesters from that cohort are still active in that month. "Active" means
# the requester posted a HIT in that month or later, i.e. its lastSeen month
# is >= the row's month.
import numpy as np
# A NaN in the pivot table means nobody from that cohort left in that month.
# Treat it as zero departures: cumsum() keeps NaN cells as NaN, which would
# otherwise punch holes into the running totals (pandas area plots then draw
# those holes as 0).
departures = pivot.fillna(0)
# Requesters of cohort c still active in month Y are those whose lastSeen is
# Y or later. That is the departures summed from the latest month backwards,
# so reverse the rows, take the running sum, and reverse back. This builds a
# new frame, so `pivot` itself is left untouched.
pc2 = departures.iloc[::-1].cumsum().iloc[::-1].copy()
# A cohort does not exist before its first month. Blank out the rows that come
# before it so they are not plotted. Months are 'YYYY-MM' strings, so string
# order is chronological.
for c in pc2.columns:
    pc2.loc[pc2.index < c, c] = np.nan
# Draw the cohorts as a stacked area chart: one band per cohort (first-seen
# month), whose height at each month is the number of its requesters still
# active then.
f = plt.figure(edgecolor='k')
ax = f.gca()
pc2.plot(ax=ax, kind='area', stacked=True, grid=True,
         cmap='Paired', legend=True, figsize=(16, 8))
ax.set_title('Amazon Mechanical Turk Cohort Analysis', color='black')
# Put the legend underneath the axes, eight entries per row.
ax.legend(loc='lower center', ncol=8, bbox_to_anchor=[0.5, -0.25])
ax.set_xlabel("Date")
ax.set_ylabel("Active Requesters")
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment