Skip to content

Instantly share code, notes, and snippets.

@tuan3w
Last active February 9, 2017 07:56
Show Gist options
  • Save tuan3w/c968e56ea8ef135096eeedb08af097db to your computer and use it in GitHub Desktop.
Save tuan3w/c968e56ea8ef135096eeedb08af097db to your computer and use it in GitHub Desktop.
SignRandomProjectionLSH
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import scala.util.Random
import breeze.linalg.normalize
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
/**
 * :: Experimental ::
 *
 * Params for [[SignRandomProjectionLSH]].
 */
private[ml] trait SignRandomProjectionLSHParams extends Params {

  /**
   * The length of each hash bucket. The value must be strictly greater than zero.
   *
   * NOTE(review): the sign-based hash function of [[SignRandomProjectionLSHModel]] in this file
   * never reads this param, so it looks like a leftover from a bucketed random projection
   * implementation -- confirm whether it is still needed before relying on it.
   *
   * @group param
   */
  val bucketLength: DoubleParam = {
    val description =
      "the length of each hash bucket, a larger bucket lowers the false negative rate."
    new DoubleParam(this, "bucketLength", description, ParamValidators.gt(0))
  }

  /** @group getParam */
  final def getBucketLength: Double = getOrDefault(bucketLength)
}
/**
 * :: Experimental ::
 *
 * Model produced by [[SignRandomProjectionLSH]], where multiple random vectors are stored. Each
 * vector is used in one hash function that keeps only the sign of the projection of the key:
 *    `h_i(x) = if (r_i.dot(x) > 0) 1.0 else 0.0`
 * where `r_i` is the i-th random unit vector. Every hash function therefore yields a single bit,
 * which is returned as a one-element dense vector.
 *
 * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function.
 */
@Experimental
@Since("2.1.0")
class SignRandomProjectionLSHModel private[ml](
    override val uid: String,
    private[ml] val randUnitVectors: Array[Vector])
  extends LSHModel[SignRandomProjectionLSHModel] with SignRandomProjectionLSHParams {

  /**
   * Maps a key to one single-element vector per stored random vector: `1.0` when the dot product
   * of the key with that vector is strictly positive, `0.0` otherwise.
   */
  @Since("2.1.0")
  override protected[ml] val hashFunction: Vector => Array[Vector] = {
    key: Vector => {
      randUnitVectors.map { randUnitVector =>
        val projection = key.asBreeze dot randUnitVector.asBreeze
        Vectors.dense(if (projection > 0) 1.0 else 0.0)
      }
    }
  }

  /**
   * Euclidean distance between two keys.
   *
   * NOTE(review): sign random projections are usually paired with an angular/cosine distance;
   * the Euclidean distance is kept here because callers may depend on it -- confirm the intent.
   */
  @Since("2.1.0")
  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
    Math.sqrt(Vectors.sqdist(x, y))
  }

  /**
   * Smallest squared distance between hash vectors at the same position of `x` and `y`.
   * Throws `UnsupportedOperationException` when either sequence is empty (`min` of nothing).
   */
  @Since("2.1.0")
  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
    // Since it's generated by hashing, it will be a pair of dense vectors.
    x.zip(y).map(vectorPair => Vectors.sqdist(vectorPair._1, vectorPair._2)).min
  }

  /**
   * Creates a copy of this model that shares `uid`, `randUnitVectors` and parent, with the
   * params of this instance plus `extra` copied over.
   *
   * `defaultCopy` cannot be used here: it instantiates the class reflectively through a
   * constructor taking only the uid, and this class only declares `(uid, randUnitVectors)`,
   * so the lookup would fail at runtime.
   */
  @Since("2.1.0")
  override def copy(extra: ParamMap): this.type = {
    val copied = new SignRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent)
    // The cast only exists to keep the declared `this.type` result; `defaultCopy` does the same.
    copyValues(copied, extra).asInstanceOf[this.type]
  }

  /** Returns a writer that persists the metadata and the random unit vectors of this model. */
  @Since("2.1.0")
  override def write: MLWriter = {
    new SignRandomProjectionLSHModel.SignRandomProjectionLSHModelWriter(this)
  }
}
/**
 * :: Experimental ::
 *
 * This [[SignRandomProjectionLSH]] builds a [[SignRandomProjectionLSHModel]] holding
 * `numHashTables` random unit vectors. Every vector is drawn from a seeded Gaussian generator
 * and then normalized.
 *
 * The input is dense or sparse vectors. The output will be vectors of configurable dimension.
 * Hash values in the same dimension are calculated by the same hash function.
 *
 * References:
 *
 * 1. <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions">
 * Wikipedia on Stable Distributions</a>
 *
 * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
 * arXiv:1408.2927 (2014).
 */
@Experimental
@Since("2.1.0")
class SignRandomProjectionLSH(override val uid: String)
  extends LSH[SignRandomProjectionLSHModel]
  with SignRandomProjectionLSHParams with HasSeed {

  /** Creates an estimator with a randomly generated uid. */
  // NOTE(review): the "brp-lsh" prefix looks inherited from a bucketed random projection
  // implementation; it is kept unchanged because it is a runtime-visible string.
  @Since("2.1.0")
  def this() = this(Identifiable.randomUID("brp-lsh"))

  @Since("2.1.0")
  override def setInputCol(value: String): this.type = {
    super.setInputCol(value)
  }

  @Since("2.1.0")
  override def setOutputCol(value: String): this.type = {
    super.setOutputCol(value)
  }

  @Since("2.1.0")
  override def setNumHashTables(value: Int): this.type = {
    super.setNumHashTables(value)
  }

  /** @group setParam */
  @Since("2.1.0")
  def setBucketLength(value: Double): this.type = {
    set(bucketLength, value)
  }

  /** @group setParam */
  @Since("2.1.0")
  def setSeed(value: Long): this.type = {
    set(seed, value)
  }

  /**
   * Builds the model for inputs of dimension `inputDim`: one normalized Gaussian random vector
   * per hash table, all drawn in order from a single generator seeded with `seed`.
   */
  @Since("2.1.0")
  override protected[this] def createRawLSHModel(
      inputDim: Int): SignRandomProjectionLSHModel = {
    val rng = new Random($(seed))

    // Draws `inputDim` Gaussian samples and scales the resulting vector with breeze `normalize`.
    def nextUnitVector(): Vector = {
      val gaussians = Array.fill(inputDim)(rng.nextGaussian())
      Vectors.fromBreeze(normalize(breeze.linalg.Vector(gaussians)))
    }

    val directions: Array[Vector] = Array.fill($(numHashTables))(nextUnitVector())
    new SignRandomProjectionLSHModel(uid, directions)
  }

  /** Checks that the input column holds vectors, then delegates the schema transformation. */
  @Since("2.1.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputColumn = $(inputCol)
    SchemaUtils.checkColumnType(schema, inputColumn, new VectorUDT)
    validateAndTransformSchema(schema)
  }

  @Since("2.1.0")
  override def copy(extra: ParamMap): this.type = {
    defaultCopy(extra)
  }
}
/** Companion of [[SignRandomProjectionLSH]]; reading is provided by `DefaultParamsReadable`. */
@Since("2.1.0")
object SignRandomProjectionLSH extends DefaultParamsReadable[SignRandomProjectionLSH] {

  /** Loads a [[SignRandomProjectionLSH]] estimator previously saved at `path`. */
  @Since("2.1.0")
  override def load(path: String): SignRandomProjectionLSH = {
    super.load(path)
  }
}
/** Companion of [[SignRandomProjectionLSHModel]] providing its ML persistence reader/writer. */
@Since("2.1.0")
object SignRandomProjectionLSHModel extends MLReadable[SignRandomProjectionLSHModel] {

  @Since("2.1.0")
  override def read: MLReader[SignRandomProjectionLSHModel] = {
    new SignRandomProjectionLSHModelReader
  }

  @Since("2.1.0")
  override def load(path: String): SignRandomProjectionLSHModel = super.load(path)

  /**
   * Writer that stores the params metadata plus the random unit vectors, packed as the rows of
   * a single matrix, in a parquet file under `<path>/data`.
   */
  private[SignRandomProjectionLSHModel] class SignRandomProjectionLSHModelWriter(
      instance: SignRandomProjectionLSHModel) extends MLWriter {

    // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
    private case class Data(randUnitVectors: Matrix)

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val numRows = instance.randUnitVectors.length
      require(numRows > 0)
      val numCols = instance.randUnitVectors.head.size
      // Row-major concatenation: vector i occupies values(i * numCols until (i + 1) * numCols).
      val values = instance.randUnitVectors.flatMap(_.toArray)
      // `Matrices.dense` would interpret `values` as column-major and interleave the vectors
      // whenever numRows > 1 and numCols > 1, so the reader's `rowIter` would not return the
      // saved vectors. `isTransposed = true` declares the row-major layout explicitly.
      // Files written with the previous column-major interpretation load exactly as before.
      val randMatrix: Matrix = new DenseMatrix(numRows, numCols, values, isTransposed = true)
      val data = Data(randMatrix)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /**
   * Reader that rebuilds the model from the stored metadata and matrix; every row of the matrix
   * becomes one random unit vector.
   */
  private class SignRandomProjectionLSHModelReader
    extends MLReader[SignRandomProjectionLSHModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[SignRandomProjectionLSHModel].getName

    override def load(path: String): SignRandomProjectionLSHModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
      val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
        .select("randUnitVectors")
        .head()
      val model = new SignRandomProjectionLSHModel(metadata.uid,
        randUnitVectors.rowIter.toArray)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment