Created
November 12, 2015 16:56
-
-
Save deltheil/e1be741aee7b26a3da02 to your computer and use it in GitHub Desktop.
Playing with torch/argcheck
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'torch'
local argcheck = require 'argcheck'
local env = require 'argcheck.env'

--- Type predicate consulted by argcheck for each argument's `type` field.
-- Overrides the default `argcheck.env.istype` so specs can name torch types.
-- @param obj the value being validated
-- @tparam string typename the `type` declared in the argcheck spec
-- @treturn boolean true when `obj` matches `typename`
function env.istype(obj, typename)
  if typename == 'torch.Tensor' then
    -- Generic tensor: accept a tensor of any storage type.
    return torch.isTensor(obj)
  elseif type(typename) == 'string' and typename:sub(1, 6) == 'torch.' then
    -- A specific torch class (e.g. 'torch.FloatTensor'): compare the exact
    -- class name so the storage type is checked too. The plain type()
    -- fallback below could never match such a name, since type() only
    -- returns built-in Lua type names.
    return torch.typename(obj) == typename
  else
    return type(obj) == typename
  end
end
--- Names accepted for the optional `tensortype` argument.
local TYPES = {
  byte = true,
  float = true,
  double = true,
}

--- argcheck `check` callback for the `tensortype` argument.
-- @param x the value supplied for `tensortype`
-- @treturn boolean true if `x` is one of the keys of TYPES, false otherwise
local function valid(x)
  -- Compare explicitly so callers always get a boolean, never nil.
  return TYPES[x] == true
end
--- Argument parser built with argcheck.
-- Positional order: dst, filename, depth, tensortype.
--   dst        torch.Tensor, optional
--   filename   string, required
--   depth      number, optional
--   tensortype string, optional; must satisfy `valid`
-- The demo below unpacks the result as four values in that order.
-- NOTE(review): omitted optional arguments presumably come back as nil;
-- confirm against the argcheck documentation.
local check = argcheck{
  {name="dst", type="torch.Tensor", opt=true},
  {name="filename", type="string"},
  {name="depth", type="number", opt=true},
  {name="tensortype", type="string", opt=true, check=valid},
}
--- Run `check` on the given arguments, print the four parsed values, and
-- return them so a caller can inspect the result further.
local function demo(...)
  local dst, filename, depth, tensortype = check(...)
  print(dst, filename, depth, tensortype)
  return dst, filename, depth, tensortype
end

local SEPARATOR = "\n\n"

-- Without a destination tensor.
demo("foo.jpg")
print(SEPARATOR)
demo("foo.jpg", 3)
print(SEPARATOR)
demo("foo.jpg", 3, "byte")
print(SEPARATOR)

-- With a (fresh) destination tensor as the first argument.
demo(torch.Tensor(), "foo.jpg")
print(SEPARATOR)
demo(torch.Tensor(), "foo.jpg", 3)
print(SEPARATOR)
do
  local dst, _, _, tensortype = demo(torch.Tensor(), "foo.jpg", 3, "byte")
  -- dst and tensortype were both accepted. The author considers this
  -- combination redundant (presumably because dst already fixes the tensor
  -- type), but the spec above does not reject it.
  if dst and tensortype then
    print('redundant params: should error')
  end
end
print(SEPARATOR)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment