Created
November 20, 2024 12:59
-
-
Save Yggdrasill501/25c5fa6a267657a59708da84668dad39 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Alternative to s3fs-fuse-awscred-lib that refreshes s3fs credentials
from a background thread using an existing HTTP credentials service.

This implementation:
- Periodically fetches fresh temporary credentials and rewrites the s3fs
  password file, without restarting s3fs.
- Obtains credentials from the service configured through the
  AWS_CREDENTIALS_SERVICE_URL and AWS_CREDENTIALS_SERVICE_SECRET
  environment variables.
- Requests credentials for a given IAM role ARN, optionally with an
  external ID.
- Reports failures through a small exception hierarchy rooted at
  CredentialRefresherError.
- Can run s3fs with debug logging enabled.

NOTE(review): s3fs may read passwd_file only once at start-up; confirm it
picks up the rewritten file, otherwise refreshing has no effect on a live
mount.
"""
import os | |
import time | |
import threading | |
import requests | |
from typing import Dict, Optional | |
import logging | |
class CredentialRefresherError(Exception):
    """Common base for every error defined by the credential refresher.

    Catch this type to handle any of the more specific failures below.
    """


class RoleAssumeError(CredentialRefresherError):
    """Signals that the requested IAM role could not be assumed."""


class CredentialServiceError(CredentialRefresherError):
    """Signals a failure while talking to the credentials service."""


class CredentialFileError(CredentialRefresherError):
    """Signals that the credentials file could not be written or replaced."""


class MountError(CredentialRefresherError):
    """Signals that creating the mount point or running s3fs failed."""


class ConfigurationError(CredentialRefresherError):
    """Signals missing or invalid configuration."""
class CredentialRefresher:
    """Keep an s3fs password file filled with fresh temporary AWS credentials.

    Credentials for ``role_arn`` are requested from an HTTP credentials
    service.  Its base URL and bearer secret are read from the
    ``AWS_CREDENTIALS_SERVICE_URL`` and ``AWS_CREDENTIALS_SERVICE_SECRET``
    environment variables.  The credentials are written to
    ``credentials_file``.  :meth:`start` launches a daemon thread that
    repeats this every ``refresh_interval`` seconds until :meth:`stop` is
    called.

    Args:
        role_arn: ARN of the IAM role to request credentials for.
        external_id: Optional external ID sent along with the role ARN.
        refresh_interval: Seconds between refreshes; must be positive.
        credentials_file: Path of the password file that is (over)written.
        request_timeout: Seconds to wait for the credentials service before
            a request is abandoned and reported as CredentialServiceError.

    Raises:
        ConfigurationError: If a service environment variable is unset or
            empty, or if ``refresh_interval`` is not positive.
    """

    def __init__(self, role_arn: str, external_id: Optional[str] = None,
                 refresh_interval: int = 300,
                 credentials_file: str = '/etc/passwd-s3fs',
                 request_timeout: float = 30.0):
        self.role_arn = role_arn
        self.external_id = external_id
        self.refresh_interval = refresh_interval
        self.credentials_file = credentials_file
        self.request_timeout = request_timeout
        self.service_url = os.environ.get('AWS_CREDENTIALS_SERVICE_URL')
        self.service_secret = os.environ.get('AWS_CREDENTIALS_SERVICE_SECRET')
        if not self.service_url or not self.service_secret:
            raise ConfigurationError("Missing required environment variables")
        if refresh_interval <= 0:
            # Event.wait() with a non-positive timeout returns immediately,
            # which would turn the refresh loop into a busy loop that
            # hammers the credentials service.
            raise ConfigurationError(
                f"refresh_interval must be positive, got {refresh_interval}")
        self._stop_event = threading.Event()
        self._refresh_thread: Optional[threading.Thread] = None
        # Serialises update_credentials_file() between the refresh thread
        # and direct callers so they never interleave writes.
        self._lock = threading.Lock()

    def get_credentials(self) -> Dict[str, str]:
        """Fetch temporary credentials for ``self.role_arn`` from the service.

        Returns:
            A dict with ``AccessKeyId``, ``SecretAccessKey`` and
            ``SessionToken`` keys, taken from the service's ``accessKeyId``,
            ``secretAccessKey`` and ``sessionToken`` JSON fields.

        Raises:
            RoleAssumeError: If the service responds with HTTP 403.
            CredentialServiceError: On connection errors, timeouts, other
                non-2xx statuses, or a response body that is not a JSON
                object containing the expected fields.
        """
        url = f"{self.service_url}/api/aws-credentials"
        payload = {"roleArn": self.role_arn}
        if self.external_id:
            payload["externalId"] = self.external_id
        headers = {
            "Authorization": f"Bearer {self.service_secret}",
            "Content-Type": "application/json",
        }
        try:
            # Without a timeout a stalled service would block the refresh
            # thread (and therefore stop()) indefinitely.
            response = requests.post(url, json=payload, headers=headers,
                                     timeout=self.request_timeout)
            if response.status_code == 403:
                raise RoleAssumeError(f"Failed to assume role {self.role_arn}")
            response.raise_for_status()
            data = response.json()
            return {
                'AccessKeyId': data['accessKeyId'],
                'SecretAccessKey': data['secretAccessKey'],
                'SessionToken': data['sessionToken'],
            }
        except requests.exceptions.RequestException as e:
            raise CredentialServiceError(f"Credential service error: {str(e)}") from e
        except (KeyError, TypeError, ValueError) as e:
            # Missing field, non-object JSON, or (on older requests
            # versions) an undecodable body: report it as a service error
            # instead of letting it escape and kill the refresh thread.
            raise CredentialServiceError(
                f"Malformed credential service response: {e!r}") from e

    def update_credentials_file(self) -> None:
        """Fetch fresh credentials and atomically replace the password file.

        The new contents are written to a uniquely named temporary file in
        the same directory.  The file is created with mode 0600 by
        ``tempfile.mkstemp``, so the secret is never readable by other users.
        It is then moved over ``credentials_file`` with ``os.replace``, so
        readers never see a partially written file.

        Raises:
            RoleAssumeError: Propagated from :meth:`get_credentials`.
            CredentialServiceError: Propagated from :meth:`get_credentials`.
            CredentialFileError: If writing or replacing the file fails.
        """
        import tempfile

        with self._lock:
            credentials = self.get_credentials()
            # NOTE(review): s3fs documents passwd_file lines as
            # "ACCESS_KEY:SECRET" or "bucket:ACCESS_KEY:SECRET"; confirm it
            # interprets this three-field "key:secret:token" line as
            # intended rather than treating the first field as a bucket name.
            creds_str = (f"{credentials['AccessKeyId']}:{credentials['SecretAccessKey']}"
                         f":{credentials['SessionToken']}")
            target = os.path.abspath(self.credentials_file)
            temp_path = None
            try:
                # A unique name (unlike a fixed "<file>.tmp") means
                # concurrent writers can never rename each other's file away.
                fd, temp_path = tempfile.mkstemp(
                    prefix=os.path.basename(target) + '.',
                    suffix='.tmp',
                    dir=os.path.dirname(target),
                )
                with os.fdopen(fd, 'w') as f:
                    f.write(creds_str)
                os.chmod(temp_path, 0o600)
                # os.replace overwrites atomically on all platforms, whereas
                # os.rename fails on Windows when the target exists.
                os.replace(temp_path, target)
                temp_path = None
            except OSError as e:
                raise CredentialFileError(f"Failed to write credentials: {str(e)}") from e
            finally:
                if temp_path is not None:
                    try:
                        os.unlink(temp_path)
                    except OSError:
                        # Best-effort cleanup; the original error is already
                        # being propagated.
                        pass

    def _refresh_loop(self) -> None:
        """Refresh credentials every ``refresh_interval`` s until stopped.

        Used as the target of the background thread started by :meth:`start`.
        """
        while not self._stop_event.is_set():
            try:
                self.update_credentials_file()
            except CredentialRefresherError as e:
                logging.error("Credential refresh error: %s", e)
            except Exception:
                # Thread boundary: an unexpected error must not silently end
                # refreshing for the rest of the mount's life; log and retry.
                logging.exception("Unexpected error while refreshing credentials")
            self._stop_event.wait(self.refresh_interval)

    def start(self) -> None:
        """Start the background refresh thread.

        The thread is a daemon, so it does not keep the interpreter alive on
        its own.  Its first refresh happens immediately.

        Raises:
            CredentialRefresherError: If the thread is already running.
        """
        if self._refresh_thread is not None:
            raise CredentialRefresherError("Already running")
        self._stop_event.clear()
        self._refresh_thread = threading.Thread(
            target=self._refresh_loop,
            name="s3fs-credential-refresher",
            daemon=True,
        )
        self._refresh_thread.start()

    def stop(self) -> None:
        """Signal the refresh thread to exit and wait for it to finish.

        Does nothing if the thread is not running.  It may block for up to
        ``request_timeout`` seconds if a request is in flight.
        """
        if self._refresh_thread is not None:
            self._stop_event.set()
            self._refresh_thread.join()
            self._refresh_thread = None
def mount_s3fs_with_refreshing_credentials(
    bucket: str,
    mount_point: str,
    role_arn: str,
    external_id: Optional[str] = None,
    refresh_interval: int = 300,
    debug: bool = False
) -> None:
    """Mount ``bucket`` at ``mount_point`` with s3fs while refreshing credentials.

    The password file is written once before s3fs starts.  A
    CredentialRefresher then rewrites it every ``refresh_interval`` seconds.
    s3fs is run in the foreground, so this call blocks until s3fs exits
    (e.g. on unmount).  The refresher is stopped afterwards.

    NOTE(review): confirm that s3fs re-reads ``passwd_file`` after start-up;
    if it only reads it once, rewriting the file does not affect the live
    mount.

    Args:
        bucket: Name of the S3 bucket to mount.
        mount_point: Directory to mount on; created if missing.
        role_arn: ARN of the IAM role to request credentials for.
        external_id: Optional external ID sent with the credentials request.
        refresh_interval: Seconds between credential refreshes.
        debug: If true, pass s3fs ``-o dbglevel=info -d``.

    Raises:
        MountError: If the mount point cannot be created, s3fs cannot be
            launched, or s3fs exits with a non-zero status.
        ConfigurationError: If the credentials service is not configured.
        RoleAssumeError: If the initial credential fetch is refused.
        CredentialServiceError: If the initial credential fetch fails.
        CredentialFileError: If the initial credentials file write fails.
    """
    import subprocess

    try:
        os.makedirs(mount_point, exist_ok=True)
    except OSError as e:
        raise MountError(f"Failed to create mount point: {str(e)}") from e

    refresher = CredentialRefresher(
        role_arn=role_arn,
        external_id=external_id,
        refresh_interval=refresh_interval
    )

    # Write the initial credentials *before* starting the background thread.
    # s3fs needs a valid file at start-up, and writing first means the main
    # thread and the refresh thread never write the file concurrently.
    refresher.update_credentials_file()

    command = [
        's3fs',
        bucket,
        mount_point,
        '-o', f'passwd_file={refresher.credentials_file}',
        # NOTE(review): allow_other together with mp_umask=0000 grants every
        # local user read/write access to the mount; tighten if unintended.
        '-o', 'allow_other',
        '-o', 'mp_umask=0000',
        # Always run in the foreground.  If s3fs daemonised, Popen.wait()
        # would return as soon as the mount was set up, and the finally
        # clause below would stop the refresher while the mount is live.
        '-f',
    ]
    if debug:
        command.extend(['-o', 'dbglevel=info', '-d'])

    refresher.start()
    try:
        try:
            process = subprocess.Popen(command)
        except OSError as e:
            raise MountError(f"Failed to start s3fs: {str(e)}") from e
        return_code = process.wait()
        if return_code != 0:
            raise MountError(f"Mount failed with code {return_code}")
    finally:
        refresher.stop()
# Usage example: mount a bucket and block until s3fs exits.
if __name__ == "__main__":
    try:
        mount_s3fs_with_refreshing_credentials(
            bucket="my-bucket",
            mount_point="/mnt/s3",
            role_arn="arn:aws:iam::123456789012:role/my-role",
            external_id="my-external-id",
            debug=True
        )
    except CredentialRefresherError as e:
        logging.error("Failed to mount: %s", e)
        # Exit non-zero so shells and service managers see the failure.
        raise SystemExit(1) from e
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment