//! Object store that represents the Amazon S3 file system.
use std::io::Read;
use std::str::FromStr;
use std::sync::{mpsc, Arc};
use std::time::Duration;

use async_trait::async_trait;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::{Config, Endpoint, Region, RetryConfig};
use aws_smithy_types::timeout::TimeoutConfig;
use aws_smithy_types_convert::date_time::DateTimeExt;
use aws_types::credentials::Credentials;
use bytes::Buf;
use futures::{stream, AsyncRead};
use http::Uri;
use tracing::trace;

use datafusion::datasource::object_store::{
    FileMeta, FileMetaStream, ListEntryStream, ObjectReader, ObjectStore, SizedFile,
};
use datafusion::error::{DataFusionError, Result};

async fn new_client(
    region: Option<String>,
    endpoint: Option<String>,
    retry_max_attempts: Option<u32>,
    api_call_attempt_timeout_seconds: Option<u64>,
    access_key: Option<String>,
    secret_key: Option<String>,
) -> aws_sdk_s3::Client {
    let region_provider = RegionProviderChain::first_try(region.map(Region::new))
        .or_default_provider()
        .or_else(Region::new("us-west-2"))
        .region()
        .await;

    let mut config_builder = Config::builder().region(region_provider);

    if let Some(endpoint) = endpoint {
        config_builder = config_builder
            .endpoint_resolver(Endpoint::immutable(Uri::from_str(&endpoint).unwrap()));
    }

    if let Some(retry_max_attempts) = retry_max_attempts {
        config_builder =
            config_builder.retry_config(RetryConfig::new().with_max_attempts(retry_max_attempts));
    }

    if let Some(api_call_attempt_timeout_seconds) = api_call_attempt_timeout_seconds {
        config_builder =
            config_builder.timeout_config(TimeoutConfig::new().with_api_call_attempt_timeout(
                Some(Duration::from_secs(api_call_attempt_timeout_seconds)),
            ));
    }

    if let (Some(access_key), Some(secret_key)) = (access_key, secret_key) {
        config_builder = config_builder.credentials_provider(Credentials::new(
            access_key, secret_key, None, None, "Static",
        ));
    }

    let config = config_builder.build();
    aws_sdk_s3::Client::from_conf(config)
}

/// Amazon S3 as Object Store.
#[derive(Debug)]
pub struct AmazonS3FileSystem {
    region: Option<String>,
    endpoint: Option<String>,
    retry_max_attempts: Option<u32>,
    api_call_attempt_timeout_seconds: Option<u64>,
    access_key: Option<String>,
    secret_key: Option<String>,
    bucket: String,
    client: aws_sdk_s3::Client,
}

impl AmazonS3FileSystem {
    pub async fn new(
        region: Option<String>,
        endpoint: Option<String>,
        retry_max_attempts: Option<u32>,
        api_call_attempt_timeout_seconds: Option<u64>,
        access_key: Option<String>,
        secret_key: Option<String>,
        bucket: &str,
    ) -> Self {
        Self {
            region: region.clone(),
            endpoint: endpoint.clone(),
            retry_max_attempts,
            api_call_attempt_timeout_seconds,
            access_key: access_key.clone(),
            secret_key: secret_key.clone(),
            bucket: bucket.to_string(),
            client: new_client(
                region,
                endpoint,
                retry_max_attempts,
                api_call_attempt_timeout_seconds,
                access_key,
                secret_key,
            )
            .await,
        }
    }
}

#[async_trait]
impl ObjectStore for AmazonS3FileSystem {
    async fn list_file(&self, prefix: &str) -> Result<FileMetaStream> {
        let objects = self
            .client
            .list_objects_v2()
            .bucket(&self.bucket)
            .prefix(prefix)
            .send()
            .await
            .map_err(|err| DataFusionError::Internal(format!("{:?}", err)))?
            .contents()
            .unwrap_or_default()
            .to_vec();

        let result = stream::iter(objects.into_iter().map(|object| {
            Ok(FileMeta {
                sized_file: SizedFile {
                    path: object.key().unwrap_or_default().to_string(),
                    size: object.size() as u64,
                },
                last_modified: object
                    .last_modified()
                    .map(|last_modified| last_modified.to_chrono_utc()),
            })
        }));

        Ok(Box::pin(result))
    }

    async fn list_dir(&self, _prefix: &str, _delimiter: Option<String>) -> Result<ListEntryStream> {
        todo!()
    }

    fn file_reader(&self, file: SizedFile) -> Result<Arc<dyn ObjectReader>> {
        Ok(Arc::new(AmazonS3FileReader::new(
            self.region.clone(),
            self.endpoint.clone(),
            self.retry_max_attempts,
            self.api_call_attempt_timeout_seconds,
            self.access_key.clone(),
            self.secret_key.clone(),
            &self.bucket,
            file,
        )?))
    }
}

struct AmazonS3FileReader {
    region: Option<String>,
    endpoint: Option<String>,
    retry_max_attempts: Option<u32>,
    api_call_attempt_timeout_seconds: Option<u64>,
    access_key: Option<String>,
    secret_key: Option<String>,
    bucket: String,
    file: SizedFile,
}

impl AmazonS3FileReader {
    fn new(
        region: Option<String>,
        endpoint: Option<String>,
        retry_max_attempts: Option<u32>,
        api_call_attempt_timeout_seconds: Option<u64>,
        access_key: Option<String>,
        secret_key: Option<String>,
        bucket: &str,
        file: SizedFile,
    ) -> Result<Self> {
        Ok(Self {
            region,
            endpoint,
            retry_max_attempts,
            api_call_attempt_timeout_seconds,
            access_key,
            secret_key,
            bucket: bucket.to_string(),
            file,
        })
    }
}

#[async_trait]
impl ObjectReader for AmazonS3FileReader {
    async fn chunk_reader(&self, _start: u64, _length: usize) -> Result<Box<dyn AsyncRead>> {
        todo!("implement once async file readers are available (arrow-rs#78, arrow-rs#111)")
    }

    fn sync_chunk_reader(&self, start: u64, length: usize) -> Result<Box<dyn Read + Send + Sync>> {
        let region = self.region.clone();
        let endpoint = self.endpoint.clone();
        let retry_max_attempts = self.retry_max_attempts;
        let api_call_attempt_timeout_seconds = self.api_call_attempt_timeout_seconds;
        let access_key = self.access_key.clone();
        let secret_key = self.secret_key.clone();
        let bucket = self.bucket.clone();
        let key = self.file.path.clone();

        // Spawn a dedicated thread with its own single-threaded runtime and a channel
        // so that the async Amazon S3 API can be driven from this synchronous reader.
        let (tx, rx) = mpsc::channel();
        std::thread::spawn(move || {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            rt.block_on(async move {
                let client = new_client(
                    region,
                    endpoint,
                    retry_max_attempts,
                    api_call_attempt_timeout_seconds,
                    access_key,
                    secret_key,
                )
                .await;

                let get_object = client.get_object().bucket(bucket).key(key);
                let resp = if length != 0 {
                    // HTTP Range ends are inclusive, so the last requested byte is
                    // start + length - 1 (see the byte-range discussion below).
                    get_object
                        .range(format!("bytes={}-{}", start, start + length as u64 - 1))
                        .send()
                        .await
                } else {
                    get_object.send().await
                };

                let bytes = match resp {
                    Ok(res) => {
                        let data = res.body.collect().await;
                        match data {
                            Ok(data) => Ok(data.into_bytes()),
                            Err(err) => Err(DataFusionError::Internal(format!("{:?}", err))),
                        }
                    }
                    Err(err) => Err(DataFusionError::Internal(format!("{:?}", err))),
                };

                match tx.send(bytes) {
                    Ok(_) => (),
                    Err(err) => println!("{:?}", err),
                };
            })
        });

        let bytes = rx
            .recv_timeout(Duration::from_secs(10))
            .map_err(|err| DataFusionError::Internal(format!("{:?}", err)))??;

        trace!(
            "sync_chunk_reader: {:?} {:?}: read {} bytes (offset {})",
            self.bucket,
            self.file.path,
            bytes.len(),
            start
        );

        Ok(Box::new(bytes.reader()))
    }

    fn length(&self) -> u64 {
        self.file.size
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::datasource::file_format::parquet::ParquetFormat;
    use datafusion::datasource::listing::*;
    use datafusion::datasource::TableProvider;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_read_files() -> Result<()> {
        let amazon_s3_file_system = AmazonS3FileSystem::new(
            None,
            Some("http://localhost:9000".to_string()),
            None,
            None,
            Some("AKIAIOSFODNN7EXAMPLE".to_string()),
            Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string()),
            "adventure_works",
        )
        .await;

        let mut files = amazon_s3_file_system
            .list_file("humanresources/department.parquet")
            .await?;

        while let Some(file) = files.next().await {
            let sized_file = file.unwrap().sized_file;
            let mut reader = amazon_s3_file_system
                .file_reader(sized_file.clone())
                .unwrap()
                .sync_chunk_reader(0, sized_file.size as usize)
                .unwrap();

            let mut bytes = Vec::new();
            let size = reader.read_to_end(&mut bytes)?;
            assert_eq!(size as u64, sized_file.size);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_read_parquet() -> Result<()> {
        let amazon_s3_file_system = Arc::new(
            AmazonS3FileSystem::new(
                None,
                Some("http://localhost:9000".to_string()),
                None,
                None,
                Some("AKIAIOSFODNN7EXAMPLE".to_string()),
                Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string()),
                "adventure_works",
            )
            .await,
        );

        let filename = "humanresources/employee.parquet";
        let listing_options = ListingOptions {
            format: Arc::new(ParquetFormat::default()),
            collect_stat: true,
            file_extension: "parquet".to_owned(),
            target_partitions: num_cpus::get(),
            table_partition_cols: vec![],
        };

        let resolved_schema = listing_options
            .infer_schema(amazon_s3_file_system.clone(), filename)
            .await?;

        let table = ListingTable::new(
            amazon_s3_file_system,
            filename.to_owned(),
            resolved_schema,
            listing_options,
        );

        let exec = table.scan(&None, 1024, &[], None).await?;
        assert_eq!(exec.statistics().num_rows, Some(290));

        Ok(())
    }
}

Do you think we should add an S3Error type?
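One possible shape, just as a sketch (the S3Error name and its variants are hypothetical, not something from the gist): a small error enum that wraps SDK failures and converts into DataFusionError, so callers can use ? instead of repeating map_err(|err| DataFusionError::Internal(...)):

use datafusion::error::DataFusionError;

/// Hypothetical error type for this object store (sketch only).
#[derive(Debug)]
pub enum S3Error {
    /// The requested bucket or key does not exist.
    NotFound(String),
    /// Any other failure reported by the AWS SDK.
    AwsSdk(String),
}

impl From<S3Error> for DataFusionError {
    fn from(err: S3Error) -> Self {
        // Reuse the Internal variant, matching how the gist reports SDK errors today.
        DataFusionError::Internal(format!("{:?}", err))
    }
}
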
I think it would be nice if we added an option for using aws_config::load_from_env() somehow, i.e. in new_client we could just have:

let config = aws_config::load_from_env().await;
Client::new(&config)
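A minimal sketch of that environment-based path, assuming the standard aws-config provider chain (the client_from_env name is made up for illustration):

use aws_sdk_s3::Client;

/// Sketch: build the client from the ambient environment (AWS_ACCESS_KEY_ID,
/// AWS_SECRET_ACCESS_KEY, AWS_REGION, shared profile, IMDS, ...) instead of the
/// hand-built Config in new_client.
async fn client_from_env() -> Client {
    let shared_config = aws_config::load_from_env().await;
    Client::new(&shared_config)
}

new_client could branch to this path whenever every option is None, keeping the explicit-credentials constructor for the MinIO tests above.
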
I tried this with MinIO as an Azure Blob Storage gateway and found an issue with the byte range: according to the spec, the range end is inclusive, and hence the tests were failing.
A range of start, start + length will fetch length + 1 bytes; updating it to start, start + length - 1 works for me.
https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
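A small sketch of the corrected header for the same start/length arguments used by sync_chunk_reader (the range_header helper name is just illustrative):

/// Range ends are inclusive, so a read of `length` bytes starting at `start`
/// must end at start + length - 1.
fn range_header(start: u64, length: usize) -> String {
    format!("bytes={}-{}", start, start + length as u64 - 1)
}

#[test]
fn range_end_is_inclusive() {
    // 4 bytes from offset 0 -> "bytes=0-3"; "bytes=0-4" would return 5 bytes.
    assert_eq!(range_header(0, 4), "bytes=0-3");
}
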
Successfully executed a SQL query using this :) This is great, thanks again @seddonm1.
Once the donation is completed, I think we can just create issues to resolve the points raised above.
@matthewmturner: Could you check whether it works with the byte range adjusted to be inclusive? I'm not sure whether the issue is with my Parquet file or with an incorrect range specification.
@gopik you are absolutely correct about the range being incorrect. I fixed it in the PR and added a test:
datafusion-contrib/datafusion-objectstore-s3#2
FYI, it uses these AWS dependencies: