Scala module for I/O on different filesystems
package nlp.salience.fsutils
/**
 * @author yc <https://github.com/skylander86/>
 * @version 0.1
 *
 * This package contains code for transparently reading and writing files to/from supported filesystems.
 * It is useful as a drop-in replacement that understands filesystem differences.
 * Right now, it only supports local files, S3 (via either Hadoop or the AWS SDK for Java), and HDFS.
 *
 * Example usage:
 *
 * {{{
 * val uos = URIOutputStream(new URI("s3://bucket/key/a.txt"))
 * uos.write("hello".getBytes("UTF-8"))
 * uos.close() // must close! the upload happens on close()
 *
 * val uis = URIInputStream(new URI("s3://bucket/key/a.txt"))
 *
 * val luis = LocalizedURIInputStream(new URI("s3://bucket/key/a.txt"))
 * try {
 *   println(s"Local file is ${luis.file}")
 * }
 * finally { luis.close() }
 * }}}
 *
 * The following dependencies are required:
 * - `"com.amazonaws" % "aws-java-sdk" % "1.10.56"` for use with the `s3` filesystem. See [[https://aws.amazon.com/sdk-for-java/ AWS Java SDK]].
 * - `"org.apache.hadoop" % "hadoop-hdfs" % "<version>"` for use with the `s3a`, `s3n`, and `hdfs` filesystems. See [[http://hadoop.apache.org/ Hadoop]].
 */
import java.io.{BufferedInputStream, BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, FileOutputStream, InputStream, OutputStream}
import java.net.URI
import java.nio.file.{Files, Path, Paths}

import com.amazonaws.services.s3.AmazonS3Client
import org.slf4j.LoggerFactory

import helper._
/**
 * @brief An output stream that saves contents to the given [[java.net.URI]] location on [[close]].
 * @details The underlying implementation is based on [[java.io.ByteArrayOutputStream]] for remote files (and [[java.io.BufferedOutputStream]]/[[java.io.FileOutputStream]] for local files), overriding the [[close]] method to upload the contents.
 *
 * We determine the filesystem to use based on the URI scheme. If the URI scheme is `file` or `null`, we default to the local filesystem.
 *
 * For some filesystems, we can pass in additional configuration options through a [[Map]] object:
 * - `s3`: `accessKeyId`, `secretAccessKey`, `credentials` and `region`. Credentials can also be passed through the user info section of the URI, i.e., `s3://<accessKeyId>:<secretAccessKey>@<bucket>/<file>`.
 * - `s3n`, `s3a`, `hdfs`: `configuration`
 * - `file`, `null`: use `file://./a/b.txt` for paths relative to the current working directory
 *
 * For details about how these settings are passed on to the underlying filesystem, please refer to the source code.
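 *
 * For example, a minimal sketch of writing to S3 with explicit credentials (the bucket, key, and credential values are placeholders):
 *
 * {{{
 * val os = URIOutputStream(new URI("s3://bucket/key/a.txt"), Map(
 *   "accessKeyId" -> "AKIA...",
 *   "secretAccessKey" -> "...",
 *   "region" -> com.amazonaws.regions.Region.getRegion(com.amazonaws.regions.Regions.US_WEST_2)
 * ))
 * os.write("hello".getBytes("UTF-8"))
 * os.close() // the actual upload happens here
 * }}}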
 *
 * @param uri [[java.net.URI]] we want to save our file at
 * @param config configuration settings to pass to the underlying filesystem
 * @throws [[java.lang.UnsupportedOperationException]] when the URI scheme is not supported
 */
case class URIOutputStream(uri: URI, config: Map[String, Any] = Map()) extends OutputStream {
  private val logger = LoggerFactory.getLogger(getClass)

  val is_local = {
    val scheme = uri.getScheme()
    (scheme == "file" || scheme == null)
  }

  private var closed = false

  // Local files are streamed straight to disk; remote files are buffered in memory and uploaded on close().
  val os: OutputStream =
    if (is_local) new BufferedOutputStream(new FileOutputStream((if (uri.getAuthority() == null) "" else uri.getAuthority()) + uri.getPath()))
    else new ByteArrayOutputStream()

  override def flush() = os.flush()
  override def write(b: Array[Byte]) = os.write(b)
  override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len)
  override def write(b: Int) = os.write(b)

  override def close() {
    if (!closed) {
      if (is_local) {
        logger.debug(s"File written to <${uri}>.")
      }
      else {
        val baos = os.asInstanceOf[ByteArrayOutputStream]
        val bytes = baos.toByteArray()

        uri.getScheme() match {
          case "s3" =>
            val s3client = createAmazonS3Client(uri, config)
            val bucket = uri.getHost()
            val key = uri.getPath().substring(1) // to ignore the leading /
            val metadata = config.getOrElse("metadata", new com.amazonaws.services.s3.model.ObjectMetadata()).asInstanceOf[com.amazonaws.services.s3.model.ObjectMetadata]
            s3client.putObject(bucket, key, new ByteArrayInputStream(bytes), metadata)

          case "s3n" | "s3a" | "hdfs" =>
            val fs = getHadoopFilesystem(uri, config)
            val dos = fs.create(new org.apache.hadoop.fs.Path(uri))
            dos.write(bytes)
            dos.close()
            fs.close()

          case "file" | null =>
            val bos = new BufferedOutputStream(new FileOutputStream((if (uri.getAuthority() == null) "" else uri.getAuthority()) + uri.getPath()))
            bos.write(bytes)
            bos.close()

          case _ => throw new UnsupportedOperationException(s"URI scheme `${uri.getScheme()}` is not supported.")
        }
        logger.debug(s"Saved ${bytes.size} bytes to <${uri}>.")
      }

      closed = true
      os.close()
    }
  }
}
/**
 * @brief Masquerades as a [[java.io.BufferedInputStream]] pointing to the given [[java.net.URI]].
 */
class URIInputStream(is: InputStream) extends BufferedInputStream(is)
/**
 * @brief Companion object to [[URIInputStream]].
 *
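 * Example (a sketch; assumes default AWS credentials are available, and the URI is a placeholder):
 *
 * {{{
 * val is = URIInputStream(new URI("s3://bucket/key/a.txt"))
 * val contents = scala.io.Source.fromInputStream(is).mkString
 * is.close()
 * }}}
 *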
 * @param uri [[java.net.URI]] to read from
 * @param config configuration settings to pass to the underlying filesystem. See [[URIOutputStream]] for a detailed explanation of the config parameter.
 * @return [[URIInputStream]] for the given [[java.net.URI]] and [[config]]
 * @see [[URIOutputStream]]
 * @throws [[java.lang.UnsupportedOperationException]] when the URI scheme is not supported
 */
object URIInputStream {
  private val logger = LoggerFactory.getLogger(getClass)

  def apply(uri: URI, config: Map[String, Any] = Map()): InputStream = {
    val is = uri.getScheme() match {
      case "s3" =>
        val s3client = createAmazonS3Client(uri, config)
        val bucket = uri.getHost()
        val key = uri.getPath().substring(1) // to ignore the leading /
        val s3object = s3client.getObject(bucket, key)
        s3object.getObjectContent().asInstanceOf[InputStream]

      case "s3n" | "s3a" | "hdfs" =>
        val fs = getHadoopFilesystem(uri, config)
        fs.open(new org.apache.hadoop.fs.Path(uri))

      case "file" | null =>
        new FileInputStream((if (uri.getAuthority() == null) "" else uri.getAuthority()) + uri.getPath())

      case _ => throw new UnsupportedOperationException(s"URI scheme `${uri.getScheme()}` is not supported.")
    }
    logger.debug(s"Opened input stream for <${uri}>.")
    new URIInputStream(is)
  }
}
/**
 * @brief An extension of [[URIInputStream]] that downloads the URI's contents to a local temporary file, deleting it on close.
 * @details This is useful for functions that require reading from a local file.
 *
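 * For instance, to hand a remote file to an API that only accepts a local [[java.io.File]] (a sketch; `loadModel` is a hypothetical function and the URI is a placeholder):
 *
 * {{{
 * val luis = LocalizedURIInputStream(new URI("hdfs://namenode/data/model.bin"))
 * try { loadModel(luis.file) } // hypothetical consumer that requires a java.io.File
 * finally { luis.close() }    // also deletes the temporary local copy
 * }}}
 *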
 * @param uri [[java.net.URI]] of the remote file
 * @param temp_path [[java.nio.file.Path]] of the local temporary copy. Prefer the companion [[LocalizedURIInputStream.apply]], which accepts a `config` map; see [[URIOutputStream]] for a detailed explanation of the config parameter.
 * @see [[URIOutputStream]]
 */
class LocalizedURIInputStream(uri: URI, temp_path: Path) extends URIInputStream(new FileInputStream(temp_path.toFile())) {
  private val logger = LoggerFactory.getLogger(getClass)

  override def close() {
    super.close() // close the underlying FileInputStream before deleting the file

    val scheme = uri.getScheme()
    if (scheme != "file" && scheme != null) {
      temp_path.toFile.delete()
      logger.debug(s"Temporary file ${temp_path} deleted.")
    }
  }

  def file = temp_path.toFile
}
object LocalizedURIInputStream {
  private val logger = LoggerFactory.getLogger(getClass)

  def apply(uri: URI, config: Map[String, Any] = Map()): LocalizedURIInputStream = {
    val scheme = uri.getScheme()
    if (scheme == "file" || scheme == null)
      new LocalizedURIInputStream(uri, Paths.get((if (uri.getAuthority() == null) "" else uri.getAuthority()) + uri.getPath()))
    else {
      val temp_path = Files.createTempFile("fsutils_", "")
      val uis = URIInputStream(uri, config)
      val len = Files.copy(uis, temp_path, java.nio.file.StandardCopyOption.REPLACE_EXISTING)
      uis.close()
      logger.debug(s"Downloaded ${len} bytes from <${uri}> to <${temp_path}>.")

      temp_path.toFile().deleteOnExit() // safety net in case close() is never called
      new LocalizedURIInputStream(uri, temp_path)
    }
  }
}
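/**
 * @brief Miscellaneous URI helpers that dispatch on the URI scheme, in the same way as [[URIOutputStream]] and [[URIInputStream]].
 *
 * For example, checking that an object exists before opening it (a sketch; the URI is a placeholder):
 *
 * {{{
 * val uri = new URI("s3://bucket/key/a.txt")
 * if (URIUtils.exists(uri)) {
 *   val is = URIInputStream(uri)
 *   // ... read from is, then close it ...
 * }
 * }}}
 */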
object URIUtils {
  def exists(uri: URI, config: Map[String, Any] = Map()): Boolean = {
    uri.getScheme() match {
      case "s3" =>
        val s3client = createAmazonS3Client(uri, config)
        val bucket = uri.getHost()
        val key = uri.getPath().substring(1) // to ignore the leading /
        s3client.doesObjectExist(bucket, key)

      case "s3n" | "s3a" | "hdfs" =>
        val fs = getHadoopFilesystem(uri, config)
        fs.exists(new org.apache.hadoop.fs.Path(uri))

      case "file" | null =>
        val f = new File((if (uri.getAuthority() == null) "" else uri.getAuthority()) + uri.getPath())
        f.exists() && f.isFile()

      case _ =>
        throw new UnsupportedOperationException(s"URI scheme `${uri.getScheme()}` is not supported.")
    }
  }
}
package object helper {
  /**
   * @brief Helper function to create an [[com.amazonaws.services.s3.AmazonS3Client]].
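   *
   * Credentials are resolved in order: explicit `accessKeyId`/`secretAccessKey` entries in [[config]], then a `credentials` object, then the URI's user info, and finally the AWS SDK's default credential provider chain. A minimal sketch (the key values are placeholders):
   *
   * {{{
   * val client = createAmazonS3Client(new URI("s3://bucket/key"), Map(
   *   "credentials" -> new com.amazonaws.auth.BasicAWSCredentials("AKIA...", "...")
   * ))
   * }}}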
   * @param uri [[java.net.URI]] of remote file
   * @param config configuration settings to pass to the underlying filesystem. See [[URIOutputStream]] for a detailed explanation of the config parameter.
   * @return an [[com.amazonaws.services.s3.AmazonS3Client]] with the appropriate settings taken from [[config]]
   */
  def createAmazonS3Client(uri: URI, config: Map[String, Any]): AmazonS3Client = {
    val s3client = {
      if (config.contains("accessKeyId") && config.contains("secretAccessKey"))
        new AmazonS3Client(new com.amazonaws.auth.BasicAWSCredentials(config("accessKeyId").asInstanceOf[String], config("secretAccessKey").asInstanceOf[String]))
      else if (config.contains("credentials"))
        new AmazonS3Client(config("credentials").asInstanceOf[com.amazonaws.auth.AWSCredentials])
      else if (uri.getUserInfo() != null) {
        val Array(accessKeyId, secretAccessKey) = uri.getUserInfo().split(":", 2)
        new AmazonS3Client(new com.amazonaws.auth.BasicAWSCredentials(accessKeyId, secretAccessKey))
      }
      else new AmazonS3Client() // falls back to the AWS SDK's default credential provider chain
    }
    if (config.contains("region")) s3client.setRegion(config("region").asInstanceOf[com.amazonaws.regions.Region])

    s3client
  }
  /**
   * @brief Helper function to create a [[org.apache.hadoop.fs.FileSystem]].
   *
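   * A minimal sketch of passing a custom Hadoop configuration via the `configuration` key (the `fs.defaultFS` value is a placeholder):
   *
   * {{{
   * val conf = new org.apache.hadoop.conf.Configuration()
   * conf.set("fs.defaultFS", "hdfs://namenode:8020")
   * val fs = getHadoopFilesystem(new URI("hdfs://namenode:8020/data/a.txt"), Map("configuration" -> conf))
   * }}}
   *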
   * @param uri [[java.net.URI]] of remote file
   * @param config configuration settings to pass to the underlying filesystem. See [[URIOutputStream]] for a detailed explanation of the config parameter.
   * @return a [[org.apache.hadoop.fs.FileSystem]]
   */
  def getHadoopFilesystem(uri: URI, config: Map[String, Any]) = {
    val conf = config.getOrElse("configuration", new org.apache.hadoop.conf.Configuration()).asInstanceOf[org.apache.hadoop.conf.Configuration]
    org.apache.hadoop.fs.FileSystem.get(uri, conf)
  }
}