# https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md

import re
import base64

# Structural pattern for a PHC string: $id[$v=version][$params][$salt[$hash]].
# Only the overall shape is checked here; each field is validated further by
# the parsing functions that consume the match.
PHC_STRING_RE = re.compile(r"""
	\$(?P<id>[a-z0-9-]{1,32})
	(\$v=(?P<version>[0-9]+))?
	# NOTE: full param parse happens later, this just checks at least one param exists.
	# The '-' must stay last in the class below so that it is a literal: written as
	# '.-=' it formed a range that admitted ':', ';' and '<' while rejecting '-' itself.
	(\$(?P<params>[a-z0-9-]{1,32}=[a-zA-Z0-9/+.=,-]*))?
	(
		\$(?P<salt>[a-zA-Z0-9/+.-]*) # NOTE: zero-length salt is allowed?
		(\$(?P<hash>[a-zA-Z0-9/+]+))?
	)? # NOTE: no trailing $
""", re.X)

def unpadded_base64_encode(data: bytes) -> str:
	"""Return *data* as standard-alphabet base64 text with trailing "=" padding removed."""
	padded = base64.b64encode(data)
	# strip the padding while still in bytes form, then convert to text
	return padded.rstrip(b"=").decode()

def canonical_unpadded_base64_decode(data: str) -> bytes:
	"""Decode unpadded base64 text, accepting only its canonical spelling.

	The input is decoded, re-encoded with unpadded_base64_encode, and compared
	against the original text.

	Raises:
		ValueError: if re-encoding the decoded bytes does not reproduce *data*
			exactly. Errors raised by base64.b64decode itself propagate unchanged.
	"""
	# pad generously before handing the text to the decoder
	raw = base64.b64decode(data + "===")
	# XXX: there are probably cheaper ways to perform this check
	if unpadded_base64_encode(raw) == data:
		return raw
	raise ValueError("non-canonical base64 encoding")

def canonical_int_decode(data: str) -> int:
	"""Parse a decimal integer, accepting only the form str() would produce.

	Raises:
		ValueError: "invalid int encoding" if *data* is not an optional "-"
			followed by ASCII digits; "non-canonical int encoding" if it is,
			but differs from str(int(data)) (e.g. leading zeros or "-0").
	"""
	if re.fullmatch(r"-?[0-9]+", data) is None:
		raise ValueError("invalid int encoding")
	value = int(data)
	# converting back to text must reproduce the input exactly
	if str(value) != data:
		raise ValueError("non-canonical int encoding")
	return value

# this should be generic to any PHC-conformant hash string
def parse_phc_string(encoded: str) -> dict:
	"""Split a PHC hash string into its fields without applying any scheme-specific rules.

	Returns a dict with the keys "id", "version", "params", "salt" and "hash".
	Fields absent from the input are None. When present, "version" is an int,
	"params" is a dict of name -> raw string value in the order written, "salt"
	is the raw (undecoded) text and "hash" is the decoded bytes.

	Raises:
		ValueError: if the string does not match PHC_STRING_RE, a parameter is
			malformed or repeated, or the version/hash is not canonically encoded.
	"""
	structure = PHC_STRING_RE.fullmatch(encoded)
	if structure is None:
		raise ValueError("invalid hash string")
	fields = structure.groupdict()

	raw_version = fields["version"]
	if raw_version:
		fields["version"] = canonical_int_decode(raw_version)

	raw_params = fields["params"]
	if raw_params:
		# NOTE: dict key order matters for ensuring canonical-ness
		collected = {}
		for item in raw_params.split(","):
			item_match = re.fullmatch(r"([a-z0-9-]{1,32})=([a-zA-Z0-9/+.-]*)", item)
			if item_match is None:
				raise ValueError("invalid hash parameter")
			name, value = item_match.group(1, 2)
			if name in collected:
				raise ValueError("duplicate parameter name")
			collected[name] = value
		fields["params"] = collected

	# NOTE: if salt exists, it SHOULD be base64, but it might not be, so we don't decode here

	raw_hash = fields["hash"]
	if raw_hash:
		fields["hash"] = canonical_unpadded_base64_decode(raw_hash)

	return fields

def parse_argon2id_v19_string(encoded: str) -> dict:
	"""Parse a PHC string and validate it as an Argon2id version-19 hash.

	Returns the dict produced by parse_phc_string with these fields refined:
	"params" (when present) maps "m", "t" and "p" to ints and the optional
	"keyid" and "data" to bytes; "salt" (when present) is decoded to bytes.
	Params, salt and hash may each be absent, in which case they stay None.

	Raises:
		ValueError: wrong id or version; parameters missing, unknown or out of
			order; m or t outside 1..2**32-1; p outside 1..255; keyid longer
			than 8 bytes; data longer than 32 bytes; salt not 8..48 bytes;
			hash not 12..64 bytes; or any non-canonical encoding.
	"""
	parsed = parse_phc_string(encoded)
	if parsed["id"] != "argon2id":
		raise ValueError("not an argon2id hash")
	if parsed["version"] != 19:
		raise ValueError("unsupported argon2id version")

	raw_params = parsed["params"]
	if raw_params:
		required_keys = ["m", "t", "p"]
		optional_keys = ["keyid", "data"]
		parsed_keys = list(raw_params)
		# m, t and p are mandatory and must come first, in this order; checking this
		# up front turns a partial list like "m=1" into ValueError instead of KeyError
		if parsed_keys[:3] != required_keys:
			raise ValueError("invalid params")
		# keyid and data may each be omitted independently, but nothing else may
		# follow and their relative order is fixed
		if parsed_keys[3:] != [key for key in optional_keys if key in raw_params]:
			raise ValueError("invalid params")

		decoded_params = {key: canonical_int_decode(raw_params[key]) for key in required_keys}
		if decoded_params["m"] not in range(1, 2**32):
			raise ValueError("m param out of range")
		if decoded_params["t"] not in range(1, 2**32):
			raise ValueError("t param out of range")
		if decoded_params["p"] not in range(1, 255+1):
			raise ValueError("p param out of range")
		if "keyid" in raw_params:
			decoded_params["keyid"] = canonical_unpadded_base64_decode(raw_params["keyid"])
			if len(decoded_params["keyid"]) > 8:
				raise ValueError("keyid too long")
		if "data" in raw_params:
			decoded_params["data"] = canonical_unpadded_base64_decode(raw_params["data"])
			if len(decoded_params["data"]) > 32:
				raise ValueError("associated data too long")
		parsed["params"] = decoded_params

	# an empty salt ("...$") is not None, so it is decoded and then fails the length check
	if parsed["salt"] is not None:
		salt = canonical_unpadded_base64_decode(parsed["salt"])
		if len(salt) not in range(8, 48+1):
			raise ValueError("invalid salt length")
		parsed["salt"] = salt

	# the hash was already decoded to bytes by parse_phc_string
	if parsed["hash"]:
		if len(parsed["hash"]) not in range(12, 64+1):
			raise ValueError("invalid hash length")
	return parsed

from cryptography.hazmat.primitives.kdf.argon2 import Argon2id
import os
from typing import Optional

def create_argon2id_password(
	password: bytes,
	*,
	memory_cost: int = 64 * 1024,
	iterations: int = 1,
	lanes: int = 4,
	hash_length: int = 32,
	salt_length: int = 16,
	secret: Optional[bytes] = None,
) -> str:
	"""Hash *password* with Argon2id and return it as a PHC-formatted string.

	Called with only *password*, this behaves as before: a fresh 16-byte random
	salt, a 32-byte digest, m=65536, t=1, p=4 and no secret.

	Args:
		password: the password bytes to hash.
		memory_cost: value passed as Argon2id's memory_cost and written as "m".
		iterations: value passed as Argon2id's iterations and written as "t".
		lanes: value passed as Argon2id's lanes and written as "p".
		hash_length: digest length in bytes.
		salt_length: number of random salt bytes drawn from os.urandom.
		secret: optional secret (pepper) handed to Argon2id. It is NOT stored in
			the returned string, so the same value must be supplied again to
			verify_argon2id_password.

	Returns:
		"$argon2id$v=19$m=..,t=..,p=..$<salt>$<hash>" with unpadded base64 fields.

	NOTE: no range checks are made here; the values go straight to the Argon2id
	constructor. Values outside what parse_argon2id_v19_string accepts would
	yield a string that parser rejects.
	"""
	salt = os.urandom(salt_length)
	kdf = Argon2id(
		salt=salt,
		length=hash_length,
		iterations=iterations,
		lanes=lanes,
		memory_cost=memory_cost,
		ad=None,
		secret=secret,
	)
	digest = kdf.derive(password)
	# the encoded parameters come from the same variables given to the KDF, so they cannot drift apart
	return (
		f"$argon2id$v=19$m={memory_cost:d},t={iterations:d},p={lanes:d}"
		f"${unpadded_base64_encode(salt)}${unpadded_base64_encode(digest)}"
	)

def verify_argon2id_password(encoded_hash: str, password: bytes, secret: Optional[bytes]=None) -> None:
	"""Check *password* against an Argon2id v19 PHC hash string.

	Args:
		encoded_hash: PHC string carrying parameters, salt and hash.
		password: candidate password bytes.
		secret: optional secret (pepper) passed to Argon2id; it is not part of
			the encoded string.

	Returns:
		None. The outcome is whatever Argon2id.verify does with the candidate
		and the stored hash; its exceptions propagate to the caller.

	Raises:
		ValueError: if the string is not a valid Argon2id v19 hash, or lacks the
			parameters, salt or hash needed for verification.
		NotImplementedError: if the hash carries a "keyid" parameter.
	"""
	hashinfo = parse_argon2id_v19_string(encoded_hash)
	params = hashinfo["params"]
	salt = hashinfo["salt"]
	expected = hashinfo["hash"]
	# the parser tolerates strings without these sections; verification needs all
	# three, so reject early rather than hit a TypeError on None further down
	if params is None or salt is None or expected is None:
		raise ValueError("incomplete argon2id hash string")
	if "keyid" in params:
		raise NotImplementedError("keyed hashing unsupported")
	kdf = Argon2id(
		salt=salt,
		length=len(expected),
		iterations=params["t"],
		lanes=params["p"],
		memory_cost=params["m"],
		ad=params.get("data"),
		secret=secret,
	)
	kdf.verify(password, expected)

# Ad-hoc smoke checks; these run at module level, i.e. on every import or execution.

# the generic parser should accept every combination of the optional sections
_GENERIC_SAMPLES = (
	"$argon2id",
	"$argon2id$v=19",
	"$argon2id$v=19$m=65536,t=2,p=1",
	"$argon2id$v=19$m=65536,t=2,p=1$gZiV/M1gPc22ElAH/Jh1Hw",
	"$argon2id$v=19$m=65536,t=2,p=1$gZiV/M1gPc22ElAH/Jh1Hw$CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno",
	"$argon2id$m=65536,t=2,p=1$gZiV/M1gPc22ElAH/Jh1Hw$CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno",
	"$argon2id$v=19$gZiV/M1gPc22ElAH/Jh1Hw$CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno",
	"$argon2id$gZiV/M1gPc22ElAH/Jh1Hw$CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno",
	"$argon2id$v=19$m=65536,t=2,p=1$", # weird but valid (I think) edge case - zero-length salt
)
for _sample in _GENERIC_SAMPLES:
	print(parse_phc_string(_sample))

# the same complete string, this time through the argon2id-specific parser
_REFERENCE_HASH = "$argon2id$v=19$m=65536,t=2,p=1$gZiV/M1gPc22ElAH/Jh1Hw$CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno"
print(parse_argon2id_v19_string(_REFERENCE_HASH))

# NOTE(review): this call is not guarded, so any exception it raises ends the script here -
# confirm that the password/secret pair is expected to verify against this hash
verify_argon2id_password(_REFERENCE_HASH, b"hunter2", b"pepper")

# round trip: a freshly created hash is verified with the right password, then with a wrong one
foo = create_argon2id_password(b"hello")
verify_argon2id_password(foo, b"hello")
try:
	verify_argon2id_password(foo, b"wrong")
except Exception as err:
	print(err)