Skip to content

Instantly share code, notes, and snippets.

@katlogic
Created July 24, 2014 14:10
Show Gist options
  • Save katlogic/f42ebb567e870a568f88 to your computer and use it in GitHub Desktop.
Save katlogic/f42ebb567e870a568f88 to your computer and use it in GitHub Desktop.
local lmath = require 'gmp'
local print, setmetatable, io, string, math, assert, pairs, table
= print, setmetatable, io, string, math, assert, pairs, table
local bn = lmath.z
local rsa = {}
local rsa_mt = {__index=rsa}
local genmod = 65537
do local _ENV = rsa
---------------------------------------
-- rsa.* API methods
---------------------------------------
-- generate random prime
-- @bits - size in bits of the random starting value drawn with secure()
--         (original author's note: rounded down to a step of 8 bits —
--         NOTE(review): confirm against the gmp binding's secure())
-- returns: the next prime found by nextprime() starting from that random
--          value. Because nextprime only searches upward, the result can
--          (very rarely) be slightly larger than 2^bits; new_key re-checks
--          the size of the resulting modulus for that reason.
function random_prime(bits)
  return bn():secure(bits):nextprime()
end
-- Build a terminal progress-bar callback.
-- @out - writer function called with a single string per update
-- returns: function(step, nsteps) usable as new_key's @progress argument.
--   step == 0      : draws an empty "[...]" frame of nsteps dots, then
--                    backspaces nsteps+1 times so the cursor sits right
--                    after the "["
--   0 < step < n   : writes one "*" over the next dot
--   step == nsteps : writes "] " to close the bar
--   step > nsteps  : ignored
function pretty_cb(out)
  local function draw(step, total)
    if step == 0 then
      local frame = "[" .. string.rep(".", total) .. "]"
      out(frame .. string.rep("\b", total + 1))
    elseif step > total then
      return
    elseif step == total then
      out("] ")
    else
      out("*")
    end
  end
  return draw
end
-- generate new key
-- @bits - maximum number of bits in the modulus (default: 2048). The two
--         primes are drawn with ceil(bits/2) and floor(bits/2) bits. The
--         modulus may come out shorter than @bits because the primes' top
--         bits are not forced on (NOTE(review): depends on secure()).
-- @?exp - public exponent (default: 65537). Must be odd: for any odd prime
--         p, p-1 is even, so an even exponent can never satisfy
--         gcd(e, p-1) == 1 and prime generation would never terminate.
-- @?progress(step,nsteps) - called with step = 0..nsteps-1 while working
--         (at most once per step), then exactly once with step == nsteps
--         when the key is complete
-- returns: key object { pub = {e, mod}, priv = {d, p, q, dmp1, dmq1, iqmp} }
--          with rsa methods available through its metatable
function new_key(bits,exp,progress)
  -- defaults
  bits = bits or 2048
  exp = exp or genmod
  local e = bn(exp)
  assert(e:gcd(bn(2)) == bn(1), "new_key: public exponent must be odd")
  -- progress reporting: 7 intermediate ticks in the common (no retry) case,
  -- then a single completion call. Extra ticks caused by retries are
  -- swallowed so that "done" is only reported once the key really is done.
  local nsteps = 7
  local step = 0
  local tick, finish
  if progress then
    tick = function()
      if step < nsteps then
        progress(step, nsteps)
      end
      step = step + 1
    end
    finish = function()
      progress(nsteps, nsteps)
    end
  else
    tick = function() end
    finish = tick
  end
  -- generate prime number r such that gcd(e, r-1) == 1;
  -- returns r and r-1
  local function rsa_prime(pbits)
    local r, r1
    tick()
    repeat
      r = rsa.random_prime(pbits)
      tick()
      r1 = r - 1
    until r1:gcd(e) == bn(1)
    return r, r1
  end
  tick()
  -- split the bit budget with integers so odd sizes stay well-defined
  local qbits = math.floor(bits / 2)
  local pbits = bits - qbits
  local p, q, p1, q1, mod
  -- because nextprime monotonously increments, we might end up (very
  -- rarely) with a modulus larger than bits; only reject *larger* moduli.
  -- Also reject p == q, for which phi = (p-1)*(q-1) below would be wrong.
  repeat
    p, p1 = rsa_prime(pbits)
    q, q1 = rsa_prime(qbits)
    mod = p * q
  until p ~= q and mod:sizeinbase(2) <= bits
  -- ensure p > q (iqmp is q^-1 mod p, matching do_private's recombination)
  if p < q then p, p1, q, q1 = q, q1, p, p1 end
  -- phi
  local phi = p1 * q1
  tick()
  -- calculate private exponent d = e^-1 mod phi
  local d = assert(e:invert(phi))
  -- CRT exponents: dmp1 = d mod (p-1), dmq1 = d mod (q-1)
  local dmp1 = d % p1
  local dmq1 = d % q1
  tick()
  local iqmp = assert(q:invert(p))
  finish()
  return setmetatable({
    pub = {
      e = e,
      mod = mod,
    },
    priv = {
      d = d,
      -- optional (enable the CRT fast path in do_private)
      p = p,
      q = q,
      dmp1 = dmp1,
      dmq1 = dmq1,
      iqmp = iqmp,
    }
  }, rsa_mt)
end
-- Apply the private key to msg (sign or decrypt).
-- @msg - bignum to transform
-- returns: msg^d mod n. When every CRT field (p, q, dmp1, dmq1, iqmp) is
--          present, the result is computed modulo p and q separately and
--          recombined with Garner's formula h = (mp - mq)*iqmp mod p,
--          m = h*q + mq; otherwise a single powm with d is used.
function do_private(self,msg)
  local key = self.priv
  local p, q = key.p, key.q
  local dp, dq, qinv = key.dmp1, key.dmq1, key.iqmp
  local have_crt = p and q and dp and dq and qinv
  if have_crt then
    local mp = msg % p
    local mq = msg % q
    -- exponentiate in place (result written back into mp / mq)
    mp:powm(dp, p, mp)
    mq:powm(dq, q, mq)
    local h = ((mp - mq) * qinv) % p
    return h * q + mq
  end
  -- no extended fields; do slow rsa
  return msg:powm(key.d, self.pub.mod)
end
-- Apply the public key to msg (verify or encrypt).
-- @msg - bignum to transform
-- returns: msg^e mod n, using the key's public exponent and modulus
function do_public(self,msg)
  local pub = self.pub
  local exponent, modulus = pub.e, pub.mod
  return msg:powm(exponent, modulus)
end
-- pretty print rsa object
-- Lists every field of the pub and priv tables as "name = <hex>", one per
-- line. Field names are sorted so the output is deterministic (pairs()
-- iteration order is unspecified).
function rsa_mt.__tostring(self)
  -- render one key table as " name = hex\n" lines in sorted key order
  local function section(t)
    local keys = {}
    for k in pairs(t) do
      keys[#keys + 1] = k
    end
    table.sort(keys)
    local lines = {}
    for i = 1, #keys do
      local k = keys[i]
      lines[i] = string.format(" %s = %s\n", k, t[k]:hex())
    end
    return table.concat(lines)
  end
  return "RSA {\n pub = {\n" .. section(self.pub)
    .. " }\n priv = {\n" .. section(self.priv)
    .. " }\n}\n"
end
end -- _ENV = rsa
---------------------------------------
-- test
---------------------------------------
-- When executed directly (no chunk arguments), run a round-trip self-test
-- and a rough private-key throughput benchmark; when loaded with require,
-- `...` carries the module name and the module table is returned instead.
if not ... then
  local function write(...)
    io.stdout:write(table.concat({...}, " "))
    io.stdout:flush()
  end
  -- write data to the named file and close it so it is flushed right away
  local function write_file(name, data)
    local f = assert(io.open(name, "w"))
    f:write(data)
    f:close()
  end
  -- read a whole file; fails with a readable message if it is missing
  local function read_file(name)
    local f = assert(io.open(name, "r"))
    local data = f:read("*a")
    f:close()
    return data
  end
  local tbits = 2048
  -- random test message 30 bits shorter than the key size, so it
  -- (almost surely) stays below the modulus
  local tpay = tbits - 30
  local k
  local kf = io.open('priv','r')
  if not kf then
    write "Generating RSA key "
    k = rsa.new_key(tbits, nil, rsa.pretty_cb(write))
    write "Done\n"
    write_file("priv", k.priv.d:hex())
    write_file("pub", k.pub.mod:hex())
  else
    -- only d and the modulus are stored, so the slow (non-CRT) path is used
    local dhex = kf:read("*a")
    kf:close()
    k = setmetatable({
      priv = { d = bn(dhex, 16) },
      pub = { e = bn(genmod), mod = bn(read_file("pub"), 16) }
    }, rsa_mt)
  end
  print(k)
  local testy = bn():secure(tpay)
  local start = os.clock()
  local total = 1
  assert(k:do_public(k:do_private(testy)) == testy)
  -- benchmark: private operations per CPU second over ~5 seconds
  while true do
    for _ = 1, 100 do
      total = total + 1
      k:do_private(testy)
    end
    if os.clock() > start + 5 then break end
    write('.')
  end
  local stop = os.clock()
  print(total / (stop - start))
else
  return rsa
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment