Skip to content

Instantly share code, notes, and snippets.

@thibaudgg
Last active October 24, 2017 09:04
Show Gist options
  • Save thibaudgg/059a3f678d478df2a49e0755ca86c4a8 to your computer and use it in GitHub Desktop.
Save thibaudgg/059a3f678d478df2a49e0755ca86c4a8 to your computer and use it in GitHub Desktop.
require "openssl"

# Time-based one-time password generator/verifier (RFC 6238).
#
# A code is the RFC 4226 HOTP value of the number of +interval+-second
# steps elapsed since the Unix epoch, truncated to +digits+ decimal digits.
class TOTP
  attr_reader :secret, :digits, :digest, :interval

  DEFAULT_INTERVAL = 30
  DEFAULT_DIGITS = 6
  DEFAULT_DIGEST = "sha1".freeze

  # @param [String] s the shared secret, Base32-encoded
  # @param [Hash] options
  # @option options [Integer] :interval (30) length in seconds of one OTP
  #   time step; 30 is the common standard
  # @option options [Integer] :digits (6) number of digits in the OTP;
  #   Google Authenticator only supports 6
  # @option options [String] :digest ("sha1") digest used in the HMAC;
  #   Google Authenticator only supports "sha1"
  # @return [TOTP]
  def initialize(s, options = {})
    @interval = options[:interval] || DEFAULT_INTERVAL
    @digits = options[:digits] || DEFAULT_DIGITS
    @digest = options[:digest] || DEFAULT_DIGEST
    @secret = s
  end

  # Generates the OTP for the current time.
  # @param [Boolean] padding when true (default) return a zero-padded
  #   String, otherwise an Integer
  # @return [String, Integer] the OTP
  def now(padding = true)
    generate_otp(timecode(Time.now), padding)
  end

  # Generates the OTP for a given moment.
  # @param [Time, Integer] time a Time, or anything whose #to_i is a Unix
  #   timestamp in seconds. Time zones are irrelevant, since only the epoch
  #   seconds are used, and the caller's object is never modified.
  # @param [Boolean] padding when true (default) return a zero-padded
  #   String, otherwise an Integer
  # @return [String, Integer] the OTP
  def at(time, padding = true)
    generate_otp(timecode(time), padding)
  end

  # Verifies +otp+ against the OTP for +time+ using a constant-time compare.
  # @param [String] otp the OTP to check
  # @param [Time, Integer] time the moment to check against (default: now)
  # @return [Boolean]
  # @raise [ArgumentError] if +otp+ is not a String
  def verify(otp, time = Time.now)
    generated = at(time)
    unless otp.is_a?(String) && generated.is_a?(String)
      raise ArgumentError, "ROTP only verifies strings - See: https://github.com/mdp/rotp/issues/32"
    end
    time_constant_compare(otp, generated)
  end

  # Verifies +otp+ against every time step within +drift+ seconds of +time+
  # (inclusive on both sides).
  # @param [String] otp the OTP to check
  # @param [Integer] drift allowed client/server clock skew in seconds; the
  #   sign is ignored
  # @param [Time, Integer] time the reference moment (default: now)
  # @return [Boolean] true if any step in the window matches. Steps before
  #   the Unix epoch are skipped, because they have no OTP.
  def verify_with_drift(otp, drift, time = Time.now)
    time = time.to_i
    drift = drift.abs
    first_step = [((time - drift) / interval).floor, 0].max
    last_step = ((time + drift) / interval).floor
    (first_step..last_step).any? { |step| verify(otp, step * interval) }
  end

  private

  # Computes the HOTP value (RFC 4226 section 5.3) for a counter.
  # @param [Integer] input the moving factor fed to the HMAC; for TOTP this
  #   is the time-step count returned by #timecode
  # @param [Boolean] padded (true) return a zero-padded String when true,
  #   an Integer otherwise
  # @return [String, Integer]
  def generate_otp(input, padded = true)
    hmac = OpenSSL::HMAC.digest(
      OpenSSL::Digest.new(digest),
      byte_secret,
      int_to_bytestring(input)
    ).bytes
    # Dynamic truncation: the low nibble of the last byte selects a 4-byte window.
    offset = hmac.last & 0x0f
    # Masking the top bit keeps the 31-bit value sign-independent.
    code = ((hmac[offset] & 0x7f) << 24) |
           (hmac[offset + 1] << 16) |
           (hmac[offset + 2] << 8) |
           hmac[offset + 3]
    otp = code % (10**digits)
    padded ? otp.to_s.rjust(digits, "0") : otp
  end

  # Number of whole +interval+-second steps since the Unix epoch.
  # Uses #to_i, not Time#utc: Time#utc converts the receiver in place and
  # would mutate a Time passed by the caller.
  def timecode(time)
    time.to_i / interval
  end

  # Raw key bytes for the HMAC.
  def byte_secret
    Base32.decode(@secret)
  end

  # Encodes a non-negative integer as a big-endian byte string, left-padded
  # with NUL bytes to +padding+ bytes. This is the counter format that
  # RFC 4226 feeds to the HMAC.
  # @raise [ArgumentError] for negative input, which would otherwise loop
  #   forever because (-1 >> 8) == -1
  def int_to_bytestring(int, padding = 8)
    raise ArgumentError, "#int_to_bytestring requires a non-negative integer" if int < 0

    bytes = []
    until int.zero?
      bytes.unshift(int & 0xFF)
      int >>= 8
    end
    bytes.pack("C*").rjust(padding, "\0")
  end

  # Compares two strings in time that depends only on their length, so the
  # comparison does not leak how many leading bytes match.
  # Returns false for empty strings or strings of different byte length.
  def time_constant_compare(a, b)
    return false if a.empty? || b.empty? || a.bytesize != b.bytesize

    a.bytes.zip(b.bytes).reduce(0) { |acc, (x, y)| acc | (x ^ y) }.zero?
  end
end
# Minimal case-insensitive RFC 4648 Base32 decoder, used to turn a TOTP
# shared secret into raw key bytes.
class Base32
  class Base32Error < RuntimeError; end

  CHARS = "abcdefghijklmnopqrstuvwxyz234567".each_char.to_a.freeze

  class << self
    # Decodes Base32 text into raw bytes.
    #
    # Each character contributes 5 bits, and a byte is emitted every time
    # 8 bits have accumulated. Fewer than 8 leftover bits at the end are
    # discarded, so padded and unpadded input decode identically.
    # Trailing "=" padding is optional. Line breaks ("\n") are ignored.
    #
    # @param [String] str Base32 text, in either case
    # @return [String] the decoded bytes (binary encoding)
    # @raise [Base32Error] if +str+ contains a character outside the
    #   alphabet, including "=" anywhere other than the end
    def decode(str)
      buffer = 0
      bits = 0
      bytes = []
      str.delete("\n").sub(/=+\z/, "").each_char do |char|
        buffer = (buffer << 5) | decode_quint(char)
        bits += 5
        next if bits < 8

        bits -= 8
        bytes << (buffer >> bits)
        # Keep only the bits not yet emitted, so buffer stays below 2**bits.
        buffer &= (1 << bits) - 1
      end
      bytes.pack("C*")
    end

    private

    # Maps one Base32 character (either case) to its 5-bit value.
    def decode_quint(q)
      CHARS.index(q.downcase) || raise(Base32Error, "Invalid Base32 Character - '#{q}'")
    end
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment