-
-
Save guizmaii/6b5d3666081960639c3df0a24e17e2fd to your computer and use it in GitHub Desktop.
package com.guizmaii.utils | |
import java.nio.charset.{Charset, StandardCharsets} | |
import java.security.{Key, SecureRandom} | |
import cats.effect.Sync | |
import doobie.{Put, Read} | |
import eu.timepit.refined.types.all.NonEmptyString | |
import io.estatico.newtype.macros.newtype | |
import javax.crypto.{Cipher, SecretKey, SecretKeyFactory} | |
import javax.crypto.spec.{GCMParameterSpec, PBEKeySpec, SecretKeySpec} | |
import pureconfig.ConfigReader | |
import pureconfig.error.CannotConvert | |
import scala.language.implicitConversions | |
object AES {

  /**
   * This type represents a password.
   *
   * The constraints on that password are fairly simple but can be improved later if needed.
   *
   * For now, a valid password is just a string which is at least `Password.MinLength` (40)
   * characters long; the check is performed by `Password.PasswordReader`.
   *
   * Why 40? It seems long enough.
   */
  @newtype case class Password(value: NonEmptyString) {

    /** Returns a newly allocated copy of the password's characters (callers may wipe it after use). */
    def toCharArray: Array[Char] = value.value.toCharArray
  }

  object Password {

    /** Minimum accepted password length, as measured by `String.length` (UTF-16 code units). */
    final val MinLength: Int = 40

    // Shown instead of the rejected value in `CannotConvert` errors. The value is a secret and
    // config errors typically get logged, so not even a prefix of it is echoed back.
    private val RedactedValue: String = "***"

    /**
     * Reads a [[Password]] from a config string.
     *
     * Values shorter than [[MinLength]] are rejected with a [[CannotConvert]] that reports only
     * the value's length, never its content.
     */
    implicit final val PasswordReader: ConfigReader[Password] =
      ConfigReader.fromString { password =>
        val length = password.length
        if (length >= MinLength) Right(Password(NonEmptyString.unsafeFrom(password)))
        else
          Left(
            CannotConvert(
              RedactedValue,
              "com.guizmaii.utils.AES.Password",
              s"The value is less than $MinLength characters. Current length: $length characters"
            )
          )
      }
  }

  /** Charset used to convert clear text to and from bytes. */
  final val UTF_8: Charset = StandardCharsets.UTF_8

  /** Base64-encoded ciphertext. */
  @newtype case class CipherText(value: Base64String)

  /** Base64-encoded initialization vector that was used to produce a [[CipherText]]. */
  @newtype case class IV(value: Base64String)

  /** Base64-encoded key-derivation salt. */
  @newtype case class Salt(value: Base64String) {

    /** Decodes the salt back into its raw bytes. */
    def toRawSalt: RawSalt = RawSalt(base64decode(value))
  }

  /** Raw (decoded) key-derivation salt bytes. */
  @newtype case class RawSalt(value: Array[Byte]) {

    /** Base64-encodes the salt, e.g. for storage. */
    def toSalt: Salt = Salt(base64encode(value))
  }

  /** Plain, unencrypted text. */
  @newtype case class ClearText(value: String) {
    def getBytes(charset: Charset): Array[Byte] = value.getBytes(charset)
  }

  /** A string expected to hold Base64 data; its content is not validated on construction. */
  @newtype case class Base64String(value: String) {
    def getBytes(charset: Charset): Array[Byte] = value.getBytes(charset)
  }

  // doobie instances, obtained through scala-newtype's `deriving` from the wrapped type's instances.
  object CipherText {
    implicit final val doobieRead: Read[CipherText] = deriving
    implicit final val doobiePut: Put[CipherText]   = deriving
  }

  object IV {
    implicit final val doobieRead: Read[IV] = deriving
    implicit final val doobiePut: Put[IV]   = deriving
  }

  object Salt {
    implicit final val doobieRead: Read[Salt] = deriving
    implicit final val doobiePut: Put[Salt]   = deriving
  }

  object Base64String {
    implicit final val doobieRead: Read[Base64String] = deriving
    implicit final val doobiePut: Put[Base64String]   = deriving
  }

  /** Encodes `in` with the standard Base64 alphabet (with padding). */
  def base64encode(in: Array[Byte]): Base64String =
    Base64String(java.util.Base64.getEncoder.encodeToString(in))

  /**
   * Decodes a standard-alphabet Base64 string.
   *
   * @throws IllegalArgumentException if `in` is not valid Base64
   */
  def base64decode(in: Base64String): Array[Byte] =
    java.util.Base64.getDecoder.decode(in.value)
}
import AES._ | |
/**
 * Encrypts [[AES.ClearText]] values and decrypts them back.
 *
 * Resources that helped:
 *   - https://mkyong.com/java/java-aes-encryption-and-decryption/
 *   - https://wiki.sei.cmu.edu/confluence/display/java/MSC61-J.+Do+not+use+insecure+or+weak+cryptographic+algorithms
 *   - https://proandroiddev.com/security-best-practices-symmetric-encryption-with-aes-in-java-7616beaaade9
 *   - https://security.stackexchange.com/a/105788/66294
 *   - https://stackoverflow.com/a/13915596
 *
 * @tparam F effect type in which encryption and decryption are performed
 */
trait AES[F[_]] {

  /**
   * Encrypts `in`.
   *
   * @return the ciphertext, together with the salt and IV that must be kept to decrypt it later
   */
  def encrypt(in: ClearText): F[(CipherText, Salt, IV)]

  /**
   * Decrypts `data`, using the `salt` and `iv` returned alongside it by [[encrypt]].
   *
   * @return the original clear text
   */
  def decrypt(data: CipherText, salt: Salt, iv: IV): F[ClearText]
}
/**
 * [[AES]] implementation using AES in GCM mode (`AES/GCM/NoPadding`) with a 256-bit key derived
 * from `password` via PBKDF2-HMAC-SHA256.
 *
 * Every `encrypt` call draws a fresh random salt (16 bytes) and IV (12 bytes), so a new key is
 * derived for each message. Both are returned and must be stored alongside the ciphertext.
 *
 * Key derivation runs 65536 PBKDF2 iterations on every call, so both operations are CPU-heavy.
 *
 * @tparam F effect type in which the crypto work is suspended
 */
final class AESImpl[F[_]: Sync](password: Password) extends AES[F] {
  private val Transformation             = "AES/GCM/NoPadding"
  private val Algorithm                  = "AES"
  private val KeyDerivationAlgorithm     = "PBKDF2WithHmacSHA256"
  private val GcmAuthenticationTagLength = 128 // bits
  private val SaltLength                 = 16 // bytes
  private val IvLength                   = 12 // bytes
  private val IterationCount             = 65536
  private val KeyLength                  = 256 // bits
  // SecureRandom instances are safe for use by multiple concurrent threads.
  private val random: SecureRandom = new SecureRandom()

  override def encrypt(in: ClearText): F[(CipherText, Salt, IV)] =
    Sync[F].delay {
      val rawSalt: RawSalt              = generateRawSalt
      val key: Key                      = getAESKeyFromPassword(rawSalt)
      val gcmParams: GCMParameterSpec   = generateIv
      val cipherText                    = runCipher(Cipher.ENCRYPT_MODE, in.getBytes(UTF_8), key, gcmParams)
      (CipherText(base64encode(cipherText)), rawSalt.toSalt, IV(base64encode(gcmParams.getIV)))
    }

  /** Failures (e.g. wrong password, tampered data: GCM tag mismatch) are raised in `F`. */
  override def decrypt(data: CipherText, salt: Salt, iv: IV): F[ClearText] =
    Sync[F].delay {
      val key: Key  = getAESKeyFromPassword(salt.toRawSalt)
      val gcmParams = new GCMParameterSpec(GcmAuthenticationTagLength, base64decode(iv.value))
      val clearText = runCipher(Cipher.DECRYPT_MODE, base64decode(data.value), key, gcmParams)
      ClearText(new String(clearText, UTF_8))
    }

  /** Runs a fresh `Cipher` in `mode` (encrypt or decrypt) over `in` in a single pass. */
  private def runCipher(mode: Int, in: Array[Byte], key: Key, gcmParams: GCMParameterSpec): Array[Byte] = {
    val c = Cipher.getInstance(Transformation)
    c.init(mode, key, gcmParams)
    c.doFinal(in)
  }

  /** Generates `SaltLength` random bytes for key derivation. */
  private def generateRawSalt: RawSalt = {
    val salt = new Array[Byte](SaltLength)
    random.nextBytes(salt)
    RawSalt(salt)
  }

  /** Generates a random `IvLength`-byte IV wrapped in GCM parameters. */
  private def generateIv: GCMParameterSpec = {
    val iv = new Array[Byte](IvLength)
    random.nextBytes(iv)
    new GCMParameterSpec(GcmAuthenticationTagLength, iv)
  }

  /**
   * AES key derived from the password and `salt` with PBKDF2.
   *
   * Comes from: https://mkyong.com/java/java-aes-encryption-and-decryption/
   *
   * Every temporary copy of secret material (password chars, the spec's internal password clone,
   * the derived key bytes) is wiped once the key has been built.
   */
  private def getAESKeyFromPassword(salt: RawSalt): SecretKey = {
    val factory       = SecretKeyFactory.getInstance(KeyDerivationAlgorithm)
    val passwordChars = password.toCharArray
    val spec          = new PBEKeySpec(passwordChars, salt.value, IterationCount, KeyLength)
    try {
      val derivedKey = factory.generateSecret(spec).getEncoded
      // SecretKeySpec copies the array, so this copy can be wiped immediately.
      try new SecretKeySpec(derivedKey, Algorithm)
      finally java.util.Arrays.fill(derivedKey, 0.toByte)
    } finally {
      // PBEKeySpec stores its own clone of the password: wipe both that clone and our copy.
      spec.clearPassword()
      java.util.Arrays.fill(passwordChars, 0.toChar)
    }
  }
}
package com.guizmaii.utils | |
import com.guizmaii.utils.AES.{ClearText, Password} | |
import eu.timepit.refined.types.string.NonEmptyString | |
import monix.eval.Task | |
import org.scalacheck.{Arbitrary, Gen, Shrink} | |
import org.scalactic.anyvals.PosInt | |
import org.scalatest.freespec.AnyFreeSpec | |
import org.scalatest.matchers.should.Matchers | |
import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks | |
/** Property tests for [[AESImpl]] over generated clear texts and passwords. */
class AESTest extends AnyFreeSpec with Matchers with ScalaCheckDrivenPropertyChecks {

  import monix.execution.Scheduler.Implicits.global

  /** Alphabetic passwords of 40 to 200 characters. */
  implicit val arbPassword: Arbitrary[Password] =
    Arbitrary {
      for {
        size     <- Gen.choose(40, 200) // if the max value here is too high, these tests will take forever.
        password <- Gen.listOfN(size, Gen.alphaChar)
      } yield Password(NonEmptyString.unsafeFrom(password.mkString))
    }

  override implicit val generatorDrivenConfig: PropertyCheckConfiguration =
    PropertyCheckConfiguration(
      minSuccessful = 1000,
      workers = PosInt.ensuringValid(Runtime.getRuntime.availableProcessors())
    )

  implicit def noShrink[T]: Shrink[T] = Shrink.shrinkAny

  "AES" - {
    "#encrypt" - {
      "encrypted text is not just base64'ed text" in forAll { (secret: String, password: Password) =>
        val service: AES[Task]  = new AESImpl(password)
        val (encrypted, _, _)   = service.encrypt(ClearText(secret)).runSyncUnsafe()
        // Compare bytes with bytes: an Array[Byte] compared against a String is never equal,
        // which would make this assertion pass vacuously.
        AES.base64decode(encrypted.value) should not equal (secret.getBytes(AES.UTF_8))
      }
      "salt should be 16 bytes long" in forAll { (secret: String, password: Password) =>
        val service: AES[Task] = new AESImpl(password)
        val (_, salt, _)       = service.encrypt(ClearText(secret)).runSyncUnsafe()
        AES.base64decode(salt.value).length shouldBe 16
      }
      "iv should be 12 bytes long" in forAll { (secret: String, password: Password) =>
        val service: AES[Task] = new AESImpl(password)
        val (_, _, iv)         = service.encrypt(ClearText(secret)).runSyncUnsafe()
        AES.base64decode(iv.value).length shouldBe 12
      }
    }
    "#decrypt" - {
      "decrypting with a different password fails" in forAll {
        (secret: String, password: Password, otherPassword: Password) =>
          whenever(password.value.value != otherPassword.value.value) {
            val encrypter: AES[Task]     = new AESImpl(password)
            val decrypter: AES[Task]     = new AESImpl(otherPassword)
            val (encrypted, salt, iv)    = encrypter.encrypt(ClearText(secret)).runSyncUnsafe()
            val result                   = decrypter.decrypt(encrypted, salt, iv).attempt.runSyncUnsafe()
            result shouldBe a[Left[_, _]]
          }
      }
    }
    "both way" - {
      "decrypted encrypted text should be equal to initial text" in forAll { (secret: String, password: Password) =>
        val service: AES[Task] = new AESImpl(password)
        // Explicit flatMap rather than a `(a, b, c) <- ...` generator: in Scala 2 a tuple pattern in a
        // for-comprehension desugars through `withFilter`, which only compiles if Task provides one
        // or the better-monadic-for plugin is enabled.
        val decrypted: Task[String] =
          service
            .encrypt(ClearText(secret))
            .flatMap { case (encrypted, salt, iv) => service.decrypt(encrypted, salt, iv) }
            .map(_.value)
        decrypted.runSyncUnsafe() shouldBe secret
      }
    }
  }
}
package com.guizmaii.utils | |
import cats.effect.Blocker | |
import com.guizmaii.utils.AES.Password | |
import com.typesafe.config.ConfigFactory | |
import monix.eval.Task | |
import org.scalacheck.Shrink | |
import org.scalatest.freespec.AnyFreeSpec | |
import org.scalatest.matchers.should.Matchers | |
import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks | |
import pureconfig.ConfigSource | |
import pureconfig.generic.auto._ | |
import pureconfig.module.catseffect._ | |
// scalastyle:off | |
/** Minimal config shape used to exercise the pureconfig reader of [[Password]]. */
final case class ExampleConfig(
  password: Password
)
/** Tests for the config-reading behaviour of the types defined in `AES`. */
class TypesTest extends AnyFreeSpec with Matchers with ScalaCheckDrivenPropertyChecks {

  override implicit val generatorDrivenConfig: PropertyCheckConfiguration =
    PropertyCheckConfiguration(minSuccessful = 100)

  import monix.execution.Scheduler.Implicits.global

  implicit def noShrink[T]: Shrink[T] = Shrink.shrinkAny

  implicit final val blocker: Blocker = Blocker.liftExecutionContext(global)

  "Password" - {
    "pureconfig PasswordReader" - {
      "should only accept string at least 40 characters long" in forAll() { value: String =>
        import scala.collection.JavaConverters.mapAsJavaMapConverter

        val c = ConfigFactory.parseMap(Map("password" -> value).asJava)
        val result: Either[Throwable, ExampleConfig] =
          loadF[Task, ExampleConfig](ConfigSource.fromConfig(c), blocker).attempt.runSyncUnsafe()

        if (value.length >= 40) {
          result match {
            case Right(config) => config.password.value.value shouldBe value
            case Left(error)   => fail(s"Expected the password to be accepted, but got: ${error.getMessage}")
          }
        } else {
          result match {
            case Right(_) =>
              fail(s"Expected a ${value.length}-character password to be rejected")
            case Left(error) =>
              error.getMessage should include(
                s"""com.guizmaii.utils.AES.Password: The value is less than 40 characters. Current length: ${value.length} characters."""
              )
              // The rejected secret itself must not appear in the error message.
              if (value.length > 3) error.getMessage should not include (value) else succeed
          }
        }
      }
    }
  }
}
Hey @ares-b, `deriving` comes from this library: https://github.com/estatico/scala-newtype
See doc here: https://github.com/estatico/scala-newtype#companion-objects
Did you add the "paradise" Scala compiler plugin to your project? See https://github.com/estatico/scala-newtype#getting-newtype
Hey @guizmaii, thanks a lot! One last thing: can you please tell me which versions of Monix, Cats and Cats-Effect you are using? I'm having a hard time guessing them :/
This was a long time ago. I'm not even using these anymore. I'm using ZIO now. It's way better. I'd advise you to do so too if you can.
The newtypes in zio-prelude are way more powerful than the ones from scala-newtype.
That being said, I don't think Cats has changed enough to break my code (I'm not following Cats' evolution, so I could be wrong).
The only Cats (Cats-Effect, actually) thing I'm using here is `cats.effect.Sync`, which most probably still exists.
Hey @guizmaii do you have it working with ZIO already?
@leandrocruz No, but I can make it. I'll try to do this and will post a link here ;)
@leandrocruz I didn't publish it as a library but if you want I can.
Here is the ZIO port: https://github.com/guizmaii/zio-AES
I finally published the lib.
It's now living here: https://github.com/guizmaii-opensource/zio-AES
Hey, I just copy/pasted your code into my IDE and I'm getting this error:
I'm running Scala 2.12. Do you have any idea what the issue could be?