CassandraLeftJoinRDD
package com.datastax.spark.connector.rdd

import org.apache.spark.metrics.InputMetricsUpdater

import com.datastax.driver.core.Session
import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.util.CqlWhereParser.{EqPredicate, InListPredicate, InPredicate, RangePredicate}
import com.datastax.spark.connector.util.{CountingIterator, CqlWhereParser}
import com.datastax.spark.connector.writer._
import com.datastax.spark.connector.util.Quote._
import org.apache.spark.rdd.RDD
import org.apache.spark.{Partition, TaskContext}

import scala.reflect.ClassTag
/**
 * An [[org.apache.spark.rdd.RDD RDD]] that performs a selecting join between the `left` RDD and the
 * specified Cassandra table. It issues individual selects to retrieve the rows from Cassandra and
 * takes advantage of RDDs that have been partitioned with the
 * [[com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner]].
 *
 * @tparam L item type on the left side of the join (any RDD)
 * @tparam R item type on the right side of the join (fetched from Cassandra)
 */
class CassandraLeftJoinRDD[L, R] private[connector](
    left: RDD[L],
    val keyspaceName: String,
    val tableName: String,
    val connector: CassandraConnector,
    val columnNames: ColumnSelector = AllColumns,
    val joinColumns: ColumnSelector = PartitionKeyColumns,
    val where: CqlWhereClause = CqlWhereClause.empty,
    val limit: Option[Long] = None,
    val clusteringOrder: Option[ClusteringOrder] = None,
    val readConf: ReadConf = ReadConf())(
  implicit
    val leftClassTag: ClassTag[L],
    val rightClassTag: ClassTag[R],
    @transient val rowWriterFactory: RowWriterFactory[L],
    @transient val rowReaderFactory: RowReaderFactory[R])
  extends CassandraRDD[(L, Seq[R])](left.sparkContext, left.dependencies)
  with CassandraTableRowReaderProvider[R] {
  // Self must be an RDD of (L, Seq[R]) to satisfy CassandraRDD's type bound,
  // so copies produce another CassandraLeftJoinRDD rather than a CassandraJoinRDD
  override type Self = CassandraLeftJoinRDD[L, R]

  override protected val classTag = rightClassTag

  override protected def copy(
      columnNames: ColumnSelector = columnNames,
      where: CqlWhereClause = where,
      limit: Option[Long] = limit,
      clusteringOrder: Option[ClusteringOrder] = None,
      readConf: ReadConf = readConf,
      connector: CassandraConnector = connector): Self = {
    new CassandraLeftJoinRDD[L, R](
      left = left,
      keyspaceName = keyspaceName,
      tableName = tableName,
      connector = connector,
      columnNames = columnNames,
      joinColumns = joinColumns,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
  }
  lazy val joinColumnNames: Seq[ColumnRef] = joinColumns match {
    case AllColumns => throw new IllegalArgumentException(
      "Unable to join against all columns in a Cassandra Table. Only primary key columns allowed.")
    case PartitionKeyColumns =>
      tableDef.partitionKey.map(col => col.columnName: ColumnRef)
    case SomeColumns(cs @ _*) =>
      checkColumnsExistence(cs)
      cs.map {
        case c: ColumnRef => c
        case _ => throw new IllegalArgumentException(
          "Unable to join against unnamed columns. No CQL Functions allowed.")
      }
  }
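  // Illustrative join specs (hypothetical column names): PartitionKeyColumns joins on the
  // full partition key, while SomeColumns("user_id", "day") joins on explicitly named
  // primary-key columns. AllColumns and CQL function calls are rejected above.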
  override def count(): Long = {
    columnNames match {
      case SomeColumns(_) =>
        logWarning("You are about to count rows but an explicit projection has been specified.")
      case _ =>
    }
    val counts =
      new CassandraJoinRDD[L, Long](
        left = left,
        connector = connector,
        keyspaceName = keyspaceName,
        tableName = tableName,
        columnNames = SomeColumns(RowCountRef),
        joinColumns = joinColumns,
        where = where,
        limit = limit,
        clusteringOrder = clusteringOrder,
        readConf = readConf)
    counts.map(_._2).reduce(_ + _)
  }
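  // Worked example: three left keys matching 2, 0, and 5 Cassandra rows respectively
  // yield per-key counts of 2, 0, and 5 (count(*) returns 0 for an unmatched key),
  // so count() returns 2 + 0 + 5 = 7.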
  /** This method will create the RowWriter required before the RDD is serialized.
    * This is called during getPartitions */
  protected def checkValidJoin(): Seq[ColumnRef] = {
    val partitionKeyColumnNames = tableDef.partitionKey.map(_.columnName).toSet
    val primaryKeyColumnNames = tableDef.primaryKey.map(_.columnName).toSet
    val colNames = joinColumnNames.map(_.columnName).toSet

    // Initialize RowWriter and Query to be used for accessing Cassandra
    rowWriter.columnNames
    singleKeyCqlQuery.length

    def checkSingleColumn(column: ColumnRef): Unit = {
      require(
        primaryKeyColumnNames.contains(column.columnName),
        s"Can't pushdown join on column $column because it is not part of the PRIMARY KEY")
    }

    // Make sure we have all of the clustering indexes between the 0th position and the max requested
    // in the join:
    val chosenClusteringColumns = tableDef.clusteringColumns
      .filter(cc => colNames.contains(cc.columnName))
    if (!tableDef.clusteringColumns.startsWith(chosenClusteringColumns)) {
      val maxCol = chosenClusteringColumns.last
      val maxIndex = maxCol.componentIndex.get
      val requiredColumns = tableDef.clusteringColumns.takeWhile(_.componentIndex.get <= maxIndex)
      val missingColumns = requiredColumns.toSet -- chosenClusteringColumns.toSet
      throw new IllegalArgumentException(
        s"Can't pushdown join on column $maxCol without also specifying [ $missingColumns ]")
    }
    val missingPartitionKeys = partitionKeyColumnNames -- colNames
    require(
      missingPartitionKeys.isEmpty,
      s"Can't join without the full partition key. Missing: [ $missingPartitionKeys ]")

    joinColumnNames.foreach(checkSingleColumn)
    joinColumnNames
  }
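  // Example of the clustering-prefix rule (hypothetical PRIMARY KEY ((pk), c1, c2)):
  // joining on (pk, c1) is valid, but joining on (pk, c2) fails above because the
  // leading clustering column c1 is missing from the join.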
  lazy val rowWriter = implicitly[RowWriterFactory[L]].rowWriter(
    tableDef, joinColumnNames.toIndexedSeq)

  def on(joinColumns: ColumnSelector): CassandraLeftJoinRDD[L, R] = {
    new CassandraLeftJoinRDD[L, R](
      left = left,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      joinColumns = joinColumns,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
  }
  // We need to make sure we get selectedColumnRefs before serialization so that our RowReader is
  // built
  lazy val singleKeyCqlQuery: String = {
    val whereClauses = where.predicates.flatMap(CqlWhereParser.parse)
    val joinColumns = joinColumnNames.map(_.columnName)
    val joinColumnPredicates = whereClauses.collect {
      case EqPredicate(c, _) if joinColumns.contains(c) => c
      case InPredicate(c) if joinColumns.contains(c) => c
      case InListPredicate(c, _) if joinColumns.contains(c) => c
      case RangePredicate(c, _, _) if joinColumns.contains(c) => c
    }.toSet

    require(
      joinColumnPredicates.isEmpty,
      s"""Columns specified in both the join on clause and the where clause.
         |Partition key columns are always part of the join clause.
         |Columns in both: ${joinColumnPredicates.mkString(", ")}""".stripMargin
    )

    logDebug("Generating Single Key Query Prepared Statement String")
    logDebug(s"SelectedColumns : $selectedColumnRefs -- JoinColumnNames : $joinColumnNames")
    val columns = selectedColumnRefs.map(_.cql).mkString(", ")
    val joinWhere = joinColumnNames.map(_.columnName).map(name => s"${quote(name)} = :$name")
    val limitClause = limit.map(limit => s"LIMIT $limit").getOrElse("")
    val orderBy = clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
    val filter = (where.predicates ++ joinWhere).mkString(" AND ")
    val quotedKeyspaceName = quote(keyspaceName)
    val quotedTableName = quote(tableName)
    // In CQL, ORDER BY must precede LIMIT
    val query =
      s"SELECT $columns " +
      s"FROM $quotedKeyspaceName.$quotedTableName " +
      s"WHERE $filter $orderBy $limitClause"
    logDebug(s"Query : $query")
    query
  }
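  // Illustrative generated statement (hypothetical table "ks"."tab" with partition
  // key "id", no extra WHERE predicates, LIMIT 10, no clustering order):
  //   SELECT "id", "value" FROM "ks"."tab" WHERE "id" = :id  LIMIT 10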
  /**
   * When computing a CassandraLeftJoinRDD the data is selected via single CQL statements
   * from the specified C* Keyspace and Table. This will be performed on whatever data is
   * available in the previous RDD in the chain.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[(L, Seq[R])] = {
    val session = connector.openSession()
    implicit val pv = protocolVersion(session)
    val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(consistencyLevel)
    val bsb = new BoundStatementBuilder[L](rowWriter, stmt, pv, where.values)
    val metricsUpdater = InputMetricsUpdater(context, readConf)
    val rowIterator = fetchIterator(session, bsb, left.iterator(split, context))
    val countingIterator = new CountingIterator(rowIterator, limit)
    context.addTaskCompletionListener { (context) =>
      val duration = metricsUpdater.finish() / 1000000000d
      logDebug(
        f"Fetched ${countingIterator.count} rows " +
        f"from $keyspaceName.$tableName " +
        f"for partition ${split.index} in $duration%.3f s.")
      session.close()
    }
    countingIterator
  }
  private def fetchIterator(
      session: Session,
      bsb: BoundStatementBuilder[L],
      lastIt: Iterator[L]): Iterator[(L, Seq[R])] = {
    val columnNamesArray = selectedColumnRefs.map(_.selectedAs).toArray
    implicit val pv = protocolVersion(session)
    for (leftSide <- lastIt) yield {
      // Materialize all matching rows for this left element into a Seq
      val rightSide = {
        val rs = session.execute(bsb.bind(leftSide))
        val iterator = new PrefetchingResultSetIterator(rs, fetchSize)
        iterator.map(rowReader.read(_, columnNamesArray)).toSeq
      }
      (leftSide, rightSide)
    }
  }
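  // Left-join semantics: every left element appears exactly once in the output,
  // paired with Seq.empty when no Cassandra rows match, whereas CassandraJoinRDD
  // (an inner join) would drop it.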
  override protected def getPartitions: Array[Partition] = {
    verify()
    checkValidJoin()
    left.partitions
  }

  override def getPreferredLocations(split: Partition): Seq[String] = left.preferredLocations(split)

  override def toEmptyCassandraRDD: EmptyCassandraRDD[(L, Seq[R])] =
    new EmptyCassandraRDD[(L, Seq[R])](
      sc = left.sparkContext,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
}
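For reference, a minimal usage sketch. It assumes a hypothetical keyspace test_ks with a table users (id int PRIMARY KEY, name text) and a reachable Cassandra host; since the constructor is private[connector], it also assumes the code is compiled into the same package, as this gist is. Implicit RowWriterFactory/RowReaderFactory instances come from the connector's defaults for tuples and CassandraRow.

package com.datastax.spark.connector.rdd

import org.apache.spark.{SparkConf, SparkContext}
import com.datastax.spark.connector._
import com.datastax.spark.connector.cql.CassandraConnector

object LeftJoinExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("left-join-sketch")
      .setMaster("local[2]")
      .set("spark.cassandra.connection.host", "127.0.0.1") // assumed host
    val sc = new SparkContext(conf)

    // Tuple1 fields are bound positionally to the partition key column (id)
    val leftIds = sc.parallelize(1 to 10).map(Tuple1(_))

    val joined = new CassandraLeftJoinRDD[Tuple1[Int], CassandraRow](
      left = leftIds,
      keyspaceName = "test_ks",
      tableName = "users",
      connector = CassandraConnector(sc.getConf))

    // Each id appears once; ids without a matching row pair with an empty Seq
    joined.collect().foreach { case (Tuple1(id), rows) =>
      println(s"$id -> ${rows.size} row(s)")
    }
    sc.stop()
  }
}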
protocolVersion is not getting recognized as a valid function.
Am I missing a lib?