-
-
Save senny/4108b128312d04b3bcd9632eb3358cde to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Time-based one-time password generator/verifier (RFC 6238 style).
#
# The shared secret is kept as a Base32 string and decoded with
# Base32.decode every time an OTP is generated.
class TOTP
  attr_reader :secret, :digits, :digest, :interval

  DEFAULT_INTERVAL = 30
  DEFAULT_DIGITS = 6

  # @param s [String] Base32-encoded shared secret
  # @param options [Hash]
  # @option options [Integer] :interval time-step length in seconds (default 30)
  # @option options [Integer] :digits number of OTP digits (default 6)
  # @option options [String] :digest OpenSSL digest name (default "sha1")
  def initialize(s, options = {})
    @interval = options[:interval] || DEFAULT_INTERVAL
    @digits = options[:digits] || DEFAULT_DIGITS
    @digest = options[:digest] || "sha1"
    @secret = s
  end

  # OTP for the current wall-clock time.
  #
  # @param padding [Boolean] true -> zero-padded String, false -> Integer
  def now(padding = true)
    generate_otp(timecode(Time.now), padding)
  end

  # OTP for the given time.
  #
  # @param time [Time, #to_i] a Time, or anything whose #to_i is epoch seconds
  # @param padding [Boolean] true -> zero-padded String, false -> Integer
  # @raise [ArgumentError] if the time is before the Unix epoch
  def at(time, padding = true)
    time = Time.at(time.to_i) unless time.is_a?(Time)
    generate_otp(timecode(time), padding)
  end

  # Checks +otp+ against the code for +time+ using a constant-time comparison.
  #
  # @param otp [String] candidate code (must be a String)
  # @param time [Time, #to_i]
  # @return [Boolean]
  # @raise [ArgumentError] if +otp+ is not a String
  def verify(otp, time = Time.now)
    unless otp.is_a?(String)
      raise ArgumentError, "ROTP only verifies strings - See: https://github.com/mdp/rotp/issues/32"
    end
    time_constant_compare(otp, at(time))
  end

  # Like #verify, but accepts any code whose time falls within +drift+ seconds
  # of +time+ (in either direction).
  #
  # @param drift [Numeric] tolerance in seconds; its sign is ignored
  # @return [Boolean]
  def verify_with_drift(otp, drift, time = Time.now)
    time = time.to_i
    # A negative drift used to produce an empty range and crash on nil.
    drift = drift.abs
    times = (time - drift..time + drift).step(interval).to_a
    # Make sure the upper edge of the window is checked even when the step
    # does not land on it exactly.
    times << time + drift if times.last < time + drift
    # Pre-epoch instants have no valid (non-negative) counter; skip them
    # instead of failing.
    times.reject(&:negative?).any? { |t| verify(otp, t) }
  end

  private

  # HOTP value for counter +input+: HMAC, then dynamic truncation (RFC 4226 §5.3).
  def generate_otp(input, padded = true)
    hmac = OpenSSL::HMAC.digest(
      OpenSSL::Digest.new(digest),
      byte_secret,
      int_to_bytestring(input)
    ).bytes
    offset = hmac.last & 0xf
    # Mask the top bit so the 31-bit result is never negative.
    code = (hmac[offset] & 0x7f) << 24 |
           (hmac[offset + 1] & 0xff) << 16 |
           (hmac[offset + 2] & 0xff) << 8 |
           (hmac[offset + 3] & 0xff)
    otp = code % 10**digits
    padded ? otp.to_s.rjust(digits, '0') : otp
  end

  # Number of whole intervals since the epoch. Uses #to_i (which is
  # zone-independent) rather than Time#utc, because #utc converts the
  # receiver in place and would mutate the caller's Time object.
  def timecode(time)
    time.to_i / interval
  end

  def byte_secret
    Base32.decode(@secret)
  end

  # Big-endian byte string of +int+, left-padded with NUL bytes to +padding+.
  #
  # @raise [ArgumentError] for negative input; the old right-shift loop never
  #   terminated on negatives, because -1 >> 8 == -1
  def int_to_bytestring(int, padding = 8)
    raise ArgumentError, "counter must be non-negative, got #{int}" if int.negative?

    bytes = []
    until int.zero?
      bytes.unshift(int & 0xFF)
      int >>= 8
    end
    bytes.pack('C*').rjust(padding, "\0")
  end

  # Compares two strings in time that depends only on their length, not on
  # where they first differ.
  def time_constant_compare(a, b)
    return false if a.empty? || b.empty? || a.bytesize != b.bytesize

    a.bytes.zip(b.bytes).reduce(0) { |acc, (x, y)| acc | (x ^ y) }.zero?
  end
end
# Minimal RFC 4648 Base32 decoder: case-insensitive, with or without
# trailing "=" padding.
class Base32
  # Raised when the input contains a character outside the Base32 alphabet.
  class Base32Error < RuntimeError; end

  # Base32 alphabet in lowercase; a character's index is its 5-bit value.
  CHARS = "abcdefghijklmnopqrstuvwxyz234567".each_char.to_a.freeze

  class << self
    # Decodes a Base32 string into raw bytes.
    #
    # @param str [String] Base32 text, upper or lower case, optionally padded
    # @return [String] the decoded bytes as a binary (ASCII-8BIT) string
    # @raise [Base32Error] on a character outside the alphabet (including a
    #   "=" that is not trailing padding)
    def decode(str)
      str.scan(/.{1,8}/).flat_map { |block| decode_block(block) }.pack('C*')
    end

    private

    # Decodes one chunk of up to 8 characters into byte values.
    #
    # Only complete bytes are emitted: 2 chars -> 1 byte, 4 -> 2, 5 -> 3,
    # 7 -> 4, 8 -> 5. The leftover low bits of a partial tail are padding,
    # not data. The previous threshold table emitted a spurious extra byte
    # for 4- and 7-character tails.
    def decode_block(block)
      buffer = 0
      bits = 0
      block.sub(/=+\z/, '').each_char.with_object([]) do |char, bytes|
        buffer = (buffer << 5) | decode_quint(char)
        bits += 5
        next if bits < 8

        bits -= 8
        bytes << ((buffer >> bits) & 0xFF)
      end
    end

    # 5-bit value of a single Base32 character.
    def decode_quint(q)
      CHARS.index(q.downcase) || raise(Base32Error, "Invalid Base32 Character - '#{q}'")
    end
  end
end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment