Skip to content

Instantly share code, notes, and snippets.

@amankrx
Created December 26, 2024 12:18
Show Gist options
  • Save amankrx/45e7d2a6ed935aa13dda0318681af2ad to your computer and use it in GitHub Desktop.
Azure Blob Storage Update and Get Implementation using Azure SDK in Rust
use base64;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use azure_core::prelude::*;
use azure_core::{Body, HttpClient, TransportOptions};
use azure_storage_blobs::prelude::*;
use futures::stream::StreamExt;
use hyper::client::HttpConnector;
use hyper::{Body as HyperBody, Client, Request as HyperRequest};
use hyper_rustls::HttpsConnector;
// Azure Storage account credentials — fill in before running.
const ACCOUNT_NAME: &str = "";
const ACCOUNT_KEY: &str = "";
// Container and default blob name used by the test cases below.
const CONTAINER: &str = "simple-test-container";
const BLOB: &str = "blob.txt";
// Path prefix prepended to every blob name (see `make_blob_path`).
const PREFIX: &str = "test-prefix-index";
// Upper bound on a single uploaded block (4 MiB) and the target number
// of blocks when splitting content in `upload_multi_part`.
const MAX_BLOCK_SIZE: usize = 4 * 1024 * 1024;
const MAX_BLOCKS: usize = 10;
/// Configuration for talking to one Azure Blob Storage account.
#[derive(Debug, Clone)]
pub struct AzureConfig {
    // Storage account name; forms the host and the canonical resource.
    account_name: String,
    // Base64-encoded shared account key used to sign requests.
    account_key: String,
    // Default container name.
    container: String,
    // Per-block upload size cap, in bytes.
    max_block_size: usize,
    // Target number of blocks for multi-part uploads.
    max_blocks: usize,
    // Value sent in the `x-ms-version` header.
    api_version: String,
}
impl AzureConfig {
pub fn new(
account_name: impl Into<String>,
account_key: impl Into<String>,
container: impl Into<String>,
) -> Self {
Self {
account_name: account_name.into(),
account_key: account_key.into(),
container: container.into(),
max_block_size: 4 * 1024 * 1024, // 4MB
max_blocks: 10,
api_version: "2020-04-08".to_string(),
}
}
}
/// Errors surfaced by the custom HTTP client and request-signing path.
#[derive(Error, Debug)]
pub enum AzureClientError {
    /// Error bubbled up from the Azure SDK core.
    #[error("Azure API error: {0}")]
    Azure(#[from] azure_core::Error),
    /// Transport-level failure from hyper.
    #[error("HTTP error: {0}")]
    Http(#[from] hyper::Error),
    /// A header value could not be constructed.
    #[error("Invalid header value: {0}")]
    InvalidHeader(#[from] hyper::header::InvalidHeaderValue),
    /// Request-signing failure (undecodable key, HMAC setup error).
    #[error("Authentication error: {0}")]
    Auth(String),
    /// Invalid or incomplete configuration.
    #[error("Configuration error: {0}")]
    Config(String),
}
/// HTTP client that signs each outgoing request with the storage
/// account's shared key before dispatching it via hyper + rustls.
#[derive(Debug, Clone)]
pub struct CustomHttpClient {
    // Connection-pooled, HTTPS-only hyper client.
    client: Client<HttpsConnector<HttpConnector>>,
    // Shared so clones of this client reuse one config.
    config: Arc<AzureConfig>,
}
impl CustomHttpClient {
    /// Builds a pooled, HTTPS-only hyper client around `config`.
    pub fn new(config: AzureConfig) -> Self {
        let https = hyper_rustls::HttpsConnectorBuilder::new()
            .with_native_roots()
            .https_only()
            .enable_http1()
            .enable_http2()
            .build();
        let client = Client::builder()
            .pool_idle_timeout(std::time::Duration::from_secs(30))
            .pool_max_idle_per_host(32)
            .build(https);
        Self {
            client,
            config: Arc::new(config),
        }
    }

    /// Current time formatted as an RFC 1123 HTTP date
    /// (e.g. "Thu, 26 Dec 2024 12:18:00 GMT"), as required by the
    /// `x-ms-date` header.
    fn format_http_date() -> String {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards");
        use chrono::{TimeZone, Utc};
        // Sub-second precision is discarded; the header only needs
        // whole seconds.
        let datetime = Utc.timestamp_opt(now.as_secs() as i64, 0).unwrap();
        datetime.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
    }

    /// Computes the `SharedKey <account>:<signature>` value for the
    /// `Authorization` header by HMAC-SHA256-signing a canonical
    /// string built from the request.
    ///
    /// NOTE(review): this string-to-sign has fewer newline-separated
    /// fields than the Shared Key format documented for recent
    /// api-versions (Content-Encoding, Date, Range, the If-* headers,
    /// etc. are absent), and the canonicalized resource appends the
    /// raw `?query` rather than one `name:value` line per query
    /// parameter — verify against "Authorize with Shared Key" before
    /// relying on this signature being accepted.
    fn sign_request(&self, request: &azure_core::Request) -> Result<String, AzureClientError> {
        // Missing headers fall back to "" — for Content-Length an
        // empty string (not "0") is what newer api-versions expect
        // for an empty body.
        let content_length = request
            .headers()
            .get_as::<String, std::string::ParseError>(
                &azure_core::headers::HeaderName::from_static("content-length"),
            )
            .unwrap_or_default();
        let content_type = request
            .headers()
            .get_as::<String, std::string::ParseError>(
                &azure_core::headers::HeaderName::from_static("content-type"),
            )
            .unwrap_or_default();
        // Build canonical headers using a BTreeMap for consistent ordering
        // (SharedKey requires x-ms-* headers sorted lexicographically).
        let mut canonical_headers = std::collections::BTreeMap::new();
        for (name, value) in request.headers().iter() {
            if name.as_str().starts_with("x-ms-") {
                canonical_headers.insert(
                    name.as_str().to_lowercase(),
                    value.as_str().trim().to_string(),
                );
            }
        }
        let canonical_headers = canonical_headers
            .iter()
            .map(|(k, v)| format!("{}:{}\n", k, v))
            .collect::<String>();
        // "/<account>/<path>[?query]" — see the NOTE above about the
        // query-parameter format.
        let canonical_resource = format!(
            "/{}/{}{}",
            self.config.account_name,
            request.url().path().trim_start_matches('/'),
            request
                .url()
                .query()
                .map(|q| format!("?{}", q))
                .unwrap_or_default()
        );
        let string_to_sign = format!(
            "{}\n\n{}\n{}\n{}\n\n{}\n{}",
            request.method(),
            content_type,
            content_length,
            "", // MD5
            canonical_headers,
            canonical_resource
        );
        // The account key is base64; the HMAC is keyed on the raw bytes.
        let key_bytes = base64::decode(&self.config.account_key)
            .map_err(|e| AzureClientError::Auth(format!("Failed to decode account key: {}", e)))?;
        let mut mac = Hmac::<Sha256>::new_from_slice(&key_bytes)
            .map_err(|e| AzureClientError::Auth(format!("Failed to create HMAC: {}", e)))?;
        mac.update(string_to_sign.as_bytes());
        let signature = base64::encode(mac.finalize().into_bytes());
        Ok(format!(
            "SharedKey {}:{}",
            self.config.account_name, signature
        ))
    }
}
#[async_trait::async_trait]
impl HttpClient for CustomHttpClient {
    /// Translates an `azure_core::Request` into a hyper request, signs
    /// it, executes it with a 30s timeout, and adapts the response
    /// back into an `azure_core::Response`.
    async fn execute_request(
        &self,
        request: &azure_core::Request,
    ) -> azure_core::Result<azure_core::Response> {
        let mut builder = HyperRequest::builder()
            .method(request.method().as_ref())
            .uri(request.url().as_str());
        // Add standard headers
        // NOTE(review): x-ms-date / x-ms-version are added only to the
        // outgoing hyper request, but `sign_request` below reads the
        // ORIGINAL request's headers — so these two headers are never
        // part of the signed canonical headers. If the service sees
        // them (it does) the signature will not match; confirm and
        // either sign after adding them or include them explicitly.
        let date = Self::format_http_date();
        builder = builder
            .header("x-ms-date", &date)
            .header("x-ms-version", &self.config.api_version);
        // Add request headers
        // NOTE(review): azure_core header names appear lowercase, so
        // comparing against "Authorization" (capital A) may never
        // match — verify; as written a pre-existing authorization
        // header could be copied through alongside ours.
        for (name, value) in request.headers().iter() {
            if name.as_str() != "Authorization" && name.as_str() != "x-ms-date" {
                builder = builder.header(name.as_str(), value.as_str());
            }
        }
        // Sign and add authorization header
        let auth_header = self
            .sign_request(request)
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        builder = builder.header("Authorization", auth_header);
        // Handle request body — only in-memory byte bodies are
        // supported; streaming bodies are rejected.
        let body = match request.body() {
            Body::Bytes(bytes) if bytes.is_empty() => HyperBody::empty(),
            Body::Bytes(bytes) => HyperBody::from(bytes.to_vec()),
            _ => {
                return Err(azure_core::Error::new(
                    azure_core::error::ErrorKind::Other,
                    "Unsupported body type",
                ))
            }
        };
        let hyper_request = builder
            .body(body)
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        // Execute request with timeout (30s wall clock for the whole
        // request, mapped to a generic error on expiry).
        let response = tokio::time::timeout(
            std::time::Duration::from_secs(30),
            self.client.request(hyper_request),
        )
        .await
        .map_err(|_| {
            azure_core::Error::new(azure_core::error::ErrorKind::Other, "Request timeout")
        })?
        .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        let (parts, body) = response.into_parts();
        // Map the response stream's hyper errors into azure_core errors.
        let mapped_stream = body.map(|result| {
            result.map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))
        });
        // Convert headers; values that are not valid UTF-8 are
        // silently dropped by the `ok()?` inside filter_map.
        let headers: HashMap<_, _> = parts
            .headers
            .iter()
            .filter_map(|(k, v)| {
                Some((
                    azure_core::headers::HeaderName::from(k.as_str().to_owned()),
                    azure_core::headers::HeaderValue::from(v.to_str().ok()?.to_owned()),
                ))
            })
            .collect();
        // NOTE(review): `expect` here panics on a status code
        // azure_core does not model — consider mapping to an error.
        Ok(azure_core::Response::new(
            azure_core::StatusCode::try_from(parts.status.as_u16()).expect("Invalid status code"),
            azure_core::headers::Headers::from(headers),
            Box::pin(mapped_stream),
        ))
    }
}
/// Builds a `BlobServiceClient` that routes every request through the
/// custom signing HTTP client (credentials are handled by the signer,
/// so the SDK-level credential is anonymous).
pub async fn create_blob_client(
    config: AzureConfig,
) -> Result<Arc<BlobServiceClient>, AzureClientError> {
    let account = config.account_name.clone();
    let transport = TransportOptions::new(Arc::new(CustomHttpClient::new(config)));
    let service_client =
        BlobServiceClient::builder(&account, azure_storage::StorageCredentials::anonymous())
            .transport(transport)
            .blob_service_client();
    Ok(Arc::new(service_client))
}
/// Joins the global `PREFIX` onto a blob name to form its full path.
fn make_blob_path(blob: &str) -> String {
    [PREFIX, blob].join("/")
}
/// One scenario exercised by `main`: what to upload, how, and the
/// outcome we expect.
struct TestCase {
    // Human-readable label printed in the logs.
    name: String,
    // Blob name; PREFIX is prepended via `make_blob_path`.
    blob_name: String,
    // Payload to upload (empty for the download-only error cases).
    content: Vec<u8>,
    // Upload via put_block/put_block_list instead of one put_block_blob.
    use_multi_part: bool,
    // Drives which operations `main` runs for this case.
    expected_result: TestResult,
}
/// Expected outcome of a `TestCase`; `main` branches on this to decide
/// which operations (upload / download / get_properties) to exercise.
enum TestResult {
    Success,
    BlobNotFound,
    ContainerNotFound,
    InvalidBlockSize,
    UploadError,
}
/// Uploads `content` as a block blob via `put_block` + `put_block_list`.
///
/// The block size is chosen so the payload splits into roughly
/// `MAX_BLOCKS` blocks, capped at `MAX_BLOCK_SIZE` per block and
/// floored at 1 byte: `slice::chunks` panics on a chunk size of 0,
/// which the original `total_size / (MAX_BLOCKS - 1)` produced for any
/// payload smaller than `MAX_BLOCKS - 1` bytes.
///
/// Returns the first upload/commit error, or `Ok(())` on success.
async fn upload_multi_part(
    blob_client: &BlobClient,
    content: &[u8],
) -> Result<(), Box<dyn std::error::Error>> {
    let total_size = content.len();
    // clamp(1, ..) fixes the divide-down-to-zero panic; empty content
    // simply stages no blocks and commits an empty block list.
    let block_size = (total_size / (MAX_BLOCKS - 1)).clamp(1, MAX_BLOCK_SIZE);
    let mut block_ids = Vec::new();
    // Stage each chunk as an uncommitted block. Ids are zero-padded to
    // a fixed width because all ids in a block list must have the same
    // encoded length.
    for (i, chunk) in content.chunks(block_size).enumerate() {
        let block_id = format!("{:032}", i);
        println!("Uploading block {} of size {}", block_id, chunk.len());
        match blob_client
            .put_block(block_id.clone(), Body::from(chunk.to_vec()))
            .await
        {
            Ok(_) => {
                println!("Successfully uploaded block {}", block_id);
                block_ids.push(block_id);
            }
            Err(e) => {
                println!("Failed to upload block {}: {:?}", block_id, e);
                return Err(Box::new(e));
            }
        }
    }
    // Committing the block list makes the staged blocks the blob's
    // content, in the order listed.
    let block_list = BlockList {
        blocks: block_ids
            .into_iter()
            .map(|id| BlobBlockType::Latest(BlockId::from(id)))
            .collect(),
    };
    match blob_client
        .put_block_list(block_list)
        .content_type("application/octet-stream")
        .await
    {
        Ok(_) => println!("Successfully committed block list"),
        Err(e) => {
            println!("Failed to commit block list: {:?}", e);
            return Err(Box::new(e));
        }
    }
    Ok(())
}
/// Runs the upload phase of a test case, picking the single-part or
/// multi-part strategy per the case's flag, and logs the outcome.
/// Upload failures are logged, not propagated.
async fn test_upload(
    client: &BlobServiceClient,
    test_case: &TestCase,
) -> Result<(), Box<dyn std::error::Error>> {
    let blob_client = client
        .container_client(CONTAINER)
        .blob_client(make_blob_path(&test_case.blob_name));
    println!("\n=== Testing upload for: {} ===", test_case.name);
    println!("Content size: {} bytes", test_case.content.len());
    if test_case.use_multi_part {
        println!("Using multi-part upload strategy");
        if let Err(e) = upload_multi_part(&blob_client, &test_case.content).await {
            println!("Multi-part upload failed: {:?}", e);
        } else {
            println!("Multi-part upload completed successfully");
        }
    } else {
        let outcome = blob_client
            .put_block_blob(Body::from(test_case.content.clone()))
            .content_type("application/octet-stream")
            .into_future()
            .await;
        match outcome {
            Ok(response) => println!("Single-part upload response: {:?}", response),
            Err(e) => println!("Single-part upload failed: {:?}", e),
        }
    }
    Ok(())
}
/// Streams the blob for a test case back down and reports the total
/// byte count. Download errors are logged and stop the stream; they
/// are not propagated.
async fn test_download(
    client: &BlobServiceClient,
    test_case: &TestCase,
) -> Result<(), Box<dyn std::error::Error>> {
    let container_client = client.container_client(CONTAINER);
    let blob_path = make_blob_path(&test_case.blob_name);
    let blob_client = container_client.blob_client(&blob_path);
    println!("\n=== Testing download for: {} ===", test_case.name);
    if test_case.content.is_empty() {
        println!("Skipping download test - no content to download");
        return Ok(());
    }
    // NOTE(review): if azure_core::Range treats `end` as exclusive,
    // `len - 1` here drops the final byte — confirm the Range
    // semantics; `Range::new(0, len)` would then cover the whole blob.
    let mut stream = blob_client
        .get()
        .range(Range::new(0, test_case.content.len() as u64 - 1))
        .into_stream();
    let mut total_bytes = 0;
    while let Some(result) = stream.next().await {
        match result {
            Ok(response) => {
                println!("Response: {:?}", response);
                // Drain this chunk's body into memory to count it.
                let data = response.data.collect().await?;
                total_bytes += data.len();
                println!("Download chunk received: {} bytes", data.len());
            }
            Err(e) => {
                println!("Download error: {:?}", e);
                if let Some(status) = e.as_http_error().map(|e| e.status()) {
                    println!("HTTP Status Code: {}", status);
                }
                break;
            }
        }
    }
    println!("Total bytes downloaded: {}", total_bytes);
    Ok(())
}
/// Fetches and prints the blob's properties for a test case.
/// Errors are logged (with the HTTP status when available), not
/// propagated.
async fn test_get_properties(
    client: &BlobServiceClient,
    test_case: &TestCase,
) -> Result<(), Box<dyn std::error::Error>> {
    let blob_client = client
        .container_client(CONTAINER)
        .blob_client(make_blob_path(&test_case.blob_name));
    println!("\n=== Testing get_properties for: {} ===", test_case.name);
    match blob_client.get_properties().await {
        Ok(props) => {
            println!("Properties retrieved successfully:");
            println!("Props: {:?}", props);
        }
        Err(e) => {
            println!("Get properties error: {:?}", e);
            if let Some(status) = e.as_http_error().map(|e| e.status()) {
                println!("HTTP Status Code: {}", status);
            }
        }
    }
    Ok(())
}
/// Entry point: builds the client, ensures the container exists, then
/// runs each test case, choosing which operations to exercise based on
/// its `expected_result`.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let azure_config = AzureConfig::new(
        ACCOUNT_NAME.to_string(),
        ACCOUNT_KEY.to_string(),
        CONTAINER.to_string(),
    );
    let client = create_blob_client(azure_config).await?;
    let test_cases = vec![
        TestCase {
            name: "Success case - Small blob".to_string(),
            blob_name: "test-blob-small.txt".to_string(),
            content: b"Hello, World!".to_vec(),
            use_multi_part: false,
            expected_result: TestResult::Success,
        },
        TestCase {
            name: "Success case - Multi-part upload".to_string(),
            blob_name: "test-blob-large.txt".to_string(),
            content: vec![0; 8 * 1024 * 1024], // 8MB blob to trigger multi-part upload
            use_multi_part: true,
            expected_result: TestResult::Success,
        },
        TestCase {
            name: "Error case - Non-existent blob".to_string(),
            blob_name: "non-existent-blob.txt".to_string(),
            content: vec![],
            use_multi_part: false,
            expected_result: TestResult::BlobNotFound,
        },
        TestCase {
            name: "Error case - Invalid container".to_string(),
            blob_name: "test-blob-invalid-container.txt".to_string(),
            content: b"Test content".to_vec(),
            use_multi_part: false,
            expected_result: TestResult::ContainerNotFound,
        },
        TestCase {
            name: "Error case - Very large block size".to_string(),
            blob_name: "test-blob-large-block.txt".to_string(),
            content: vec![0; MAX_BLOCK_SIZE + 1], // Exceeds max block size
            use_multi_part: true,
            expected_result: TestResult::InvalidBlockSize,
        },
    ];
    // Create container if it doesn't exist (errors are logged, not
    // fatal — the container may already be there).
    let container_client = client.container_client(CONTAINER);
    match container_client.create().await {
        Ok(_) => println!("Container created or already exists"),
        Err(e) => println!("Container creation error: {:?}", e),
    }
    for test_case in &test_cases {
        println!("\nExecuting test case: {}", test_case.name);
        match test_case.expected_result {
            // Happy path: full upload → properties → download cycle.
            TestResult::Success => {
                test_upload(&client, test_case).await?;
                test_get_properties(&client, test_case).await?;
                test_download(&client, test_case).await?;
            }
            _ => {
                // For error cases, we might want to test specific scenarios
                match test_case.expected_result {
                    // Read paths against a blob that was never uploaded.
                    TestResult::BlobNotFound => {
                        test_get_properties(&client, test_case).await?;
                        test_download(&client, test_case).await?;
                    }
                    TestResult::ContainerNotFound => {
                        // Test with non-existent container
                        let invalid_container = client.container_client("non-existent-container");
                        let blob_path = make_blob_path(&test_case.blob_name);
                        let result = invalid_container
                            .blob_client(&blob_path)
                            .get_properties()
                            .await;
                        println!("Testing invalid container: {:?}", result);
                    }
                    // Oversized-block case only exercises the upload path.
                    TestResult::InvalidBlockSize => {
                        test_upload(&client, test_case).await?;
                    }
                    _ => {}
                }
            }
        }
    }
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment