Skip to content

Instantly share code, notes, and snippets.

@filipeandre
Last active October 14, 2025 15:40
Show Gist options
  • Save filipeandre/90245f8634a3aeb523296d11804ae8b0 to your computer and use it in GitHub Desktop.
Save filipeandre/90245f8634a3aeb523296d11804ae8b0 to your computer and use it in GitHub Desktop.
This script acts as a lockdown switch for all regional AWS WAF web ACLs in the current region.
#!/usr/bin/env python3
"""
Add a top-priority "block-all" rule to every AWS WAFv2 Web ACL in the current Region (REGIONAL scope),
without removing existing rules. The change is fully reversible via a local backup file.
Features
- Enumerates all REGIONAL Web ACLs in the configured AWS region
- Creates (or reuses) IP sets that match all IPv4 and IPv6 addresses (0.0.0.0/0 and ::/0)
- Inserts a new rule at priority 0 that blocks all traffic, shifting existing rule priorities down
- Stores a per-WebACL backup (original rules + metadata) under ./waf_backups/<region>/<web_acl_id>.json
- Revert command removes the block-all rule and restores the original rules exactly as saved in the backup
- Dry run mode for both apply and revert
- Optional role assumption before any API calls
Usage
Apply (add the rule):
python aws-waf-add-block-all-rule.py apply --region eu-west-1
Revert (restore from backup):
python aws-waf-add-block-all-rule.py revert --region eu-west-1
With role assumption:
python aws-waf-add-block-all-rule.py apply --region eu-west-1 --role-arn arn:aws:iam::123456789012:role/Admin
Dry run:
python aws-waf-add-block-all-rule.py apply --region eu-west-1 --dry-run
python aws-waf-add-block-all-rule.py revert --region eu-west-1 --dry-run
Notes
- Scope is REGIONAL only. If you need CLOUDFRONT scope, extend the script accordingly.
- Backups are local files; keep them safe to ensure reversibility.
"""
import argparse
import base64
import ipaddress
import json
import os
import sys
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import boto3
from botocore.exceptions import BotoCoreError, ClientError
# Names of the temporary WAF objects managed by this script. The rule name is
# used to detect whether a Web ACL is already locked down; the IP set names are
# used to find and reuse previously created IP sets.
BLOCK_ALL_RULE_NAME: str = "__TEMP_BLOCK_ALL__"
IPSET_V4_NAME: str = "__TEMP_ALL_IPV4__"
IPSET_V6_NAME: str = "__TEMP_ALL_IPV6__"

# Root directory (relative to the working directory) for per-Web-ACL backups.
BACKUP_DIR: str = "waf_backups"

# Only REGIONAL Web ACLs in the selected region are handled (no CLOUDFRONT).
SCOPE: str = "REGIONAL"
@dataclass
class AwsClients:
    """Bundle of the boto3 objects the script talks to AWS through.

    Fields are typed ``Any`` because boto3 clients are generated at runtime and
    have no importable static type here.
    """

    waf: Any  # boto3 "wafv2" client
    waf_paginator: Any  # object exposing ``paginate(Scope=...)`` over list_web_acls pages
    sts: Any  # boto3 "sts" client
def assume_role_if_needed(role_arn: Optional[str], session: Optional[boto3.Session] = None) -> boto3.Session:
    """Return the boto3 session whose credentials the script should use.

    Without ``role_arn`` this is ``session`` itself (or a fresh default
    ``boto3.Session`` when none is given). With ``role_arn``, the role is assumed
    through STS using that session, and a new session built from the temporary
    credentials is returned, keeping the source session's region.
    """
    source_session = session if session is not None else boto3.Session()
    if role_arn:
        sts_client = source_session.client("sts")
        assumed = sts_client.assume_role(RoleArn=role_arn, RoleSessionName="waf-blockall-script")
        temp_creds = assumed["Credentials"]
        return boto3.Session(
            aws_access_key_id=temp_creds["AccessKeyId"],
            aws_secret_access_key=temp_creds["SecretAccessKey"],
            aws_session_token=temp_creds["SessionToken"],
            region_name=source_session.region_name,
        )
    return source_session
class _NextMarkerPaginator:
    """Fallback paginator for WAFv2 list operations, which page via ``NextMarker``.

    Mirrors the botocore paginator interface (``paginate(**kwargs)`` yielding raw
    response pages) so callers can iterate pages the same way either way.
    """

    def __init__(self, client: Any, operation_name: str) -> None:
        self._operation = getattr(client, operation_name)

    def paginate(self, **kwargs: Any):
        """Yield every response page, following ``NextMarker`` until it runs out."""
        params = dict(kwargs)
        seen_markers = set()
        while True:
            page = self._operation(**params)
            yield page
            marker = page.get("NextMarker")
            # Stop on the last page; the seen-set also guards against a marker
            # being echoed back forever.
            if not marker or marker in seen_markers:
                return
            seen_markers.add(marker)
            params["NextMarker"] = marker


def make_clients(region: str, role_arn: Optional[str]) -> AwsClients:
    """Create the WAFv2/STS clients for ``region``, assuming ``role_arn`` first if given.

    Returns:
        AwsClients whose ``waf_paginator`` supports ``paginate(Scope=...)`` over
        ``list_web_acls`` pages.
    """
    base = boto3.Session(region_name=region)
    sess = assume_role_if_needed(role_arn, base)
    waf = sess.client("wafv2", region_name=region)
    sts = sess.client("sts", region_name=region)
    # botocore defines no paginator for wafv2 list_web_acls (WAFv2 list APIs are
    # paged manually with NextMarker), and get_paginator would then raise
    # OperationNotPageableError. Use a real paginator only when one exists.
    if waf.can_paginate("list_web_acls"):
        paginator = waf.get_paginator("list_web_acls")
    else:
        paginator = _NextMarkerPaginator(waf, "list_web_acls")
    return AwsClients(waf=waf, waf_paginator=paginator, sts=sts)
def _ip_address_version(addresses: List[str]) -> str:
    """Return the WAFv2 ``IPAddressVersion`` ("IPV4"/"IPV6") shared by all CIDRs in ``addresses``.

    Raises:
        ValueError: if ``addresses`` is empty, contains an invalid CIDR, or mixes
            IPv4 and IPv6 networks (a single IP set holds one IP version).
    """
    versions = {ipaddress.ip_network(cidr, strict=False).version for cidr in addresses}
    if len(versions) != 1:
        raise ValueError(f"IP set addresses must be non-empty and of a single IP version: {addresses!r}")
    return "IPV6" if versions == {6} else "IPV4"


def ensure_ipset(clients: AwsClients, name: str, description: str, addresses: List[str], region: str, scope: str = SCOPE) -> Tuple[str, str]:
    """Create or reuse the IP set called ``name`` so that it contains exactly ``addresses``.

    Existing IP sets are searched by name across all ``list_ip_sets`` pages. A
    match whose addresses differ is updated in place. Otherwise a new set is
    created with the IP version derived from ``addresses``.

    Args:
        clients: AWS clients; only ``clients.waf`` is used.
        name: IP set name to look up or create.
        description: description used when a new set is created.
        addresses: CIDR strings, all of the same IP version.
        region: not used; kept for call compatibility.
        scope: WAFv2 scope of the IP set.

    Returns:
        ``(Id, ARN)`` of the IP set.

    Raises:
        ValueError: when a new set must be created and ``addresses`` is empty,
            malformed, or mixes IPv4 and IPv6.
        botocore.exceptions.ClientError: on WAFv2 API failures.
    """
    waf = clients.waf
    wanted = sorted(addresses)
    list_kwargs = {"Scope": scope}
    while True:
        page = waf.list_ip_sets(**list_kwargs)
        for summary in page.get("IPSets", []):
            if summary.get("Name") != name:
                continue
            # Fetch details (addresses + LockToken) to reconcile the contents.
            current = waf.get_ip_set(Scope=scope, Name=name, Id=summary["Id"])
            if sorted(current["IPSet"].get("Addresses", [])) != wanted:
                waf.update_ip_set(
                    Scope=scope,
                    Name=name,
                    Id=summary["Id"],
                    Addresses=addresses,
                    LockToken=current["LockToken"],
                )
            return summary["Id"], summary["ARN"]
        next_marker = page.get("NextMarker")
        # Stop on the last page, and never re-request with the same marker.
        if not next_marker or next_marker == list_kwargs.get("NextMarker"):
            break
        list_kwargs["NextMarker"] = next_marker

    created = waf.create_ip_set(
        Name=name,
        Scope=scope,
        Description=description,
        IPAddressVersion=_ip_address_version(addresses),
        Addresses=addresses,
    )
    return created["Summary"]["Id"], created["Summary"]["ARN"]
def list_all_web_acls(clients: AwsClients, scope: str = SCOPE) -> List[Dict]:
    """Return the summaries of every Web ACL in ``scope``, gathered from all result pages."""
    pages = clients.waf_paginator.paginate(Scope=scope)
    return [acl for page in pages for acl in page.get("WebACLs", [])]
# Tag key used to round-trip ``bytes`` values through the JSON backup files.
_BACKUP_BYTES_KEY = "__bytes_b64__"


def _backup_file_path(region: str, web_acl_id: str) -> str:
    """Return the backup file path for a Web ACL without touching the filesystem."""
    return os.path.join(BACKUP_DIR, region, f"{web_acl_id}.json")


def _encode_backup_value(value: Any) -> Dict[str, str]:
    """``json.dump`` ``default`` hook: encode ``bytes`` as a tagged base64 object.

    boto3 returns WAFv2 blob fields (e.g. ``ByteMatchStatement.SearchString``) as
    ``bytes``, which the ``json`` module cannot serialize on its own.
    """
    if isinstance(value, (bytes, bytearray)):
        return {_BACKUP_BYTES_KEY: base64.b64encode(bytes(value)).decode("ascii")}
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _decode_backup_object(obj: Dict) -> Any:
    """``json.load`` ``object_hook``: turn tagged base64 objects back into ``bytes``."""
    if len(obj) == 1 and _BACKUP_BYTES_KEY in obj:
        return base64.b64decode(obj[_BACKUP_BYTES_KEY])
    return obj


def load_backup_path(region: str, web_acl_id: str) -> str:
    """Return the backup file path ``BACKUP_DIR/<region>/<web_acl_id>.json``, creating its directory."""
    path = _backup_file_path(region, web_acl_id)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path


def save_backup(region: str, web_acl_id: str, data: Dict) -> None:
    """Write ``data`` as the Web ACL's backup, replacing any previous backup atomically.

    The JSON is written to a temporary sibling file and moved into place with
    ``os.replace``, so an interrupted or failed write never leaves a truncated
    backup (and never destroys the previous one).
    """
    path = load_backup_path(region, web_acl_id)
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, sort_keys=True, default=_encode_backup_value)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; otherwise clean it up.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def load_backup(region: str, web_acl_id: str) -> Optional[Dict]:
    """Return the saved backup for a Web ACL, or ``None`` when there is none.

    Values saved as ``bytes`` are restored as ``bytes``. No directories are created.
    """
    path = _backup_file_path(region, web_acl_id)
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f, object_hook=_decode_backup_object)
def build_block_all_rule(ipv4_ipset_arn: str, ipv6_ipset_arn: Optional[str]) -> Dict:
    """Build the WAFv2 rule definition named ``BLOCK_ALL_RULE_NAME`` with a Block action.

    The rule matches the IPv4 IP set and, when an IPv6 ARN is supplied, the IPv6
    IP set as well (combined with an ``OrStatement``). It uses priority 0, and
    both CloudWatch metrics and sampled requests are enabled.
    """
    arns = [ipv4_ipset_arn] + ([ipv6_ipset_arn] if ipv6_ipset_arn else [])
    references = [{"IPSetReferenceStatement": {"ARN": arn}} for arn in arns]
    if len(references) > 1:
        match_statement = {"OrStatement": {"Statements": references}}
    else:
        match_statement = references[0]
    visibility = {
        "SampledRequestsEnabled": True,
        "CloudWatchMetricsEnabled": True,
        "MetricName": BLOCK_ALL_RULE_NAME,
    }
    return {
        "Name": BLOCK_ALL_RULE_NAME,
        "Priority": 0,  # evaluated before every other rule
        "Statement": match_statement,
        "Action": {"Block": {}},
        "VisibilityConfig": visibility,
    }
def apply_block_all_to_acl(clients: AwsClients, region: str, web_acl_summary: Dict, dry_run: bool) -> None:
    """Insert the block-all rule at priority 0 of one Web ACL, keeping its existing rules.

    Steps:
    1. Fetch the full Web ACL and skip it if it already has the block-all rule.
    2. In dry-run mode only print the plan: no IP sets are created and no backup
       is written.
    3. Otherwise make sure the IPv4 (and, if possible, IPv6) IP sets exist.
    4. Save a backup of the current configuration.
    5. Call ``update_web_acl`` with the new rule first and every existing rule
       shifted down by one priority.

    Raises:
        botocore.exceptions.ClientError / BotoCoreError: AWS API or validation errors.
        OSError: if the backup cannot be written; the Web ACL is then left untouched.
    """
    # Optional UpdateWebACL settings. update_web_acl replaces the whole Web ACL
    # definition, so any of these omitted from the call would be reset.
    optional_settings = (
        "CustomResponseBodies",
        "CaptchaConfig",
        "ChallengeConfig",
        "TokenDomains",
        "AssociationConfig",
        "DataProtectionConfig",
    )
    waf = clients.waf
    name = web_acl_summary["Name"]
    web_acl_id = web_acl_summary["Id"]

    # Full Web ACL is needed for LockToken, DefaultAction, Rules, VisibilityConfig...
    got = waf.get_web_acl(Scope=SCOPE, Name=name, Id=web_acl_id)
    web_acl = got["WebACL"]
    lock = got["LockToken"]

    # Idempotency: an ACL that already carries the rule is left alone.
    existing_rules: List[Dict] = web_acl.get("Rules", [])
    if any(r.get("Name") == BLOCK_ALL_RULE_NAME for r in existing_rules):
        print(f"[SKIP] {name} already has {BLOCK_ALL_RULE_NAME}")
        return

    print(f"[APPLY] {name}: will add {BLOCK_ALL_RULE_NAME} at priority 0 and shift {len(existing_rules)} rules")
    if dry_run:
        print("[DRY-RUN] Skipping update_web_acl()")
        return

    # Ensure all-IP IP sets exist.
    _, v4_arn = ensure_ipset(
        clients,
        IPSET_V4_NAME,
        "Temporary all-IPv4 IPSet for block-all rule (0.0.0.0/0)",
        ["0.0.0.0/0"],
        region,
    )
    v6_arn: Optional[str] = None
    try:
        _, v6_arn = ensure_ipset(
            clients,
            IPSET_V6_NAME,
            "Temporary all-IPv6 IPSet for block-all rule (::/0)",
            ["::/0"],
            region,
        )
    except ClientError as e:
        # Accounts/regions without IPv6 support proceed with the IPv4 set only.
        if e.response.get("Error", {}).get("Code") not in {"WAFInvalidParameterException", "WAFUnavailableEntityException"}:
            raise
        print("[WARN] IPv6 IPSet creation failed or unsupported; proceeding with IPv4 only.")

    # Shift every existing rule down by one. Shallow copies suffice because only
    # the top-level Priority changes. A JSON round-trip would also fail on
    # bytes-valued blob fields such as ByteMatchStatement.SearchString.
    shifted_rules = [
        {**rule, "Priority": rule.get("Priority", 0) + 1}
        for rule in sorted(existing_rules, key=lambda r: r.get("Priority", 0))
    ]
    new_rules = [build_block_all_rule(v4_arn, v6_arn)] + shifted_rules

    preserved = {key: web_acl[key] for key in optional_settings if web_acl.get(key)}

    # Back up the original state before modifying the Web ACL, for revert.
    backup_payload = {
        "Name": name,
        "Id": web_acl_id,
        "Scope": SCOPE,
        "Region": region,
        "OriginalRules": existing_rules,
        "DefaultAction": web_acl.get("DefaultAction"),
        "VisibilityConfig": web_acl.get("VisibilityConfig"),
        "Description": web_acl.get("Description"),
        "CaptchaConfig": web_acl.get("CaptchaConfig"),
        "ChallengeConfig": web_acl.get("ChallengeConfig"),
        **preserved,
    }
    save_backup(region, web_acl_id, backup_payload)

    update_kwargs = {
        "Name": name,
        "Scope": SCOPE,
        "Id": web_acl_id,
        "DefaultAction": web_acl["DefaultAction"],
        "Rules": new_rules,
        "VisibilityConfig": web_acl["VisibilityConfig"],
        "LockToken": lock,
        **preserved,
    }
    # Description must be non-empty when sent, so pass it only if the ACL has one.
    if web_acl.get("Description"):
        update_kwargs["Description"] = web_acl["Description"]
    waf.update_web_acl(**update_kwargs)
    print(f"[DONE] {name}: block-all rule added")
def revert_acl_from_backup(clients: AwsClients, region: str, web_acl_summary: Dict, dry_run: bool) -> None:
    """Restore one Web ACL's rules from its local backup, removing the block-all rule.

    The restore only runs while the Web ACL still contains the block-all rule.
    Otherwise the backup is treated as stale: the ACL was already reverted, or the
    apply never completed, and restoring it could overwrite rules changed since.

    Settings other than the rules use the backed-up value when present and fall
    back to the Web ACL's current value, so nothing is reset by omission.

    Raises:
        botocore.exceptions.ClientError / BotoCoreError: AWS API or validation errors.
    """
    # Optional UpdateWebACL settings. update_web_acl replaces the whole
    # definition, so omitted settings would be reset.
    optional_settings = (
        "CustomResponseBodies",
        "CaptchaConfig",
        "ChallengeConfig",
        "TokenDomains",
        "AssociationConfig",
        "DataProtectionConfig",
    )
    waf = clients.waf
    name = web_acl_summary["Name"]
    web_acl_id = web_acl_summary["Id"]

    backup = load_backup(region, web_acl_id)
    if not backup:
        print(f"[SKIP] No backup found for {name} ({web_acl_id}), nothing to revert.")
        return

    got = waf.get_web_acl(Scope=SCOPE, Name=name, Id=web_acl_id)
    web_acl = got["WebACL"]
    lock = got["LockToken"]

    if not any(r.get("Name") == BLOCK_ALL_RULE_NAME for r in web_acl.get("Rules", [])):
        print(f"[SKIP] {name} does not contain {BLOCK_ALL_RULE_NAME}; not restoring its (possibly stale) backup.")
        return

    # Restore the rules exactly as saved.
    original_rules = backup.get("OriginalRules", [])
    print(f"[REVERT] {name}: will restore {len(original_rules)} original rules and remove {BLOCK_ALL_RULE_NAME} if present")
    if dry_run:
        print("[DRY-RUN] Skipping update_web_acl()")
        return

    update_kwargs = {
        "Name": name,
        "Scope": SCOPE,
        "Id": web_acl_id,
        "DefaultAction": backup.get("DefaultAction") or web_acl["DefaultAction"],
        "Rules": original_rules,
        "VisibilityConfig": backup.get("VisibilityConfig") or web_acl["VisibilityConfig"],
        "LockToken": lock,
    }
    # Backups store Description as None when the ACL had none; the API rejects
    # None/empty descriptions, so only send a real one.
    description = backup.get("Description") or web_acl.get("Description")
    if description:
        update_kwargs["Description"] = description
    for key in optional_settings:
        value = backup.get(key) or web_acl.get(key)
        if value:
            update_kwargs[key] = value

    waf.update_web_acl(**update_kwargs)
    print(f"[DONE] {name}: restored original rule set")
def _run_for_all_acls(
    clients: AwsClients,
    region: str,
    dry_run: bool,
    action: Callable[[AwsClients, str, Dict, bool], None],
    label: str,
) -> int:
    """Run ``action`` on every Web ACL in the region, continuing past per-ACL failures.

    A failure on one Web ACL must not stop the others from being processed. AWS
    errors, botocore validation errors, backup I/O errors and malformed input or
    backups are therefore reported per ACL.

    Returns:
        0 when every Web ACL was handled, 1 if any of them failed.
    """
    count = 0
    failed = 0
    for acl in list_all_web_acls(clients):
        try:
            action(clients, region, acl, dry_run)
        except (ClientError, BotoCoreError, OSError, ValueError) as e:
            failed += 1
            print(f"[ERROR] {acl['Name']}: {e}")
        else:
            count += 1
    print(f"[SUMMARY] Processed {count} Web ACL(s) in {region} ({label})")
    if failed:
        print(f"[SUMMARY] {failed} Web ACL(s) failed in {region} ({label})")
        return 1
    return 0


def cmd_apply(clients: AwsClients, region: str, dry_run: bool) -> int:
    """Add the block-all rule to every Web ACL in the region; return the exit code."""
    return _run_for_all_acls(clients, region, dry_run, apply_block_all_to_acl, "apply")


def cmd_revert(clients: AwsClients, region: str, dry_run: bool) -> int:
    """Restore every Web ACL in the region from its backup; return the exit code."""
    return _run_for_all_acls(clients, region, dry_run, revert_acl_from_backup, "revert")
def parse_args(argv: List[str]) -> argparse.Namespace:
    """Parse ``argv`` into ``command`` ("apply"/"revert"), ``region``, ``role_arn`` and ``dry_run``.

    The shared options (``--region``, ``--role-arn``, ``--dry-run``) belong to each
    subcommand and therefore follow its name, e.g. ``apply --region eu-west-1``.
    argparse exits with status 2 on invalid input.
    """
    parser = argparse.ArgumentParser(description="Add (and revert) a temporary top-priority block-all rule to all WAFv2 REGIONAL Web ACLs.")
    subcommands = parser.add_subparsers(dest="command", required=True)

    # Options shared by both subcommands, inherited through ``parents=``.
    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--region", required=True, help="AWS region, e.g. eu-west-1")
    common.add_argument("--role-arn", default=None, help="Optional role ARN to assume before making any API calls")
    common.add_argument("--dry-run", action="store_true", help="Do not perform updates; print planned changes only")

    subcommands.add_parser("apply", parents=[common], help="Add the block-all rule at the top of every Web ACL")
    subcommands.add_parser("revert", parents=[common], help="Revert using local backups (restores original rules)")
    return parser.parse_args(argv)
def main(argv: List[str]) -> int:
    """Parse ``argv``, build the AWS clients, and dispatch to the chosen subcommand.

    Returns the subcommand's exit code, or 2 for an unrecognised command.
    """
    args = parse_args(argv)
    clients = make_clients(args.region, args.role_arn)
    handlers = {"apply": cmd_apply, "revert": cmd_revert}
    handler = handlers.get(args.command)
    if handler is None:
        print("Unknown command")
        return 2
    return handler(clients, args.region, args.dry_run)
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment