Created July 18, 2019 21:04
-
-
Save keveman/7dce82b0824d96eb12e412a38b0aa490 to your computer and use it in GitHub Desktop.
Circle - Peglib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "peglib.h"
/// A `rows` x `cols` integer matrix whose entries live in `data`.
///
/// NOTE(review): `data` presumably holds rows * cols entries; the storage
/// order (row- vs column-major) is not fixed by anything visible here.
struct matrix {
  // Default member initializers so a default-constructed matrix (e.g. a
  // member of a default-constructed `access`) never carries indeterminate
  // dimensions that are later copied or read.
  int rows = 0;
  int cols = 0;
  std::vector<int> data;
};
struct access { | |
std::string tensor_name; | |
matrix A; | |
matrix b; | |
}; | |
/// The operation an `expr` node applies to its operands.
///
/// The values are spelled out explicitly; they match the implicit
/// numbering (0, 1, 2) the enumerators would otherwise receive.
enum class operation {
  AssignAdd = 0,  // NOTE(review): name suggests `+=`-style accumulation; confirm
  Multiply = 1,
  Add = 2,
};
struct expr { | |
operation op; | |
std::vector<expr *> operands; | |
}; | |
struct loop { | |
std::vector<int> bounds; | |
std::vector<std::string> loop_indices; | |
expr body; | |
void dump(std::ostream &o) const { | |
o << "all("; | |
for (int i = 0, e = loop_indices.size(); i != e; ++i) { | |
o << loop_indices[i]; | |
if (i < e - 1) o << ", "; | |
} | |
o << ") in ("; | |
for (int i = 0, e = bounds.size(); i != e; ++i) { | |
o << bounds[i]; | |
if (i < e - 1) o << ", "; | |
} | |
o << ")\n"; | |
} | |
}; | |
loop parse(const char *str); | |
using namespace peg; | |
int get_number(const Ast &node) { | |
assert(node.name == "Number"); | |
if (node.nodes.size() == 1) { | |
assert(node.nodes[0]->name == "PlainNumber" && node.nodes[0]->is_token); | |
return std::stoi(node.nodes[0]->token); | |
} | |
assert(node.nodes[0]->name == "Sign" && node.nodes[0]->is_token); | |
assert(node.nodes[1]->name == "PlainNumber" && node.nodes[1]->is_token); | |
int retval = std::stoi(node.nodes[1]->token); | |
if (node.nodes[0]->token == "-") return -retval; | |
return retval; | |
} | |
/// Parses `str` as a loop header followed by a tensor access, for example
/// `all(i, j) in (4, 8) A[i][j]`, and returns the loop header.
///
/// The returned `loop` carries `loop_indices` and `bounds`; `body` is left
/// value-initialized. Progress goes to std::cout:
/// - the TensorAccess action prints "Loop" and dumps the header parsed so far;
/// - on success, "Parse succeeded" is printed followed by the final header.
///
/// On failure (null input, a grammar that does not load, or input that does
/// not match), a message is printed and `assert` fires in debug builds. In
/// NDEBUG builds a default `loop` is returned instead of falling through to
/// the success path.
///
/// NOTE(review): numbers are converted with std::stoi, which throws
/// std::out_of_range for values that do not fit in int; nothing here
/// catches it.
loop parse(const char *str) {
  if (str == nullptr) {
    std::cout << "Failed to parse\n";
    assert(false);
    return loop{};
  }

  auto grammar = R"(
ROOT <- LOOP
TensorAccess
LOOP <- ALL LEFT_BRACKET IDLIST RIGHT_BRACKET
IN LEFT_BRACKET NUMLIST RIGHT_BRACKET
ALL <- 'all'
IN <- 'in'
LEFT_BRACKET <- '('
RIGHT_BRACKET <- ')'
PlainNumber <- T([0-9]+)
Sign <- T('-' / '+')
Number <- Sign? PlainNumber
ID <- < [a-zA-Z_] [a-zA-Z0-9_]* >
IDLIST <- ID (',' ID)*
NUMLIST <- Number (',' Number)*
Additive <- Multitive '+' Additive
/ Multitive '-' Additive
/ Multitive
Multitive <- Primary '*' Multitive / Primary
Primary <- '(' Additive ')' / Number / ID
Expr <- Additive
IndexExpr <- '[' Expr ']'
TensorAccess <- ID IndexExpr*
%whitespace <- [ \t\n]*
T(x) <- < x >
)";

  parser parser;
  // Checked explicitly. An assert alone would compile away under NDEBUG and
  // let an unloaded parser reject every input without explanation.
  if (!parser.load_grammar(grammar)) {
    std::cout << "Failed to load grammar\n";
    assert(false);
    return loop{};
  }

  parser["LOOP"] = [](const peg::SemanticValues &sv, any &dt) -> loop {
    // The indices assume each of the eight rule references in LOOP
    // contributes one semantic value: IDLIST -> sv[2], NUMLIST -> sv[6].
    loop l{};
    l.loop_indices = sv[2].get<std::vector<std::string>>();
    l.bounds = sv[6].get<std::vector<int>>();
    // Publish the header through dt so the TensorAccess action can read it.
    dt = l;
    return l;
  };

  parser["IDLIST"] =
      [](const peg::SemanticValues &sv) -> std::vector<std::string> {
    std::vector<std::string> ids;
    ids.reserve(sv.size());
    for (const auto &id : sv)
      ids.push_back(id.get<std::string>());
    return ids;
  };

  parser["ID"] = [](const peg::SemanticValues &sv) -> std::string {
    return std::string(sv.token());
  };

  parser["NUMLIST"] = [](const peg::SemanticValues &sv) -> std::vector<int> {
    std::vector<int> nums;
    nums.reserve(sv.size());
    for (const auto &n : sv)
      nums.push_back(n.get<int>());
    return nums;
  };

  parser["Number"] = [](const peg::SemanticValues &sv) -> int {
    // The optional Sign contributes -1 or +1 and PlainNumber the magnitude,
    // so their product is the signed value.
    int ret = 1;
    for (const auto &n : sv) {
      ret *= n.get<int>();
    }
    return ret;
  };

  parser["Sign"] = [](const peg::SemanticValues &sv) -> int {
    // Decide from the matched text instead of sv.choice(). The '-' / '+'
    // choice is nested inside the T(x) macro, so its choice index is not
    // recorded on Sign's own semantic values. The match starts at the sign
    // character because whitespace is only skipped after tokens/literals.
    const std::string text(sv.token());
    return (!text.empty() && text[0] == '-') ? -1 : 1;
  };

  parser["PlainNumber"] = [](const peg::SemanticValues &sv) -> int {
    return std::stoi(sv.token(), nullptr, 10);
  };

  parser["TensorAccess"] = [](const peg::SemanticValues & /*sv*/,
                              any &dt) -> int {
    // ROOT parses LOOP first, so dt already holds the header it stored.
    const loop &partial_loop = dt.get<loop>();
    std::cout << "Loop\n";
    partial_loop.dump(std::cout);
    return 0;
  };

  loop l{};
  any dt = loop{};
  if (!parser.parse(str, dt, l)) {
    std::cout << "Failed to parse\n";
    assert(false);
    // Do not fall through to the success path when asserts are disabled.
    return loop{};
  }
  std::cout << "Parse succeeded\n";
  l.dump(std::cout);
  return l;
}
/// Entry point: parses the loop description passed as the first
/// command-line argument.
///
/// Returns 1 after printing a usage message to std::cerr when no argument
/// is supplied; otherwise returns 0 after parse() has reported its result.
int main(int argc, char *argv[]) {
  // argv[1] is a null pointer when argc == 1 and out of range when argc == 0,
  // so it must not be read before checking argc.
  if (argc < 2) {
    const char *program =
        (argc > 0 && argv[0] != nullptr) ? argv[0] : "program";
    std::cerr << "usage: " << program << " '<loop description>'\n";
    return 1;
  }
  parse(argv[1]);
  return 0;
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.