Last active
July 18, 2022 06:33
-
-
Save opyate/1927001 to your computer and use it in GitHub Desktop.
Consistent Hashing in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// scalac -unchecked -optimise ConsistentHash.scala && scala CHApp | |
import scala.collection.immutable.LinearSeq | |
import java.util.{TreeMap => JTreeMap} | |
import java.util.{SortedMap => JSortedMap} | |
/**
 * Factory for [[ConsistentHash]] rings.
 *
 * Inspired by http://www.lexemetech.com/2007/11/consistent-hashing.html
 */
object ConsistentHash {

  /**
   * Creates a ring and registers every element of `nodes` on it.
   *
   * @param hashFunction     produces the ring positions (of type `K`)
   * @param numberOfReplicas replica setting handed to the ring
   * @param nodes            nodes to add, in iteration order
   * @return the populated ring
   */
  def apply[K <% Ordered[K], V](
      hashFunction: HashFunction[K],
      numberOfReplicas: Int,
      nodes: LinearSeq[V]) = {
    val ring = new ConsistentHash(hashFunction, numberOfReplicas, nodes)
    for (node <- nodes) ring.add(Some(node))
    ring
  }
}
/**
 * A consistent-hash ring backed by a sorted map from ring position to node.
 *
 * HashFunction returns a hash of type K, which is the key type of the TreeMap.
 * Nodes are of type V.
 *
 * Each node is placed on the ring once per replica; the position of replica `i`
 * is the hash of the string `"<node><SEPARATOR><i>"`.
 *
 * The constructor is private: instances are obtained through the companion
 * object's `apply`.
 *
 * @param hashFunction     hashes both replica labels and lookup keys
 * @param numberOfReplicas number of ring positions created per node
 * @param nodes            not read by this class; the companion `apply` adds the
 *                         nodes explicitly after construction
 */
class ConsistentHash[K <% Ordered[K], V] private(
    hashFunction: HashFunction[K],
    numberOfReplicas: Int,
    nodes: LinearSeq[V]) {

  /** Number of nodes added minus number of nodes removed. */
  var numNodes = 0

  /** Separator between the node's string form and the replica index. */
  val SEPARATOR = ":"

  /** Ring positions, sorted by hash, each mapped to the node that owns it. */
  val circle: JSortedMap[K, V] = new JTreeMap[K, V]()

  /**
   * Ring positions for every replica of `n`.
   *
   * Uses `until` so that exactly `numberOfReplicas` positions are produced
   * (an inclusive `to` would yield one extra).
   */
  private[this] def replicaKeys(n: V): Seq[K] =
    (0 until numberOfReplicas).map { i =>
      hashFunction.hash("%s%s%s".format(n, SEPARATOR, i))
    }

  /**
   * Add a node V to the pool.
   *
   * `None` is ignored: nothing is put on the ring and the node count is left
   * untouched. Adding the same node twice is not detected and counts twice.
   */
  def add(node: Option[V]): Unit = node.foreach { n =>
    replicaKeys(n).foreach(position => circle.put(position, n))
    numNodes += 1
  }

  /**
   * Remove a node V from the pool.
   *
   * `None` is ignored. Whether the node was actually on the ring is not
   * checked; the node count is decremented for every defined argument.
   */
  def remove(node: Option[V]): Unit = node.foreach { n =>
    replicaKeys(n).foreach(position => circle.remove(position))
    numNodes -= 1
  }

  /**
   * @param key the object to locate; `None` yields `None`
   * @return the cache node that will contain our object, by key.
   * @throws RuntimeException if the ring holds no positions at all
   */
  def get(key: Option[AnyRef]): Option[V] = {
    if (circle.isEmpty) throw new RuntimeException("No nodes in ring")
    key.map { k =>
      // tailMap is inclusive of its lower bound, so an exact hit on a ring
      // position is returned as the first key of the tail view.
      val tail = circle.tailMap(hashFunction.hash(k))
      // Past the last position: wrap around to the start of the ring.
      val owner = if (tail.isEmpty) circle.firstKey else tail.firstKey
      circle.get(owner)
    }
  }

  /** @return the current node count (see `numNodes`). */
  def ringSize(): Int = numNodes
}
/**
 * Strategy that maps an arbitrary object to a hash value.
 *
 * @tparam H type of the hash produced
 */
trait HashFunction[H] {

  /**
   * @param o the object to hash
   * @return the hash of `o`
   */
  def hash(o: AnyRef): H
}
import java.security.MessageDigest | |
import java.io.ByteArrayOutputStream | |
import java.io.ObjectOutputStream | |
import java.math.BigInteger | |
/**
 * [[HashFunction]] that Java-serializes its argument and returns the MD5 digest
 * of the serialized bytes as a hexadecimal string left-padded with zeros to 32
 * characters.
 */
class MD5HashFunction extends HashFunction[String] {

  /** Digest instance reused for every call to `hash`. */
  val md: MessageDigest = MessageDigest.getInstance("MD5")

  /**
   * @param o object to hash; it is written through an ObjectOutputStream
   * @return the zero-padded hexadecimal MD5 digest of the serialized form
   */
  override def hash(o: AnyRef): String = {
    val serialized = toBinary(o)
    val digest = md.digest(serialized)
    byteArrayToString(digest)
  }

  /** Renders `data` as an unsigned base-16 number, padded to at least 32 digits. */
  private[this] def byteArrayToString(data: Array[Byte]): String = {
    val hex = new BigInteger(1, data).toString(16)
    // A non-positive repeat count yields the empty string, so longer values pass through.
    ("0" * (32 - hex.length)) + hex
  }

  /** Converts a nibble value (0-15) to its lowercase hexadecimal digit. */
  private[this] def toHex(b: Byte): Char = {
    require(b >= 0 && b <= 15, "Byte " + b + " was not between 0 and 15")
    val code = if (b < 10) '0' + b else 'a' + (b - 10)
    code.toChar
  }

  /** Serializes `obj` with Java object serialization into a byte array. */
  private[this] def toBinary(obj: AnyRef): Array[Byte] = {
    val buffer = new ByteArrayOutputStream
    val objectStream = new ObjectOutputStream(buffer)
    objectStream.writeObject(obj)
    objectStream.close()
    buffer.toByteArray
  }
}
import java.util.Random | |
/**
 * Entry point: runs the experiment for 1 to 500 virtual nodes with 10000
 * objects each and hands the collected results to `Runner.write`.
 */
object CHApp extends App {
  val r = new Runner

  val stats = (1 to 500).toList.map { vnodes =>
    print(" " + vnodes) // makeshift progress indicator
    r.run(10000, vnodes)
  }

  // writes to file called "dat" in '.'
  r.write(stats)
}
/**
 * Drives our implementation, and generates some stats.
 */
class Runner {

  /** When true, `debug` prints nothing. */
  val quiet = true

  /** Renders an optional cache as its `toString`, or "-NOCACHE-" when absent. */
  implicit def cacheAsString(cache: Option[Cache]): String =
    cache.map(_.toString).getOrElse("-NOCACHE-")

  /** Renders an optional value as its `toString`, or the empty string when absent. */
  implicit def optionAnyToString(o: Option[Any]): String =
    o.map(_.toString).getOrElse("")

  /**
   * Looks every object up on the ring and counts how many landed on each node.
   *
   * @return one count per node that received at least one object; nodes that
   *         received nothing do not appear in the result
   */
  def stat(ch: ConsistentHash[_,_], objects: IndexedSeq[CachableThingy]) = {
    val owners = objects.map(o => ch.get(Some(o)))
    owners.groupBy(x => x).map { case (_, hits) => hits.size }
  }

  /** Square of `i`. */
  def sq(i: Double) = i * i

  /** Prints `s` unless `quiet` is set. */
  def debug(s: String): Unit = {
    if (!quiet) println(s)
  }

  /**
   * Writes one tab-separated "vnodes, value" line per tuple to the file "dat"
   * in the current directory.
   */
  def write(data: List[Tuple2[Int,Double]]): Unit = {
    import java.io._
    // Runs `op` against a writer on `f` and always closes the writer afterwards.
    def printToFile(f: File)(op: PrintWriter => Unit): Unit = {
      val p = new PrintWriter(f)
      try { op(p) } finally { p.close() }
    }
    printToFile(new File("dat")) { p =>
      data.foreach { case (vnodes, value) => p.println("%s\t%s".format(vnodes, value)) }
    }
  }

  /**
   * Builds a ring of ten caches with the given replica setting, distributes
   * `numobjects` random objects over it, and measures how evenly they spread.
   *
   * @param numobjects number of objects to place
   * @param vnodes     replica setting passed to the ring
   * @return `(vnodes, standard deviation of objects-per-node as a percentage of the mean)`
   */
  def run(numobjects: Int, vnodes: Int) = {
    debug("\n\n%s objects, and %s vnodes".format(numobjects, vnodes))

    val ch = ConsistentHash(new MD5HashFunction, vnodes,
      List(Cache("alpha"),
        Cache("beta"),
        Cache("gamma"),
        Cache("delta"),
        Cache("epsilon"),
        Cache("zeta"),
        Cache("eta"),
        Cache("theta"),
        Cache("iota"),
        Cache("kappa")))

    val rand = new Random(System.currentTimeMillis())
    val objects: IndexedSeq[CachableThingy] =
      for { i <- 0 until numobjects } yield CachableThingy(rand.nextInt(10000))

    // mean: average
    // variance: average of the squared differences from the mean.
    // standard deviation: square root of the variance
    val stats = stat(ch, objects)
    val nodeCount = ch.ringSize()

    // we already know the total number of objects, and can cheat to get the mean.
    // Floating-point division: integer division would truncate the mean.
    val mean = numobjects.toDouble / nodeCount

    val sqDistFromMean = stats.map { opc =>
      debug("%s of %s / size(%s) = Mean(%s)".format(opc, numobjects, nodeCount, mean))
      sq(mean - opc)
    }
    debug("%s".format(sqDistFromMean))

    // Nodes that received no objects are missing from `stats`; each of them
    // deviates from the mean by the full mean and must be counted too.
    val emptyNodes = nodeCount - stats.size
    val variance: Double = (sqDistFromMean.sum + emptyNodes * sq(mean)) / nodeCount
    debug("Variance: %s".format(variance))

    val sd = scala.math.sqrt(variance)
    debug("Standard deviation: %s".format(sd))

    val sdm = (sd / mean) * 100
    debug("Standard deviation as a percentage of the mean: %s".format(sdm))
    debug("-")

    (vnodes, sdm)
  }
}
/**
 * A cache node, identified by the name `s`.
 */
case class Cache(s: String)
/**
 * An object to be placed on a cache node, carrying the integer payload `i`.
 */
case class CachableThingy(i: Int)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot the contents of "dat" (column 1 against column 2) with a fitted line.
reset
set logscale

# x-axis range up to the maximum number of vnodes (currently disabled)
#set xrange [1:1000]

# explicit tic labels for both axes
set ytics ("5" 5, "10" 10, "20" 20, "50" 50, "100" 100)
set xtics ("1" 1, "2" 2, "5" 5, "10" 10, "20" 20, "50" 50, "100" 100, "200" 200, "500" 500, "1000" 1000, "2000" 2000, "5000" 5000, "10000" 10000)

# axis labels
set xlabel 'Number of Virtual Nodes'
set ylabel 'Standard Deviation as % of mean'

# linear regression: fit the straight line f(x) to the data
f(x) = m*x + b
fit f(x) "dat" using 1:2 via m,b

# m and b are now stored, so f(x) can be drawn together with the original data
plot "dat" using 1:2 title 'Consistent Hashing' with points pointtype 5, \
     f(x) title 'Line Fit'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use a finer set of y-axis tic labels, then redraw the previous plot.
set ytics ("1" 1, "2" 2, "3" 3, "4" 4, "5" 5, "10" 10, "20" 20, "50" 50, "100" 100)
replot
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment