|
import java.math.BigInteger |
|
import java.security.MessageDigest |
|
import java.security.SecureRandom |
|
import javax.crypto.Cipher |
|
import javax.crypto.spec.SecretKeySpec |
|
import kotlin.system.measureTimeMillis |
|
|
|
// Single source of cryptographic randomness used for every key generated in this script.
val random: SecureRandom = SecureRandom()

// Public Diffie-Hellman parameters shared by all actors. For brevity they are generated
// once at start-up (not negotiated per session): `g` is a random 10-bit probable prime
// used as the base and `n` is a random 1000-bit probable prime used as the modulus.
// NOTE(review): nothing here checks that `g` generates a large subgroup modulo `n`.
val g: BigInteger = BigInteger.probablePrime(10, random)
val n: BigInteger = BigInteger.probablePrime(1000, random)
|
|
|
/** Kinds of messages that can travel over the channel. */
enum class MessageType {
    /** Carries a party's public key. */
    HANDSHAKE_HELLO,

    /** Carries the encrypted "finished" marker that confirms both sides derived the same key. */
    HANDSHAKE_FINISH,

    /** Signals that the handshake could not be completed. */
    HANDSHAKE_ERROR,

    /** Carries an encrypted application message. */
    ENCRYPTED_MESSAGE
}
|
|
|
/**
 * One unit of traffic on the channel.
 *
 * @property type what the message is for; receivers dispatch on it.
 * @property payload raw bytes whose interpretation depends on [type].
 */
class Message(
    val type: MessageType,
    val payload: ByteArray
)
|
|
|
/**
 * Simple AES cipher whose symmetric key is derived from a session key.
 *
 * The 256-bit AES key is the first 32 bytes of the SHA-512 digest of [keyBytes]. The
 * original author's note: apparently truncate(SHA-512, 32) is a better choice than SHA-256.
 *
 * NOTE(security): the transformation is spelled out as ECB with PKCS#5 padding, which is
 * what the SunJCE provider selects for a bare "AES" request. ECB is deterministic (equal
 * plaintexts give equal ciphertexts). It is kept because callers in this file verify the
 * handshake by comparing ciphertexts byte-for-byte; do not copy this into production code.
 *
 * @param keyBytes raw session-key material; it is hashed and never used as the key directly.
 */
class AESCipher(keyBytes: ByteArray) {

    // A fresh MessageDigest per derivation: digest objects are stateful, so a single shared
    // instance would be corrupted if two ciphers were constructed concurrently.
    private val keySpec: SecretKeySpec = MessageDigest.getInstance(KEY_DIGEST)
        .digest(keyBytes)
        .copyOf(KEY_SIZE_BYTES)
        .let { SecretKeySpec(it, ALGORITHM) }

    init {
        // Fail at construction time if the provider rejects the transformation or the key.
        newCipher(Cipher.ENCRYPT_MODE)
    }

    /**
     * Encrypts [message] and returns the ciphertext (padded to a multiple of 16 bytes).
     */
    fun encrypt(message: ByteArray): ByteArray = newCipher(Cipher.ENCRYPT_MODE).doFinal(message)

    /**
     * Decrypts [message] and returns the plaintext.
     *
     * @throws javax.crypto.IllegalBlockSizeException if the input is not a multiple of the block size.
     * @throws javax.crypto.BadPaddingException if the padding is invalid (e.g. wrong key).
     */
    fun decrypt(message: ByteArray): ByteArray = newCipher(Cipher.DECRYPT_MODE).doFinal(message)

    // A new Cipher per call: a Cipher may need to be reset after doFinal throws and is not
    // safe to share between threads, so long-lived instances are avoided.
    private fun newCipher(mode: Int): Cipher =
        Cipher.getInstance(TRANSFORMATION).apply { init(mode, keySpec) }

    companion object {
        private const val ALGORITHM = "AES"
        private const val TRANSFORMATION = "AES/ECB/PKCS5Padding"
        private const val KEY_DIGEST = "SHA-512"
        private const val KEY_SIZE_BYTES = 32
    }
}
|
|
|
/**
 * Models a channel out in the public (for example the internet), where anyone may
 * intercept the messages travelling through it.
 *
 * Actors are looked up by name, and every message and every response that passes through
 * is printed to standard output to play the role of the eavesdropper.
 */
inner class PublicChannel {
    // Name -> actor registry used to resolve the destination of a message.
    private val dnsServer = mutableMapOf<String, Actor>()

    /**
     * Makes [actor] reachable under [name], replacing any previous registration.
     *
     * @return this channel, so calls can be chained.
     */
    fun register(name: String, actor: Actor): PublicChannel {
        dnsServer[name] = actor
        return this
    }

    /**
     * Delivers [message] from [from] to the actor registered as [to] and returns its reply.
     * Both the message and the reply are printed before being passed on.
     *
     * @throws IllegalArgumentException if no actor is registered under [to].
     */
    fun sendMessage(from: String, to: String, message: Message): Message {
        val recipient: Actor = dnsServer[to]
            ?: throw IllegalArgumentException("Actor $to does not exist")

        println("PublicChannel intercepted a message: ${interceptedMessageAsString(message)}")

        return recipient.receiveMessage(from, message).also { reply ->
            println("PublicChannel intercepted a response: ${interceptedMessageAsString(reply)}")
        }
    }

    // Renders what an eavesdropper can read from the message: the type, plus the payload
    // decoded as a number (hello), as text (finish / encrypted), or a fixed marker (error).
    private fun interceptedMessageAsString(message: Message): Any {
        val prefix = "{type=${message.type}, message="
        val visiblePayload: Any = when (message.type) {
            MessageType.HANDSHAKE_HELLO -> BigInteger(message.payload)
            MessageType.HANDSHAKE_FINISH -> String(message.payload)
            MessageType.HANDSHAKE_ERROR -> "Handshake error"
            MessageType.ENCRYPTED_MESSAGE -> String(message.payload)
        }
        return "$prefix$visiblePayload}"
    }
}
|
|
|
/**
 * One of the entities that publicly exchange their public keys over a [PublicChannel]
 * and afterwards talk to each other using encrypted messages.
 *
 * Each actor registers itself on [channel] under [name] when constructed. A separate
 * session cipher is kept per peer, and a cipher is only used for traffic once the
 * HANDSHAKE_FINISH exchange with that peer has been verified.
 *
 * @param name the name this actor is registered and addressed by.
 * @param channel the channel used for all outgoing messages.
 */
inner class Actor(
    private val name: String,
    private val channel: PublicChannel
) {
    private val privateKey = BigInteger.probablePrime(n.bitLength() / 2, random)
    private val publicKey = g.modPow(privateKey, n).toByteArray()

    // Ciphers of peers whose handshake has been verified, keyed by peer name.
    private val sessions = mutableMapOf<String, AESCipher>()

    // Ciphers derived from a peer's HELLO that still await a valid FINISH, keyed by peer name.
    private val pendingHandshakes = mutableMapOf<String, AESCipher>()

    init {
        channel.register(name, this)
    }

    /**
     * Entry point for traffic delivered by the channel; dispatches on the message type.
     *
     * @param from name of the sending actor.
     * @return the reply to hand back to the sender.
     * @throws IllegalStateException on HANDSHAKE_ERROR, or on an encrypted message from a
     * peer without an established session.
     */
    fun receiveMessage(from: String, message: Message): Message = when (message.type) {
        MessageType.HANDSHAKE_HELLO -> handleHandshakeHello(from, message)
        MessageType.HANDSHAKE_FINISH -> handleHandshakeFinish(from, message)
        MessageType.HANDSHAKE_ERROR -> handleHandshakeError()
        MessageType.ENCRYPTED_MESSAGE -> handleEncryptedMessage(from, message)
    }

    /**
     * Sends [message] encrypted to the actor named [to], performing the handshake first
     * when no session with that peer exists yet.
     *
     * @return the peer's decrypted reply.
     * @throws IllegalStateException if the handshake fails or the reply is not an encrypted message.
     */
    fun sendMessage(to: String, message: String): String {
        val cipher = sessions[to] ?: handshake(to)

        val response = channel
            .sendMessage(name, to, Message(MessageType.ENCRYPTED_MESSAGE, cipher.encrypt(message.toByteArray())))

        check(response.type == MessageType.ENCRYPTED_MESSAGE) {
            "Expected an encrypted reply from $to but got ${response.type}"
        }

        return String(cipher.decrypt(response.payload))
    }

    // Derives the session cipher from the peer's public key, or returns null when the key
    // is unusable. Keys outside 2..n-2 are rejected: 0, 1 and n-1 force a shared secret an
    // eavesdropper can predict, and an empty payload is not a number at all.
    private fun createCipher(otherPublicKey: ByteArray): AESCipher? {
        if (otherPublicKey.isEmpty()) return null

        val two = BigInteger.valueOf(2)
        val peerKey = BigInteger(otherPublicKey)
        if (peerKey < two || peerKey > n.subtract(two)) return null

        val sessionKey = peerKey.modPow(privateKey, n)
        return AESCipher(sessionKey.toByteArray())
    }

    // Initiator side of the handshake. The cipher is stored only after the peer's FINISH
    // has been verified, so a failed handshake leaves no half-trusted session behind.
    private fun handshake(to: String): AESCipher {
        // We send our public key and get the other public key
        val hello = channel
            .sendMessage(name, to, Message(MessageType.HANDSHAKE_HELLO, publicKey))
        check(hello.type == MessageType.HANDSHAKE_HELLO) {
            "Invalid during handshake state: ${hello.type}"
        }

        val candidate = checkNotNull(createCipher(hello.payload)) {
            "$to answered the handshake with an invalid public key"
        }

        // As a finish message to verify the encryption, we send the encrypted
        // "$name:finished" and the other side validates it.
        val response = channel
            .sendMessage(name, to, Message(MessageType.HANDSHAKE_FINISH, buildHandshakeFinishPayload(candidate, name)))
        check(response.type == MessageType.HANDSHAKE_FINISH) {
            "Invalid during handshake state: ${response.type}"
        }
        check(response.payload.contentEquals(buildHandshakeFinishPayload(candidate, to))) {
            "Failed to verify HANDSHAKE_FINISH message"
        }

        sessions[to] = candidate
        return candidate
    }

    // Responder side, step 1: derive a cipher from the peer's key and answer with ours.
    private fun handleHandshakeHello(from: String, message: Message): Message {
        val candidate = createCipher(message.payload)
            ?: return Message(MessageType.HANDSHAKE_ERROR, byteArrayOf())

        pendingHandshakes[from] = candidate
        return Message(MessageType.HANDSHAKE_HELLO, publicKey)
    }

    // The finish marker for [who], encrypted with [cipher]; both sides must produce the
    // same bytes for the handshake to be accepted.
    private fun buildHandshakeFinishPayload(cipher: AESCipher, who: String): ByteArray =
        cipher.encrypt("$who:finished".toByteArray())

    // Responder side, step 2: verify the peer's finish marker and promote the pending
    // cipher to an established session. A FINISH without a preceding HELLO is an error.
    private fun handleHandshakeFinish(from: String, message: Message): Message {
        val candidate = pendingHandshakes.remove(from)
            ?: return Message(MessageType.HANDSHAKE_ERROR, byteArrayOf())

        if (!buildHandshakeFinishPayload(candidate, from).contentEquals(message.payload)) {
            return Message(MessageType.HANDSHAKE_ERROR, byteArrayOf())
        }

        sessions[from] = candidate
        return Message(MessageType.HANDSHAKE_FINISH, buildHandshakeFinishPayload(candidate, name))
    }

    private fun handleHandshakeError(): Nothing = throw IllegalStateException("$name got a handshake error")

    // Decrypts and prints the peer's message, then answers with a fixed encrypted greeting.
    private fun handleEncryptedMessage(from: String, message: Message): Message {
        val cipher = checkNotNull(sessions[from]) {
            "$name has no established session with $from"
        }

        val plainTextMessage = String(cipher.decrypt(message.payload))

        println("$name got: $plainTextMessage")

        return Message(message.type, cipher.encrypt("Hey there!".toByteArray()))
    }
}
|
|
|
// Demo: two actors registered on one public channel exchange a message each way.
val channel = PublicChannel()

val alice = Actor(name = "alice", channel = channel)
val bob = Actor(name = "bob", channel = channel)

// Alice has no key material for Bob yet, so this first send performs the handshake.
alice.sendMessage(to = "bob", message = "Hey Bob!")

// Bob's side of the key exchange was completed while answering Alice, so this send
// goes out encrypted without a new handshake.
bob.sendMessage(to = "alice", message = "Hey Alice!")