PySpark code to derive a histogram of the number of products per basket
from pyspark.sql.functions import col, countDistinct, lit
# df is a Spark DataFrame: DataFrame[basket: string, product: string, customer: string, store: string]
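# Total number of distinct baskets; used as the denominator for the fractions below.
baskets_tally = df.groupBy().agg(countDistinct(col('basket'))).collect()[0][0]
# One row per basket, with the number of rows (products) recorded for that basket.
df = df.groupBy(col('basket')).count().withColumnRenamed('count', 'tally_of_products_per_basket')
# One row per basket size, with the number of baskets of that size.
df = df.groupBy("tally_of_products_per_basket") \
    .count() \
    .withColumnRenamed('count', 'tally_of_baskets_containing_products_tally') \
    .orderBy(col("tally_of_products_per_basket").asc())
# Convert the basket counts into fractions of all baskets.
df = df.withColumn(
    'fraction_of_baskets_containing_products_tally',
    col('tally_of_baskets_containing_products_tally') / lit(baskets_tally)
)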
"""
To illustrate the calculated data:
[(row[0], row[2]) for row in df \
    .orderBy(col('tally_of_products_per_basket')) \
    .limit(10) \
    .collect()]
returns this histogram (x, y):
[(1, 0.1324264771159618),
(2, 0.1370917625841512),
(3, 0.11921614825173989),
(4, 0.09889617171282909),
(5, 0.08101257741810314),
(6, 0.06629952190050782),
(7, 0.054068733410406467),
(8, 0.044480808689380966),
(9, 0.03708986734310147),
(10, 0.030869985466493634)]
where y is the fraction of all baskets that contain exactly x products
"""
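
The two-step grouping above can be checked without a Spark session. The following is a minimal plain-Python sketch of the same calculation, not the gist's own code: the `rows` data and the `basket_size_histogram` helper are made up for illustration, and `collections.Counter` stands in for Spark's `groupBy(...).count()`.

```python
from collections import Counter

# Hypothetical toy data: one (basket, product) pair per row, as in the Spark DataFrame.
rows = [
    ("b1", "apple"),
    ("b2", "apple"), ("b2", "bread"),
    ("b3", "milk"),
    ("b4", "apple"), ("b4", "bread"), ("b4", "milk"),
]


def basket_size_histogram(rows):
    """Return [(products_per_basket, fraction_of_baskets)] sorted by basket size."""
    # Step 1: count rows per basket (mirrors groupBy('basket').count()).
    products_per_basket = Counter(basket for basket, _ in rows)
    # Step 2: count baskets per basket size (mirrors the second groupBy).
    baskets_per_size = Counter(products_per_basket.values())
    # The number of distinct baskets is the denominator (mirrors countDistinct).
    total_baskets = len(products_per_basket)
    return [(size, n / total_baskets) for size, n in sorted(baskets_per_size.items())]


histogram = basket_size_histogram(rows)
print(histogram)  # [(1, 0.5), (2, 0.25), (3, 0.25)]
```

As in the Spark version, step 1 counts rows rather than distinct products, so a product listed twice in the same basket is counted twice. The fractions sum to 1 across all basket sizes.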