//
// Copyright 2013 Paytronix Systems, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package com.paytronix.datainsights.storage.utils.scoobi

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInput, DataInputStream, DataOutput, DataOutputStream, InterruptedIOException, IOException}
import java.net.InetAddress
import java.util.Arrays
import javax.naming.NamingException

import com.nicta.scoobi.application.{DListPersister, ScoobiConfiguration}
import com.nicta.scoobi.core.DList
import com.nicta.scoobi.io.{DataSink, DataSource, InputConverter, OutputConverter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, Writable}
import org.apache.hadoop.hbase.HConstants
import org.apache.hadoop.hbase.client.{Delete, HBaseAdmin, HTable, Put, Result => HBaseResult, Scan}
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.{TableInputFormat, TableMapReduceUtil, TableOutputFormat, TableRecordReaderImpl}
import org.apache.hadoop.hbase.util.{Addressing, Base64, Bytes, Strings}
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext, RecordReader, TaskAttemptContext}
import org.apache.hadoop.net.DNS
import org.slf4j.LoggerFactory
/** Helpers for reading from and writing to HBase tables as Scoobi DLists. */
object HBaseIO {
    /** Read rows from an HBase table, one DList element per Result, optionally restricted by one or more Scans. */
    def fromTable(table: String, scans: Seq[Scan] = Seq(new Scan()), estimatedSize: Long = 1024*1024*1024): DList[HBaseResult] =
        DList.fromSource(new HBaseSource(table, scans, estimatedSize))

    /** Write a DList of HBase mutation Writables (e.g. Put or Delete) to an HBase table. */
    def toTable[A <: Writable](table: String, dl: DList[A]): DListPersister[A] =
        new DListPersister(dl, new HBaseSink(table))

    /** Scoobi DataSource backed by MultiScanTableInputFormat, so a single job can read several row ranges of one table. */
    class HBaseSource(val tableName: String, val scans: Seq[Scan], val estimatedSize: Long) extends DataSource[ImmutableBytesWritable, HBaseResult, HBaseResult] {
        val inputFormat = classOf[MultiScanTableInputFormat]

        def inputCheck(implicit sc: ScoobiConfiguration) =
            if (!new HBaseAdmin(sc.configuration).isTableAvailable(tableName))
                throw new IOException("Table " + tableName + " is not available in HBase, so cannot input from it")

        def inputConfigure(job: Job)(implicit sc: ScoobiConfiguration) = {
            val conf = job.getConfiguration
            conf.set(TableInputFormat.INPUT_TABLE, tableName)
            MultiScanTableInputFormat.configureScans(conf, scans)
        }

        def inputSize(implicit sc: ScoobiConfiguration): Long =
            estimatedSize

        lazy val inputConverter = new InputConverter[ImmutableBytesWritable, HBaseResult, HBaseResult] {
            def fromKeyValue(context: InputContext, k: ImmutableBytesWritable, v: HBaseResult) = v
        }
    }

    /** Scoobi DataSink backed by TableOutputFormat; each DList element must be an HBase mutation Writable. */
    class HBaseSink[A <: Writable](val tableName: String) extends DataSink[NullWritable, Writable, A] {
        val outputFormat = classOf[TableOutputFormat[NullWritable]]
        val outputKeyClass = classOf[NullWritable]
        val outputValueClass = classOf[Writable]

        def outputCheck(implicit sc: ScoobiConfiguration) =
            if (!new HBaseAdmin(sc.configuration).isTableAvailable(tableName))
                throw new IOException("Table " + tableName + " is not available in HBase, so cannot output to it")

        def outputConfigure(job: Job)(implicit sc: ScoobiConfiguration) = {
            val conf = job.getConfiguration
            conf.set(TableOutputFormat.OUTPUT_TABLE, tableName)
        }

        lazy val outputConverter = new OutputConverter[NullWritable, Writable, A] {
            def toKeyValue(a: A) = (NullWritable.get, a)
        }
    }
}
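
// A minimal usage sketch (not part of the original gist): it shows how fromTable/toTable might be wired into a
// Scoobi job. The table names, the row-to-Put transform, and the ScoobiApp entry point are hypothetical
// placeholders; only the HBaseIO calls are taken from this file.
//
//     object CopyRowKeys extends com.nicta.scoobi.application.ScoobiApp {
//         def run() {
//             val rows: DList[HBaseResult] = HBaseIO.fromTable("source_table", Seq(new Scan()))
//             val puts: DList[Put] = rows.map { r => new Put(r.getRow) }   // hypothetical transform
//             persist(HBaseIO.toTable("dest_table", puts))
//         }
//     }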
object MultiScanTableInputFormat {
    val NUM_SCANS = "hbase.mapreduce.numscans"
    def SCAN(index: Int) = "hbase.mapreduce.scan." + index

    def configureScans(conf: Configuration, scans: Seq[Scan]): Unit = {
        // FIXME in the future. This is only compatible with HBase 0.94.2 or so, and becomes invalid for 0.95+
        conf.set(NUM_SCANS, scans.size.toString)
        scans.zipWithIndex.foreach { case (scan, index) =>
            val baos = new ByteArrayOutputStream()
            val dos = new DataOutputStream(baos)
            scan.write(dos)
            conf.set(SCAN(index), Base64.encodeBytes(baos.toByteArray))
        }
    }

    def recoverScans(conf: Configuration): Seq[Scan] =
        (0 until conf.getInt(NUM_SCANS, 0)) map { index =>
            val bais = new ByteArrayInputStream(Base64.decode(conf.get(SCAN(index))))
            val dis = new DataInputStream(bais)
            val scan = new Scan()
            scan.readFields(dis)
            scan
        }
}
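
// A minimal round-trip sketch (not part of the original gist): configureScans serializes each Scan through its
// 0.94-era Writable support into a Base64 string under hbase.mapreduce.scan.<i>, and recoverScans reads them
// back out of the job Configuration. The row keys below are hypothetical illustration values.
//
//     val conf = new Configuration()
//     val scans = Seq(
//         new Scan(Bytes.toBytes("row-000"), Bytes.toBytes("row-500")),
//         new Scan(Bytes.toBytes("row-500"), Bytes.toBytes("row-999"))
//     )
//     MultiScanTableInputFormat.configureScans(conf, scans)
//     val recovered = MultiScanTableInputFormat.recoverScans(conf)   // two Scans equivalent to the originals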
/** TableInputFormat that runs several Scans over one table in a single job, producing one split per region that
  * carries the row ranges overlapping that region. */
class MultiScanTableInputFormat extends TableInputFormat {
    private val logger = LoggerFactory.getLogger(getClass)

    private var scans: Seq[Scan] = Seq.empty
    private var recordReader: Option[TableRecordMultiReader] = None

    override def setConf(conf: Configuration): Unit = {
        super.setConf(conf)
        scans = MultiScanTableInputFormat.recoverScans(conf)
        setScan(scans.headOption.orNull) // the single scan will be used for non-row-range parameters, so make sure something is set
    }

    @throws(classOf[IOException])
    override def createRecordReader(split: InputSplit, context: TaskAttemptContext): RecordReader[ImmutableBytesWritable, HBaseResult] =
        if (getHTable == null) throw new IOException (
            "Cannot create a record reader because of a" +
            " previous error. Please look at the previous logs lines from" +
            " the task's full log for more details."
        ) else {
            val tms = split.asInstanceOf[TableMultiSplit]
            val trmr = recordReader.getOrElse {
                val reader = new TableRecordMultiReader()
                recordReader = Some(reader)
                reader
            }
            trmr.setup(tms.makeScans(getScan), getHTable)
            try trmr.initialize(tms, context) catch { case e: InterruptedException =>
                throw new InterruptedIOException(e.getMessage)
            }
            trmr
        }

    private var reverseDNSCacheMap = Map.empty[InetAddress, String]

    private def reverseDNS(nameServer: String, ipAddress: InetAddress): String =
        this.reverseDNSCacheMap.get(ipAddress).getOrElse {
            val hostName = Strings.domainNamePointerToHostName(DNS.reverseDns(ipAddress, nameServer))
            reverseDNSCacheMap += ipAddress -> hostName
            hostName
        }

    @throws(classOf[IOException])
    override def getSplits(context: JobContext): java.util.List[InputSplit] =
        if (getHTable == null) throw new IOException("No table was provided.")
        else {
            val table = getHTable
            // Get the name server address and the default value is null.
            // can't set this because it's private
            val nameServer = context.getConfiguration.get("hbase.nameserver.address", null)
            val rawRegionRanges = getHTable.getStartEndKeys
            if (rawRegionRanges == null || rawRegionRanges.getFirst == null || rawRegionRanges.getFirst.length == 0) {
                val regLoc = getHTable.getRegionLocation(HConstants.EMPTY_BYTE_ARRAY, false)
                if (regLoc == null) throw new IOException("Expecting at least one region.")
                val split = new TableMultiSplit (
                    getHTable.getTableName, Seq((HConstants.EMPTY_BYTE_ARRAY, HConstants.EMPTY_BYTE_ARRAY)),
                    regLoc.getHostnamePort.split(Addressing.HOSTNAME_PORT_SEPARATOR)(0)
                )
                java.util.Arrays.asList(split)
            } else {
                val rowRanges = scans.map { scan =>
                    (scan.getStartRow, scan.getStopRow)
                }.sortWith { case ((a, _), (b, _)) => Bytes.compareTo(a, b) < 0 }
                val regionRanges = rawRegionRanges.getFirst zip rawRegionRanges.getSecond
                Arrays.asList(regionRanges.flatMap {
                    case (regionStart, regionStop) if !includeRegionInSplit(regionStart, regionStop) =>
                        None
                    case (regionStart, regionStop) =>
                        val regionServerAddress = table.getRegionLocation(regionStart).getServerAddress
                        val regionAddress = regionServerAddress.getInetSocketAddress.getAddress
                        val regionLocation = try reverseDNS(nameServer, regionAddress) catch { case e: NamingException =>
                            logger.error("Cannot resolve the host name for " + regionAddress + " because of " + e)
                            regionServerAddress.getHostname
                        }
                        val overlappingRanges = rowRanges.filter { case (startRow, stopRow) =>
                            (startRow.length == 0 || regionStop.length == 0 || Bytes.compareTo(startRow, regionStop) < 0) &&
                            (stopRow .length == 0 || Bytes.compareTo(stopRow, regionStart) > 0)
                        }
                        if (overlappingRanges.nonEmpty) {
                            val trimmedRanges = overlappingRanges.map { case (startRow, stopRow) =>
                                (
                                    if (startRow.length == 0 || Bytes.compareTo(regionStart, startRow) >= 0) regionStart else startRow,
                                    if ((stopRow.length == 0 || Bytes.compareTo(regionStop, stopRow) <= 0) && regionStop.length > 0) regionStop else stopRow
                                )
                            }
                            val split = new TableMultiSplit(getHTable.getTableName, trimmedRanges, regionLocation)
                            // really noisy with a large number of splits
                            // if (logger.isDebugEnabled) logger.debug("getSplits: " + split)
                            Some(split)
                        } else None
                }: _*)
            }
        }
}
/** An InputSplit covering one or more row ranges of a single table, all served from the same region location. */
class TableMultiSplit(var tableName: Array[Byte], var ranges: Seq[(Array[Byte], Array[Byte])], var regionLocation: String)
extends InputSplit with Writable with Comparable[TableMultiSplit] {
    def this() = this(Array(), Seq((HConstants.EMPTY_BYTE_ARRAY, HConstants.EMPTY_BYTE_ARRAY)), "")

    /** Produce one Scan per row range, copying all non-range parameters (filters, caching, etc.) from the donor Scan. */
    def makeScans(donor: Scan): Seq[Scan] =
        ranges.map { case (startRow, stopRow) =>
            val scan = new Scan(donor)
            scan.setStartRow(startRow)
            scan.setStopRow(stopRow)
            scan
        }

    override def getLocations() = Array(regionLocation)
    override def getLength() = 0L

    @throws(classOf[IOException])
    override def readFields(in: DataInput) = {
        tableName = Bytes.readByteArray(in)
        ranges = List.fill(in.readInt())((Bytes.readByteArray(in), Bytes.readByteArray(in)))
        regionLocation = Bytes.toString(Bytes.readByteArray(in))
    }

    @throws(classOf[IOException])
    override def write(out: DataOutput) = {
        Bytes.writeByteArray(out, tableName)
        out.writeInt(ranges.length)
        ranges.foreach { case (start, stop) =>
            Bytes.writeByteArray(out, start)
            Bytes.writeByteArray(out, stop)
        }
        Bytes.writeByteArray(out, Bytes.toBytes(regionLocation))
    }

    override def toString() =
        regionLocation + ":" + ranges.map { case (start, stop) => Bytes.toStringBinary(start) + "," + Bytes.toStringBinary(stop) }.mkString("[", ", ", "]")

    override def compareTo(split: TableMultiSplit) =
        Bytes.compareTo(ranges.head._1, split.ranges.head._1)

    override def equals(anySplit: Any) =
        anySplit match {
            case split: TableMultiSplit if
                Bytes.equals(tableName, split.tableName) &&
                ranges.length == split.ranges.length &&
                (ranges zip split.ranges).forall { case (a, b) => Bytes.equals(a._1, b._1) && Bytes.equals(a._2, b._2) } &&
                regionLocation == split.regionLocation => true
            case _ => false
        }

    override def hashCode() =
        List (
            Seq (
                Option(tableName).map(Arrays.hashCode).getOrElse(0),
                Option(regionLocation).map(_.hashCode).getOrElse(0)
            ),
            ranges.flatMap { case (start, stop) =>
                Seq (
                    Option(start).map(Arrays.hashCode).getOrElse(0),
                    Option(stop).map(Arrays.hashCode).getOrElse(0)
                )
            }
        ).flatten.foldLeft(0) { (accum, hash) => 31 * accum + hash }
}
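
// A minimal serialization sketch (not part of the original gist): Hadoop ships each TableMultiSplit to its task
// by calling write and then readFields, much as below. The table name, row keys, and host name are hypothetical
// illustration values.
//
//     val original = new TableMultiSplit(
//         Bytes.toBytes("my_table"),
//         Seq((Bytes.toBytes("row-000"), Bytes.toBytes("row-500"))),
//         "regionserver1.example.com"
//     )
//     val baos = new ByteArrayOutputStream()
//     original.write(new DataOutputStream(baos))
//     val copy = new TableMultiSplit()
//     copy.readFields(new DataInputStream(new ByteArrayInputStream(baos.toByteArray)))
//     assert(copy == original)   // equals compares table name, ranges, and region location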
/** RecordReader that runs the scans of a TableMultiSplit one after another, delegating to a single TableRecordReaderImpl. */
class TableRecordMultiReader extends RecordReader[ImmutableBytesWritable, HBaseResult] {
    private val impl = new TableRecordReaderImpl()

    private var _scans: Seq[Scan] = Seq.empty
    private var _scanIndex: Int = 0
    private var _scanIterator: Iterator[Scan] = Iterator.empty

    def setup(scans: Seq[Scan], htable: HTable): Unit = {
        _scans = scans
        impl.setHTable(htable)
    }

    override def close() = impl.close()

    @throws(classOf[IOException])
    override def getCurrentKey(): ImmutableBytesWritable = impl.getCurrentKey()

    @throws(classOf[IOException])
    @throws(classOf[InterruptedException])
    override def getCurrentValue(): HBaseResult = impl.getCurrentValue()

    @throws(classOf[IOException])
    @throws(classOf[InterruptedException])
    override def initialize(genericSplit: InputSplit, context: TaskAttemptContext): Unit = {
        _scanIterator = _scans.iterator
        _scanIndex = 0
        impl.setScan(_scanIterator.next)
        impl.initialize(genericSplit, context)
    }

    @throws(classOf[IOException])
    @throws(classOf[InterruptedException])
    override def nextKeyValue(): Boolean =
        if (impl.nextKeyValue()) true
        else if (_scanIterator.hasNext) {
            // the current scan is exhausted, so restart the underlying reader on the next scan in the split
            val newScan = _scanIterator.next
            _scanIndex += 1
            impl.setScan(newScan)
            impl.restart(newScan.getStartRow)
            nextKeyValue()
        } else false

    override def getProgress() = {
        // overall progress is the completed scans plus the fractional progress of the current one, kept in [0, 1]
        val fractionPerScan = 1.0f / _scans.length.toFloat
        fractionPerScan * (_scanIndex + impl.getProgress())
    }
}