@davidandrzej
Created September 14, 2011 00:23
Use quicksort-style pivots to find top k items in expected linear time

(ns topk.topk)

(defn- pivot-partition
  "Partition values into [less-than equal-to greater-than] vectors by comparing
  (keyfn x) with (keyfn pivotval); keyfn must return numbers"
  [argvals keyfn pivotval]
  (let [pivotkey (keyfn pivotval)]
    (loop [vals (seq argvals)
           lt   (transient [])
           p    (transient [])
           gt   (transient [])]
      (if-let [[curval & more] vals]
        ;; Compute each key once and add the value to exactly one bucket
        (let [curkey (keyfn curval)]
          (cond
            (< curkey pivotkey)  (recur more (conj! lt curval) p gt)
            (== curkey pivotkey) (recur more lt (conj! p curval) gt)
            :else                (recur more lt p (conj! gt curval))))
        ;; Input exhausted: freeze the three buckets
        [(persistent! lt) (persistent! p) (persistent! gt)]))))
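
pivot-partition compares keys with <, == and >, so it only works when keyfn returns numbers. Below is a minimal sketch of the same single-pass three-way split that uses compare instead. It assumes the keys are mutually comparable (strings, keywords, numbers). partition-by-pivot-key is a hypothetical name, and unlike the original it takes the pivot's key rather than the pivot element.

```clojure
;; Hypothetical variant of pivot-partition: keys are compared with `compare`,
;; so any mutually comparable keys (strings, keywords, numbers) work.
(defn partition-by-pivot-key
  "Split coll into [lt eq gt] vectors relative to pivot-key, using (keyfn x)."
  [coll keyfn pivot-key]
  (let [step (fn [[lt eq gt] x]
               ;; one comparison per item; each item lands in one bucket
               (let [c (compare (keyfn x) pivot-key)]
                 (cond
                   (neg? c) [(conj! lt x) eq gt]
                   (pos? c) [lt eq (conj! gt x)]
                   :else    [lt (conj! eq x) gt])))
        [lt eq gt] (reduce step [(transient []) (transient []) (transient [])] coll)]
    [(persistent! lt) (persistent! eq) (persistent! gt)]))

;; Example: partition words around the pivot key "m"
(partition-by-pivot-key ["pear" "apple" "melon" "kiwi" "zucchini" "m"] identity "m")
;; => [["apple" "kiwi"] ["m"] ["pear" "melon" "zucchini"]]
```

Input order is preserved within each bucket, as it is in pivot-partition.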
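
(def ^:private rng
  "Shared random number generator used to pick pivots"
  (java.util.Random.))

(defn- random-element
  "Return a random element from vector v"
  [v]
  (nth v (.nextInt ^java.util.Random rng (count v))))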

(defn- randomly-pivot
  "Randomly select a pivot element and partition collection with respect to it"
  [v keyfn]
  (pivot-partition v keyfn (random-element v)))
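
The random pivot is what makes the expected running time linear, but it also makes individual runs hard to replay. Here is a small sketch for the case where you want reproducible pivots while debugging. The hypothetical random-element-from takes an explicitly seeded java.util.Random instead of the shared rng. clojure.core/rand-nth is the built-in way to pick a random element, but it always draws from the global generator.

```clojure
;; Hypothetical helper: choose a pivot with a caller-supplied (e.g. seeded)
;; java.util.Random so that a run of the algorithm can be replayed.
(defn random-element-from
  "Return a random element of vector v drawn with generator r."
  [^java.util.Random r v]
  (nth v (.nextInt r (count v))))

;; Two generators created with the same seed choose the same pivots
(let [v  [10 20 30 40 50]
      r1 (java.util.Random. 42)
      r2 (java.util.Random. 42)]
  (= (vec (repeatedly 20 #(random-element-from r1 v)))
     (vec (repeatedly 20 #(random-element-from r2 v)))))
;; => true
```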

(defn top-k
  "Use quicksort-style pivots to find the top k items, i.e. the k items with the
  smallest (keyfn x), returned in no particular order, in expected linear time"
  ([v k]
   (top-k v identity k))
  ([v keyfn k]
   (if (or (<= k 0) (empty? v))
     ;; Nothing left to select; this also ends the recursion into an empty
     ;; greater-than set when k exceeds the number of items
     []
     (let [[lt p gt] (randomly-pivot v keyfn)]
       (cond
         ;; Top k are entirely contained in less-than set
         (> (count lt) k)
         (top-k lt keyfn k)
         ;; Top k are exactly equal to less-than set (what luck!)
         (== (count lt) k)
         lt
         ;; Top k consist of less-than set, plus...?
         (< (count lt) k)
         (cond
           ;; some, but not all, of the pivot items
           (> (+ (count lt) (count p)) k)
           (concat lt (take (- k (count lt)) p))
           ;; exactly all of the pivot items
           (== (+ (count lt) (count p)) k)
           (concat lt p)
           ;; all of the pivot items, plus some of the greater-than set
           (< (+ (count lt) (count p)) k)
           (concat lt p (top-k gt keyfn (- k (+ (count lt) (count p)))))))))))
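
Each call does one partition pass and then recurses into at most one of lt or gt, never both. This is the quickselect pattern, and it is why the expected running time is linear. The following sketches the standard argument. It assumes a pass over n items costs at most cn and the pivot is uniformly random, and it ignores rounding. Ties land in p and only make lt and gt smaller. Call a pivot good when its rank lies in the middle half. A good pivot then shows up within two tries on average, and it leaves at most 3n/4 items for the recursive call:

```latex
\begin{aligned}
\Pr\big[\text{pivot rank} \in [n/4,\ 3n/4]\big] &\ge \tfrac12
  \quad\Rightarrow\quad |lt|,\ |gt| \le \tfrac34 n,\\
\mathbb{E}[T(n)] &\le 2cn + \mathbb{E}\!\left[T\!\left(\tfrac34 n\right)\right]
  \le 2cn \sum_{i \ge 0} \left(\tfrac34\right)^{i} = 8cn = O(n).
\end{aligned}
```

This bounds only the expectation. A run of consistently extreme pivots still costs O(n^2) in the worst case.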
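
;; Example REPL session: top-k finds the same ten smallest values as
;; (take 10 (sort rvals)), in no particular order. rvals is a lazy seq, so
;; the first timed expression also pays for generating its million numbers.
;;topk.topk> (def rvals (for [i (range 1000000)] (.nextInt rng 100000000)))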
;;#'topk.topk/rvals
;;topk.topk> (time (take 10 (sort rvals)))
;;"Elapsed time: 4105.471 msecs"
;;(76 89 386 476 619 708 798 842 843 891)
;;topk.topk> (time (top-k rvals 10))
;;"Elapsed time: 288.22 msecs"
;;(619 89 386 476 76 708 798 842 843 891)
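
The session above shows both approaches returning the same ten values. A more systematic check can be sketched with a hypothetical top-k-sketch. It condenses top-k into three filterv passes instead of one loop, uses rand-nth for the pivot, and keeps numeric keys as in the original. For several values of k, including 0 and a k larger than the input, the sorted result should equal (take k (sort data)).

```clojure
(defn top-k-sketch
  "Hypothetical condensed restatement of top-k: the k items of coll with the
  smallest (keyfn x), in no particular order. Keys must be numbers."
  [coll keyfn k]
  (let [v (vec coll)]
    (if (or (<= k 0) (empty? v))
      []
      (let [pk   (keyfn (rand-nth v))            ; random pivot key
            lt   (filterv #(< (keyfn %) pk) v)   ; strictly smaller keys
            eq   (filterv #(== (keyfn %) pk) v)  ; keys equal to the pivot's
            gt   (filterv #(> (keyfn %) pk) v)   ; strictly larger keys
            n-lt (count lt)
            n-le (+ n-lt (count eq))]
        (cond
          ;; all k answers lie strictly below the pivot
          (> n-lt k)  (top-k-sketch lt keyfn k)
          ;; lt plus some (possibly none) of the pivot-equal items
          (>= n-le k) (into lt (take (- k n-lt) eq))
          ;; lt, all pivot-equal items, and the rest from gt
          :else       (into (into lt eq) (top-k-sketch gt keyfn (- k n-le))))))))

;; Compare against the O(n log n) sort-and-take baseline on seeded random data
(let [r    (java.util.Random. 1)
      data (vec (repeatedly 10000 #(.nextInt r 1000)))]
  (every? (fn [k]
            (= (sort (top-k-sketch data identity k))
               (take k (sort data))))
          [0 1 10 500 10000 20000]))
;; => true
```

The output is sorted before comparing because top-k returns its items in no particular order, as the (619 89 386 ...) result above shows.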