Skip to content

Instantly share code, notes, and snippets.

@RussellSpitzer
Created January 30, 2015 01:02
Show Gist options
  • Save RussellSpitzer/371562c078ef5d01e55d to your computer and use it in GitHub Desktop.
Save RussellSpitzer/371562c078ef5d01e55d to your computer and use it in GitHub Desktop.
package com.datastax.spark.connector.rdd
import java.net.InetAddress
import com.datastax.driver.core.BatchStatement.Type
import com.datastax.driver.core.ConsistencyLevel
import com.datastax.spark.connector.{SomeColumns, NamedColumnRef, AllColumns, ColumnSelector}
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.rdd.partitioner.CassandraPartition
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.util.CountingIterator
import com.datastax.spark.connector.writer._
import org.apache.spark.{Partitioner, TaskContext, Partition}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import scala.reflect.ClassTag
import scala.util.Random
// O[ld] is the type of the RDD we are mapping from; N[ew] is the type we are mapping to.
/**
 * RDD that reads rows from a Cassandra table for the keys supplied by a parent RDD.
 *
 * Every element of `prev` is bound to a prepared statement built by `singleKeyCqlQuery`
 * (one equality restriction per partition-key column) and run against
 * `keyspaceName.tableName`.
 *
 * NOTE(review): `ReplicaMapper`, `QueryExecutor`, and the inherited members `where`,
 * `tableDef`, `readConf`, `quote` and `selectedColumnNames` are declared outside this file;
 * comments about them below describe only how they are used here.
 *
 * @tparam O element type of the parent RDD (the keys being looked up)
 * @tparam N element type produced by this RDD
 * @param prev          parent RDD providing the keys
 * @param keyspaceName  keyspace of the table to query
 * @param tableName     table to query
 * @param connector     connector used to obtain sessions and the list of hosts
 * @param rwf           row writer factory for `O`, forwarded when the RDD is re-created
 * @param columnsToRead columns to select; defaults to all columns
 */
class CassandraPartitionKeysRDD[O, N](prev: RDD[O],
                                      keyspaceName: String,
                                      tableName: String,
                                      connector: CassandraConnector,
                                      rwf: RowWriterFactory[O],
                                      columnsToRead: ColumnSelector = AllColumns)
                                     (implicit oldTag: ClassTag[O], newTag: ClassTag[N])
  extends CassandraRDD[N](prev.sparkContext, connector, keyspaceName, tableName, columnsToRead) {

  /**
   * Pairs every key of the parent RDD with the replica set reported by `ReplicaMapper`.
   *
   * The keys live in `prev` (an `RDD[O]`); mapping over `this` would iterate `N` values,
   * which matches neither `ReplicaMapper[O]` nor the declared result type.
   */
  private def keyByReplica(implicit rwf: RowWriterFactory[O]): RDD[(Set[InetAddress], O)] = {
    val converter = ReplicaMapper[O](connector, keyspaceName, tableName)
    prev.mapPartitions(primaryKeys => converter.mapReplicas(primaryKeys))
  }

  /**
   * Builds the CQL text of the per-key SELECT: the selected columns, any predicates already
   * present in `where`, and a `name = :name` restriction for every partition-key column.
   */
  private def singleKeyCqlQuery: String = {
    val columns = selectedColumnNames.map(_.cql).mkString(", ")
    val partitionWhere = tableDef.partitionKey
      .map(_.columnName)
      .map(name => s"$name = :$name")
      .mkString(" AND ")
    // Append the key restriction as ONE element after the existing predicates.
    // `predicates +: partitionWhere` would instead prepend the whole predicate collection
    // onto the String treated as a sequence of characters.
    val filter = (where.predicates :+ partitionWhere).mkString(" AND ")
    val quotedKeyspaceName = quote(keyspaceName)
    val quotedTableName = quote(tableName)
    s"SELECT $columns FROM $quotedKeyspaceName.$quotedTableName WHERE $filter"
  }

  /**
   * Prepares the per-key statement and hands the parent's keys for `split` to
   * `ReplicaMapper.bindStatements`.
   *
   * NOTE(review): the session is only guaranteed to be available inside `withSessionDo`.
   * If `bindStatements` returns a lazy iterator it may be consumed after the session has
   * been released - confirm, and tie the session to task completion if so.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[N] = {
    connector.withSessionDo { session =>
      val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(readConf.consistencyLevel)
      // NOTE(review): not used yet; kept from the original pending the TODO below.
      val queryExecutor = new QueryExecutor(session, 20) //TODO add readconf for this
      // Same explicit type argument as in keyByReplica so the key type is not left to inference.
      val converter = ReplicaMapper[O](connector, keyspaceName, tableName)
      converter.bindStatements(prev.iterator(split, context), stmt)
    }
  }

  /** Partition carrying the replica addresses it should preferably be processed on. */
  case class EndpointPartition(index: Int, endpoint: Set[InetAddress]) extends Partition

  /**
   * Uses the parent's partitions: `compute` passes each split straight to `prev.iterator`.
   *
   * Returning `partitions` here would never terminate, because `RDD.partitions` is itself
   * implemented in terms of `getPartitions`.
   */
  override protected def getPartitions: Array[Partition] = prev.partitions

  /**
   * Preferred hosts for `split`: the host addresses of an `EndpointPartition`, the parent
   * class' answer for a `CassandraPartition`, and no preference for anything else.
   */
  override def getPreferredLocations(split: Partition): Seq[String] = {
    split match {
      case epp: EndpointPartition =>
        epp.endpoint.map(_.getHostAddress).toSeq
      case _: CassandraPartition =>
        super.getPreferredLocations(split)
      case _ =>
        // Splits now come from `prev` and may be of any type. The parent implementation is
        // presumably written for CassandraPartition only - TODO confirm.
        Nil
    }
  }

  /**
   * Shuffles the parent's keys so that keys sharing a replica end up in one of the
   * `partitionsPerReplicaSet` partitions reserved for that host, then wraps the result in a
   * new `CassandraPartitionKeysRDD`.
   *
   * NOTE(review): only the constructor arguments are carried over; state configured on this
   * instance after construction (e.g. the `where` clause) is not copied.
   *
   * @param partitionsPerReplicaSet number of Spark partitions to create per Cassandra host
   */
  private def partitionByReplica(partitionsPerReplicaSet: Int = 10)
                                (implicit rwf: RowWriterFactory[O]): CassandraPartitionKeysRDD[O, N] = {
    require(partitionsPerReplicaSet > 0, "partitionsPerReplicaSet must be positive")

    /**
     * Gives every host a contiguous block of `partitionsPerReplicaSet` partition indices:
     * host `h` owns `[h * partitionsPerReplicaSet, (h + 1) * partitionsPerReplicaSet)`.
     */
    class ReplicaPartitioner(partitionsPerReplicaSet: Int) extends Partitioner {
      val hosts = connector.hosts
      val hostMap = hosts.zipWithIndex.toMap //TODO We Need JAVA-312 to get sets of replicas instead of single endpoints
      val indexMap = hostMap.map(_.swap)
      val numHosts = hosts.size
      val rand = new Random()

      override def getPartition(key: Any): Int = {
        // Keys are produced by keyByReplica, so they are always Set[InetAddress].
        val replicaSet = key.asInstanceOf[Set[InetAddress]]
        // Unknown replica or empty replica set: fall back to a random host.
        val hostIndex = replicaSet.lastOption
          .flatMap(hostMap.get)
          .getOrElse(rand.nextInt(numHosts))
        // Scale by the block size; `hostIndex + offset` would make neighbouring hosts share
        // partitions and leave most of the numPartitions range unused.
        hostIndex * partitionsPerReplicaSet + rand.nextInt(partitionsPerReplicaSet)
      }

      override def numPartitions: Int = partitionsPerReplicaSet * numHosts

      /** Host owning partition `index`; inverse of the block layout used by getPartition. */
      def getEndpointsForParition(index: Int): InetAddress = {
        indexMap.getOrElse(index / partitionsPerReplicaSet,
          throw new RuntimeException(s"${indexMap} : Can't get an endpoint for Partition $index"))
      }
    }

    val part = new ReplicaPartitioner(partitionsPerReplicaSet)
    val output = this.keyByReplica
      .partitionBy(part)
      .map(_._2)
    // Forward columnsToRead as well; omitting it silently resets the selection to AllColumns.
    new CassandraPartitionKeysRDD[O, N](prev = output, keyspaceName = keyspaceName, tableName = tableName,
      connector = connector, rwf = rwf, columnsToRead = columnsToRead)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment