The existing examples for this are good, but they miss a pretty critical observation: the number of partitions and how it affects things.
Assume we have the following script, aggregate_by_key.py:
import pprint

from pyspark.context import SparkContext

def sequence_fn(x, y):
    # At first, x is always the neutral "zero" value passed to aggregateByKey,
    # y is always a value from the pair RDD on which you're aggregating
    return '[ %s %s ]' % (x, y['value'])

def comb_fn(a1, a2):
    # This function is only used to merge aggregates created across partitions
    return '{ %s %s }' % (a1, a2)

with SparkContext() as sc:
    letters = sc.parallelize(list('aaaaaabbcccccddddeeeeeeeeeeeee'))\
        .repartition(1)\
        .map(lambda l: (l, {'value': l}))\
        .aggregateByKey('start', sequence_fn, comb_fn)

    partitions = letters.getNumPartitions()
    letters = letters.collectAsMap()

    print('Letter partitions: %d' % partitions)
    pprint.pprint(letters)
First, we'll run with: spark-submit aggregate_by_key.py
which gives us:
Letter partitions: 1
{'a': '[ [ [ [ [ [ start a ] a ] a ] a ] a ] a ]',
'b': '[ [ start b ] b ]',
'c': '[ [ [ [ [ start c ] c ] c ] c ] c ]',
'd': '[ [ [ [ start d ] d ] d ] d ]',
'e': '[ [ [ [ [ [ [ [ [ [ [ [ [ start e ] e ] e ] e ] e ] e ] e ] e ] e ] e ] e ] e ] e ]'}
A few things to note here.
First of all, for the b key we see an aggregate of [ [ start b ] b ], which indicates two calls to sequence_fn:
sequence_fn('start', {'value': 'b'})
sequence_fn('[ start b ]', {'value': 'b'})
The key takeaway here is that "sequence function" is a bad name for what is essentially a per-partition reduce. For Pythonistas, it's helpful to think of sequence_fn like so:
>>> from functools import reduce  # needed on Python 3, where reduce is no longer a builtin
>>> reduce(lambda a, b: '[ %s %s ]' % (a, b), list('aaaa'), 'start')
'[ [ [ [ start a ] a ] a ] a ]'
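To make the parallel with the script explicit, here's a minimal pure-Python sketch (no Spark involved; values_for_b is just a stand-in for the values Spark would feed sequence_fn for the b key on its single partition):

from functools import reduce

def sequence_fn(x, y):
    return '[ %s %s ]' % (x, y['value'])

# Stand-in for the values Spark hands sequence_fn for key 'b' on one partition.
values_for_b = [{'value': 'b'}, {'value': 'b'}]

print(reduce(sequence_fn, values_for_b, 'start'))  # [ [ start b ] b ]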
Notice that since we only have a single partition, comb_fn is never called to merge aggregates across partitions. If we have multiple partitions, this changes. We could try the following script:
import pprint

from pyspark.context import SparkContext

def sequence_fn(x, y):
    # At first, x is always the neutral "zero" value passed to aggregateByKey,
    # y is always a value from the pair RDD on which you're aggregating
    return '[ %s %s ]' % (x, y['value'])

def comb_fn(a1, a2):
    # This function is only used to merge aggregates created across partitions
    return '{ %s %s }' % (a1, a2)

with SparkContext() as sc:
    # more partitions!
    letters = sc.parallelize(list('aaaaaabbcccccddddeeeeeeeeeeeee'))\
        .repartition(5)\
        .map(lambda l: (l, {'value': l}))\
        .aggregateByKey('start', sequence_fn, comb_fn)

    partitions = letters.getNumPartitions()
    letters = letters.collectAsMap()

    print('Letter partitions: %d' % partitions)
    pprint.pprint(letters)
In case you missed it, we changed the number of partitions from 1 to 5. You'll get output similar to (but not exactly like) the following:
Letter partitions: 5
{'a': '[ [ [ [ [ [ start a ] a ] a ] a ] a ] a ]',
'b': '[ [ start b ] b ]',
'c': '[ [ [ [ [ start c ] c ] c ] c ] c ]',
'd': '{ [ start d ] [ [ [ start d ] d ] d ] }',
'e': '{ { [ [ [ [ [ [ [ start e ] e ] e ] e ] e ] e ] e ] [ [ start e ] e ] } [ [ [ [ start e ] e ] e ] e ] }'}
All of a sudden, we have curly braces: square brackets still mark the per-partition reduces done by sequence_fn, while each pair of curly braces marks a call to comb_fn merging aggregates from different partitions.
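If it helps, here's a pure-Python sketch of what Spark is doing conceptually for the d key above, assuming its four values happened to land on two partitions as a 1/3 split (the actual split depends on how repartition(5) shuffles the data):

from functools import reduce

def sequence_fn(x, y):
    return '[ %s %s ]' % (x, y['value'])

def comb_fn(a1, a2):
    return '{ %s %s }' % (a1, a2)

# Hypothetical placement of the four 'd' values across two partitions.
partition_1 = [{'value': 'd'}]
partition_2 = [{'value': 'd'}, {'value': 'd'}, {'value': 'd'}]

# Step 1: per-partition reduce with sequence_fn, starting from the zero value.
agg_1 = reduce(sequence_fn, partition_1, 'start')  # [ start d ]
agg_2 = reduce(sequence_fn, partition_2, 'start')  # [ [ [ start d ] d ] d ]

# Step 2: merge the per-partition aggregates with comb_fn.
print(comb_fn(agg_1, agg_2))  # { [ start d ] [ [ [ start d ] d ] d ] }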
Finally, the documentation (and this example) makes this pretty clear: your per-partition reduce function need not produce a result of the same type as the values in the pair RDD. In the example above, the values in the pair RDD were dicts, but the per-partition results were strings.
The cross-partition combine function, however, is expected to return the same type of result as the per-partition reduce.
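As one more (untested, but hopefully illustrative) sketch of that type flexibility, here's an aggregateByKey call that turns the same dict values into plain integer counts per key; sc is the SparkContext from the scripts above:

counts = sc.parallelize(list('aaabbc'))\
    .map(lambda l: (l, {'value': l}))\
    .aggregateByKey(0,
                    lambda acc, v: acc + 1,          # dict values in, int aggregate out
                    lambda acc1, acc2: acc1 + acc2)  # combine stays int -> int
# counts.collectAsMap() should give {'a': 3, 'b': 2, 'c': 1}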
Hope that clears things up for some people!
Mike