Last active
July 14, 2020 07:15
-
-
Save Rulexec/664747f12179d6bd94d0a112a138d679 to your computer and use it in GitHub Desktop.
Lua port of the Punycode `encode` function from punycode.js (https://github.com/bestiejs/punycode.js).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local bit = require "bit" | |
local _M = {}

-- Bootstring parameter values for Punycode (RFC 3492, section 5).
local base = 36
local tMin = 1
local tMax = 26
local skew = 38
local damp = 700
local initialBias = 72
local initialN = 0x80 -- first non-basic code point (128)
local delimiter = 0x2D -- byte value of "-" (45)
local baseMinusTMin = base - tMin

-- Upper bound for the encoder's running delta (2^31 - 1); exceeding it is an overflow.
local maxInt = 0x7FFFFFFF
--- Decode a UTF-8 byte string into an array of Unicode code points.
-- Malformed input raises an error in three cases:
--   * a byte that cannot start a sequence (0x80-0xBF or 0xF8-0xFF);
--   * a lead byte that is not followed by enough continuation bytes (0x80-0xBF);
--   * a sequence cut short by the end of the string.
-- Overlong forms and encoded surrogates are not rejected.
-- @tparam string str UTF-8 encoded text
-- @treturn table sequence of integer code points
local function utf8_decode(str)
  local output = {}
  local length = #str
  local i = 1
  while i <= length do
    local c = string.byte(str, i)
    -- `code` starts as the payload bits of the lead byte; `extra` is the
    -- number of continuation bytes that must follow it.
    local code, extra
    if c < 0x80 then
      code, extra = c, 0
    elseif c >= 0xc0 and c <= 0xdf then
      code, extra = c - 0xc0, 1
    elseif c >= 0xe0 and c <= 0xef then
      code, extra = c - 0xe0, 2
    elseif c >= 0xf0 and c <= 0xf7 then
      code, extra = c - 0xf0, 3
    else
      error(string.format("invalid UTF-8 lead byte 0x%02x at position %d", c, i))
    end
    if i + extra > length then
      error(string.format("truncated UTF-8 sequence at position %d", i))
    end
    for j = i + 1, i + extra do
      local cont = string.byte(str, j)
      if cont < 0x80 or cont > 0xbf then
        error(string.format("invalid UTF-8 continuation byte 0x%02x at position %d", cont, j))
      end
      -- Each continuation byte contributes 6 more low-order bits.
      code = code * 0x40 + (cont - 0x80)
    end
    output[#output + 1] = code
    i = i + extra + 1
  end
  return output
end
--- Merge UTF-16 surrogate pairs in a code-point array into single code points.
-- A high surrogate (0xD800-0xDBFF) immediately followed by a low surrogate
-- (0xDC00-0xDFFF) becomes one supplementary code point. Every other value,
-- including unpaired surrogates, is copied unchanged.
-- @tparam table codes sequence of integer code units / code points
-- @treturn table new sequence; the input table is not modified
local function ucs2decode(codes)
  local output = {}
  local length = #codes
  local counter = 1
  while counter <= length do
    local value = codes[counter]
    counter = counter + 1
    local extra = codes[counter]
    -- An explicit range check is used rather than a 16-bit mask: inputs here
    -- are full code points, so values above 0xFFFF (e.g. 0x1DC00) must not be
    -- mistaken for low surrogates.
    if value >= 0xD800 and value <= 0xDBFF and counter <= length
        and extra >= 0xDC00 and extra <= 0xDFFF then
      -- Equivalent to ((value & 0x3FF) << 10) + (extra & 0x3FF) + 0x10000.
      output[#output + 1] = (value - 0xD800) * 0x400 + (extra - 0xDC00) + 0x10000
      counter = counter + 1
    else
      output[#output + 1] = value
    end
  end
  return output
end
--- Map a Punycode digit (0-35) to the byte value of its basic code point.
-- Digits 0-25 map to "a"-"z" (97-122) and digits 26-35 map to "0"-"9" (48-57).
-- When `flag` is anything other than 0, the result is reduced by 32, which
-- turns the letters upper-case. The encoder in this file always passes 0.
-- @tparam number digit value in 0..35
-- @tparam number flag 0 for lower-case output
-- @treturn number byte value
local function digit_to_basic(digit, flag)
  local offset
  if digit < 26 then
    offset = 97 -- "a"
  else
    offset = 22 -- 26 + 22 == 48, i.e. "0"
  end
  if flag == 0 then
    return digit + offset
  end
  return digit + offset - 32
end
--- Bias adaptation function (RFC 3492, section 6.1).
-- @tparam number delta the delta just encoded (non-negative integer <= maxInt)
-- @tparam number numPoints count of code points handled, including this one
-- @tparam boolean firstTime true for the first adaptation, which damps harder
-- @treturn number the new bias
local function adapt(delta, numPoints, firstTime)
  local scaled
  if firstTime then
    scaled = math.floor(delta / damp)
  else
    -- Same as a right shift by one for the non-negative deltas seen here.
    scaled = math.floor(delta / 2)
  end
  scaled = scaled + math.floor(scaled / numPoints)

  -- ((base - tMin) * tMax) / 2, i.e. 455 with the standard parameters.
  local limit = math.floor(baseMinusTMin * tMax / 2)
  local k = 0
  while scaled > limit do
    scaled = math.floor(scaled / baseMinusTMin)
    k = k + base
  end
  return math.floor(k + (baseMinusTMin + 1) * scaled / (scaled + skew))
end
--- Build a string from an array of byte values.
-- Collects one-character strings and joins them once with table.concat,
-- which avoids the quadratic cost of repeated `..` concatenation in a loop.
-- @tparam table arr sequence of integers in 0..255
-- @treturn string the concatenated bytes ("" for an empty array)
local function from_char_codes(arr)
  local chars = {}
  for i, code in ipairs(arr) do
    chars[i] = string.char(code)
  end
  return table.concat(chars)
end
--- Append the Bootstring variable-length integer `q` to `output` as basic
-- code point byte values (RFC 3492, section 6.3), choosing each digit's
-- threshold from the current `bias`.
-- @tparam table output byte-value array being built (modified in place)
-- @tparam number q non-negative integer to encode
-- @tparam number bias current bias
local function emit_variable_integer(output, q, bias)
  local k = base
  while true do
    -- Threshold for this digit position, clamped to [tMin, tMax].
    local t
    if k <= bias then
      t = tMin
    elseif k >= bias + tMax then
      t = tMax
    else
      t = k - bias
    end
    if q < t then
      break
    end
    local qMinusT = q - t
    local baseMinusT = base - t
    output[#output + 1] = digit_to_basic(t + qMinusT % baseMinusT, 0)
    q = math.floor(qMinusT / baseMinusT)
    k = k + base
  end
  -- The final digit is below its threshold, which terminates the integer.
  output[#output + 1] = digit_to_basic(q, 0)
end

--- Punycode-encode one label (RFC 3492 encoding procedure).
-- The output is built in three parts:
--   1. the basic code points (< 0x80), copied first;
--   2. a "-" delimiter, present only if there were basic code points;
--   3. the encoded deltas for the remaining code points, in ascending order.
-- The "xn--" prefix is not added here.
-- @tparam table input sequence of code points; surrogate pairs are merged first
-- @treturn table sequence of byte values of the encoded label
-- @raise error('overflow') when the running delta would exceed maxInt
local function raw_encode(input)
  local output = {}
  input = ucs2decode(input)
  local inputLength = #input
  local n = initialN
  local delta = 0
  local bias = initialBias

  -- Copy the basic code points verbatim.
  for _, val in ipairs(input) do
    if val < 0x80 then
      output[#output + 1] = val
    end
  end

  local basicLength = #output
  local handledCPCount = basicLength
  if basicLength > 0 then
    output[#output + 1] = delimiter
  end

  while handledCPCount < inputLength do
    -- Smallest code point >= n that still has to be handled.
    local m = maxInt
    for _, val in ipairs(input) do
      if val >= n and val < m then
        m = val
      end
    end

    -- Advance delta for the jump from n to m, guarding against overflow.
    local handledCPCountPlusOne = handledCPCount + 1
    if m - n > math.floor((maxInt - delta) / handledCPCountPlusOne) then
      error('overflow')
    end
    delta = delta + (m - n) * handledCPCountPlusOne
    n = m

    for _, val in ipairs(input) do
      if val < n then
        delta = delta + 1
        if delta > maxInt then
          error('overflow')
        end
      end
      if val == n then
        emit_variable_integer(output, delta, bias)
        bias = adapt(delta, handledCPCountPlusOne, handledCPCount == basicLength)
        delta = 0
        handledCPCount = handledCPCount + 1
      end
    end

    delta = delta + 1
    n = n + 1
  end

  return output
end
-- Code points that separate labels: "." plus the alternative full stops
-- U+3002 (ideographic), U+FF0E (fullwidth) and U+FF61 (halfwidth ideographic).
-- punycode.js `toASCII` also maps these to ".".
local label_separators = {
  [0x2E] = true,
  [0x3002] = true,
  [0xFF0E] = true,
  [0xFF61] = true,
}

--- Convert a UTF-8 domain name to its ASCII-compatible (Punycode) form.
-- Labels are split on "." and on the alternative full stops listed in
-- `label_separators`; all separators are emitted as ".".
-- Labels containing any code point >= 0x80 are Punycode-encoded and prefixed
-- with "xn--". Other labels are copied unchanged.
-- No case folding or nameprep normalization is performed.
-- @tparam string domain UTF-8 encoded domain name
-- @treturn string ASCII domain name; a pure-ASCII input is returned as-is
-- @raise errors from the UTF-8 decoder or error('overflow') from the encoder
function _M.encode_domain(domain)
  -- Fast path: nothing to do when every byte is ASCII (< 0x80). DEL (0x7F) is
  -- ASCII and must not trigger encoding.
  if not string.find(domain, "[\128-\255]") then
    return domain
  end

  local labels = {}
  local label = {}
  local label_has_non_ascii = false

  -- Append the current label (encoded if needed) to `labels` and start a new one.
  local function finish_label()
    if label_has_non_ascii then
      labels[#labels + 1] = "xn--" .. from_char_codes(raw_encode(label))
    else
      labels[#labels + 1] = from_char_codes(label)
    end
    label = {}
    label_has_non_ascii = false
  end

  for _, c in ipairs(utf8_decode(domain)) do
    if label_separators[c] then
      finish_label()
    else
      if c >= 0x80 then
        label_has_non_ascii = true
      end
      label[#label + 1] = c
    end
  end
  finish_label()

  return table.concat(labels, ".")
end

return _M
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment