//
// Copyright 2013 Paytronix Systems, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package com.paytronix.datainsights.storage.utils.scoobi

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInput, DataInputStream, DataOutput, DataOutputStream, InterruptedIOException, IOException}
import java.net.InetAddress
import java.util.Arrays
import javax.naming.NamingException
import com.nicta.scoobi.application.{DListPersister, ScoobiConfiguration}
import com.nicta.scoobi.core.DList
import com.nicta.scoobi.io.{DataSink, DataSource, InputConverter, OutputConverter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, Writable}
import org.apache.hadoop.hbase.HConstants
import org.apache.hadoop.hbase.client.{Delete, HBaseAdmin, HTable, Put, Result => HBaseResult, Scan}
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.{TableInputFormat, TableMapReduceUtil, TableOutputFormat, TableRecordReaderImpl}
import org.apache.hadoop.hbase.util.{Addressing, Base64, Bytes, Strings}
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext, RecordReader, TaskAttemptContext}
import org.apache.hadoop.net.DNS
import org.slf4j.LoggerFactory
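
/**
 * Scoobi sources and sinks for reading from and writing to HBase tables, with support
 * for running several scans against one table in a single MapReduce job.
 */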
object HBaseIO {
    def fromTable(table: String, scans: Seq[Scan] = Seq(new Scan()), estimatedSize: Long = 1024*1024*1024): DList[HBaseResult] =
        DList.fromSource(new HBaseSource(table, scans, estimatedSize))

    def toTable[A <: Writable](table: String, dl: DList[A]): DListPersister[A] =
        new DListPersister(dl, new HBaseSink(table))

    class HBaseSource(val tableName: String, val scans: Seq[Scan], val estimatedSize: Long) extends DataSource[ImmutableBytesWritable, HBaseResult, HBaseResult] {
        val inputFormat = classOf[MultiScanTableInputFormat]

        def inputCheck(implicit sc: ScoobiConfiguration) =
            if (!new HBaseAdmin(sc.configuration).isTableAvailable(tableName))
                throw new IOException("Table " + tableName + " is not available in HBase, so cannot input from it")

        def inputConfigure(job: Job)(implicit sc: ScoobiConfiguration) = {
            val conf = job.getConfiguration
            conf.set(TableInputFormat.INPUT_TABLE, tableName)
            MultiScanTableInputFormat.configureScans(conf, scans)
        }

        def inputSize(implicit sc: ScoobiConfiguration): Long =
            estimatedSize

        lazy val inputConverter = new InputConverter[ImmutableBytesWritable, HBaseResult, HBaseResult] {
            def fromKeyValue(context: InputContext, k: ImmutableBytesWritable, v: HBaseResult) = v
        }
    }

    class HBaseSink[A <: Writable](val tableName: String) extends DataSink[NullWritable, Writable, A] {
        val outputFormat = classOf[TableOutputFormat[NullWritable]]
        val outputKeyClass = classOf[NullWritable]
        val outputValueClass = classOf[Writable]

        def outputCheck(implicit sc: ScoobiConfiguration) =
            if (!new HBaseAdmin(sc.configuration).isTableAvailable(tableName))
                throw new IOException("Table " + tableName + " is not available in HBase, so cannot output to it")

        def outputConfigure(job: Job)(implicit sc: ScoobiConfiguration) = {
            val conf = job.getConfiguration
            conf.set(TableOutputFormat.OUTPUT_TABLE, tableName)
        }

        lazy val outputConverter = new OutputConverter[NullWritable, Writable, A] {
            def toKeyValue(a: A) = (NullWritable.get, a)
        }
    }
}
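
// A minimal usage sketch (not part of the original gist): reading the rows covered by
// two key ranges and writing a Put per row back to another table. The table names and
// row keys are hypothetical, and a WireFormat[Put] instance is assumed to be in scope
// for the intermediate DList.
//
//   import com.nicta.scoobi.Scoobi._
//
//   object CopyRows extends ScoobiApp {
//     def run() {
//       val scans = Seq(
//         new Scan(Bytes.toBytes("rowA"), Bytes.toBytes("rowB")),
//         new Scan(Bytes.toBytes("rowM"), Bytes.toBytes("rowN"))
//       )
//       val rows: DList[HBaseResult] = HBaseIO.fromTable("events", scans)
//       val puts: DList[Put] = rows.map { r => new Put(r.getRow) }
//       persist(HBaseIO.toTable("events-copy", puts))
//     }
//   }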
object MultiScanTableInputFormat {
    val NUM_SCANS = "hbase.mapreduce.numscans"
    def SCAN(index: Int) = "hbase.mapreduce.scan." + index

    def configureScans(conf: Configuration, scans: Seq[Scan]): Unit = {
        // FIXME in the future. This is only compatible with HBase 0.94.2 or so, and becomes invalid for 0.95+
        conf.set(NUM_SCANS, scans.size.toString)
        scans.zipWithIndex.foreach { case (scan, index) =>
            val baos = new ByteArrayOutputStream()
            val dos = new DataOutputStream(baos)
            scan.write(dos)
            conf.set(SCAN(index), Base64.encodeBytes(baos.toByteArray))
        }
    }

    def recoverScans(conf: Configuration): Seq[Scan] =
        (0 until (conf.getInt(NUM_SCANS, 0))) map { index =>
            val bais = new ByteArrayInputStream(Base64.decode(conf.get(SCAN(index))))
            val dis = new DataInputStream(bais)
            val scan = new Scan()
            scan.readFields(dis)
            scan
        }
}
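
// configureScans and recoverScans round-trip each Scan through its Writable form,
// storing it Base64-encoded in the job configuration under the keys defined above.
// An illustrative round trip (values shown are what the code above produces):
//
//   val conf = new Configuration()
//   MultiScanTableInputFormat.configureScans(conf, Seq(new Scan()))
//   conf.get("hbase.mapreduce.numscans")   // "1"
//   conf.get("hbase.mapreduce.scan.0")     // Base64 of the serialized Scan
//   val recovered = MultiScanTableInputFormat.recoverScans(conf)   // one Scan

/**
 * A TableInputFormat that accepts several scans and emits one split per region, each
 * split carrying the row ranges of every configured scan that overlaps that region.
 */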
class MultiScanTableInputFormat extends TableInputFormat {
    private val logger = LoggerFactory.getLogger(getClass)
    private var scans: Seq[Scan] = Seq.empty
    private var recordReader: Option[TableRecordMultiReader] = None

    override def setConf(conf: Configuration): Unit = {
        super.setConf(conf)
        scans = MultiScanTableInputFormat.recoverScans(conf)
        setScan(scans.headOption.orNull) // the single scan will be used for non-row-range parameters, so make sure something is set
    }

    @throws(classOf[IOException])
    override def createRecordReader(split: InputSplit, context: TaskAttemptContext): RecordReader[ImmutableBytesWritable, HBaseResult] =
        if (getHTable == null) throw new IOException (
            "Cannot create a record reader because of a" +
            " previous error. Please look at the previous log lines from" +
            " the task's full log for more details."
        ) else {
            val tms = split.asInstanceOf[TableMultiSplit]
            val trmr = recordReader.getOrElse {
                val reader = new TableRecordMultiReader()
                recordReader = Some(reader)
                reader
            }
            trmr.setup(tms.makeScans(getScan), getHTable)
            try trmr.initialize(tms, context) catch { case e: InterruptedException =>
                throw new InterruptedIOException(e.getMessage)
            }
            trmr
        }

    private var reverseDNSCacheMap = Map.empty[InetAddress, String]
    private def reverseDNS(nameServer: String, ipAddress: InetAddress): String =
        this.reverseDNSCacheMap.get(ipAddress).getOrElse {
            val hostName = Strings.domainNamePointerToHostName(DNS.reverseDns(ipAddress, nameServer))
            reverseDNSCacheMap += ipAddress -> hostName
            hostName
        }

    @throws(classOf[IOException])
    override def getSplits(context: JobContext): java.util.List[InputSplit] =
        if (getHTable == null) throw new IOException("No table was provided.")
        else {
            val table = getHTable
            // Get the name server address, defaulting to null. Read it directly from the
            // configuration, since the corresponding field of TableInputFormatBase is private.
            val nameServer = context.getConfiguration.get("hbase.nameserver.address", null)
            val rawRegionRanges = getHTable.getStartEndKeys
            if (rawRegionRanges == null || rawRegionRanges.getFirst == null || rawRegionRanges.getFirst.length == 0) {
                // no region boundaries available, so emit a single split covering the whole key space
                val regLoc = getHTable.getRegionLocation(HConstants.EMPTY_BYTE_ARRAY, false)
                if (regLoc == null) throw new IOException("Expecting at least one region.")
                val split = new TableMultiSplit (
                    getHTable.getTableName, Seq((HConstants.EMPTY_BYTE_ARRAY, HConstants.EMPTY_BYTE_ARRAY)),
                    regLoc.getHostnamePort.split(Addressing.HOSTNAME_PORT_SEPARATOR)(0)
                )
                java.util.Arrays.asList(split)
            } else {
                val rowRanges = scans.map { scan =>
                    (scan.getStartRow, scan.getStopRow)
                }.sortWith { case ((a, _), (b, _)) => Bytes.compareTo(a, b) < 0 }
                val regionRanges = rawRegionRanges.getFirst zip rawRegionRanges.getSecond
                Arrays.asList(regionRanges.flatMap {
                    case (regionStart, regionStop) if !includeRegionInSplit(regionStart, regionStop) =>
                        None
                    case (regionStart, regionStop) =>
                        val regionServerAddress = table.getRegionLocation(regionStart).getServerAddress
                        val regionAddress = regionServerAddress.getInetSocketAddress.getAddress
                        val regionLocation = try reverseDNS(nameServer, regionAddress) catch { case e: NamingException =>
                            logger.error("Cannot resolve the host name for " + regionAddress + " because of " + e)
                            regionServerAddress.getHostname
                        }
                        val overlappingRanges = rowRanges.filter { case (startRow, stopRow) =>
                            (startRow.length == 0 || regionStop.length == 0 || Bytes.compareTo(startRow, regionStop) < 0) &&
                            (stopRow .length == 0 || Bytes.compareTo(stopRow, regionStart) > 0)
                        }
                        if (overlappingRanges.nonEmpty) {
                            // clip each overlapping scan range to the region's boundaries
                            val trimmedRanges = overlappingRanges.map { case (startRow, stopRow) =>
                                (
                                    if (startRow.length == 0 || Bytes.compareTo(regionStart, startRow) >= 0) regionStart else startRow,
                                    if ((stopRow.length == 0 || Bytes.compareTo(regionStop, stopRow) <= 0) && regionStop.length > 0) regionStop else stopRow
                                )
                            }
                            val split = new TableMultiSplit(getHTable.getTableName, trimmedRanges, regionLocation)
                            // really noisy with a large number of splits
                            // if (logger.isDebugEnabled) logger.debug("getSplits: " + split)
                            Some(split)
                        } else None
                }: _*)
            }
        }
}
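
/**
 * An InputSplit covering one region, carrying the (startRow, stopRow) range of every
 * configured scan that overlaps that region. The Writable form, as defined by
 * write/readFields below, is: table name, range count, each range's start and stop
 * rows, then the region location.
 */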
class TableMultiSplit(var tableName: Array[Byte], var ranges: Seq[(Array[Byte], Array[Byte])], var regionLocation: String)
    extends InputSplit with Writable with Comparable[TableMultiSplit] {

    def this() = this(Array(), Seq((HConstants.EMPTY_BYTE_ARRAY, HConstants.EMPTY_BYTE_ARRAY)), "")

    def makeScans(donor: Scan): Seq[Scan] =
        ranges.map { case (startRow, stopRow) =>
            val scan = new Scan(donor)
            scan.setStartRow(startRow)
            scan.setStopRow(stopRow)
            scan
        }

    override def getLocations() = Array(regionLocation)
    override def getLength() = 0L

    @throws(classOf[IOException])
    override def readFields(in: DataInput) = {
        tableName = Bytes.readByteArray(in)
        ranges = List.fill(in.readInt())((Bytes.readByteArray(in), Bytes.readByteArray(in)))
        regionLocation = Bytes.toString(Bytes.readByteArray(in))
    }

    @throws(classOf[IOException])
    override def write(out: DataOutput) = {
        Bytes.writeByteArray(out, tableName)
        out.writeInt(ranges.length)
        ranges.foreach { case (start, stop) =>
            Bytes.writeByteArray(out, start)
            Bytes.writeByteArray(out, stop)
        }
        Bytes.writeByteArray(out, Bytes.toBytes(regionLocation))
    }

    override def toString() =
        regionLocation + ":" + ranges.map { case (start, stop) => Bytes.toStringBinary(start) + "," + Bytes.toStringBinary(stop) }.mkString("[", ", ", "]")

    override def compareTo(split: TableMultiSplit) =
        Bytes.compareTo(ranges.head._1, split.ranges.head._1)

    override def equals(anySplit: Any) =
        anySplit match {
            case split: TableMultiSplit if
                Bytes.equals(tableName, split.tableName) &&
                ranges.length == split.ranges.length &&
                (ranges zip split.ranges).forall { case (a, b) => Bytes.equals(a._1, b._1) && Bytes.equals(a._2, b._2) } &&
                regionLocation == split.regionLocation => true
            case _ => false
        }

    override def hashCode() =
        List (
            Seq (
                Option(tableName).map(Arrays.hashCode).getOrElse(0),
                Option(regionLocation).map(_.hashCode).getOrElse(0)
            ),
            ranges.flatMap { case (start, stop) =>
                Seq (
                    Option(start).map(Arrays.hashCode).getOrElse(0),
                    Option(stop).map(Arrays.hashCode).getOrElse(0)
                )
            }
        ).flatten.foldLeft(0) { (accum, hash) => 31 * accum + hash }
}
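
/**
 * A RecordReader that chains one TableRecordReaderImpl across several scans: when the
 * rows of one scan are exhausted, it restarts the underlying reader at the start row
 * of the next scan, until all scans have been consumed.
 */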
class TableRecordMultiReader extends RecordReader[ImmutableBytesWritable, HBaseResult] {
    private val impl = new TableRecordReaderImpl()
    private var _scans: Seq[Scan] = Seq.empty
    private var _scanIndex: Int = 0
    private var _scanIterator: Iterator[Scan] = Iterator.empty

    def setup(scans: Seq[Scan], htable: HTable): Unit = {
        _scans = scans
        impl.setHTable(htable)
    }

    override def close() = impl.close()

    @throws(classOf[IOException])
    override def getCurrentKey(): ImmutableBytesWritable = impl.getCurrentKey()

    @throws(classOf[IOException])
    @throws(classOf[InterruptedException])
    override def getCurrentValue(): HBaseResult = impl.getCurrentValue()

    @throws(classOf[IOException])
    @throws(classOf[InterruptedException])
    override def initialize(genericSplit: InputSplit, context: TaskAttemptContext): Unit = {
        _scanIterator = _scans.iterator
        _scanIndex = 0
        impl.setScan(_scanIterator.next)
        impl.initialize(genericSplit, context)
    }

    @throws(classOf[IOException])
    @throws(classOf[InterruptedException])
    override def nextKeyValue(): Boolean =
        if (impl.nextKeyValue()) true
        else if (_scanIterator.hasNext) {
            val newScan = _scanIterator.next
            _scanIndex += 1
            impl.setScan(newScan)
            impl.restart(newScan.getStartRow)
            nextKeyValue()
        } else false

    override def getProgress() = {
        // each scan contributes an equal share of the overall progress, so scale the
        // underlying reader's progress into the current scan's share rather than adding
        // it unscaled (which could report values greater than 1)
        val fractionPerScan = 1.0f / (_scans.length: Float)
        fractionPerScan * (_scanIndex + impl.getProgress())
    }
}