Skip to content

Instantly share code, notes, and snippets.

@dapangmao
Last active August 29, 2015 14:17
Show Gist options
  • Save dapangmao/41d77e1f8207804ee7a7 to your computer and use it in GitHub Desktop.
Save dapangmao/41d77e1f8207804ee7a7 to your computer and use it in GitHub Desktop.
maintain a median

###On a single machine

import heapq
class find_median(object):
    def __init__(self):
        self.first_half = [] # will be a max heap
        self.second_half = [] # will be a min heap, 1/2 chance has one more element
        self.N = 0

    def insert(self, x):
        heapq.heappush(self.first_half, -x)
        self.N += 1
        if len(self.second_half) == len(self.first_half):
            to_second, to_first = map(heapq.heappop, [self.first_half, self.second_half])
            heapq.heappush(self.second_half, -to_second)
            heapq.heappush(self.first_half, -to_first)
        else:
            to_second = heapq.heappop(self.first_half)
            heapq.heappush(self.second_half, -to_second)
        
    def median(self):
        if self.N == 0:
            raise IOError('please use insert method to input')
        if self.N % 2 == 0:
            return (-self.first_half[0] + self.second_half[0]) / 2.0
        return self.second_half[0]

if __name__ == '__main__':
    test = find_median()
    input = [3, 5, 6, 9, 8, 1, 100, 2, 11, 7, 1]
    for i, x in enumerate(input):
        test.insert(x)
        print test.median(), test.first_half, test.second_half

###On a Spark cluster

from pyspark.streaming import *
ssc = StreamingContext(sc, 31)

lines = ssc.socketTextStream(host, int(port))
inputs = lines.flatMap(lambda line: line.split(" ")).map(int)
test = find_median()

def process(time, rdd):
    try:
        for x in rdd.collect():
            test.insert(x)
        print "========= %s =========" % str(time)
        print test.median()
    except:
        pass

inputs.foreachRDD(process)
ssc.start()
ssc.awaitTermination()

####Further design #####Step 1:

class find_median:
    def findMedianSortedstreamrrays(self, stream, rdd):
        length = len(stream) + len(rdd)
        if length % 2 == 0:
            return ( self.findKth(stream, 0, rdd, 0, length / 2) + \
                self.findKth(stream, 0, rdd, 0, length / 2 + 1) ) / 2.0
        else:
            return self.findKth(stream, 0, rdd, 0, length / 2 + 1)

    def findKth(self, stream, stream_start, rdd, rdd_start, k):
        if stream_start >= len(stream):
            return rdd[rdd_start + k - 1]
        if rdd_start >= len(rdd):
            return stream[stream_start + k - 1]
        if k == 1:
            return min(stream[stream_start], rdd[rdd_start])
        if stream_start + k/2 -1 < len(stream):
            stream_key = stream[stream_start + k/2 -1]
        else:
            stream_key = 9223372036854775807
        if rdd_start + k/2 -1 < len(rdd):
            rdd_key = rdd[rdd_start + k/2 -1]
        else:
            rdd_key = 9223372036854775807
        if stream_key < rdd_key:
            return self.findKth(stream, stream_start + k / 2, rdd, rdd_start, k - k/2)
        else:
            return self.findKth(stream, stream_start, rdd, rdd_start + k / 2, k - k/2)

#####Step 2

def merge(self, rdd, m, stream, n):
    m -= 1
    n -= 1
    last = m + n + 1
    while m >= 0 and n >= 0:
        if rdd[m] >= stream[n]:
            rdd[last] = rdd[m]
            m -= 1
        else:
            rdd[last] = stream[n]
            n -= 1
        last -= 1
    if n >= 0:
        for i, x in enumerate(stream[:n+1]):
            rdd[i] = x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment