Skip to content

Instantly share code, notes, and snippets.

@jessiepathfinder
Created April 30, 2025 20:02
Show Gist options
  • Save jessiepathfinder/4c0315cc356b0397ea78aaf5873a1e78 to your computer and use it in GitHub Desktop.
Save jessiepathfinder/4c0315cc356b0397ea78aaf5873a1e78 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
import hashlib
import numpy as np
import math
from scipy.stats import binomtest
from functorch.compile import memory_efficient_fusion
# Seed both RNGs so every run draws the same messages and the same initial weights.
torch.manual_seed(1234567890)
np.random.seed(987654321)
# ====== Config ======
BATCH_SIZE = 4_096   # messages generated per training step
INPUT_BITS = 512     # length of each random message, in bits
HASH_BITS = 256      # SHA-256 digest length in bits (the model's input width)
BATCHES = 10_000     # number of training steps
VAL_SIZE = 65_536    # number of fresh messages in the final test

iscuda = torch.cuda.is_available()
DEVICE = "cuda" if iscuda else "cpu"

# Tensors created without an explicit device/dtype default to DEVICE and float32.
torch.set_default_device(DEVICE)
torch.set_default_dtype(torch.float32)

if iscuda:
    # Permit TF32 tensor-core math for matmuls and cuDNN (faster, slightly less precise).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# ====== Helper Functions ======
def sha256_hash_bits(message_bytes):
    """Hash *message_bytes* with SHA-256 and return the digest as individual bits.

    Args:
        message_bytes: any object accepted by ``hashlib.sha256`` (bytes or a
            contiguous buffer such as a uint8 NumPy array).

    Returns:
        A 256-element ``np.uint8`` array of 0/1 values, most significant bit
        of each digest byte first (``np.unpackbits`` ordering).
    """
    digest_octets = np.frombuffer(hashlib.sha256(message_bytes).digest(), dtype=np.uint8)
    return np.unpackbits(digest_octets)
def generate_batch(batch_size):
    """Generate random messages and return their SHA-256 bits and final message bit.

    Messages are ``INPUT_BITS`` uniformly random bits drawn from NumPy's global
    RNG (one ``randint`` call, so seeding makes the batch reproducible). Each
    message is packed into bytes with ``np.packbits`` and hashed.

    Args:
        batch_size: number of messages to generate (0 yields empty tensors).

    Returns:
        ``(X, y)`` float32 tensors created on torch's default device:
        ``X`` has shape ``(batch_size, HASH_BITS)`` holding the 0/1 digest bits;
        ``y`` has shape ``(batch_size, 1)`` holding the last bit of each message.
    """
    messages = np.random.randint(0, 2, size=(batch_size, INPUT_BITS), dtype=np.uint8)
    # Pack all rows in one vectorized call (same bytes as packing row by row).
    packed = np.packbits(messages, axis=1)
    # Preallocate so an empty batch still has shape (0, HASH_BITS).
    # NOTE: assumes HASH_BITS == 256 (the SHA-256 digest length); otherwise the
    # row assignment below raises.
    hashes = np.empty((batch_size, HASH_BITS), dtype=np.uint8)
    for row_index, message_bytes in enumerate(packed):
        hashes[row_index] = sha256_hash_bits(message_bytes.tobytes())
    # The prediction target is the final bit of each original message.
    labels = messages[:, -1]
    return (
        torch.tensor(hashes, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.float32).unsqueeze(1),
    )
def makekaiminglinear(inputs, outputs, bias=True, gain=1.0):
    """Create an ``nn.Linear`` with weights drawn from N(0, (gain / sqrt(inputs))**2).

    Only the weight matrix is re-initialized. The bias, when present, keeps
    PyTorch's default ``nn.Linear`` initialization.

    Args:
        inputs: number of input features (fan-in).
        outputs: number of output features.
        bias: whether the layer has a bias term.
        gain: multiplier applied to the 1/sqrt(fan-in) standard deviation.

    Returns:
        The initialized ``torch.nn.Linear`` module.
    """
    layer = torch.nn.Linear(inputs, outputs, bias)
    weight_std = gain / math.sqrt(inputs)
    torch.nn.init.normal_(layer.weight, mean=0.0, std=weight_std)
    return layer
class Arctan(torch.nn.Module):
    """Parameter-free activation that applies the element-wise arctangent."""

    def forward(self, input):
        # atan squashes any real input into the open interval (-pi/2, pi/2).
        return torch.atan(input)
class ConstantMul(torch.nn.Module):
    """Scale the input by a fixed factor ``x``.

    In this file ``x`` is a Python float, so it is a plain attribute rather
    than a trainable parameter.
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, input):
        return torch.mul(input, self.x)
class ConstantSub(torch.nn.Module):
    """Subtract a fixed offset ``x`` from the input (``input - x``).

    In this file ``x`` is a Python float, so it is not a trainable parameter.
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, input):
        return torch.sub(input, self.x)
class ConstantDiv(torch.nn.Module):
    """Divide the input by a fixed divisor ``x`` (true division).

    In this file ``x`` is a Python float, so it is not a trainable parameter.
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, input):
        return torch.div(input, self.x)
# Significance test behind the script's final verdict.
def p_value_binomial_test(validation_size, accuracy):
    """Return the one-sided exact binomial-test p-value for an observed accuracy.

    Tests H0 "success probability == 0.5" against H1 "success probability > 0.5".

    Args:
        validation_size: number of scored predictions (binomial trials).
        accuracy: observed fraction correct. It is turned into a success count
            by rounding ``validation_size * accuracy`` to the nearest integer.

    Returns:
        The ``pvalue`` attribute of ``scipy.stats.binomtest``'s result.
    """
    successes = int(round(validation_size * accuracy))
    test_result = binomtest(successes, validation_size, 0.5, alternative='greater')
    return test_result.pvalue
# ====== Model ======
# Gain passed to every makekaiminglinear call below (value is not explained in the source).
generator_gain = 1.48663410298
# One stateless Arctan instance is shared as the activation after each hidden layer.
arctan_mod = Arctan()
# Input bits {0, 1} are centered to {-0.5, +0.5} and scaled by 0.90747 * sqrt(12).
# NOTE(review): sqrt(12) is 1/std of a Uniform(0, 1) variable, but centered bits
# have std 0.5, so the intended normalization is undocumented. TODO confirm.
# The output logit is divided by the same 0.90747 constant.
model = nn.Sequential(
    ConstantSub(0.5),
    ConstantMul(0.90747 * math.sqrt(12)),
    makekaiminglinear(HASH_BITS, 1024, gain=generator_gain),
    arctan_mod,
    # Four 1024 -> 1024 hidden layers, each followed by the arctan activation,
    # built in the same order as writing them out by hand.
    *[
        module
        for _ in range(4)
        for module in (makekaiminglinear(1024, 1024, gain=generator_gain), arctan_mod)
    ],
    makekaiminglinear(1024, 1, gain=generator_gain),
    ConstantDiv(0.90747),
)
# Fused wrapper used for training. It presumably shares `model`'s parameters,
# since the optimizer steps model.parameters() while training calls fast_model.
fast_model = memory_efficient_fusion(model)
# The network emits a single raw logit per sample.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# ====== Training Loop ======
for batch in range(BATCHES):
    # A brand-new random batch is generated every step; there is no fixed training set.
    X, y = generate_batch(BATCH_SIZE)
    X = X.to(DEVICE)
    y = y.to(DEVICE)

    # One SGD step through the fused model.
    optimizer.zero_grad()
    outputs = fast_model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    # logit > 0 is equivalent to sigmoid(logit) > 0.5, i.e. "predict bit = 1".
    predictions = (outputs > 0.0).float()
    print(
        f"Batch {batch + 1}: Loss = {loss.item():.4f}, "
        f"Accuracy = {(predictions == y).sum().item() / BATCH_SIZE:.4f}"
    )
# ====== Inference Test ======
# Score the trained model on a fresh set of VAL_SIZE messages.
X_test, y_test = generate_batch(VAL_SIZE)
with torch.no_grad():
    y_pred = model(X_test.to(DEVICE))
# Positive logit -> predicted bit 1.
pred_labels = (y_pred > 0.0).float()
test_accuracy = (pred_labels == y_test).float().mean().item()
print(f"\nTest accuracy: {test_accuracy:.4f}")

# One-sided test of H0 "accuracy == 0.5" against H1 "accuracy > 0.5".
p_value = p_value_binomial_test(VAL_SIZE, test_accuracy)
print(f"p-value: {p_value:.6f}")

# Under H0 the p-value is approximately Uniform(0, 1), so only SMALL p-values
# are evidence that the final message bit is predictable from the hash.
# Requiring p > 0.95 for "no concern" would flag ~95% of runs against an ideal
# hash, so conventional significance levels are used, and the strictest tier is
# reserved for a "broken" verdict.
if p_value >= 0.05:
    print("No concern: SHA-256 neural preimage resistance intact")
elif p_value >= 0.01:
    print("Minor concern: Potential weaknesses in SHA-256 neural preimage resistance")
elif p_value >= 0.001:
    print("Major concern: Significant vulnerabilities detected")
else:
    print("Disproven: SHA-256 neural preimage resistance broken")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment