Last active
December 15, 2015 18:39
-
-
Save berngp/5305183 to your computer and use it in GitHub Desktop.
HDFS Thrift Serialization in Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2012-2013 Bernardo Gomez Palacio. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package com.github.berngp.thriftexample | |
import java.io._ | |
import net.liftweb.common.Logger | |
import org.apache.hadoop.io.serializer.{Serializer, Deserializer} | |
import org.apache.hadoop.util.ReflectionUtils | |
import org.apache.thrift._ | |
import org.apache.thrift.protocol.{TBinaryProtocol, TProtocolFactory, TProtocol} | |
import org.apache.thrift.transport.{TTransport, TIOStreamTransport} | |
/** A Hadoop `BytesWritable` whose payload is a Thrift object serialized with `TBinaryProtocol`.
  *
  * The wire format produced by [[write]] and consumed by [[readFields]] is a 4-byte length
  * prefix (`DataOutput.writeInt`) followed by exactly that many bytes of Thrift binary data.
  *
  * Subclasses supply [[newBaseInstance]], used to lazily create the backing Thrift object when
  * none was handed to the constructor, and [[baseClass]], which is only used by [[toString]].
  *
  * @param bytes  initial backing array, forwarded to `BytesWritable`
  * @param length number of valid bytes in `bytes`, forwarded to `BytesWritable`
  * @tparam T the Thrift struct type
  * @tparam F the Thrift field-id enum of `T`
  */
abstract class ThriftBytesWritable[T <: TBase[T, F], F <: TFieldIdEnum](bytes: Array[Byte], length: Int)
  extends org.apache.hadoop.io.BytesWritable(bytes, length)
  with Logger {

  /** The backing Thrift object; created on demand by `_ensureBaseInstance`. */
  private var base: Option[T] = None

  /** Builds a writable over the whole of `b`. */
  def this(b: Array[Byte]) = this(b, b.length)

  /** Builds an empty writable. */
  def this() = this(Array.empty[Byte])

  /** Builds a writable backed by `thrift` and loads its serialized form into the byte buffer. */
  def this(thrift: T) = {
    this()
    base = Some(thrift)
    resetBytesFromThrift()
  }

  /** Replaces the `BytesWritable` buffer with the current serialized form of the Thrift object. */
  private def resetBytesFromThrift(): Unit = {
    val serialized = toBytes
    setCapacity(serialized.length)
    set(serialized, 0, serialized.length)
  }

  /** Creates a fresh, empty Thrift object of type `T`. */
  def newBaseInstance: T

  /** Runtime class of `T`; reported by [[toString]]. */
  val baseClass: Class[T]

  private val factory: TProtocolFactory = new TBinaryProtocol.Factory()

  /** Upper bound (exclusive) on the number of serialized bytes accepted by [[write]] and
    * [[readFields]]; override it if larger objects are expected.
    * The default is **1073741824**.
    */
  protected def getMaxLength: Int = 1073741824

  /** Returns the backing Thrift object, creating it through [[newBaseInstance]] if absent. */
  private def _ensureBaseInstance: Option[T] = synchronized {
    if (base.isEmpty) {
      base = Some(newBaseInstance)
    }
    base
  }

  /** Serializes the backing Thrift object; yields an empty array if there is none. */
  private def toBytes: Array[Byte] =
    _ensureBaseInstance.fold(Array.empty[Byte])(thrift => new TSerializer(factory).serialize(thrift))

  /** Returns the Thrift object re-serialized on every call (not the inherited buffer). */
  override def getBytes(): Array[Byte] = toBytes

  /** Returns an independent copy of the serialized Thrift object. */
  override def copyBytes(): Array[Byte] = toBytes.clone()

  /** Populates the backing Thrift object from the constructor's `bytes` and returns it.
    *
    * NOTE(review): this goes through [[readFields]], which expects the 4-byte length prefix
    * written by [[write]] -- confirm that callers construct the instance with framed bytes.
    *
    * @throws java.io.IOException if the bytes cannot be read
    */
  @throws[IOException]
  def getFromBytes: T = _ensureBaseInstance match {
    case Some(thrift) =>
      readFields(new DataInputStream(new ByteArrayInputStream(this.bytes)))
      thrift
    case None =>
      throw new IllegalStateException("Unable to ensure a base reference.")
  }

  /** Writes the serialized Thrift object as a 4-byte length followed by the payload.
    *
    * The length is always written with `writeInt`, including for an empty payload, so that the
    * output is symmetric with the `readInt` performed by [[readFields]].
    *
    * @param out destination
    * @throws java.io.IOException if `out` fails
    * @throws java.lang.IllegalArgumentException if the payload is not smaller than [[getMaxLength]]
    */
  @throws[IOException]
  override def write(out: DataOutput): Unit = {
    val payload = toBytes
    require(payload.length < getMaxLength,
      s"Length of the writable ${payload.length} exceeds the max allowed of ${getMaxLength} bytes, if intended please override `getMaxLength`.")
    out.writeInt(payload.length)
    out.write(payload, 0, payload.length)
  }

  /** Reads a length-prefixed payload from `in` and deserializes it into the backing Thrift object.
    *
    * @param in source positioned at the 4-byte length prefix
    * @throws java.io.IOException if `in` fails or ends prematurely
    * @throws java.lang.IllegalArgumentException if the length is negative or not smaller than
    *                                            [[getMaxLength]]
    */
  @throws[IOException]
  override def readFields(in: DataInput): Unit = _ensureBaseInstance match {
    case Some(thrift) =>
      val size = in.readInt()
      require(size >= 0, s"Length of the writable [${size}] must not be negative.")
      require(size < getMaxLength,
        s"Length of the writable [${size}] exceeds the max allowed of ${getMaxLength} bytes, if intended please override `getMaxLength`.")
      val buffer = new Array[Byte](size)
      in.readFully(buffer, 0, size)
      // The same instance may be read into repeatedly; clear it first so that fields absent
      // from this record do not keep values from a previous one.
      thrift.clear()
      new TDeserializer(factory).deserialize(thrift, buffer)
    case None =>
      throw new IllegalStateException("Unable to ensure a base reference.")
  }

  /** Describes the base class, the current Thrift object and the inherited byte dump. */
  override def toString(): String = {
    s"""
       |Base Class:${baseClass}
       |TBase:${base}
       |Bytes:
       |${super.toString()}""".stripMargin
  }
}
/** Mix-in that makes a Thrift struct usable as a Hadoop `Writable`.
  *
  * The struct is serialized with `TBinaryProtocol` and framed as a 4-byte length prefix
  * (`DataOutput.writeInt`) followed by exactly that many payload bytes.
  *
  * @tparam T the Thrift struct type
  * @tparam F the Thrift field-id enum of `T`
  */
trait ThriftHadoopWritable[T <: TBase[_, _], F <: TFieldIdEnum]
  extends TBase[T, F]
  with org.apache.hadoop.io.Writable
  with Logger {

  private val factory: TProtocolFactory = new TBinaryProtocol.Factory()

  /** Upper bound (exclusive) on the number of serialized bytes accepted by [[write]] and
    * [[readFields]]; override it if larger objects are expected.
    * The default is **1073741824**.
    */
  protected def getMaxLength: Int = 1073741824

  /** Writes this struct as a 4-byte length followed by its Thrift binary form.
    *
    * The length is always written with `writeInt`, including for an empty payload, so that the
    * output is symmetric with the `readInt` performed by [[readFields]].
    *
    * @param out destination
    * @throws java.io.IOException if `out` fails
    * @throws java.lang.IllegalArgumentException if the payload is not smaller than [[getMaxLength]]
    */
  @throws[IOException]
  override def write(out: DataOutput): Unit = {
    val payload = new TSerializer(factory).serialize(this)
    require(payload.length < getMaxLength,
      s"Length of the writable ${payload.length} exceeds the max allowed of ${getMaxLength} bytes, if intended please override `getMaxLength`.")
    out.writeInt(payload.length)
    out.write(payload, 0, payload.length)
  }

  /** Reads a length-prefixed payload from `in` and deserializes it into this struct.
    *
    * @param in source positioned at the 4-byte length prefix
    * @throws java.io.IOException if `in` fails or ends prematurely
    * @throws java.lang.IllegalArgumentException if the length is negative or not smaller than
    *                                            [[getMaxLength]]
    */
  @throws[IOException]
  override def readFields(in: DataInput): Unit = {
    val size = in.readInt()
    require(size >= 0, s"Length of the writable [${size}] must not be negative.")
    require(size < getMaxLength,
      s"Length of the writable [${size}] exceeds the max allowed of ${getMaxLength} bytes, if intended please override `getMaxLength`.")
    val buffer = new Array[Byte](size)
    in.readFully(buffer, 0, size)
    // The same instance may be read into repeatedly; clear it first so that fields absent
    // from this record do not keep values from a previous one.
    clear()
    new TDeserializer(factory).deserialize(this, buffer)
  }
}
/** Hadoop `Serialization` for Thrift structs, using `TBinaryProtocol` directly over the
  * stream handed to the serializer / deserializer.
  *
  * TODO Refactor: the two adapters still duplicate their transport/protocol set-up.
  *
  * @tparam T the Thrift struct type
  */
class ThriftSerialization[T <: TBase[T, _]]
  extends org.apache.hadoop.io.serializer.Serialization[T] with Logger {

  import scala.util.control.NonFatal

  /** Accepts every class that is a `TBase`, whether it implements the interface directly or
    * inherits it from a superclass or another interface. A `null` class is rejected.
    */
  def accept(c: Class[_]): Boolean =
    c != null && classOf[TBase[_, _]].isAssignableFrom(c)

  private def getProtocolFactory = new TBinaryProtocol.Factory()

  /** Runs `closing` and logs (rather than propagates) any non-fatal failure, so that one
    * resource failing to close does not prevent the next from being closed.
    */
  private def closeQuietly(failureMessage: String)(closing: => Unit): Unit =
    try {
      closing
    } catch {
      case NonFatal(t) =>
        warn(failureMessage, t)
    }

  def getSerializer(c: Class[T]): Serializer[T] = new TSerializerAdapter

  /** Writes Thrift structs to the `OutputStream` supplied to [[open]]. */
  class TSerializerAdapter extends org.apache.hadoop.io.serializer.Serializer[T] {

    protected val factory: TProtocolFactory = getProtocolFactory

    // Unset (null) until open() is called.
    private var _transport: TTransport = null
    private var _protocol: TProtocol = null
    private var _out: OutputStream = null

    /** Binds this serializer to `out`; must be called before [[serialize]]. */
    def open(out: OutputStream): Unit = synchronized {
      require(out != null, "OutputStream required!")
      _out = out
      _transport = new TIOStreamTransport(_out)
      _protocol = factory.getProtocol(_transport)
    }

    /** Writes `t` to the open stream.
      *
      * @throws java.io.IOException wrapping any `TException` raised by Thrift
      */
    @throws[IOException]
    def serialize(t: T): Unit = {
      require(_protocol != null, "A Transport Protocol is missing, please open the Serializer!")
      try {
        t.write(_protocol)
      } catch {
        case e: TException =>
          throw new IOException(e)
      }
    }

    /** Closes the transport and then the stream; failures are logged, not thrown. */
    @throws[IOException]
    def close(): Unit = synchronized {
      Option(_transport).foreach { transport =>
        closeQuietly("Throwable caught while closing transport.")(transport.close())
      }
      Option(_out).foreach { stream =>
        closeQuietly("Throwable caught while closing Output Stream.")(stream.close())
      }
    }
  }

  def getDeserializer(c: Class[T]): Deserializer[T] = new TDeserializerAdapter(c)

  /** Reads Thrift structs from the `InputStream` supplied to [[open]].
    *
    * @param tClass class instantiated when [[deserialize]] is given no instance to reuse
    */
  class TDeserializerAdapter(tClass: Class[T]) extends org.apache.hadoop.io.serializer.Deserializer[T] {

    protected val factory: TProtocolFactory = getProtocolFactory

    // Unset (null) until open() is called.
    private var _transport: TTransport = null
    private var _protocol: TProtocol = null
    private var _in: InputStream = null

    /** Binds this deserializer to `in`; must be called before [[deserialize]]. */
    def open(in: InputStream): Unit = synchronized {
      require(in != null, "InputStream required!")
      _in = in
      _transport = new TIOStreamTransport(_in)
      _protocol = factory.getProtocol(_transport)
    }

    /** Returns `t` cleared for reuse, or a new instance of `tClass` when `t` is null. */
    private def _getThriftBase(t: T): T = {
      if (t == null) {
        ReflectionUtils.newInstance(tClass, null)
      } else {
        t.clear()
        t
      }
    }

    /** Reads the next struct from the open stream.
      *
      * @param t instance to reuse, or `null` to have a new one created
      * @return the populated struct (`t` itself when it was non-null)
      * @throws java.io.IOException wrapping any `TException` raised by Thrift
      */
    @throws[IOException]
    def deserialize(t: T): T = {
      require(_protocol != null, "A Transport Protocol is missing, please open the Deserializer!")
      val target = _getThriftBase(t)
      try {
        target.read(_protocol)
      } catch {
        // Same translation as TSerializerAdapter.serialize, so callers see IOException only.
        case e: TException =>
          throw new IOException(e)
      }
      target
    }

    /** Closes the transport and then the stream; failures are logged, not thrown. */
    @throws[IOException]
    def close(): Unit = synchronized {
      Option(_transport).foreach { transport =>
        closeQuietly("Throwable caught while closing transport.")(transport.close())
      }
      Option(_in).foreach { stream =>
        closeQuietly("Throwable caught while closing InputStream.")(stream.close())
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment