Created
August 26, 2018 18:04
-
-
Save pfirsich/3a79b57b63b7548d545f8df4c48672c4 to your computer and use it in GitHub Desktop.
A LuaJIT bitmask class (like std::bitset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- LuaJIT's `bit` library provides bitwise operations on 32-bit words.
local bit = require("bit")
local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift

--- Power of two: a word with only bit `n` (0-based, 0..31) set.
-- Note that LuaJIT bit ops return signed 32-bit results, so pot(31) is negative.
-- @tparam number n bit position
-- @treturn number the shifted word
local function pot(n)
    return lshift(1, n)
end
-- BitMask is the class table and also the metatable of every instance.
-- Calling the class (`BitMask(...)`) creates an instance and forwards all
-- arguments to BitMask:initialize.
local BitMask = {}
BitMask.__index = BitMask
setmetatable(BitMask, {
    __call = function(class, ...)
        local instance = setmetatable({}, class)
        instance:initialize(...)
        return instance
    end,
})
--- Construct the mask (invoked through `BitMask(...)`).
-- @param arg either a bit count (nil defaults to 32), which yields a mask
--   with every bit cleared, or another BitMask whose size and bits are copied.
--   Any other type fails the assertion.
function BitMask:initialize(arg)
    arg = arg or 32
    local kind = type(arg)
    assert(kind == "number" or kind == "table")
    if kind == "table" then
        -- copy constructor: duplicate the word array so that the two masks
        -- are independent afterwards
        local source = arg
        self.size = source.size
        local words = {}
        for w = 1, #source.numbers do
            words[w] = source.numbers[w]
        end
        self.numbers = words
    else
        -- 32 bits are packed into each numeric word
        self.size = arg
        local words = {}
        for w = 1, math.ceil(arg / 32) do
            words[w] = 0
        end
        self.numbers = words
    end
end
-- TODO: BitMask:count, BitMask:all and BitMask:any (the std::bitset-style queries) are not implemented yet.
--- Locate bit `i` in the word array.
-- @tparam number i 1-based bit index; must be an integer in [1, self.size]
-- @treturn number index into self.numbers of the word holding bit i
-- @treturn number word with only bit i's position set
-- @raise error when i is not an integer within range
function BitMask:_getNumIndex(i)
    -- Reject fractional indices explicitly. LuaJIT's bit ops round or truncate
    -- non-integral arguments, so e.g. i = 2.5 would silently address a
    -- neighbouring bit instead of failing.
    if type(i) ~= "number" or i % 1 ~= 0 or i < 1 or i > self.size then
        -- level 3: report at the caller of the public accessor (set/get/...)
        -- rather than inside this internal helper
        error(string.format("BitMask: bit index %s out of range [1, %s]",
            tostring(i), tostring(self.size)), 3)
    end
    local offset = i - 1
    return math.floor(offset / 32) + 1, pot(offset % 32)
end
--- Set bit `i` to 1 when `value` is truthy, otherwise clear it to 0.
-- @tparam number i 1-based bit index
-- @param value truthy to set, falsy (nil/false) to clear
function BitMask:set(i, value)
    local word, bitMask = self:_getNumIndex(i)
    local current = self.numbers[word]
    if not value then
        self.numbers[word] = band(current, bnot(bitMask))
    else
        self.numbers[word] = bor(current, bitMask)
    end
end
--- Read bit `i`.
-- @tparam number i 1-based bit index
-- @treturn boolean true when the bit is set
function BitMask:get(i)
    local word, bitMask = self:_getNumIndex(i)
    -- bitMask has exactly one bit set, so the AND is either 0 or bitMask
    local selected = band(self.numbers[word], bitMask)
    return selected ~= 0
end
--- Flip bit `i` (0 becomes 1, 1 becomes 0).
-- @tparam number i 1-based bit index
function BitMask:toggle(i)
    local word, bitMask = self:_getNumIndex(i)
    -- XOR with a single-bit mask flips exactly that bit
    self.numbers[word] = bxor(self.numbers[word], bitMask)
end
--- Test whether every bit set in `other` is also set in `self`
-- (i.e. `self & other == other`). Both masks must have the same size.
-- @param other BitMask
-- @treturn boolean
function BitMask:check(other)
    assert(self.size == other.size)
    local mine, theirs = self.numbers, other.numbers
    for w = 1, #mine do
        local required = theirs[w]
        if band(mine[w], required) ~= required then
            return false
        end
    end
    return true
end
--- Clear every bit of `self` that is set in `other` (`self = self & ~other`).
-- Both masks must have the same size.
-- @param other BitMask
function BitMask:remove(other)
    assert(self.size == other.size)
    local mine, theirs = self.numbers, other.numbers
    for w = 1, #mine do
        mine[w] = band(mine[w], bnot(theirs[w]))
    end
end
--- Combine `other` into `self` word by word: self[w] = op(self[w], other[w]).
-- @tparam function op binary operation on 32-bit words (e.g. bit.band)
-- @param other BitMask of the same size (asserted)
function BitMask:setOp(op, other)
    assert(self.size == other.size)
    local mine, theirs = self.numbers, other.numbers
    for w = 1, #mine do
        mine[w] = op(mine[w], theirs[w])
    end
end
-- In-place bitwise combination with another mask of the same size; each is a
-- thin wrapper around BitMask:setOp with the matching word operator.

--- self = self AND other
function BitMask:setAnd(other)
    return self:setOp(band, other)
end

--- self = self OR other
function BitMask:setOr(other)
    return self:setOp(bor, other)
end

--- self = self XOR other
function BitMask:setXor(other)
    return self:setOp(bxor, other)
end
--- Invert the mask in place (`self = ~self`).
-- Only the `self.size` meaningful bits are flipped. The padding bits above
-- `self.size` in the last word are kept at 0, the same invariant that
-- initialize/set/remove maintain. Otherwise masks produced by setNot/retNot
-- would carry stray high bits and make `check` return wrong answers
-- (e.g. a full mask would not "contain" the inverse of an empty one).
function BitMask:setNot()
    local words = self.numbers
    local count = #words
    for w = 1, count do
        words[w] = bnot(words[w])
    end
    local usedInLast = self.size % 32
    if count > 0 and usedInLast ~= 0 then
        -- lshift(-1, k) has the low k bits clear; its complement keeps only them
        words[count] = band(words[count], bnot(lshift(-1, usedInLast)))
    end
end
-- Copy `source` and apply the named in-place method to the copy, leaving
-- `source` itself unchanged.
local function copyAndApply(source, methodName, other)
    local copy = BitMask(source)
    copy[methodName](copy, other)
    return copy
end

--- Return a new mask equal to self AND other.
function BitMask:retAnd(other)
    return copyAndApply(self, "setAnd", other)
end

--- Return a new mask equal to self OR other.
function BitMask:retOr(other)
    return copyAndApply(self, "setOr", other)
end

--- Return a new mask equal to self XOR other.
function BitMask:retXor(other)
    return copyAndApply(self, "setXor", other)
end

--- Return a new mask equal to NOT self (computed with setNot).
function BitMask:retNot()
    return copyAndApply(self, "setNot")
end
--- Render the mask as a string of '0'/'1' characters.
-- The most significant bit (index self.size) comes first and bit 1 comes
-- last, matching std::bitset::to_string. Previously the reverse loop wrote
-- parts[i], which made the loop direction irrelevant and emitted bit 1 first.
-- @treturn string of length self.size
function BitMask:string()
    local parts = {}
    local n = 0
    for i = self.size, 1, -1 do
        n = n + 1
        parts[n] = self:get(i) and "1" or "0"
    end
    return table.concat(parts)
end
-- Self-test: run with `luajit <this file> test`.
-- The main chunk receives the command-line arguments as `...`, whereas
-- `require` passes the module name instead. This avoids indexing the global
-- `arg` (which can be nil outside the standalone interpreter) and keeps the
-- tests from running merely because the host script was started with "test".
if (...) == "test" then
    -- Build a mask of `size` bits (nil -> default 32) and set `numSet` random
    -- bits; collisions mean fewer distinct bits may end up set.
    local function randomMask(size, numSet)
        local mask = BitMask(size)
        for _ = 1, numSet do
            mask:set(math.random(1, mask.size), true)
        end
        return mask
    end

    local masks = {}

    -- Store a random mask plus a copy of it; the copy must match exactly.
    local function addMaskAndCopy(size, numSet)
        local original = randomMask(size, numSet)
        local copy = BitMask(original)
        masks[#masks + 1] = original
        masks[#masks + 1] = copy
        assert(copy:check(original))
        assert(original:check(copy))
    end

    for _ = 1, 1000 do
        addMaskAndCopy(nil, math.random(1, 16))
        addMaskAndCopy(32, math.random(1, 16))
        addMaskAndCopy(64, math.random(1, 32))
        addMaskAndCopy(128, math.random(1, 64))
    end

    for _, mask in ipairs(masks) do
        local other = randomMask(mask.size, mask.size / 4)
        local i = math.random(1, mask.size)
        local val = math.random() > 0.5
        mask:set(i, val)
        assert(mask:get(i) == val)
        mask:toggle(i)
        assert(mask:get(i) ~= val)
        assert(mask:check(mask))
        assert(#mask:string() == mask.size)
        mask:setOr(other)
        assert(mask:check(other))
        mask:remove(other)
        -- `other` always has at least one bit set (size/4 >= 8 draws), so
        -- after removing it the mask can no longer contain it
        assert(not mask:check(other))
    end
end

return BitMask
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment