Skip to content

Instantly share code, notes, and snippets.

@tompng
Last active August 17, 2025 04:37
Show Gist options
  • Save tompng/bcd09eaa448b463cf0431e405e62e5ce to your computer and use it in GitHub Desktop.
Save tompng/bcd09eaa448b463cf0431e405e62e5ce to your computer and use it in GitHub Desktop.
NTT (Number Theoretic Transform) multiplication
# Reference: https://qiita.com/AngrySadEight/items/0dfde26060daaf6a2fda
# Multipliers for the NTT moduli p = base * 2**SHIFT + 1 (presumably primes; not checked here).
# Each modulus supports transform lengths up to 2**SHIFT. Frozen so the shared configuration
# cannot be mutated accidentally.
BASES = [29, 26, 24].freeze
SHIFT = 27
# Number-theoretic transform over Z/pZ with p = base * 2**shift + 1, used to compute
# integer convolutions modulo p. Supports power-of-two transform lengths up to 2**shift.
class NTT
  # The modulus p. This reader shadows Kernel#p inside the class, so the instance
  # methods below read @p directly.
  attr_reader :p

  # @param base [Integer] multiplier of the modulus
  # @param shift [Integer] log2 of the largest supported transform length (expected >= 1)
  def initialize(base, shift)
    @base = base
    @shift = shift
    @p = (base << shift) + 1
    # r**(2**shift) == 17**(p - 1) == 1 (mod p) when p is prime, so r has order dividing 2**shift.
    # The order is exactly 2**shift only if 17 is a quadratic non-residue mod p; that is
    # assumed here (TODO confirm per base).
    r = 17.pow(@base, @p)
    # After the reverse, @root[k] == r**(2**(shift - 1 - k)), a root of unity of order 2**(k + 1).
    @root = @shift.times.map { |i| r.pow(1 << i, @p) }.reverse
    # Modular inverses via Fermat's little theorem (x**(p - 2) == x**-1 mod prime p).
    @inv_root = @root.map { |w| w.pow(@p - 2, @p) }
  end

  # Transforms +a+ without modifying it. The result is not scaled by 1/n; #convolution does
  # that after the inverse pass.
  #
  # @param a [Array<Integer>] input whose length is a power of two (at most 2**shift)
  # @param dir [Integer] > 0 for the forward transform, otherwise the inverse-root transform
  # @return [Array<Integer>] a new array of the same length with entries in 0...p
  # @raise [ArgumentError] if the length is not a power of two or exceeds 2**shift
  def ntt(a, dir)
    size = a.size
    unless (size & (size - 1)).zero?
      raise ArgumentError, "NTT length must be a power of two, got #{size}"
    end
    # Length-0/1 transforms are the identity. Without this, the recursion would get depth < 0
    # and never reach its depth == 0 base case.
    return a.map { |v| v % @p } if size <= 1

    depth = size.bit_length - 2
    root = dir > 0 ? @root : @inv_root
    if depth >= root.size
      raise ArgumentError, "NTT length #{size} exceeds the maximum 2**#{@shift}"
    end
    ntt_rec(a, depth, root[depth])
  end

  # One radix-2 stage. It first recurses into stages depth - 1 .. 0 using the squared root,
  # then combines pairs of entries +stride+ apart. The output is in natural DFT order
  # (hand-checked for lengths 2 and 4).
  #
  # @param a [Array<Integer>] input of power-of-two length 2**k
  # @param depth [Integer] stage index in 0...k; this stage produces 2**(depth + 1) groups
  # @param r [Integer] root of unity of order 2**(depth + 1)
  # @return [Array<Integer>] a new array of the same length
  def ntt_rec(a, depth, r)
    stride = 1 << (a.size.bit_length - depth - 2)
    a2 = depth == 0 ? a : ntt_rec(a, depth - 1, r * r % @p)
    n = a.size / stride
    n_half = n / 2
    out = []
    rn = 1
    n.times do |i|
      j = i % n_half * 2 * stride
      stride.times do |chunk|
        out << (a2[j + chunk] + rn * a2[j + chunk + stride]) % @p
      end
      rn = (rn * r) % @p
    end
    out
  end

  # Acyclic convolution of +a+ and +b+ modulo p. Each coefficient is exact only if the true
  # value is < p.
  #
  # For non-empty inputs, the result has length n = 1 << (a.size + b.size - 1).bit_length,
  # which is at least a.size + b.size. Entries past index a.size + b.size - 2 are zero.
  # Returns [] when either input is empty (there are no product terms).
  #
  # @param a [Array<Integer>]
  # @param b [Array<Integer>]
  # @return [Array<Integer>]
  # @raise [ArgumentError] if n exceeds 2**shift
  def convolution(a, b)
    return [] if a.empty? || b.empty?

    n = 1 << (a.size + b.size - 1).bit_length
    fa = ntt(a + [0] * (n - a.size), +1)
    fb = ntt(b + [0] * (n - b.size), +1)
    pointwise = fa.zip(fb).map { |x, y| x * y % @p }
    n_inv = n.pow(@p - 2, @p)
    ntt(pointwise, -1).map { |v| v * n_inv % @p }
  end
end
MULT_BASE = 1_000_000_000
N = 10_000
va = rand(MULT_BASE ** N - 1)
vb = rand(MULT_BASE ** N - 1)
a = va.digits(MULT_BASE)
b = vb.digits(MULT_BASE)
ntts = BASES.map { NTT.new(it, SHIFT) }
# Each convolution coefficient is a sum of at most N products of two base-MULT_BASE digits,
# so it is at most (MULT_BASE - 1)**2 * N. CRT recovers it exactly only if the product of
# the chosen moduli is strictly greater than that bound.
max_coefficient = (MULT_BASE - 1)**2 * N
cnt = (1..ntts.size).find { |k| ntts.take(k).map(&:p).inject(:*) > max_coefficient }
# Fail loudly instead of looping forever when even all BASES together are too small.
raise "moduli from BASES are too small to multiply #{N}-limb numbers exactly" unless cnt
ntts = ntts.take(cnt)
puts "needs multiple P: #{cnt}" if cnt > 1
# Extended Euclidean algorithm for coprime integers.
#
# @param a [Integer]
# @param b [Integer]
# @return [Array(Integer, Integer)] [x, y] such that a * x + b * y == 1
# @raise [ArgumentError] if gcd(a, b) != 1, since no such pair exists
def ext_euclid(a, b)
  unless a.gcd(b) == 1
    raise ArgumentError, "ext_euclid needs coprime arguments, got #{a} and #{b}"
  end

  # Invariant: a * old_x + b * old_y == old_r and a * x + b * y == r.
  old_r, r = a, b
  old_x, x = 1, 0
  old_y, y = 0, 1
  until r.zero?
    q = old_r / r
    old_r, r = r, old_r - q * r
    old_x, x = x, old_x - q * x
    old_y, y = y, old_y - q * y
  end
  # old_r is now the gcd up to sign (+1 or -1); flip signs so the combination equals +1.
  old_r.negative? ? [-old_x, -old_y] : [old_x, old_y]
end
# Reconstructs each convolution coefficient from its residues modulo several moduli using the
# Chinese remainder theorem, folding in one modulus at a time.
#
# @param convolutions [Array<Array<Integer>>] convolutions[i][k] is coefficient k modulo ps[i];
#   all rows must have the same length
# @param ps [Array<Integer>] the moduli; assumed pairwise coprime
# @return [Array<Integer>] for each k, the value in 0...(product of ps) congruent to
#   convolutions[i][k] modulo ps[i] for every i
def restore_from_convolution_mods(convolutions, ps)
  # Precompute, once for all coefficients, each step as [pa, pb, x, y, pa * pb].
  # pa is the product of the earlier moduli, and pa * x + pb * y == 1.
  steps = []
  prefix = 1
  ps.each do |pb|
    x, y = ext_euclid(prefix, pb)
    steps << [prefix, pb, x, y, prefix * pb]
    prefix *= pb
  end

  convolutions.transpose.map do |remainders|
    remainders.zip(steps).inject(0) do |acc, (remb, (pa, pb, x, y, pab))|
      # acc is the residue mod pa. Since pb * y == 1 (mod pa) and pa * x == 1 (mod pb),
      # this is congruent to acc mod pa and to remb mod pb.
      (acc * pb * y + remb * pa * x) % pab
    end
  end
end
# Convolves the limb arrays +a+ and +b+ modulo each NTT's modulus, then recombines the
# residues with the CRT. The result is exact only if the product of the moduli exceeds
# every true coefficient.
def convolution(a, b, ntts)
  residue_rows = ntts.map { |transform| transform.convolution(a, b) }
  moduli = ntts.map(&:p)
  restore_from_convolution_mods(residue_rows, moduli)
end
require 'bigdecimal'
biga = BigDecimal(va)
bigb = BigDecimal(vb)
# Benchmark NTT convolution against BigDecimal and Integer multiplication. Durations use the
# monotonic clock so wall-clock adjustments cannot skew them.
now = -> { Process.clock_gettime(Process::CLOCK_MONOTONIC) }
t0 = now.call
result = convolution(a, b, ntts)
t1 = now.call
p t1 - t0
biga * bigb
t2 = now.call
p t2 - t1
BigDecimal(biga.to_i * bigb.to_i)
t3 = now.call
p t3 - t2
# Propagate carries so every little-endian limb is < MULT_BASE.
carry = 0
mult_result = result.map do |v|
  carry, mod = (v + carry).divmod(MULT_BASE)
  mod
end
# Keep any carry left after the last coefficient instead of silently dropping it.
mult_result.concat(carry.digits(MULT_BASE)) if carry.positive?
# Decimal digits per limb; assumes MULT_BASE is a power of ten.
digits_per_limb = MULT_BASE.to_s.size - 1
vab = mult_result.map { it.to_s.rjust(digits_per_limb, '0') }.reverse.join.to_i(10)
p va * vb == vab
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment