Skip to content

Instantly share code, notes, and snippets.

@turekt
Created March 6, 2023 17:40
Show Gist options
  • Save turekt/41b6948701e331ad6524714f99e8cac8 to your computer and use it in GitHub Desktop.
Save turekt/41b6948701e331ad6524714f99e8cac8 to your computer and use it in GitHub Desktop.
Minimal Python 3 script for Ansible Vault encryption and decryption
from argparse import ArgumentParser
from binascii import hexlify, unhexlify
from code import InteractiveConsole
from enum import Enum
from os import urandom
from secrets import compare_digest
from sys import exit, stdin
from textwrap import wrap
from typing import Optional, Tuple

from cryptography.hazmat.primitives import hashes, hmac, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
# Prompt strings: the first-line prompt of interactive mode, and the
# password prompt shown before every encrypt/decrypt operation.
PROMPT_INTERACTIVE, PROMPT_INPUT_PASSWORD = ">>> ", "Input password: "
class VaultOperation(str, Enum):
    """Operations selectable with the ``-o/--operation`` command-line flag.

    Mixing in ``str`` lets members compare equal to their plain string values,
    which is how the parsed argument is compared against them.
    """

    ENCRYPT = "encrypt"  # produce a vault envelope from plaintext
    DECRYPT = "decrypt"  # recover plaintext from a vault envelope
class AnsibleVault:
    """Encrypt/decrypt payloads in the ``$ANSIBLE_VAULT;1.1;AES256`` format.

    Keys are derived from ``password`` and ``salt`` with PBKDF2-HMAC-SHA256.
    Plaintext is PKCS#7-padded and encrypted with AES-CTR. The ciphertext is
    authenticated with HMAC-SHA256.
    """

    # Total PBKDF2 output: 32-byte AES key + 32-byte HMAC key + 16-byte CTR nonce.
    KEYS_LEN = 80
    # PKCS#7 block size in bits (128 bits == one AES block).
    PADDING_SIZE = 128
    HEADER = b"$ANSIBLE_VAULT;1.1;AES256"
    # PBKDF2 iteration count used for key derivation.
    KDF_ITERATIONS = 10000
    # Number of hex characters per line in the packed envelope body.
    LINE_WIDTH = 80

    def __init__(self, password: bytes, salt: bytes):
        self.password = password
        self.salt = salt
        self.key, self.hmac_key, self.nonce = self._derive_keys()

    def _derive_keys(self) -> Tuple[bytes, bytes, bytes]:
        """Derive (AES key, HMAC key, CTR nonce) from the password and salt."""
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=AnsibleVault.KEYS_LEN,
            salt=self.salt,
            iterations=AnsibleVault.KDF_ITERATIONS,
        )
        k = kdf.derive(self.password)
        # key, hmac_key, iv
        return k[:32], k[32:64], k[64:AnsibleVault.KEYS_LEN]

    def _decrypt(self, ciphertext: bytes) -> bytes:
        """AES-CTR decrypt *ciphertext* (padding is left in place)."""
        cipher = Cipher(algorithms.AES(self.key), modes.CTR(self.nonce))
        decryptor = cipher.decryptor()
        return decryptor.update(ciphertext) + decryptor.finalize()

    def _encrypt(self, plaintext: bytes) -> bytes:
        """AES-CTR encrypt already-padded *plaintext*."""
        cipher = Cipher(algorithms.AES(self.key), modes.CTR(self.nonce))
        encryptor = cipher.encryptor()
        return encryptor.update(plaintext) + encryptor.finalize()

    def calculate_hmac(self, ciphertext: bytes) -> bytes:
        """Return the HMAC-SHA256 of *ciphertext* under the derived HMAC key."""
        h = hmac.HMAC(self.hmac_key, hashes.SHA256())
        h.update(ciphertext)
        return h.finalize()

    def encrypt(self, plaintext: bytes) -> Tuple[bytes, bytes, bytes]:
        """Pad and encrypt *plaintext*.

        Returns:
            ``(salt, hmac, ciphertext)``, ready to be passed to :meth:`pack`.
        """
        ciphertext = self._encrypt(self.pad(plaintext))
        hm = self.calculate_hmac(ciphertext)
        return self.salt, hm, ciphertext

    def decrypt(self, ciphertext: bytes, hmac_content: Optional[bytes] = None) -> bytes:
        """Verify (optionally) and decrypt *ciphertext*, returning unpadded plaintext.

        Args:
            ciphertext: the encrypted payload.
            hmac_content: expected HMAC. When given (even if empty) it must
                match, otherwise verification is skipped.

        Raises:
            ValueError: if *hmac_content* is given and does not match, or if
                the decrypted padding is invalid.
        """
        if hmac_content is not None:
            # Constant-time comparison avoids leaking how many bytes matched.
            # An empty HMAC must fail rather than silently skip verification.
            if not compare_digest(self.calculate_hmac(ciphertext), hmac_content):
                raise ValueError("HMAC verification failed")
        return self.unpad(self._decrypt(ciphertext))

    def pad(self, data: bytes) -> bytes:
        """Apply PKCS#7 padding to *data*."""
        padder = padding.PKCS7(AnsibleVault.PADDING_SIZE).padder()
        return padder.update(data) + padder.finalize()

    def unpad(self, padded: bytes) -> bytes:
        """Strip PKCS#7 padding from *padded*."""
        unpadder = padding.PKCS7(AnsibleVault.PADDING_SIZE).unpadder()
        return unpadder.update(padded) + unpadder.finalize()

    @staticmethod
    def unpack(content: bytes) -> Tuple[bytes, bytes, bytes, bytes]:
        """Parse a vault envelope into ``(header, salt, hmac, ciphertext)``.

        The body after the header line is hex-encoded text of the form
        ``hex(salt) \\n hex(hmac) \\n hex(ciphertext)``.

        Raises:
            ValueError: if *content* is empty or the payload is malformed.
        """
        lines = content.splitlines()
        if not lines:
            raise ValueError("Empty vault content")
        header = lines[0].strip()
        # Join the wrapped hex lines; splitlines() also copes with CRLF input.
        hex_value = b"".join(line.strip() for line in lines[1:])
        if len(hex_value) % 2:
            # Padding with a leading zero would shift every nibble and corrupt
            # the data, so reject the payload instead.
            raise ValueError("Malformed vault payload: odd-length hex data")
        fields = unhexlify(hex_value).split(b"\n")
        if len(fields) != 3:
            raise ValueError("Malformed vault payload: expected salt, HMAC and ciphertext")
        salt, hm, ciphertext = (unhexlify(field) for field in fields)
        return header, salt, hm, ciphertext

    @staticmethod
    def pack(salt: bytes, hm: bytes, ciphertext: bytes) -> bytes:
        """Build a vault envelope: header line plus hex body wrapped at LINE_WIDTH."""
        hex_values = b"\n".join(hexlify(part) for part in (salt, hm, ciphertext))
        payload = hexlify(hex_values)
        width = AnsibleVault.LINE_WIDTH
        body = [payload[i:i + width] for i in range(0, len(payload), width)]
        return b"\n".join([AnsibleVault.HEADER] + body)
def interactive():
    """Run a small REPL that encrypts or decrypts pasted text.

    Input lines are buffered until an empty line is entered. The buffer is
    then decrypted if it starts with the vault header prefix, and encrypted
    otherwise. A blank line with nothing buffered is ignored. Ctrl-C/Ctrl-D
    ends the session. Any other error is printed and also ends the session.
    """
    console = InteractiveConsole()
    # b"$ANSIBLE_VAULT": the part of the header before the version field.
    vault_prefix = AnsibleVault.HEADER.split(b";", 1)[0]
    lines = []
    while True:
        try:
            # Show the prompt only for the first line of a new payload.
            inp = console.raw_input(PROMPT_INTERACTIVE if not lines else "")
            if inp:
                lines.append(inp)
                continue
            if not lines:
                # Nothing buffered: don't ask for a password to encrypt "".
                continue
            data_bytes = "\n".join(lines).encode()
            lines = []
            transform(data_bytes, data_bytes.startswith(vault_prefix))
        except (KeyboardInterrupt, EOFError):
            print()
            break
        except Exception as e:
            print(e)
            break
def encrypt(data, passwd):
    """Encrypt *data* with *passwd* and return the packed vault envelope.

    A fresh random 32-byte salt is drawn for every call.
    """
    vault = AnsibleVault(passwd, urandom(32))
    return AnsibleVault.pack(*vault.encrypt(data))
def decrypt(data, passwd):
    """Parse the vault envelope *data* and return the decrypted plaintext.

    The salt and HMAC stored in the envelope are fed to AnsibleVault, which
    may raise ValueError if HMAC verification fails.
    """
    _header, salt, hm, ct = AnsibleVault.unpack(data)
    return AnsibleVault(passwd, salt).decrypt(ct, hmac_content=hm)
def transform(data, cond):
    """Prompt for a password, then print the decrypted or encrypted *data*.

    Args:
        data: bytes to process.
        cond: True to decrypt *data* as a vault envelope, False to encrypt it.
    """
    # getpass does not echo the password. When a terminal is available it
    # reads from it rather than stdin, so this still works after the payload
    # itself was read from stdin; plain input() would hit EOF there.
    from getpass import getpass

    passwd = getpass(PROMPT_INPUT_PASSWORD).encode()
    operation = decrypt if cond else encrypt
    print(operation(data, passwd).decode())
if __name__ == "__main__":
    # Command-line entry point: interactive REPL, or one-shot file/stdin mode.
    parser = ArgumentParser(description="Minimal ansible vault encryption and decryption script")
    parser.add_argument(
        "-i",
        "--interactive",
        help="Run script in interactive mode",
        action="store_true",
    )
    parser.add_argument(
        "-o",
        "--operation",
        default=VaultOperation.ENCRYPT,
        choices=[op.value for op in VaultOperation],
        help="Operation to execute",
    )
    # With no -f, the file descriptor of stdin is opened instead of a path.
    parser.add_argument(
        "-f",
        "--file",
        default=stdin.fileno(),
        help="Path to file on which to execute the operation",
    )
    args = parser.parse_args()

    if args.interactive:
        interactive()
        exit(0)

    with open(args.file, "rb") as fp:
        payload = fp.read()
        transform(payload, args.operation == VaultOperation.DECRYPT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment