Created
July 21, 2016 11:18
-
-
Save aristotle0x01/30eed733d9a378cbf6192f9ddc17c3d6 to your computer and use it in GitHub Desktop.
SimHash duplicate detection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// scalastyle:off println | |
import duplicate.SimHash | |
import org.apache.commons.lang3.StringUtils | |
import org.apache.log4j.{Level, Logger} | |
import org.apache.spark.{SparkConf, SparkContext} | |
import scopt.OptionParser | |
import scala.collection.mutable.ArrayBuffer | |
/**
 * Spark job that finds near-duplicate records by comparing 64-bit SimHash fingerprints.
 *
 * Each input line is `id<0x01>hash`, where `hash` is a decimal 64-bit integer. For every id
 * that has at least one near-duplicate, the job writes one output line `id,dup1,dup2,...`
 * listing the other ids whose fingerprints differ from it in at most [[HAMMING_DISTANCE]] bits.
 *
 * Candidate pairs come from splitting each hash into [[BAND_COUNT]] bands of [[BAND_WIDTH]]
 * bits. Only ids that share an exact band value are compared bit-by-bit. By the pigeonhole
 * principle, two hashes differing in at most `BAND_COUNT - 1` bits agree on at least one band,
 * so no duplicate is missed while `HAMMING_DISTANCE < BAND_COUNT`.
 */
object GroupedDuplicate {
  /** Maximum number of differing bits for two fingerprints to count as duplicates. */
  val HAMMING_DISTANCE = 7

  // NOTE(review): assumes SimHash.extractSub(hash, width, offset) returns the `width` bits of
  // `hash` starting at bit `offset` -- confirm against SimHash. BAND_WIDTH * BAND_COUNT must
  // cover all 64 bits, and HAMMING_DISTANCE must stay below BAND_COUNT for the pigeonhole
  // guarantee in the object doc to hold.
  private val BAND_WIDTH = 8
  private val BAND_COUNT = 8

  /**
   * Entry point: parses `--input` and `--output` (both required) and runs the job.
   * Prints usage and exits with status 1 when the arguments are invalid.
   */
  def main(args: Array[String]): Unit = {
    val parser = new OptionParser[Params]("Duplicates") {
      head("Duplicates: find duplicates in code snippets.")
      opt[String]("input")
        .required()
        .text("input file path")
        .action((x, c) => c.copy(input = x))
      opt[String]("output")
        .required()
        .text("output file path")
        .action((x, c) => c.copy(output = x))
    }
    parser.parse(args, Params()) match {
      case Some(params) => run(params)
      case None =>
        parser.showUsageAsError
        sys.exit(1)
    }
  }

  /**
   * Parses a decimal 64-bit integer, ignoring surrounding whitespace.
   *
   * @param s text to parse; may be null
   * @return the parsed value, or 0 when `s` is null or is not a valid `Long`.
   *         The job drops rows whose hash is 0.
   */
  def parseLong(s: String): java.lang.Long =
    if (s == null) java.lang.Long.valueOf(0L)
    else
      try java.lang.Long.valueOf(s.trim)
      catch { case _: NumberFormatException => java.lang.Long.valueOf(0L) }

  /** Creates the SparkContext, runs the job, and always stops the context afterwards. */
  private def run(params: Params): Unit = {
    val conf = new SparkConf()
      .setAppName(s"Duplicates with $params")
      // Each band table and the id -> hash map are collected to the driver in full.
      .set("spark.driver.maxResultSize", "64g")
    val sc = new SparkContext(conf)
    Logger.getRootLogger.setLevel(Level.WARN)
    try {
      findDuplicates(sc, params)
    } finally {
      sc.stop()
    }
  }

  /**
   * Reads `id<0x01>hash` lines as `(id, hash)` pairs.
   *
   * Drops rows that do not split into exactly two fields, rows with a blank id, and rows whose
   * hash fails to parse or parses to 0. The id is trimmed.
   */
  private def loadHashes(sc: SparkContext, input: String) =
    sc.textFile(input)
      .map(line => StringUtils.split(line, 0x01.toChar))
      .filter(fields => fields.length == 2 && !fields.head.trim.isEmpty && !fields.last.isEmpty)
      .map(fields => (fields.head.trim, parseLong(fields.last).longValue))
      .filter { case (_, hash) => hash != 0L }

  /** Computes near-duplicates, writes them to `params.output`, and prints a summary. */
  private def findDuplicates(sc: SparkContext, params: Params): Unit = {
    val start = System.nanoTime()

    val idHash = loadHashes(sc, params.input).cache()
    val validIdCount = idHash.count()
    println(s"\t valid id: $validIdCount")

    // One lookup table per band: band value -> ids whose hash has that value in that band.
    val bands: IndexedSeq[Map[Long, Iterable[String]]] = (0 until BAND_COUNT).map { band =>
      idHash
        .map { case (id, hash) => (SimHash.extractSub(hash, BAND_WIDTH, band * BAND_WIDTH), id) }
        .groupByKey()
        .collect()
        .toMap
    }

    // Broadcast the driver-side tables once per executor instead of serializing them into
    // every task closure.
    val bandsBc = sc.broadcast(bands)
    val hashesBc = sc.broadcast(idHash.collect().toMap)

    val duplicates = idHash
      .map { case (id, hash) =>
        val bandTables = bandsBc.value
        val fingerprints = hashesBc.value
        // Every id sharing at least one exact band value with this hash, deduplicated in
        // first-seen order (band 0 first).
        val candidates = (0 until BAND_COUNT)
          .flatMap { band =>
            bandTables(band).getOrElse(SimHash.extractSub(hash, BAND_WIDTH, band * BAND_WIDTH), Nil)
          }
          .distinct
        // Case-insensitive self-exclusion. Ids differing only in letter case are never reported
        // as each other's duplicates. NOTE(review): the id -> hash map is still case-sensitive;
        // confirm this is intended.
        val matches = candidates.filter { other =>
          !id.equalsIgnoreCase(other) &&
            SimHash.hammingDistance(hash, fingerprints(other)) <= HAMMING_DISTANCE
        }
        (id, matches)
      }
      .filter { case (_, matches) => matches.nonEmpty }
      // Output line: the id followed by its duplicates, all comma separated.
      .map { case (id, matches) => id + "," + matches.mkString(",") }
      // Cached so the save and the count below do not both recompute the distance pass.
      .cache()

    duplicates.coalesce(1, shuffle = true).saveAsTextFile(params.output)
    val dCount = duplicates.count()

    println()
    println(s"Duplicates summary:")
    println(s"\t count of duplicates is: $dCount")
    println()
    val preprocessElapsed = (System.nanoTime() - start) / 1e9
    println(s"\t Elapsed time is: $preprocessElapsed")
  }

  /** Command-line options: input and output paths. */
  private case class Params(
      input: String = "",
      output: String = "")
    extends AbstractParams[Params]
}
// scalastyle:on println |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment