Last active
February 9, 2019 07:20
-
-
Save cwchentw/9b618ca04508c644d838a20b4790a961 to your computer and use it in GitHub Desktop.
Math Vector in Pure Lua (Apache 2.0)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Vector: a fixed-size vector whose elements live in a plain Lua array
-- (`_array`). Instances use this table as their metatable, so method lookups
-- (via __index) and the arithmetic metamethods below both resolve here.
local Vector = {}
Vector.__index = Vector
-- Register the module table under the key 'Vector' as soon as it exists.
package.loaded['Vector'] = Vector
--- Element-wise equality metamethod.
-- Two vectors are equal when they have the same length and every pair of
-- corresponding elements compares equal with `==`.
-- Operands that are not vector-like (not tables, or lacking `at`/`len`)
-- compare unequal instead of raising, so e.g. `v == {}` is simply false.
-- @param a left operand
-- @param b right operand
-- @treturn boolean
Vector.__eq = function (a, b)
  if type(a) ~= "table" or type(b) ~= "table" then
    return false
  end
  if not (a.at and a.len and b.at and b.len) then
    return false
  end
  local n = a:len()
  if n ~= b:len() then
    return false
  end
  for i = 1, n do
    -- Compare position i on both sides; every position must match.
    if a:at(i) ~= b:at(i) then
      return false
    end
  end
  return true
end
--- Addition metamethod.
-- Supports vector + vector (element-wise; lengths must match) as well as
-- number + vector and vector + number (the scalar is added to every element).
-- Always returns a newly allocated Vector; operands are not modified.
-- Raises an assertion error when two vectors have different lengths.
Vector.__add = function (a, b)
  -- Return a new vector holding v0[i] + s for every position i.
  local function _scalar_add(s, v0)
    local n = v0:len()
    local v = Vector:new(n)
    -- Use the vector's own length: `#v0` on the instance table is 0, because
    -- the elements are stored in v0._array rather than in v0 itself.
    for i = 1, n do
      v:setAt(i, v0:at(i) + s)
    end
    return v
  end
  if type(a) == "number" then
    return _scalar_add(a, b)
  end
  if type(b) == "number" then
    return _scalar_add(b, a)
  end
  assert(a:len() == b:len())
  local v = Vector:new(a:len())
  for i = 1, a:len() do
    v:setAt(i, a:at(i) + b:at(i))
  end
  return v
end
--- Subtraction metamethod.
-- number - vector: each element is subtracted from the scalar.
-- vector - number: the scalar is subtracted from each element.
-- vector - vector: element-wise difference; lengths must match (asserted).
-- Always returns a freshly allocated Vector.
Vector.__sub = function (a, b)
  if type(a) == "number" then
    local out = Vector:new(b:len())
    for k = 1, b:len() do
      out:setAt(k, a - b:at(k))
    end
    return out
  elseif type(b) == "number" then
    local out = Vector:new(a:len())
    for k = 1, a:len() do
      out:setAt(k, a:at(k) - b)
    end
    return out
  end

  assert(a:len() == b:len())
  local out = Vector:new(a:len())
  for k = 1, a:len() do
    out:setAt(k, a:at(k) - b:at(k))
  end
  return out
end
--- Multiplication metamethod.
-- number * vector and vector * number scale every element by the scalar.
-- vector * vector is the element-wise (Hadamard) product; lengths must match
-- (asserted). Always returns a freshly allocated Vector.
Vector.__mul = function (a, b)
  -- Work out which operand (if any) is the scalar and which is the vector.
  local scalar, vec
  if type(a) == "number" then
    scalar, vec = a, b
  elseif type(b) == "number" then
    scalar, vec = b, a
  end

  if scalar ~= nil then
    local product = Vector:new(vec:len())
    for idx = 1, vec:len() do
      product:setAt(idx, vec:at(idx) * scalar)
    end
    return product
  end

  assert(a:len() == b:len())
  local product = Vector:new(a:len())
  for idx = 1, a:len() do
    product:setAt(idx, a:at(idx) * b:at(idx))
  end
  return product
end
--- Division metamethod.
-- number / vector: the scalar is divided by each element.
-- vector / number: each element is divided by the scalar.
-- vector / vector: element-wise quotient; lengths must match (asserted).
-- Always returns a freshly allocated Vector.
Vector.__div = function (a, b)
  -- Build a new vector of length n whose k-th element is f(k).
  local function build(n, f)
    local out = Vector:new(n)
    for k = 1, n do
      out:setAt(k, f(k))
    end
    return out
  end

  if type(a) == "number" then
    return build(b:len(), function (k) return a / b:at(k) end)
  end
  if type(b) == "number" then
    return build(a:len(), function (k) return a:at(k) / b end)
  end
  assert(a:len() == b:len())
  return build(a:len(), function (k) return a:at(k) / b:at(k) end)
end
--- Create a vector of `size` elements, all initialised to 0.
-- The receiver is ignored: the result always has `Vector` as its metatable.
-- @tparam number size number of elements (a non-number raises in the loop)
-- @return a new Vector
function Vector:new(size)
  local array = {}
  for i = 1, size do
    array[i] = 0
  end
  return setmetatable({ _array = array }, Vector)
end
--- Build a vector by copying positions 1..#t of the sequence `t`.
-- `t` should be a proper sequence; with nil holes `#t` is unspecified.
-- @tparam table t source values
-- @return a new Vector
function Vector:fromTable(t)
  local count = #t
  local result = Vector:new(count)
  for idx = 1, count do
    result:setAt(idx, t[idx])
  end
  return result
end
--- Element stored at 1-based position `i` (nil if nothing is stored there).
-- No bounds checking is performed.
function Vector:at(i)
  local array = self._array
  return array[i]
end
--- Store `e` at 1-based position `i`, writing straight into the backing array.
-- No bounds or type checks are made; writing nil (or past the end) can leave
-- a hole in `_array`, after which `len` is unreliable.
function Vector:setAt(i, e)
  local array = self._array
  array[i] = e
end
--- Number of elements in the vector.
-- Computed with `#` on the backing array, so it is only well-defined while
-- that array is a proper sequence (no nil holes).
-- @treturn number
function Vector:len()
  local array = self._array
  return #array
end

-- Module export.
return Vector
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Test script for the Vector module. Every check is a plain `assert`; note
-- that `assert(x, y)` treats `y` as the error message, so comparisons must be
-- written as `assert(x == y)`.
local vector = require("vector")

-- construction
do
  local v = vector:new(3)
  assert(v:len() == 3)
  assert(v:at(1) == 0)
  assert(v:at(2) == 0)
  assert(v:at(3) == 0)
end

do
  local v = vector:fromTable({1, 2, 3})
  assert(v:len() == 3)
  assert(v:at(1) == 1)
  assert(v:at(2) == 2)
  assert(v:at(3) == 3)
end

-- vector equality
do
  local v1 = vector:fromTable({1, 2, 3})
  local v2 = vector:fromTable({1, 2, 3})
  local v3 = vector:fromTable({2, 3, 4})
  local v4 = vector:fromTable({1, 2, 3, 4})
  -- differs only after the first element
  local v5 = vector:fromTable({1, 2, 4})
  assert(v1 == v2)
  assert(v1 ~= v3)
  assert(v1 ~= v4)
  assert(v1 ~= v5)
end

-- vector addition
do
  local v1 = vector:fromTable({1, 2, 3})
  local v2 = vector:fromTable({2, 3, 4})
  local v = v1 + v2
  assert(v:len() == 3)
  assert(v:at(1) == 3)
  assert(v:at(2) == 5)
  assert(v:at(3) == 7)
end

-- vector subtraction
do
  local v1 = vector:fromTable({1, 2, 3})
  local v2 = vector:fromTable({2, 3, 4})
  local v = v1 - v2
  assert(v:len() == 3)
  assert(v:at(1) == -1)
  assert(v:at(2) == -1)
  assert(v:at(3) == -1)
end

-- vector multiplication
do
  local v1 = vector:fromTable({1, 2, 3})
  local v2 = vector:fromTable({2, 3, 4})
  local v = v1 * v2
  assert(v:len() == 3)
  assert(v:at(1) == 2)
  assert(v:at(2) == 6)
  assert(v:at(3) == 12)
end

-- vector division
do
  local v1 = vector:fromTable({1, 2, 3})
  local v2 = vector:fromTable({2, 3, 4})
  local v = v1 / v2
  assert(v:len() == 3)
  assert(math.abs(v:at(1) - 0.5) < 1 / 1000000)
  assert(math.abs(v:at(2) - 0.6666667) < 1 / 1000000)
  assert(math.abs(v:at(3) - 0.75) < 1 / 1000000)
end

-- scalar arithmetic (scalar on either side)
do
  local v = vector:fromTable({1, 2, 3})
  local s = 10 + v
  assert(s:len() == 3)
  assert(s:at(1) == 11 and s:at(2) == 12 and s:at(3) == 13)
  assert(v + 1 == vector:fromTable({2, 3, 4}))
  assert(10 - v == vector:fromTable({9, 8, 7}))
  assert(v - 1 == vector:fromTable({0, 1, 2}))
  assert(2 * v == vector:fromTable({2, 4, 6}))
  assert(v * 2 == vector:fromTable({2, 4, 6}))
  assert(6 / v == vector:fromTable({6, 3, 2}))
  assert(v / 2 == vector:fromTable({0.5, 1, 1.5}))
end

-- Prevent Segmentation fault (kept from the original script; the crash it
-- works around is not visible here -- presumably an interpreter shutdown
-- issue, TODO confirm).
os.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment