Skip to content

Instantly share code, notes, and snippets.

@Dekkonot
Created September 30, 2024 00:15
Show Gist options
  • Save Dekkonot/49143d03962be70dfaeb71b3258dd305 to your computer and use it in GitHub Desktop.
Save Dekkonot/49143d03962be70dfaeb71b3258dd305 to your computer and use it in GitHub Desktop.
--!native
--!optimize 2
-- Alias so the annotations in this file can refer to the backing `Vector3`
-- type as `vector`.
type vector = Vector3
-- Lookup table from a single lowercase hexadecimal digit to its four binary
-- digits. `to_bin_string` passes it to `string.gsub` as the replacement table.
-- stylua: ignore
local HEX_TO_BINARY = {
["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111",
["8"] = "1000", ["9"] = "1001", ["a"] = "1010", ["b"] = "1011",
["c"] = "1100", ["d"] = "1101", ["e"] = "1110", ["f"] = "1111",
}
-- Packs two 32-bit halves into the three-component `u64` layout:
--   x = upper 22 bits of `most`
--   y = lower 10 bits of `most` (stored in bits 10-19) combined with the
--       lower 10 bits of `least` (stored in bits 0-9)
--   z = upper 22 bits of `least`
local function u64_from_pair(most: number, least: number): vector
	local x = bit32.rshift(most, 10)
	local least_low = bit32.band(least, 0x3FF)
	local most_low = bit32.band(most, 0x3FF)
	local y = bit32.bor(bit32.lshift(most_low, 10), least_low)
	local z = bit32.rshift(least, 10)
	return vector(x, y, z)
end
-- Builds a `u64` from a single 32-bit value. The x component is left at zero;
-- y receives the low 10 bits of `u32` and z receives the remaining upper bits.
local function u64_from_u32(u32: number): vector
	local low_bits = bit32.band(u32, 0x3FF)
	local high_bits = bit32.rshift(u32, 10)
	return vector(0, low_bits, high_bits)
end
-- Builds a `u64` from a regular number by splitting it at 2^32: the floored
-- quotient becomes the most-significant word and the remainder becomes the
-- least-significant word. `bit32.bor` with one argument coerces each part to
-- a 32-bit unsigned integer.
local function u64_from_f64(f64: number): vector
	local word = 2 ^ 32
	local most = bit32.bor(f64 // word)
	local least = bit32.bor(f64 % word)
	return u64_from_pair(most, least)
end
-- Reads a `u64` from `buf` starting at byte `offset` (0 when omitted). The
-- first four bytes are read as the least-significant word and the next four
-- as the most-significant word.
local function u64_from_buffer(buf: buffer, offset: number?): vector
	local start = (offset or 0) :: number
	local least = buffer.readu32(buf, start)
	local most = buffer.readu32(buf, start + 4)
	return u64_from_pair(most, least)
end
-- Reads a `u64` from eight bytes of `str` starting at position `offset`
-- (1 when omitted). The input is treated as little endian: the first byte is
-- the lowest byte of the least-significant word and the eighth byte is the
-- highest byte of the most-significant word.
local function u64_from_string(str: string, offset: number?): vector
	local first = (offset or 1) :: number
	local b0, b1, b2, b3, b4, b5, b6, b7 = string.byte(str, first, first + 7)

	-- Each 16-bit half is built by placing the higher byte into bits 8-15 of
	-- the lower byte.
	local least_upper = bit32.replace(b2, b3, 8, 8)
	local least_lower = bit32.replace(b0, b1, 8, 8)
	local least = bit32.bor(bit32.lshift(least_upper, 16), least_lower)

	local most_upper = bit32.replace(b6, b7, 8, 8)
	local most_lower = bit32.replace(b4, b5, 8, 8)
	local most = bit32.bor(bit32.lshift(most_upper, 16), most_lower)

	return u64_from_pair(most, least)
end
-- Unpacks a `u64` into its two 32-bit words, returned as (most, least).
-- The y component is shared: its bits above bit 9 complete the
-- most-significant word and its low 10 bits complete the least-significant one.
local function u32_from_u64(u64: vector): (number, number)
	local shared = u64.Y
	local most_upper = bit32.lshift(u64.X, 10)
	local least_upper = bit32.lshift(u64.Z, 10)
	return bit32.bor(most_upper, bit32.rshift(shared, 10)), bit32.bor(least_upper, bit32.band(shared, 0x3FF))
end
-- Unpacks a `u64` into four 16-bit words, ordered from most significant to
-- least significant.
local function u16_from_u64(u64: vector): (number, number, number, number)
	local most, least = u32_from_u64(u64)
	local word_a = bit32.extract(most, 16, 16)
	local word_b = bit32.extract(most, 0, 16)
	local word_c = bit32.extract(least, 16, 16)
	local word_d = bit32.extract(least, 0, 16)
	return word_a, word_b, word_c, word_d
end
-- Converts a `u64` to a regular number by recombining its two 32-bit words as
-- `most * 2^32 + least`.
local function f64_from_u64(u64: vector): number
	local high, low = u32_from_u64(u64)
	local scaled_high = high * 4294967296 -- 2^32
	return scaled_high + low
end
-- Serialises a `u64` into a freshly created 8-byte buffer. The
-- least-significant word is written at offset 0 and the most-significant word
-- at offset 4, which is the layout `u64_from_buffer` reads back.
local function to_bytes_buffer(u64: vector): buffer
	local high, low = u32_from_u64(u64)
	local out = buffer.create(8)
	buffer.writeu32(out, 0, low)
	buffer.writeu32(out, 4, high)
	return out
end
-- Serialises a `u64` into an 8-character string, lowest byte first: the four
-- bytes of the least-significant word followed by the four bytes of the
-- most-significant word.
local function to_bytes_string(u64: vector): string
	local most, least = u32_from_u64(u64)

	local l0 = bit32.band(least, 0xFF)
	local l1 = bit32.extract(least, 8, 8)
	local l2 = bit32.extract(least, 16, 8)
	local l3 = bit32.rshift(least, 24)

	local m0 = bit32.band(most, 0xFF)
	local m1 = bit32.extract(most, 8, 8)
	local m2 = bit32.extract(most, 16, 8)
	local m3 = bit32.rshift(most, 24)

	return string.format("%c%c%c%c%c%c%c%c", l0, l1, l2, l3, m0, m1, m2, m3)
end
-- The u64 value 0: every component is zero.
local ZERO_VECTOR = vector(0, 0, 0)
-- The u64 value 1. The low 10 bits of the least-significant word live in the
-- low 10 bits of the Y component (see `u64_from_pair` / `u32_from_u64`), so
-- the value 1 is encoded as Y = 1 with X and Z left at zero.
local ONE_VECTOR = vector(0, 0x1, 0)
--[=[
Formats the provided `u64` as 16 lowercase hexadecimal digits.
The most-significant word is written first (big-endian digit order) and each
32-bit word is zero-padded to 8 digits, so the result is always 16 bytes long.
]=]
local function to_hex_string(u64: vector): string
	local high, low = u32_from_u64(u64)
	return ("%08x%08x"):format(high, low)
end
--[=[
Formats the provided `u64` as 64 binary digits.
The most-significant bit is written first (big-endian digit order) and the
result is always 64 bytes long.
]=]
local function to_bin_string(u64: vector): string
	local hex = to_hex_string(u64)
	-- Not fast: every hex digit is expanded to four binary digits through the
	-- lookup table. Only the substituted string is returned, not the
	-- replacement count that `string.gsub` also produces.
	local bits = string.gsub(hex, "(.)", HEX_TO_BINARY)
	return bits
end
--[=[
Reports whether the provided `u64` equals zero.
Same result as comparing it against `from_u32(0)`.
]=]
local function is_zero(u64: vector): boolean
	local matches_zero = ZERO_VECTOR == u64
	return matches_zero
end
--[=[
Computes the bitwise AND of the two provided values.
Takes exactly two operands; unlike the `bit32` equivalent there is no vararg
form, for performance reasons.
]=]
local function band(lhs: vector, rhs: vector): vector
	local x = bit32.band(lhs.X, rhs.X)
	local y = bit32.band(lhs.Y, rhs.Y)
	local z = bit32.band(lhs.Z, rhs.Z)
	return vector(x, y, z)
end
--[=[
Computes the bitwise OR of the two provided values.
Takes exactly two operands; unlike the `bit32` equivalent there is no vararg
form, for performance reasons.
]=]
local function bor(lhs: vector, rhs: vector): vector
	local x = bit32.bor(lhs.X, rhs.X)
	local y = bit32.bor(lhs.Y, rhs.Y)
	local z = bit32.bor(lhs.Z, rhs.Z)
	return vector(x, y, z)
end
--[=[
Computes the bitwise XOR of the two provided values.
Takes exactly two operands; unlike the `bit32` equivalent there is no vararg
form, for performance reasons.
]=]
local function bxor(lhs: vector, rhs: vector): vector
	local x = bit32.bxor(lhs.X, rhs.X)
	local y = bit32.bxor(lhs.Y, rhs.Y)
	local z = bit32.bxor(lhs.Z, rhs.Z)
	return vector(x, y, z)
end
--[=[
Computes the bitwise negation of the provided value.
]=]
local function bnot(u64: vector): vector
	-- Each component only stores part of the 64 bits (22 in X, 20 in Y, 22 in
	-- Z). `bit32.bnot` would also set the unused upper bits of every
	-- component, producing values that no longer fit the packed layout, so
	-- only the meaningful bits of each field are flipped.
	return vector(
		bit32.bxor(u64.X, 0x3FFFFF),
		bit32.bxor(u64.Y, 0xFFFFF),
		bit32.bxor(u64.Z, 0x3FFFFF)
	)
end
--[=[
Shifts the provided `u64` logically left by `n` bits.
`n` is expected to be within the range `[0, 64]`; this function performs no
explicit range check of its own.
]=]
local function lshift(u64: vector, n: number): vector
	if n == 0 then
		return u64
	end
	local most, least = u32_from_u64(u64)
	if n < 32 then
		-- The top `n` bits of `least` move into the bottom of `most`, and the
		-- low `32 - n` bits of `most` move up above them.
		local kept = 32 - n
		local carried = bit32.rshift(least, kept)
		local new_most = bit32.replace(carried, most, n, kept)
		local new_least = bit32.lshift(least, n)
		return u64_from_pair(new_most, new_least)
	end
	-- Shifting by a whole word or more: only bits of `least` can survive, and
	-- they land in the most-significant word.
	return u64_from_pair(bit32.lshift(least, n - 32), 0)
end
--[=[
Shifts the provided `u64` logically right by `n` bits.
`n` is expected to be within the range `[0, 64]`; this function performs no
explicit range check of its own.
]=]
local function rshift(u64: vector, n: number): vector
	if n == 0 then
		return u64
	end
	local most, least = u32_from_u64(u64)
	if n < 32 then
		-- The low `n` bits of `most` drop into the top `n` bits of `least`.
		local new_most = bit32.rshift(most, n)
		local shifted_least = bit32.rshift(least, n)
		local new_least = bit32.replace(shifted_least, most, 32 - n, n)
		return u64_from_pair(new_most, new_least)
	end
	-- Shifting by a whole word or more: only bits of `most` can survive, and
	-- they land in the least-significant word.
	return u64_from_pair(0, bit32.rshift(most, n - 32))
end
--[=[
Shifts the provided `u64` arithmetically right by `n` bits. Because the
values are unsigned, this simply copies the most significant bit into the
vacated positions instead of filling them with zeros.
`n` is expected to be within the range `[0, 64]`; this function performs no
explicit range check of its own.
]=]
local function arshift(u64: vector, n: number): vector
	if n == 0 then
		return u64
	end
	local most, least = u32_from_u64(u64)
	if n < 32 then
		-- `most` is shifted arithmetically; its low `n` bits drop into the top
		-- `n` bits of `least`.
		local new_most = bit32.arshift(most, n)
		local shifted_least = bit32.rshift(least, n)
		local new_least = bit32.replace(shifted_least, most, 32 - n, n)
		return u64_from_pair(new_most, new_least)
	end
	-- Shifting by a whole word or more: the most-significant word becomes pure
	-- sign fill and the least-significant word is `most` shifted arithmetically.
	local fill = 0x0000_0000
	if bit32.btest(most, 0x8000_0000) then
		fill = 0xFFFF_FFFF
	end
	return u64_from_pair(fill, bit32.arshift(most, n - 32))
end
--[=[
Rotates the bits of the provided `u64` left by `n` bits.
`n` is expected to be within the range `[0, 64]`. A rotation by 64 returns
the value unchanged.
]=]
local function lrotate(u64: vector, n: number): vector
	if n ~= 64 then
		-- Bits pushed out on the left by the shift re-enter on the right.
		local high_part = lshift(u64, n)
		local wrapped = rshift(u64, 64 - n)
		return bor(high_part, wrapped)
	end
	return u64
end
--[=[
Rotates the bits of the provided `u64` right by `n` bits.
`n` is expected to be within the range `[0, 64]`. A rotation by 64 returns
the value unchanged.
]=]
local function rrotate(u64: vector, n: number): vector
	if n ~= 64 then
		-- Bits pushed out on the right by the shift re-enter on the left.
		local low_part = rshift(u64, n)
		local wrapped = lshift(u64, 64 - n)
		return bor(low_part, wrapped)
	end
	return u64
end
--[=[
Counts the consecutive zero bits of the provided `u64`, starting at the
left-most (most significant) bit.
]=]
local function countlz(u64: vector): number
	local most, least = u32_from_u64(u64)
	if most ~= 0 then
		return bit32.countlz(most)
	end
	-- The whole upper word is zero, so add its 32 bits to the lower count.
	return bit32.countlz(least) + 32
end
--[=[
Counts the consecutive zero bits of the provided `u64`, starting at the
right-most (least significant) bit.
]=]
local function countrz(u64: vector): number
	local most, least = u32_from_u64(u64)
	if least ~= 0 then
		return bit32.countrz(least)
	end
	-- The whole lower word is zero, so add its 32 bits to the upper count.
	return bit32.countrz(most) + 32
end
--[=[
Returns a boolean describing whether the bitwise AND of `lhs` and
`rhs` is different from zero, i.e. whether the two values share at least
one set bit.
This does not accept a vararg like the `bit32` equivalent for performance
reasons.
]=]
local function btest(lhs: vector, rhs: vector): boolean
	-- `is_zero` is true when no bit is shared, so the result must be negated
	-- to match the documented (and `bit32.btest`) meaning.
	return not is_zero(band(lhs, rhs))
end
--[=[
Returns the provided `u64` with the order of its eight bytes reversed.
]=]
local function byteswap(u64: vector): vector
	local most, least = u32_from_u64(u64)
	-- Reversing all eight bytes swaps the two words and reverses the bytes
	-- inside each of them.
	local new_most = bit32.byteswap(least)
	local new_least = bit32.byteswap(most)
	return u64_from_pair(new_most, new_least)
end
--[=[
Reports whether `lhs` is strictly greater than `rhs`.
]=]
local function gt(lhs: vector, rhs: vector): boolean
	local l_most, l_least = u32_from_u64(lhs)
	local r_most, r_least = u32_from_u64(rhs)
	-- The upper words decide the ordering unless they are equal.
	if l_most ~= r_most then
		return l_most > r_most
	end
	return l_least > r_least
end
--[=[
Reports whether `lhs` is greater than or equal to `rhs`.
]=]
local function gt_equal(lhs: vector, rhs: vector): boolean
	if lhs == rhs then
		return true
	end
	return gt(lhs, rhs)
end
--[=[
Reports whether `lhs` is strictly less than `rhs`.
]=]
local function lt(lhs: vector, rhs: vector): boolean
	local l_most, l_least = u32_from_u64(lhs)
	local r_most, r_least = u32_from_u64(rhs)
	-- The upper words decide the ordering unless they are equal.
	if l_most ~= r_most then
		return l_most < r_most
	end
	return l_least < r_least
end
--[=[
Reports whether `lhs` is less than or equal to `rhs`.
]=]
local function lt_equal(lhs: vector, rhs: vector): boolean
	if lhs == rhs then
		return true
	end
	return lt(lhs, rhs)
end
--[=[
Calculates the sum of the two provided values. Equivalent to `+` for
normal integers.
If the sum is equal to or greater than 2^64, the result wraps around instead
of growing beyond 64 bits.
]=]
local function add(lhs: vector, rhs: vector): vector
	local l_most, l_least = u32_from_u64(lhs)
	local r_most, r_least = u32_from_u64(rhs)
	local word = 2 ^ 32

	-- Add word by word as plain numbers, then propagate the carry by hand.
	local high = l_most + r_most
	local low = l_least + r_least
	if low >= word then
		low -= word
		high += 1
	end
	-- A carry out of the upper word is discarded (wrap-around).
	if high >= word then
		high -= word
	end
	return u64_from_pair(high, low)
end
--[=[
Calculates the difference of the two provided values. Equivalent to `-` for
normal integers.
If the difference is less than 0, the result wraps around instead of going
negative.
]=]
local function sub(lhs: vector, rhs: vector): vector
	local l_most, l_least = u32_from_u64(lhs)
	local r_most, r_least = u32_from_u64(rhs)
	local word = 2 ^ 32

	-- Subtract word by word as plain numbers, then propagate the borrow by hand.
	local high = l_most - r_most
	local low = l_least - r_least
	if low < 0 then
		low += word
		high -= 1
	end
	-- A borrow out of the upper word wraps the result around.
	if high < 0 then
		high += word
	end
	return u64_from_pair(high, low)
end
--[=[
Calculates the product of the two provided values. Equivalent to `*` for
normal integers.
If the product is greater than or equal to 2^64, only the low 64 bits are
kept: the returned value wraps around rather than expanding beyond 64 bits.
]=]
local function mult(lhs: vector, rhs: vector): vector
-- We represent 64-bit numbers as two 32-bit ones.
-- Multiplying them is:
-- (A + B) * (C + D) = (A * C) + (A * D) + (B * C) + (B * D)
--
-- However, multiplying two 32-bit numbers might overflow. So, we need to
-- use 16-bit numbers. This turns our math into this:
-- (A + B + C + D) * (E + F + G + H) =
-- (A * E) + (A * F) + (A * G) + (A * H) +
-- (B * E) + (B * F) + (B * G) + (B * H) +
-- (C * E) + (C * F) + (C * G) + (C * H) +
-- (D * E) + (D * F) + (D * G) + (D * H)
--
-- We can skip (A * E), (A * F), (A * G), (B * E), (B * F), and (C * E)
-- because they don't exist within the bounds of the final product.
-- Since the numbers are built as A * 2 ^ 48 + B * 2 ^ 32 + C * 2 ^ 16 + D
-- you end up with e.g. (A * 2 ^ 48 * E * 2 ^ 48) which is A * E * 2^96.
-- Otherwise... Here we go.
-- a..d and e..h are the 16-bit words of lhs and rhs, most significant first.
local a, b, c, d = u16_from_u64(lhs)
local e, f, g, h = u16_from_u64(rhs)
-- product_4 .. product_1 accumulate the four 16-bit words of the result, from
-- least significant (product_4) to most significant (product_1). Terms are
-- added one at a time, with the carry (bits 16 and up) moved into the next
-- word in between.
local product_4 = d * h
-- Word 1 (weight 2^16): carry from d*h, then c*h and d*g.
local product_3 = bit32.rshift(product_4, 16) + c * h
local product_2 = bit32.rshift(product_3, 16)
product_3 = bit32.band(product_3, 0xFFFF) + d * g
-- Word 2 (weight 2^32): carries from word 1, then b*h, c*g and d*f.
product_2 += bit32.rshift(product_3, 16) + b * h
local product_1 = bit32.rshift(product_2, 16)
product_2 = bit32.band(product_2, 0xFFFF) + c * g
product_1 += bit32.rshift(product_2, 16)
product_2 = bit32.band(product_2, 0xFFFF) + d * f
-- Word 3 (weight 2^48): carries from word 2, then a*h, b*g, c*f and d*e.
product_1 += bit32.rshift(product_2, 16) + a * h + b * g + c * f + d * e
-- We skip truncating any of the products the last time since they'll never
-- overflow and bit32 will truncate them for us.
-- Each 32-bit word is assembled from the low 16 bits of two adjacent products.
return u64_from_pair(bit32.replace(product_2, product_1, 16, 16), bit32.replace(product_4, product_3, 16, 16))
end
--[=[
Calculates the quotient of the two provided values. Equivalent to `//` for
normal integers.
The remainder of the division is returned as a second value.
This function will error if `rhs` is 0.
]=]
local function div(lhs: vector, rhs: vector): (vector, vector)
	if is_zero(rhs) then
		error("cannot divide integers by zero", 2)
	elseif countlz(lhs) >= 11 and countlz(rhs) >= 11 then
		-- At least 11 leading zero bits means both operands are below 2^53
		-- and therefore exactly representable as regular numbers, so the
		-- division can be done directly.
		local l_real = f64_from_u64(lhs)
		local r_real = f64_from_u64(rhs)
		return u64_from_f64(l_real // r_real), u64_from_f64(math.fmod(l_real, r_real))
	end

	-- Binary long division (shift and subtract).
	local quotient = ZERO_VECTOR
	local remainder = lhs

	-- Align the divisor's highest set bit with the dividend's. Using the
	-- leading-zero counts (instead of shifting until the divisor exceeds the
	-- dividend) guarantees the divisor is never shifted past bit 63, which
	-- would wrap it to a smaller value and prevent the loop from terminating.
	local shift = countlz(rhs) - countlz(lhs)
	if shift < 0 then
		-- rhs has more significant bits than lhs, so rhs > lhs.
		return quotient, remainder
	end

	local divisor = lshift(rhs, shift)
	for power = shift, 0, -1 do
		if lt_equal(divisor, remainder) then
			remainder = sub(remainder, divisor)
			quotient = bor(quotient, lshift(ONE_VECTOR, power))
		end
		divisor = rshift(divisor, 1)
	end
	return quotient, remainder
end
-- TODO: replace and extract
-- TODO: to_decimal_string
-- TODO: vararg versions of band, bor, bxor
-- TODO: constants
-- Public interface of the module. Each key maps to one of the local functions
-- defined above.
return {
-- Constructors
from_pair = u64_from_pair,
from_u32 = u64_from_u32,
from_f64 = u64_from_f64,
from_string = u64_from_string,
from_buffer = u64_from_buffer,
-- Conversions out of the u64 representation
to_pair = u32_from_u64,
to_quartet = u16_from_u64,
to_string = to_bytes_string,
to_buffer = to_bytes_buffer,
to_hex_string = to_hex_string,
to_bin_string = to_bin_string,
-- Bitwise operations
band = band,
bor = bor,
bxor = bxor,
btest = btest,
bnot = bnot,
byteswap = byteswap,
countlz = countlz,
countrz = countrz,
rshift = rshift,
arshift = arshift,
rrotate = rrotate,
lshift = lshift,
lrotate = lrotate,
-- Arithmetic
add = add,
sub = sub,
mult = mult,
div = div,
-- Comparisons
lt = lt,
lt_equal = lt_equal,
gt = gt,
gt_equal = gt_equal,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment