Last active
August 29, 2015 14:00
-
-
Save apeiros/11374740 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'digest/sha2'
require 'openssl'
require 'securerandom'
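# Automatically upgradable encryption provider: every encrypted blob starts with
# a pattern-version byte, so data written with an older pattern can still be
# decrypted and re-encrypted with the newest one.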
module Encryption
  # Base class for every "this blob cannot be decrypted" condition.
  # Encryption.try_decrypt rescues this class (and its subclasses) and returns nil.
  class InvalidData < StandardError
  end

  # Raised by a pattern's #decrypt when the blob's version byte does not match
  # the pattern it was handed to.
  class InvalidPattern < InvalidData
    # @param version [Integer, nil] the version byte found in the blob
    def initialize(version)
      super("Invalid pattern version: #{version}")
    end
  end

  # Raised when a blob is empty/truncated or its HMAC does not verify
  # (tampered data or wrong password).
  class InvalidMessage < InvalidData
    # The argument is optional so that a bare `raise InvalidMessage` (which
    # calls .new with no arguments) works. Its value is ignored.
    # @param _version [Object, nil] ignored; kept for backward compatibility
    def initialize(_version = nil)
      super("Invalid message")
    end
  end

  # Raised when a blob carries a pattern version this library does not know,
  # i.e. the data is newer than the library. Deliberately not an InvalidData,
  # so try_decrypt does not swallow it.
  class OutdatedLibrary < StandardError
    # @param version [Integer] the unknown version byte
    def initialize(version)
      super("This library is outdated as it does not support pattern version #{version}")
    end
  end

  # Compare two strings without short-circuiting on the first differing byte,
  # to avoid timing attacks. Mostly copied from Rails 4.1's MessageVerifier.
  #
  # @param a [String]
  # @param b [String]
  # @return [Boolean] true when both strings contain exactly the same bytes
  def self.secure_compare(a, b)
    return false unless a.bytesize == b.bytesize

    res = 0
    a.bytes.zip(b.bytes) { |a_byte, b_byte| res |= a_byte ^ b_byte }
    res.zero?
  end

  # Pattern version 0. Wire format (lengths in bytes):
  #
  #   [version (1) = 0x00][salt (64)][hex HMAC-SHA1 (40)][IV (16)][AES-256-CBC ciphertext]
  #
  # The key material is derived with PBKDF2-HMAC-SHA1 from the password and
  # salt. The HMAC is computed over the ciphertext only (encrypt-then-MAC).
  #
  # NOTE(review): the HMAC does not cover the version byte or the IV, and the
  # same derived key is used for both the HMAC and AES. Fixing that changes the
  # wire format, so it belongs in a new pattern appended to Patterns.
  class Pattern0
    Cipher = 'AES-256-CBC'.freeze
    SaltLength = 64
    Keylength = 64
    DigestLength = 40
    IvLength = 16
    Iterations = 1 << 16
    PatternVersion = 0
    # Size of the fixed-length prefix (version, salt, digest, IV) before the ciphertext.
    HeaderLength = 1 + SaltLength + DigestLength + IvLength

    # @param environment [Symbol] :test uses a single PBKDF2 iteration so tests
    #   run fast; any other value uses Iterations.
    def initialize(environment = :production)
      @iterations = environment == :test ? 1 : Iterations
    end

    # @return [String] SaltLength random bytes from SecureRandom
    def generate_salt
      SecureRandom.random_bytes(SaltLength)
    end

    # @return [String] Keylength bytes of PBKDF2-HMAC-SHA1 key material
    def password_digest(password, salt)
      OpenSSL::PKCS5.pbkdf2_hmac_sha1(password, salt, @iterations, Keylength)
    end

    # @return [String] hex HMAC-SHA1 of data under key (DigestLength characters)
    def data_digest(key, data)
      OpenSSL::HMAC.hexdigest(OpenSSL::Digest::SHA1.new, key, data)
    end

    # OpenSSL::Cipher::Cipher is a long-deprecated compatibility alias;
    # OpenSSL::Cipher works on every openssl version.
    def new_cipher
      OpenSSL::Cipher.new(Cipher)
    end

    # Split a blob into its fields. The blob is read as raw bytes whatever its
    # String encoding, so a multi-byte encoding tag cannot shift the offsets.
    #
    # @return [Array] version (Integer, nil if blob is empty), salt, digest, iv, ciphertext
    def split(blob)
      blob = blob.b
      version = blob.getbyte(0)
      salt = blob[1, SaltLength]
      digest = blob[1 + SaltLength, DigestLength]
      iv = blob[1 + SaltLength + DigestLength, IvLength]
      data = blob[HeaderLength..-1]
      [version, salt, digest, iv, data]
    end

    # Encrypt data under a key derived from password and a fresh random salt.
    #
    # @param data [String] plaintext (may be empty)
    # @param password [String]
    # @return [String] binary blob in the format described on the class
    def encrypt(data, password)
      salt = generate_salt
      key = password_digest(password, salt)
      cipher = new_cipher
      cipher.encrypt
      cipher.key = cipher_key(cipher, key)
      iv = cipher.random_iv
      encrypted = run_cipher(cipher, data)
      digest = data_digest(key, encrypted)
      [PatternVersion.chr, salt, digest, iv, encrypted].join.b
    end

    # Verify and decrypt a blob produced by #encrypt.
    #
    # @param data [String] binary blob
    # @param password [String]
    # @return [String] plaintext (binary encoding)
    # @raise [InvalidMessage] if the blob is truncated, tampered or the password is wrong
    # @raise [InvalidPattern] if the blob's version byte is not PatternVersion
    def decrypt(data, password)
      raise InvalidMessage if data.bytesize < HeaderLength

      version, salt, digest, iv, encrypted = split(data)
      # Cheap version check before the expensive key derivation.
      raise InvalidPattern, version unless version == PatternVersion

      key = password_digest(password, salt)
      raise InvalidMessage unless Encryption.secure_compare(digest, data_digest(key, encrypted))

      cipher = new_cipher
      cipher.decrypt
      cipher.key = cipher_key(cipher, key)
      cipher.iv = iv
      run_cipher(cipher, encrypted)
    end

    private

    # AES-256 takes a 32-byte key, but password_digest yields Keylength (64)
    # bytes. openssl < 2.0 silently used the leading bytes of an overlong key;
    # openssl >= 2.0 raises instead. Truncating explicitly keeps existing
    # ciphertexts decryptable and makes encryption work on both.
    def cipher_key(cipher, key)
      key.byteslice(0, cipher.key_len)
    end

    # Feed input through an initialised cipher and finalise it.
    # Cipher#update raises ArgumentError on an empty String, so skip it then.
    def run_cipher(cipher, input)
      output = input.empty? ? ''.b : cipher.update(input)
      output + cipher.final
    end
  end

  # Every known pattern, indexed by its version byte. New patterns are
  # appended; the last one is used for encryption.
  Patterns = [
    Pattern0,
  ].freeze

  module_function

  # Encrypt with the newest pattern.
  #
  # @param data [String] plaintext
  # @param password [String]
  # @param environment [Symbol] :test selects fast (insecure) key derivation
  # @return [String] binary blob
  def encrypt(data:, password:, environment: :production)
    Patterns.last.new(environment).encrypt(data, password)
  end

  # Like encrypt, but returns strict (newline-free) Base64.
  def encrypt_base64(data:, password:, environment: :production)
    [encrypt(data: data, password: password, environment: environment)].pack('m0')
  end

  # Decrypt a binary blob produced by any known pattern.
  #
  # @return [String] plaintext (binary encoding)
  # @raise [InvalidMessage] if data is empty, truncated, tampered or the password is wrong
  # @raise [OutdatedLibrary] if the blob's version is unknown to this library
  def decrypt(data:, password:, environment: :production)
    # First byte as a raw byte, independent of the String's encoding tag.
    pattern_version = data.getbyte(0)
    raise InvalidMessage unless pattern_version

    pattern_class = Patterns[pattern_version]
    raise OutdatedLibrary, pattern_version unless pattern_class

    pattern_class.new(environment).decrypt(data, password)
  end

  # Like decrypt, but takes Base64 input.
  def decrypt_base64(data:, password:, environment: :production)
    decrypt(data: data.unpack("m*").first, password: password, environment: environment)
  end

  # Like decrypt, but returns nil instead of raising InvalidData.
  # OutdatedLibrary is still raised.
  def try_decrypt(data:, password:, environment: :production)
    decrypt(data: data, password: password, environment: environment)
  rescue InvalidData
    nil
  end

  # Like try_decrypt, but takes Base64 input.
  def try_decrypt_base64(data:, password:, environment: :production)
    try_decrypt(data: data.unpack("m*").first, password: password, environment: environment)
  end

  # Upgrade the encryption of data encrypted with an older pattern by
  # decrypting it and re-encrypting with the newest pattern.
  # `environment` defaults to :production, matching the previous behaviour.
  def upgrade(data:, password:, environment: :production)
    decrypted = decrypt(data: data, password: password, environment: environment)
    encrypt(data: decrypted, password: password, environment: environment)
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment