|
import java.security.KeyPairGenerator |
|
import java.security.MessageDigest |
|
import java.security.PublicKey |
|
import java.util.* |
|
import javax.crypto.Cipher |
|
import javax.crypto.KeyGenerator |
|
import javax.crypto.SecretKey |
|
import javax.crypto.spec.SecretKeySpec |
|
import kotlin.random.Random |
|
|
|
/**
 * In-memory registry ("DNS") of [NetworkNode]s keyed by id, which also acts as the transport that
 * carries messages from one node to another.
 */
class Network {

    private val dns = mutableMapOf<UUID, NetworkNode>()

    /**
     * Registers [node] under [id], replacing any node previously registered under the same id.
     *
     * @return this network, so registrations can be chained.
     */
    fun register(id: UUID, node: NetworkNode) = apply { dns[id] = node }

    /**
     * Picks [amount] distinct registered nodes in random order.
     *
     * Each pair holds the id the node is registered under, which is the address [forwardMessage]
     * resolves, and the node itself.
     *
     * @throws IllegalArgumentException if [amount] is negative.
     * @throws IllegalStateException if fewer than [amount] nodes are registered.
     */
    fun getRandomNodes(amount: Int): List<Pair<UUID, NetworkNode>> {
        require(amount >= 0) { "amount must be non-negative, was $amount" }
        check(dns.size >= amount) { "There are not enough nodes available in the network" }

        // Shuffling the whole registry and taking a prefix is a uniform sample without replacement.
        return dns.entries
            .shuffled()
            .take(amount)
            .map { (id, node) -> id to node }
    }

    /**
     * Delivers [message] (with the end-to-end [hash]) to the node registered under [to] and returns
     * that node's response.
     *
     * @throws IllegalStateException if no node is registered under [to].
     */
    fun forwardMessage(to: UUID, message: ByteArray, hash: ByteArray): ByteArray {
        val node = dns[to] ?: throw IllegalStateException("Node '$to' does not exist.")

        return node.handleMessage(message, hash)
    }

    /**
     * Sends [message] addressed to [to] from a randomly chosen registered node and returns the
     * decrypted response.
     *
     * Throws `NoSuchElementException` (from `random()`) if no node is registered.
     */
    fun sendMessage(to: String, message: String): String = dns.values.random().sendMessage(to, message)
}
|
|
|
/**
 * A relay in the toy onion-routing [Network].
 *
 * Every node owns:
 *  - an RSA key pair. Another node can encrypt this node's AES key with the public half; see
 *    [sendMessage] and [getSymmetricKey].
 *  - an AES key. The node uses it to remove its own layer from incoming messages and to add a
 *    layer to the responses it relays back.
 *
 * The node registers itself with [network] under [id] as soon as it is constructed.
 *
 * NOTE(review): the bare "AES" transformation resolves to the provider default, which for the
 * SunJCE provider is AES/ECB/PKCS5Padding. ECB leaks plaintext structure and gives no integrity.
 * Anything beyond a demo should use an authenticated mode (e.g. AES/GCM) with a fresh IV per message.
 */
class NetworkNode(
    private val network: Network
) {

    val id: UUID = UUID.randomUUID()

    private val keyPair = KEY_PAIR_GEN.generateKeyPair()

    private val symmetricKey: SecretKey = KeyGenerator.getInstance(CIPHER_ALGORITHM).generateKey()

    // One cipher per direction, initialised once and reused for every message
    // (Cipher.doFinal resets the cipher to its initialised state).
    private val aesDecryptor = Cipher.getInstance(CIPHER_ALGORITHM)
        .apply { init(Cipher.DECRYPT_MODE, symmetricKey) }

    private val aesEncryptor = Cipher.getInstance(CIPHER_ALGORITHM)
        .apply { init(Cipher.ENCRYPT_MODE, symmetricKey) }

    init {
        // Declared after every property, so the instance handed to the network is fully initialised.
        network.register(id, this)
    }

    /**
     * Returns this node's AES key encrypted with [requesterPublicKey]. Only the holder of the
     * matching private key can recover it.
     */
    private fun getSymmetricKey(requesterPublicKey: PublicKey): ByteArray =
        Cipher.getInstance(KEY_PAIR_ALGORITHM)
            .apply { init(Cipher.ENCRYPT_MODE, requesterPublicKey) }
            .doFinal(symmetricKey.encoded)

    /**
     * Peels this node's layer off [message], then either answers it or relays it.
     *
     * The layer is decrypted with this node's AES key.
     * - If the SHA-512 of the plaintext equals [hash], this node is the last hop and produces the
     *   response itself. Currently that response is a fixed `"200 OK"` placeholder; no request to
     *   the destination is made.
     * - Otherwise the plaintext must be `"<next-hop UUID>@<inner ciphertext>"`, and the inner
     *   ciphertext is forwarded to the next hop through [network].
     *
     * Whatever response comes back is encrypted with this node's key before being returned. The
     * originator therefore receives one encryption layer per hop.
     *
     * @throws IllegalArgumentException if a non-final layer is too short, lacks the `@` separator,
     *   or does not start with a valid UUID.
     * @throws IllegalStateException if the next hop is not registered in [network].
     */
    fun handleMessage(message: ByteArray, hash: ByteArray): ByteArray {
        val plaintext = aesDecryptor.doFinal(message)

        val response = if (MessageDigest.isEqual(sha512.digest(plaintext), hash)) {
            "200 OK".toByteArray()
        } else {
            val (nextHop, innerMessage) = parseRelayLayer(plaintext)
            network.forwardMessage(nextHop, innerMessage, hash)
        }

        return aesEncryptor.doFinal(response)
    }

    /** Splits a relay layer `"<UUID>@<payload>"` into the next hop's id and the still-encrypted payload. */
    private fun parseRelayLayer(plaintext: ByteArray): Pair<UUID, ByteArray> {
        require(plaintext.size >= DEST_BYTES_SIZE && plaintext[UUID_SIZE] == DEST_SEPARATOR_BYTE) {
            "Malformed relay layer: expected '<uuid>$DEST_SEPARATOR<payload>' (${plaintext.size} bytes received)"
        }
        val nextHop = UUID.fromString(String(plaintext, 0, UUID_SIZE))
        return nextHop to plaintext.copyOfRange(DEST_BYTES_SIZE, plaintext.size)
    }

    /**
     * Sends `"$to@$message"` through a random circuit of [NODE_CIRCUIT_SIZE] nodes and returns the
     * decrypted response.
     *
     * @throws IllegalStateException if the network has fewer than [NODE_CIRCUIT_SIZE] nodes.
     */
    fun sendMessage(to: String, message: String): String {
        // Recovers each circuit node's AES key, which that node encrypts with our RSA public key.
        val keyUnwrapper = Cipher.getInstance(KEY_PAIR_ALGORITHM)
            .apply { init(Cipher.DECRYPT_MODE, keyPair.private) }

        // Random circuit of distinct nodes, in hop order, each paired with its AES key.
        val nodeCircuit = network.getRandomNodes(NODE_CIRCUIT_SIZE)
            .map { (nodeId, node) ->
                val keyBytes = keyUnwrapper.doFinal(node.getSymmetricKey(keyPair.public))
                nodeId to SecretKeySpec(keyBytes, CIPHER_ALGORITHM)
            }

        val cipher = Cipher.getInstance(CIPHER_ALGORITHM)

        // The exit node recognises the innermost layer by comparing its SHA-512 against this hash.
        val payload = "$to@$message".toByteArray()
        val hash = sha512.digest(payload)

        // Build the onion from the inside out: k1("<id2>@" + k2("<id3>@" + k3(payload))).
        // Each layer is prefixed with the id of the hop that must receive the layer inside it;
        // the innermost (exit) layer has no prefix.
        val encryptedMessage = nodeCircuit.indices.reversed().fold(payload) { inner, hop ->
            val nextHopPrefix = nodeCircuit.getOrNull(hop + 1)
                ?.let { (nextId, _) -> "$nextId$DEST_SEPARATOR" }
                ?: ""
            cipher.apply { init(Cipher.ENCRYPT_MODE, nodeCircuit[hop].second) }
                .doFinal(nextHopPrefix.toByteArray() + inner)
        }

        val response = network.forwardMessage(nodeCircuit.first().first, encryptedMessage, hash)

        // The response comes back as k1(k2(k3(r))): the entry node added the outermost layer last,
        // so peel the layers in circuit order.
        val decryptedResponse = nodeCircuit.fold(response) { bytes, (_, key) ->
            cipher.apply { init(Cipher.DECRYPT_MODE, key) }.doFinal(bytes)
        }

        return String(decryptedResponse)
    }

    companion object {
        /** Number of relays a message passes through. */
        const val NODE_CIRCUIT_SIZE = 2

        /** Length of a UUID in its canonical string form. */
        const val UUID_SIZE = 36

        /** Length of a relay-layer header: the next hop's UUID followed by '@'. */
        const val DEST_BYTES_SIZE = 1 + UUID_SIZE // @ + UUID

        const val CIPHER_ALGORITHM = "AES"

        const val KEY_PAIR_ALGORITHM = "RSA"

        /** Separator between the next hop's UUID and the inner ciphertext in a relay layer. */
        private const val DEST_SEPARATOR = '@'

        private val DEST_SEPARATOR_BYTE: Byte = DEST_SEPARATOR.code.toByte()

        // NOTE(review): the two instances below are shared by every node. Neither KeyPairGenerator
        // nor MessageDigest is documented as thread-safe, so nodes should not run concurrently.
        val KEY_PAIR_GEN: KeyPairGenerator = KeyPairGenerator
            .getInstance(KEY_PAIR_ALGORITHM)
            .apply { initialize(2048) }

        val sha512: MessageDigest = MessageDigest.getInstance("SHA-512")
    }
}
|
|
|
// Demo: build a five-node network and route one request through a random circuit.
val network = Network()

// Each NetworkNode registers itself with `network` in its init block, so constructing the nodes is
// all the setup needed; the list itself is not otherwise used below.
val nodes = List(5) { NetworkNode(network) }

val response = network.sendMessage("google", "GET /something")
println(response)