Skip to content

Instantly share code, notes, and snippets.

@adamori
Last active September 25, 2024 16:01
Show Gist options
  • Save adamori/e0b7f805f17c64e8db6d65ab62f56b53 to your computer and use it in GitHub Desktop.
Save adamori/e0b7f805f17c64e8db6d65ab62f56b53 to your computer and use it in GitHub Desktop.
This Python script manages temporary firewall rules for a Hetzner Cloud server, ideal for use in CI/CD pipelines like GitHub Actions. It allows temporary SSH access from a specific IP address during automated tasks.
import argparse
import ipaddress
import logging
import os
import sys
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Optional

import requests
from dotenv import load_dotenv
# Root logger: INFO and above, one line per record formatted as
# "<timestamp> [<LEVEL>] <logger name>: <message>".
_root_log_settings = {
    "level": logging.INFO,
    "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    "datefmt": "%Y-%m-%d %H:%M:%S",
}
logging.basicConfig(**_root_log_settings)
@dataclass
class Config:
    """Settings for the temporary-firewall script.

    Normally built with :meth:`from_env`; :meth:`from_mapping` parses the
    same variables from any string mapping (handy for tests).
    """

    api_key: str  # Hetzner Cloud API token (HETZNER_API_KEY)
    server_id: int  # ID of the server the firewall is attached to (HETZNER_SERVER_ID)
    ssh_port: int = 22  # TCP port opened to the allowed IP (SSH_PORT)
    created_by: str = "gh-actions"  # value of the ``created-by`` label (CREATED_BY)

    @classmethod
    def from_env(cls) -> "Config":
        """Load a ``.env`` file via ``load_dotenv()``, then parse ``os.environ``.

        Raises:
            ValueError: see :meth:`from_mapping`.
        """
        load_dotenv()
        return cls.from_mapping(os.environ)

    @classmethod
    def from_mapping(cls, env: Mapping[str, str]) -> "Config":
        """Build a Config from a mapping of environment-style variables.

        Reads HETZNER_API_KEY and HETZNER_SERVER_ID (both required), plus
        SSH_PORT (default "22") and CREATED_BY (default "gh-actions").

        Raises:
            ValueError: if a required variable is missing or empty, if
                HETZNER_SERVER_ID / SSH_PORT is not an integer, or if SSH_PORT
                is outside 1-65535.
        """
        api_key = env.get("HETZNER_API_KEY")
        server_id_str = env.get("HETZNER_SERVER_ID")
        if not api_key or not server_id_str:
            raise ValueError(
                "Environment variables HETZNER_API_KEY and HETZNER_SERVER_ID must be set."
            )

        # `from None`: the int() failure adds nothing beyond our own message.
        try:
            server_id = int(server_id_str)
        except ValueError:
            raise ValueError("HETZNER_SERVER_ID must be an integer.") from None

        try:
            ssh_port = int(env.get("SSH_PORT", "22"))
        except ValueError:
            raise ValueError("SSH_PORT must be an integer.") from None
        if not 1 <= ssh_port <= 65535:
            raise ValueError("SSH_PORT must be between 1 and 65535.")

        return cls(
            api_key=api_key,
            server_id=server_id,
            ssh_port=ssh_port,
            created_by=env.get("CREATED_BY", "gh-actions"),
        )
class HetznerFirewallManager:
    """Manage a temporary Hetzner Cloud firewall that opens SSH to one IP.

    The firewall is created already attached to ``server_id``, has a single
    inbound rule allowing TCP ``ssh_port`` from one source address, and is
    labelled ``created-by=<created_by>`` so later runs can find it.

    Every HTTP request uses ``timeout`` seconds as its ``requests`` timeout,
    so a stalled connection fails the run instead of hanging it.
    """

    BASE_URL = "https://api.hetzner.cloud/v1"
    # Third-party service that replies with the caller's public IP as text.
    IP_LOOKUP_URL = "https://api.ipify.org"

    def __init__(
        self,
        api_key: str,
        server_id: int,
        ssh_port: int = 22,
        created_by: str = "gh-actions",
        timeout: float = 30.0,
    ):
        """Store connection settings and set up a dedicated logger.

        Args:
            api_key: Hetzner Cloud API token, sent as a Bearer token.
            server_id: ID of the server the firewall is attached to.
            ssh_port: TCP port opened to the allowed address.
            created_by: Value of the ``created-by`` label; also used in the
                firewall name.
            timeout: Seconds passed as ``timeout`` to every HTTP request.
        """
        self.api_key = api_key
        self.server_id = server_id
        self.ssh_port = ssh_port
        self.created_by = created_by
        self.timeout = timeout
        self.headers = {"Authorization": f"Bearer {self.api_key}"}

        self.logger = logging.getLogger(self.__class__.__name__)
        # This logger gets its own handler below; without propagate=False the
        # same record would also reach root-logger handlers and print twice.
        self.logger.propagate = False
        if not self.logger.handlers:  # don't stack handlers on re-instantiation
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter(
                    fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
            )
            self.logger.addHandler(handler)

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _error_text(exc: Exception) -> str:
        """Return the HTTP response body attached to *exc*, else ``str(exc)``."""
        response = getattr(exc, "response", None)
        text = getattr(response, "text", None)
        return text if text else str(exc)

    @staticmethod
    def _source_cidr(ip: str) -> str:
        """Return *ip* as a single-host CIDR: ``/32`` for IPv4, ``/128`` for IPv6.

        Raises:
            ValueError: if *ip* (after stripping whitespace) is not an IP address.
        """
        address = ipaddress.ip_address(ip.strip())
        return f"{address}/{address.max_prefixlen}"

    def _request(self, method: str, path: str, **kwargs):
        """Send an authenticated Hetzner API request and return the response.

        *path* is appended to BASE_URL; extra keyword arguments go to
        ``requests.request``.

        Raises:
            requests.RequestException: on transport errors, timeouts, or a
                non-2xx status (via ``raise_for_status``).
        """
        response = requests.request(
            method,
            f"{self.BASE_URL}{path}",
            headers=self.headers,
            timeout=self.timeout,
            **kwargs,
        )
        response.raise_for_status()
        return response

    def _wait_for_actions(
        self, actions: list, poll_interval: float = 1.0, max_wait: float = 60.0
    ) -> None:
        """Block until every Hetzner action object in *actions* has finished.

        Hetzner reports asynchronous firewall changes as action objects whose
        ``status`` stays ``"running"`` until the change completes. Running
        actions are re-fetched every *poll_interval* seconds.

        Raises:
            RuntimeError: if an action ends with status ``"error"``.
            TimeoutError: if an action is still running after *max_wait* seconds.
            requests.RequestException: if polling an action fails.
        """
        deadline = time.monotonic() + max_wait
        for action in actions:
            current = action
            while current.get("status") == "running":
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"Hetzner action {current.get('id')} still running after {max_wait}s"
                    )
                time.sleep(poll_interval)
                current = self._request(
                    "GET", f"/firewalls/actions/{current['id']}"
                ).json()["action"]
            if current.get("status") == "error":
                raise RuntimeError(
                    f"Hetzner action {current.get('id')} failed: {current.get('error')}"
                )

    # -------------------------------------------------------------- public API

    def get_my_ip(self) -> str:
        """Return this machine's public IP address as reported by IP_LOOKUP_URL.

        Raises:
            requests.RequestException: if the lookup request fails (logged).
            ValueError: if the service's reply is not a valid IP address.
        """
        try:
            response = requests.get(self.IP_LOOKUP_URL, timeout=self.timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            self.logger.error(f"Error getting public IP: {e}")
            raise
        ip = response.text.strip()
        # Validate before a third-party reply ends up in a firewall rule.
        ipaddress.ip_address(ip)
        return ip

    def apply_firewall(self, allow_ip: str) -> str:
        """Create the labelled firewall attached to the server; return its ID.

        The single inbound rule allows TCP ``ssh_port`` from *allow_ip* only
        (``/32`` for IPv4, ``/128`` for IPv6). Returns once the actions
        reported by the create call have finished.

        Raises:
            ValueError: if *allow_ip* is not a valid IP address.
            requests.RequestException: if an API call fails (logged).
            RuntimeError, TimeoutError: if attaching fails or does not finish.
        """
        data = {
            "name": f"{self.created_by} Temp Firewall",
            "apply_to": [{"type": "server", "server": {"id": self.server_id}}],
            "labels": {"created-by": self.created_by},
            "rules": [
                {
                    "description": "Temporary access for CI/CD",
                    "direction": "in",
                    "protocol": "tcp",
                    "port": str(self.ssh_port),
                    "source_ips": [self._source_cidr(allow_ip)],
                }
            ],
        }
        try:
            payload = self._request("POST", "/firewalls", json=data).json()
        except requests.RequestException as e:
            self.logger.error(f"Error applying firewall: {e} - {self._error_text(e)}")
            raise
        firewall_id = str(payload["firewall"]["id"])
        self.logger.info(f"Firewall applied with ID: {firewall_id}")
        self._wait_for_actions(payload.get("actions", []))
        return firewall_id

    def get_applied_firewall_id(self) -> Optional[str]:
        """Return the ID of a ``created-by`` labelled firewall attached to the server.

        Only ``applied_to`` entries of type ``server`` matching ``server_id``
        count. Returns ``None`` if no labelled firewall is attached. Only the
        first page of the API's list response is examined (no paging).

        Raises:
            requests.RequestException: if the API call fails (logged).
        """
        params = {"label_selector": f"created-by=={self.created_by}"}
        try:
            firewalls = (
                self._request("GET", "/firewalls", params=params)
                .json()
                .get("firewalls", [])
            )
        except requests.RequestException as e:
            self.logger.error(f"Error getting firewalls: {e} - {self._error_text(e)}")
            raise
        for firewall in firewalls:
            attached = any(
                resource.get("type") == "server"
                and resource.get("server", {}).get("id") == self.server_id
                for resource in firewall.get("applied_to", [])
            )
            if attached:
                return str(firewall.get("id"))
        return None

    def remove_firewall(self, firewall_id: str) -> None:
        """Detach firewall *firewall_id* from the server and wait for completion.

        Raises:
            requests.RequestException: if the API call fails (logged).
            RuntimeError, TimeoutError: if the detach fails or does not finish.
        """
        data = {"remove_from": [{"type": "server", "server": {"id": self.server_id}}]}
        try:
            payload = self._request(
                "POST",
                f"/firewalls/{firewall_id}/actions/remove_from_resources",
                json=data,
            ).json()
        except requests.RequestException as e:
            self.logger.error(f"Error removing firewall: {e} - {self._error_text(e)}")
            raise
        # The detach runs asynchronously; wait so that a delete issued right
        # afterwards does not race a firewall that is still attached.
        self._wait_for_actions(payload.get("actions", []))
        self.logger.info(
            f"Firewall {firewall_id} removed from server {self.server_id}"
        )

    def delete_firewall(self, firewall_id: str) -> None:
        """Delete firewall *firewall_id*.

        Raises:
            requests.RequestException: if the API call fails (logged).
        """
        try:
            self._request("DELETE", f"/firewalls/{firewall_id}")
        except requests.RequestException as e:
            self.logger.error(f"Error deleting firewall: {e} - {self._error_text(e)}")
            raise
        self.logger.info(f"Firewall {firewall_id} deleted")
def main():
    """Command-line entry point: ``apply`` or ``delete`` the temporary firewall.

    Settings come from ``Config.from_env()``; an invalid configuration is
    logged and exits with status 1, as does any error raised while running
    the chosen command.
    """
    cli = argparse.ArgumentParser(
        description="Manage temporary Hetzner firewall for CI/CD."
    )
    commands = cli.add_subparsers(dest="command", required=True)
    commands.add_parser("apply", help="Apply the temporary firewall")
    commands.add_parser("delete", help="Delete the temporary firewall")
    command = cli.parse_args().command

    try:
        config = Config.from_env()
    except ValueError as err:
        logging.error(err)
        sys.exit(1)

    manager = HetznerFirewallManager(
        api_key=config.api_key,
        server_id=config.server_id,
        ssh_port=config.ssh_port,
        created_by=config.created_by,
    )
    log = manager.logger

    def apply() -> None:
        # Idempotent: leave an already-attached labelled firewall alone.
        existing_id = manager.get_applied_firewall_id()
        if not existing_id:
            manager.apply_firewall(manager.get_my_ip())
            return
        log.info(
            f"Firewall labeled '{config.created_by}' already exists with ID: {existing_id}"
        )

    def delete() -> None:
        existing_id = manager.get_applied_firewall_id()
        if not existing_id:
            log.info(f"No firewall labeled '{config.created_by}' found.")
            return
        log.info(f"Removing firewall with ID: {existing_id}")
        manager.remove_firewall(existing_id)
        manager.delete_firewall(existing_id)

    # command -> (action, prefix of the error message logged on failure)
    dispatch = {
        "apply": (apply, "Error applying firewall"),
        "delete": (delete, "Error deleting firewall"),
    }
    action, failure_prefix = dispatch[command]
    try:
        action()
    except Exception as err:
        log.error(f"{failure_prefix}: {err}")
        sys.exit(1)


if __name__ == "__main__":
    main()
@adamori
Copy link
Author

adamori commented Sep 25, 2024

Hetzner Cloud Temporary Firewall Manager

This Python script manages temporary firewall rules for a Hetzner Cloud server, ideal for use in CI/CD pipelines like GitHub Actions. It allows temporary SSH access from a specific IP address during automated tasks.

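Important: This script calls the third-party service https://api.ipify.org to look up the public IP address of the machine running it (e.g. the CI runner). That address is the one allowed through the firewall.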

Key Features:

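  • Apply Temporary Firewall: Fetches the public IP of the machine running the script and creates a firewall, attached to the server, that allows SSH access (default port 22) only from that IP. If a firewall with the same label is already attached to the server, nothing is changed.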
  • Delete Firewall: Detaches the temporary firewall from the server and then deletes it, closing SSH access once your tasks are done.

Environment Variables:

Set the following as environment variables, or in a .env file (loaded with python-dotenv):

  • HETZNER_API_KEY: Your Hetzner API key.
  • HETZNER_SERVER_ID: The server ID where the firewall will be applied.
  • SSH_PORT (optional, defaults to 22): The SSH port to allow.
  • CREATED_BY (optional, defaults to "gh-actions"): Label to identify who created the firewall.

Commands:

  • Apply Firewall:

    python hetzner_firewall_manager.py apply
    
  • Delete Firewall:

    python hetzner_firewall_manager.py delete
    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment