Created
December 26, 2024 12:18
-
-
Save amankrx/45e7d2a6ed935aa13dda0318681af2ad to your computer and use it in GitHub Desktop.
Azure Blob Storage Update and Get Implementation using Azure SDK in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use base64; | |
use hmac::{Hmac, Mac}; | |
use sha2::Sha256; | |
use std::collections::HashMap; | |
use std::fmt::Debug; | |
use std::sync::Arc; | |
use std::time::{SystemTime, UNIX_EPOCH}; | |
use thiserror::Error; | |
use azure_core::prelude::*; | |
use azure_core::{Body, HttpClient, TransportOptions}; | |
use azure_storage_blobs::prelude::*; | |
use futures::stream::StreamExt; | |
use hyper::client::HttpConnector; | |
use hyper::{Body as HyperBody, Client, Request as HyperRequest}; | |
use hyper_rustls::HttpsConnector; | |
// Storage account credentials (intentionally blank — fill in before running).
const ACCOUNT_NAME: &str = "";
const ACCOUNT_KEY: &str = "";
// Container and blob names used by the test scenarios below.
const CONTAINER: &str = "simple-test-container";
const BLOB: &str = "blob.txt";
// Pseudo-directory prefix prepended to every blob path (see `make_blob_path`).
const PREFIX: &str = "test-prefix-index";
// Upper bound on a single block in a multi-part block-blob upload (4 MiB).
const MAX_BLOCK_SIZE: usize = 4 * 1024 * 1024;
// Target number of blocks a multi-part upload is split into.
const MAX_BLOCKS: usize = 10;
/// Connection settings for the Azure Blob Storage account plus the
/// upload-tuning knobs used by the custom HTTP client.
#[derive(Debug, Clone)]
pub struct AzureConfig {
    // Storage account name (used as the signing scope in `sign_request`).
    account_name: String,
    // Base64-encoded shared account key used for request signing.
    account_key: String,
    // Default container targeted by this configuration.
    container: String,
    // Maximum size of a single uploaded block, in bytes.
    max_block_size: usize,
    // Target number of blocks for a multi-part upload.
    max_blocks: usize,
    // `x-ms-version` value sent with every request.
    api_version: String,
}
impl AzureConfig { | |
pub fn new( | |
account_name: impl Into<String>, | |
account_key: impl Into<String>, | |
container: impl Into<String>, | |
) -> Self { | |
Self { | |
account_name: account_name.into(), | |
account_key: account_key.into(), | |
container: container.into(), | |
max_block_size: 4 * 1024 * 1024, // 4MB | |
max_blocks: 10, | |
api_version: "2020-04-08".to_string(), | |
} | |
} | |
} | |
/// Errors surfaced by this custom client layer. The `#[from]` variants let
/// `?` convert the underlying library errors automatically.
#[derive(Error, Debug)]
pub enum AzureClientError {
    /// Error bubbled up from the Azure SDK core.
    #[error("Azure API error: {0}")]
    Azure(#[from] azure_core::Error),
    /// Transport-level failure from hyper.
    #[error("HTTP error: {0}")]
    Http(#[from] hyper::Error),
    /// A header value could not be constructed for the outgoing request.
    #[error("Invalid header value: {0}")]
    InvalidHeader(#[from] hyper::header::InvalidHeaderValue),
    /// Shared-key signing failed (bad key material, HMAC setup, ...).
    #[error("Authentication error: {0}")]
    Auth(String),
    /// Invalid or incomplete configuration.
    #[error("Configuration error: {0}")]
    Config(String),
}
/// `HttpClient` implementation that executes requests over hyper/rustls and
/// signs each one with the storage account's shared key.
#[derive(Debug, Clone)]
pub struct CustomHttpClient {
    // Pooled hyper client over an HTTPS-only connector.
    client: Client<HttpsConnector<HttpConnector>>,
    // Shared so clones of this client reuse one config allocation.
    config: Arc<AzureConfig>,
}
impl CustomHttpClient { | |
pub fn new(config: AzureConfig) -> Self { | |
let https = hyper_rustls::HttpsConnectorBuilder::new() | |
.with_native_roots() | |
.https_only() | |
.enable_http1() | |
.enable_http2() | |
.build(); | |
let client = Client::builder() | |
.pool_idle_timeout(std::time::Duration::from_secs(30)) | |
.pool_max_idle_per_host(32) | |
.build(https); | |
Self { | |
client, | |
config: Arc::new(config), | |
} | |
} | |
fn format_http_date() -> String { | |
let now = SystemTime::now() | |
.duration_since(UNIX_EPOCH) | |
.expect("Time went backwards"); | |
use chrono::{TimeZone, Utc}; | |
let datetime = Utc.timestamp_opt(now.as_secs() as i64, 0).unwrap(); | |
datetime.format("%a, %d %b %Y %H:%M:%S GMT").to_string() | |
} | |
fn sign_request(&self, request: &azure_core::Request) -> Result<String, AzureClientError> { | |
let content_length = request | |
.headers() | |
.get_as::<String, std::string::ParseError>( | |
&azure_core::headers::HeaderName::from_static("content-length"), | |
) | |
.unwrap_or_default(); | |
let content_type = request | |
.headers() | |
.get_as::<String, std::string::ParseError>( | |
&azure_core::headers::HeaderName::from_static("content-type"), | |
) | |
.unwrap_or_default(); | |
// Build canonical headers using a BTreeMap for consistent ordering | |
let mut canonical_headers = std::collections::BTreeMap::new(); | |
for (name, value) in request.headers().iter() { | |
if name.as_str().starts_with("x-ms-") { | |
canonical_headers.insert( | |
name.as_str().to_lowercase(), | |
value.as_str().trim().to_string(), | |
); | |
} | |
} | |
let canonical_headers = canonical_headers | |
.iter() | |
.map(|(k, v)| format!("{}:{}\n", k, v)) | |
.collect::<String>(); | |
let canonical_resource = format!( | |
"/{}/{}{}", | |
self.config.account_name, | |
request.url().path().trim_start_matches('/'), | |
request | |
.url() | |
.query() | |
.map(|q| format!("?{}", q)) | |
.unwrap_or_default() | |
); | |
let string_to_sign = format!( | |
"{}\n\n{}\n{}\n{}\n\n{}\n{}", | |
request.method(), | |
content_type, | |
content_length, | |
"", // MD5 | |
canonical_headers, | |
canonical_resource | |
); | |
let key_bytes = base64::decode(&self.config.account_key) | |
.map_err(|e| AzureClientError::Auth(format!("Failed to decode account key: {}", e)))?; | |
let mut mac = Hmac::<Sha256>::new_from_slice(&key_bytes) | |
.map_err(|e| AzureClientError::Auth(format!("Failed to create HMAC: {}", e)))?; | |
mac.update(string_to_sign.as_bytes()); | |
let signature = base64::encode(mac.finalize().into_bytes()); | |
Ok(format!( | |
"SharedKey {}:{}", | |
self.config.account_name, signature | |
)) | |
} | |
} | |
#[async_trait::async_trait]
impl HttpClient for CustomHttpClient {
    /// Bridges azure_core to hyper: copies the request over, injects
    /// `x-ms-date`/`x-ms-version`, signs it with the shared key, executes
    /// it under a 30-second timeout, and adapts the response (status,
    /// headers, streaming body) back into azure_core types.
    async fn execute_request(
        &self,
        request: &azure_core::Request,
    ) -> azure_core::Result<azure_core::Response> {
        let mut builder = HyperRequest::builder()
            .method(request.method().as_ref())
            .uri(request.url().as_str());
        // Standard headers required by the storage service on every call.
        let date = Self::format_http_date();
        builder = builder
            .header("x-ms-date", &date)
            .header("x-ms-version", &self.config.api_version);
        // Copy the caller's headers, skipping ones we set ourselves.
        // NOTE(review): azure_core header names are typically lowercase, so
        // the case-sensitive "Authorization" comparison here may never
        // match — confirm whether "authorization" was intended.
        for (name, value) in request.headers().iter() {
            if name.as_str() != "Authorization" && name.as_str() != "x-ms-date" {
                builder = builder.header(name.as_str(), value.as_str());
            }
        }
        // Compute the SharedKey signature from the original request.
        // NOTE(review): the x-ms-date added above is not on `request`, so
        // it is not part of the signed canonical headers — verify against
        // a live call.
        let auth_header = self
            .sign_request(request)
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        builder = builder.header("Authorization", auth_header);
        // Only in-memory byte bodies are supported; any other body kind is
        // rejected up front.
        let body = match request.body() {
            Body::Bytes(bytes) if bytes.is_empty() => HyperBody::empty(),
            Body::Bytes(bytes) => HyperBody::from(bytes.to_vec()),
            _ => {
                return Err(azure_core::Error::new(
                    azure_core::error::ErrorKind::Other,
                    "Unsupported body type",
                ))
            }
        };
        let hyper_request = builder
            .body(body)
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        // Hard 30-second wall-clock cap on the whole round trip; a timeout
        // and a transport error are both mapped to azure_core errors.
        let response = tokio::time::timeout(
            std::time::Duration::from_secs(30),
            self.client.request(hyper_request),
        )
        .await
        .map_err(|_| {
            azure_core::Error::new(azure_core::error::ErrorKind::Other, "Request timeout")
        })?
        .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        let (parts, body) = response.into_parts();
        // Re-wrap the hyper body stream so its errors become azure_core
        // errors while the body stays lazily streamed.
        let mapped_stream = body.map(|result| {
            result.map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))
        });
        // Convert hyper headers to azure_core headers; values that are not
        // valid UTF-8 are silently dropped (`to_str().ok()?`).
        let headers: HashMap<_, _> = parts
            .headers
            .iter()
            .filter_map(|(k, v)| {
                Some((
                    azure_core::headers::HeaderName::from(k.as_str().to_owned()),
                    azure_core::headers::HeaderValue::from(v.to_str().ok()?.to_owned()),
                ))
            })
            .collect();
        Ok(azure_core::Response::new(
            azure_core::StatusCode::try_from(parts.status.as_u16()).expect("Invalid status code"),
            azure_core::headers::Headers::from(headers),
            Box::pin(mapped_stream),
        ))
    }
}
pub async fn create_blob_client( | |
config: AzureConfig, | |
) -> Result<Arc<BlobServiceClient>, AzureClientError> { | |
let http_client = Arc::new(CustomHttpClient::new(config.clone())); | |
let transport_options = TransportOptions::new(http_client); | |
let client = BlobServiceClient::builder( | |
&config.account_name, | |
azure_storage::StorageCredentials::anonymous(), | |
) | |
.transport(transport_options) | |
.blob_service_client(); | |
Ok(Arc::new(client)) | |
} | |
fn make_blob_path(blob: &str) -> String { | |
format!("{}/{}", PREFIX, blob) | |
} | |
/// One upload/download scenario executed by `main`.
struct TestCase {
    // Human-readable label printed alongside the results.
    name: String,
    // Blob name (prefixed via `make_blob_path` before use).
    blob_name: String,
    // Payload to upload and, for success cases, to range-read back.
    content: Vec<u8>,
    // When true, use the block-list multi-part upload path.
    use_multi_part: bool,
    // Outcome the scenario is expected to produce; drives which checks run.
    expected_result: TestResult,
}
/// Expected outcome of a `TestCase`; `main` dispatches on this to decide
/// which operations to exercise.
enum TestResult {
    // Upload, properties, and download should all succeed.
    Success,
    // Properties/download against a blob that was never uploaded.
    BlobNotFound,
    // Properties against a container that does not exist.
    ContainerNotFound,
    // Upload with a payload that stresses the block-size limit.
    InvalidBlockSize,
    // Generic upload failure (no scenario currently uses this).
    UploadError,
}
async fn upload_multi_part( | |
blob_client: &BlobClient, | |
content: &[u8], | |
) -> Result<(), Box<dyn std::error::Error>> { | |
let total_size = content.len(); | |
let block_size = std::cmp::min(total_size / (MAX_BLOCKS - 1), MAX_BLOCK_SIZE); | |
let mut block_ids = Vec::new(); | |
// Upload blocks | |
for (i, chunk) in content.chunks(block_size).enumerate() { | |
let block_id = format!("{:032}", i); | |
println!("Uploading block {} of size {}", block_id, chunk.len()); | |
match blob_client | |
.put_block(block_id.clone(), Body::from(chunk.to_vec())) | |
.await | |
{ | |
Ok(_) => { | |
println!("Successfully uploaded block {}", block_id); | |
block_ids.push(block_id); | |
} | |
Err(e) => { | |
println!("Failed to upload block {}: {:?}", block_id, e); | |
return Err(Box::new(e)); | |
} | |
} | |
} | |
// Commit block list | |
let block_list = BlockList { | |
blocks: block_ids | |
.into_iter() | |
.map(|id| BlobBlockType::Latest(BlockId::from(id))) | |
.collect(), | |
}; | |
match blob_client | |
.put_block_list(block_list) | |
.content_type("application/octet-stream") | |
.await | |
{ | |
Ok(_) => println!("Successfully committed block list"), | |
Err(e) => { | |
println!("Failed to commit block list: {:?}", e); | |
return Err(Box::new(e)); | |
} | |
} | |
Ok(()) | |
} | |
async fn test_upload( | |
client: &BlobServiceClient, | |
test_case: &TestCase, | |
) -> Result<(), Box<dyn std::error::Error>> { | |
let container_client = client.container_client(CONTAINER); | |
let blob_path = make_blob_path(&test_case.blob_name); | |
let blob_client = container_client.blob_client(&blob_path); | |
println!("\n=== Testing upload for: {} ===", test_case.name); | |
println!("Content size: {} bytes", test_case.content.len()); | |
if test_case.use_multi_part { | |
println!("Using multi-part upload strategy"); | |
match upload_multi_part(&blob_client, &test_case.content).await { | |
Ok(_) => println!("Multi-part upload completed successfully"), | |
Err(e) => println!("Multi-part upload failed: {:?}", e), | |
} | |
} else { | |
match blob_client | |
.put_block_blob(Body::from(test_case.content.clone())) | |
.content_type("application/octet-stream") | |
.into_future() | |
.await | |
{ | |
Ok(response) => println!("Single-part upload response: {:?}", response), | |
Err(e) => println!("Single-part upload failed: {:?}", e), | |
} | |
} | |
Ok(()) | |
} | |
/// Streams the blob's bytes back and tallies how many arrived; download
/// errors are logged (with HTTP status when available) and end the loop
/// rather than failing the whole run.
async fn test_download(
    client: &BlobServiceClient,
    test_case: &TestCase,
) -> Result<(), Box<dyn std::error::Error>> {
    let container_client = client.container_client(CONTAINER);
    let blob_path = make_blob_path(&test_case.blob_name);
    let blob_client = container_client.blob_client(&blob_path);
    println!("\n=== Testing download for: {} ===", test_case.name);
    // Empty expected content means there is nothing to range-read; the
    // `len() - 1` below would also underflow on an empty payload.
    if test_case.content.is_empty() {
        println!("Skipping download test - no content to download");
        return Ok(());
    }
    // NOTE(review): assumes `Range::new(start, end)` is inclusive of `end`
    // (hence len - 1) — confirm against azure_core's Range semantics.
    let mut stream = blob_client
        .get()
        .range(Range::new(0, test_case.content.len() as u64 - 1))
        .into_stream();
    let mut total_bytes = 0;
    while let Some(result) = stream.next().await {
        match result {
            Ok(response) => {
                println!("Response: {:?}", response);
                // Drain this chunk's body stream into one contiguous buffer.
                let data = response.data.collect().await?;
                total_bytes += data.len();
                println!("Download chunk received: {} bytes", data.len());
            }
            Err(e) => {
                println!("Download error: {:?}", e);
                if let Some(status) = e.as_http_error().map(|e| e.status()) {
                    println!("HTTP Status Code: {}", status);
                }
                break;
            }
        }
    }
    println!("Total bytes downloaded: {}", total_bytes);
    Ok(())
}
async fn test_get_properties( | |
client: &BlobServiceClient, | |
test_case: &TestCase, | |
) -> Result<(), Box<dyn std::error::Error>> { | |
let container_client = client.container_client(CONTAINER); | |
let blob_path = make_blob_path(&test_case.blob_name); | |
let blob_client = container_client.blob_client(&blob_path); | |
println!("\n=== Testing get_properties for: {} ===", test_case.name); | |
match blob_client.get_properties().await { | |
Ok(props) => { | |
println!("Properties retrieved successfully:"); | |
println!("Props: {:?}", props); | |
} | |
Err(e) => { | |
println!("Get properties error: {:?}", e); | |
if let Some(status) = e.as_http_error().map(|e| e.status()) { | |
println!("HTTP Status Code: {}", status); | |
} | |
} | |
} | |
Ok(()) | |
} | |
#[tokio::main] | |
async fn main() -> Result<(), Box<dyn std::error::Error>> { | |
let azure_config = AzureConfig::new( | |
ACCOUNT_NAME.to_string(), | |
ACCOUNT_KEY.to_string(), | |
CONTAINER.to_string(), | |
); | |
let client = create_blob_client(azure_config).await?; | |
let test_cases = vec![ | |
TestCase { | |
name: "Success case - Small blob".to_string(), | |
blob_name: "test-blob-small.txt".to_string(), | |
content: b"Hello, World!".to_vec(), | |
use_multi_part: false, | |
expected_result: TestResult::Success, | |
}, | |
TestCase { | |
name: "Success case - Multi-part upload".to_string(), | |
blob_name: "test-blob-large.txt".to_string(), | |
content: vec![0; 8 * 1024 * 1024], // 8MB blob to trigger multi-part upload | |
use_multi_part: true, | |
expected_result: TestResult::Success, | |
}, | |
TestCase { | |
name: "Error case - Non-existent blob".to_string(), | |
blob_name: "non-existent-blob.txt".to_string(), | |
content: vec![], | |
use_multi_part: false, | |
expected_result: TestResult::BlobNotFound, | |
}, | |
TestCase { | |
name: "Error case - Invalid container".to_string(), | |
blob_name: "test-blob-invalid-container.txt".to_string(), | |
content: b"Test content".to_vec(), | |
use_multi_part: false, | |
expected_result: TestResult::ContainerNotFound, | |
}, | |
TestCase { | |
name: "Error case - Very large block size".to_string(), | |
blob_name: "test-blob-large-block.txt".to_string(), | |
content: vec![0; MAX_BLOCK_SIZE + 1], // Exceeds max block size | |
use_multi_part: true, | |
expected_result: TestResult::InvalidBlockSize, | |
}, | |
]; | |
// Create container if it doesn't exist | |
let container_client = client.container_client(CONTAINER); | |
match container_client.create().await { | |
Ok(_) => println!("Container created or already exists"), | |
Err(e) => println!("Container creation error: {:?}", e), | |
} | |
for test_case in &test_cases { | |
println!("\nExecuting test case: {}", test_case.name); | |
match test_case.expected_result { | |
TestResult::Success => { | |
test_upload(&client, test_case).await?; | |
test_get_properties(&client, test_case).await?; | |
test_download(&client, test_case).await?; | |
} | |
_ => { | |
// For error cases, we might want to test specific scenarios | |
match test_case.expected_result { | |
TestResult::BlobNotFound => { | |
test_get_properties(&client, test_case).await?; | |
test_download(&client, test_case).await?; | |
} | |
TestResult::ContainerNotFound => { | |
// Test with non-existent container | |
let invalid_container = client.container_client("non-existent-container"); | |
let blob_path = make_blob_path(&test_case.blob_name); | |
let result = invalid_container | |
.blob_client(&blob_path) | |
.get_properties() | |
.await; | |
println!("Testing invalid container: {:?}", result); | |
} | |
TestResult::InvalidBlockSize => { | |
test_upload(&client, test_case).await?; | |
} | |
_ => {} | |
} | |
} | |
} | |
} | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment