package com.datastax.spark.connector.rdd

import java.net.InetAddress

import com.datastax.driver.core.BatchStatement.Type
import com.datastax.driver.core.ConsistencyLevel
import com.datastax.spark.connector.{SomeColumns, NamedColumnRef, AllColumns, ColumnSelector}
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.rdd.partitioner.CassandraPartition
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.util.CountingIterator
import com.datastax.spark.connector.writer._
import org.apache.spark.{Partitioner, TaskContext, Partition}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

import scala.reflect.ClassTag
import scala.util.Random
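/**
 * Draft RDD that, for every element of `prev`, looks up the Cassandra rows
 * matching that element's partition key. Elements can first be regrouped by
 * the replica set owning them (see `partitionByReplica` below) so that each
 * Spark task reads from a node that holds its data. A work-in-progress sketch,
 * not the released connector API.
 */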
// O[ld] is the type of the RDD we are mapping from, N[ew] the type we are mapping to
class CassandraPartitionKeysRDD[O, N](prev: RDD[O],
    keyspaceName: String,
    tableName: String,
    connector: CassandraConnector,
    rwf: RowWriterFactory[O],
    columnsToRead: ColumnSelector = AllColumns)
    (implicit oldTag: ClassTag[O], newTag: ClassTag[N])
  extends CassandraRDD[N](prev.sparkContext, connector, keyspaceName, tableName, columnsToRead) {
  // Pairs every parent element with the set of replica addresses that own its partition key.
  private def keyByReplica(implicit rwf: RowWriterFactory[O]): RDD[(Set[InetAddress], O)] = {
    val converter = ReplicaMapper[O](connector, keyspaceName, tableName)
    prev.mapPartitions(primaryKeys =>
      converter.mapReplicas(primaryKeys)
    )
  }
  // Builds the per-key SELECT template; each partition key column is bound by name.
  private def singleKeyCqlQuery: String = {
    val columns = selectedColumnNames.map(_.cql).mkString(", ")
    val partitionWhere = tableDef.partitionKey.map(_.columnName).map(name => s"$name = :$name").mkString(" AND ")
    val filter = (where.predicates :+ partitionWhere).mkString(" AND ")
    val quotedKeyspaceName = quote(keyspaceName)
    val quotedTableName = quote(tableName)
    s"SELECT $columns FROM $quotedKeyspaceName.$quotedTableName WHERE $filter"
  }
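  // For example, for a table with partition key (id) and selected columns id and
  // value, the template above comes out roughly as:
  //   SELECT id, value FROM "ks"."tab" WHERE id = :id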
  override def compute(split: Partition, context: TaskContext): Iterator[N] = {
    connector.withSessionDo { session =>
      val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(readConf.consistencyLevel)
      val queryExecutor = new QueryExecutor(session, 20) //TODO add a ReadConf setting for this concurrency level
      val converter = ReplicaMapper(connector, keyspaceName, tableName)
      //TODO this draft stops at binding: the bound statements still need to be
      //executed and their result rows converted to N
      converter.bindStatements(prev.iterator(split, context), stmt)
    }
  }
  case class EndpointPartition(index: Int, endpoint: Set[InetAddress]) extends Partition

  // Computed partition-for-partition over the parent RDD.
  override protected def getPartitions: Array[Partition] = prev.partitions
  override def getPreferredLocations(split: Partition): Seq[String] = {
    split match {
      case epp: EndpointPartition =>
        epp.endpoint.map(_.getHostAddress).toSeq
      case _ =>
        super.getPreferredLocations(split)
    }
  }
  private def partitionByReplica(partitionsPerReplicaSet: Int = 10)
      (implicit rwf: RowWriterFactory[O]): CassandraPartitionKeysRDD[O, N] = {

    // Assigns each host a contiguous block of partitionsPerReplicaSet partition indices.
    class ReplicaPartitioner(partitionsPerReplicaSet: Int) extends Partitioner {
      val hosts = connector.hosts
      val hostMap = hosts.zipWithIndex.toMap //TODO We need JAVA-312 to get sets of replicas instead of single endpoints
      val indexMap = hostMap.map(_.swap)
      val numHosts = hosts.size
      val rand = new Random()

      override def getPartition(key: Any): Int = {
        val replicaSet = key.asInstanceOf[Set[InetAddress]]
        // Spread keys owned by the same host across its block of partitions;
        // fall back to a random host for keys with no known replica.
        val hostIndex = hostMap.getOrElse(replicaSet.last, rand.nextInt(numHosts))
        hostIndex * partitionsPerReplicaSet + rand.nextInt(partitionsPerReplicaSet)
      }

      override def numPartitions: Int = partitionsPerReplicaSet * numHosts

      def getEndpointsForPartition(index: Int): InetAddress = {
        indexMap.getOrElse(index / partitionsPerReplicaSet,
          throw new RuntimeException(s"$indexMap : Can't get an endpoint for Partition $index"))
      }
    }

    val part = new ReplicaPartitioner(partitionsPerReplicaSet)
    val output = this.keyByReplica
      .partitionBy(part)
      .map(_._2)
    new CassandraPartitionKeysRDD[O, N](prev = output, keyspaceName = keyspaceName, tableName = tableName,
      connector = connector, rwf = rwf)
  }
}
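A minimal usage sketch of the intended API, not part of the gist: it assumes a hypothetical keyspace `test` with a table `kv` whose partition key is `k`, a `SparkContext` named `sc`, and that the connector's default mappers supply an implicit `RowWriterFactory` for a plain case class.

// Usage sketch: Key, KV, "test", and "kv" are placeholder assumptions.
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.writer.RowWriterFactory

case class Key(k: Int)            // mirrors the partition key of test.kv
case class KV(k: Int, v: String)  // target row type for the lookups

val keys = sc.parallelize(1 to 1000).map(Key(_))

// Fetch the full rows for each key; partitionByReplica (private above) would
// additionally regroup the keys so each task reads replica-locally.
val rows = new CassandraPartitionKeysRDD[Key, KV](
  prev = keys,
  keyspaceName = "test",
  tableName = "kv",
  connector = CassandraConnector(sc.getConf),
  rwf = implicitly[RowWriterFactory[Key]])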