Created
May 26, 2024 18:57
-
-
Save oopsmishap/622734a7623e8a8124e002b357e6fc10 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <fmt/core.h>
#include <zasm/zasm.hpp>

#include <windows.h>
using namespace zasm; | |
class ASTNode | |
{ | |
public: | |
virtual ~ASTNode() = default; | |
virtual void print() const = 0; | |
virtual void optimize(std::vector<std::unique_ptr<ASTNode>>& optimized) const = 0; | |
}; | |
using Nodes = std::vector<std::unique_ptr<ASTNode>>; | |
void optimizeSequence(const Nodes& nodes, Nodes& optimized); | |
class Operation : public ASTNode | |
{ | |
public: | |
explicit Operation(int count) : value(count) | |
{ | |
} | |
void print() const override | |
{ | |
fmt::println("Operation({})", value); | |
} | |
void optimize(Nodes& optimized) const override | |
{ | |
if (value != 0) | |
{ | |
optimized.push_back(std::make_unique<Operation>(*this)); | |
} | |
} | |
int value; | |
}; | |
class Move : public ASTNode | |
{ | |
public: | |
explicit Move(int count) : value(count) | |
{ | |
} | |
void print() const override | |
{ | |
fmt::println("Move({})", value); | |
} | |
void optimize(Nodes& optimized) const override | |
{ | |
if (value != 0) | |
{ | |
optimized.push_back(std::make_unique<Move>(*this)); | |
} | |
} | |
int value; | |
}; | |
class Output : public ASTNode | |
{ | |
public: | |
void print() const override | |
{ | |
fmt::println("Output()"); | |
} | |
void optimize(Nodes& optimized) const override | |
{ | |
optimized.push_back(std::make_unique<Output>(*this)); | |
} | |
}; | |
class Input : public ASTNode | |
{ | |
public: | |
void print() const override | |
{ | |
fmt::println("Input()"); | |
} | |
void optimize(Nodes& optimized) const override | |
{ | |
optimized.push_back(std::make_unique<Input>(*this)); | |
} | |
}; | |
class SetToZero : public ASTNode | |
{ | |
public: | |
void print() const override | |
{ | |
fmt::println("SetToZero()"); | |
} | |
void optimize(Nodes& optimized) const override | |
{ | |
optimized.push_back(std::make_unique<SetToZero>(*this)); | |
} | |
}; | |
int indent = 0; | |
class Loop : public ASTNode | |
{ | |
public: | |
explicit Loop(Nodes body) : body(std::move(body)) | |
{ | |
} | |
void print() const override | |
{ | |
fmt::println("{} LoopStart()", std::string(indent, ' ')); | |
indent++; | |
for (const auto& node : body) | |
{ | |
fmt::print("{}", std::string(indent, ' ')); | |
node->print(); | |
} | |
indent--; | |
fmt::println("{} LoopEnd()", std::string(indent, ' ')); | |
} | |
void optimize(Nodes& optimized) const override | |
{ | |
Nodes optimizedBody; | |
optimizeSequence(body, optimizedBody); | |
if (!optimizedBody.empty()) | |
{ | |
optimized.push_back(std::make_unique<Loop>(std::move(optimizedBody))); | |
} | |
} | |
Nodes body; | |
}; | |
void optimizeSequence(const Nodes& nodes, Nodes& optimized) | |
{ | |
int moveCount = 0; | |
int opCount = 0; | |
auto pushMove = [&optimized](int count) | |
{ | |
if (count != 0) | |
{ | |
optimized.push_back(std::make_unique<Move>(count)); | |
} | |
}; | |
auto pushOperation = [&optimized](int count) | |
{ | |
if (count != 0) | |
{ | |
optimized.push_back(std::make_unique<Operation>(count)); | |
} | |
}; | |
auto finalizeCurrent = [&]() | |
{ | |
pushMove(moveCount); | |
pushOperation(opCount); | |
moveCount = 0; | |
opCount = 0; | |
}; | |
for (const auto& node : nodes) | |
{ | |
if (auto move = dynamic_cast<Move*>(node.get())) | |
{ | |
finalizeCurrent(); | |
moveCount += move->value; | |
} | |
else if (auto op = dynamic_cast<Operation*>(node.get())) | |
{ | |
finalizeCurrent(); | |
opCount += op->value; | |
} | |
else if( auto loop = dynamic_cast< Loop* >( node.get() ) ) | |
{ | |
// Check if the loop is a SetToZero pattern | |
if (loop->body.size() == 1) | |
{ | |
if (auto loopOp = dynamic_cast<Operation*>(loop->body.front().get())) | |
{ | |
if (loopOp->value == -1) | |
{ | |
finalizeCurrent(); | |
optimized.push_back(std::make_unique<SetToZero>()); | |
continue; | |
} | |
} | |
} | |
finalizeCurrent(); | |
Nodes optimizedBody; | |
optimizeSequence(loop->body, optimizedBody); | |
optimized.push_back(std::make_unique<Loop>(std::move(optimizedBody))); | |
} | |
else | |
{ | |
finalizeCurrent(); | |
node->optimize(optimized); | |
} | |
} | |
finalizeCurrent(); | |
} | |
Nodes optimizeAST(const Nodes& nodes) | |
{ | |
Nodes optimized; | |
optimizeSequence(nodes, optimized); | |
return optimized; | |
} | |
class Lexer | |
{ | |
public: | |
explicit Lexer(const std::string& source) : source(source), pos(0) | |
{ | |
} | |
char next() | |
{ | |
char c = peek(); | |
if (c != '\0') | |
advance(); | |
return c; | |
} | |
char peek() | |
{ | |
while (pos < source.size() && !isValidCharacter(source[pos])) | |
{ | |
advance(); | |
} | |
return (pos < source.size()) ? source[pos] : '\0'; | |
} | |
private: | |
std::string source; | |
size_t pos; | |
const std::string validChars = "+-<>[].,"; | |
void advance() | |
{ | |
++pos; | |
} | |
bool isValidCharacter(char c) const | |
{ | |
return validChars.find(c) != std::string::npos; | |
} | |
}; | |
class Parser | |
{ | |
public: | |
explicit Parser(Lexer& lexer) : lexer(lexer) | |
{ | |
} | |
Nodes parse() | |
{ | |
Nodes nodes; | |
char token = 0; | |
while ((token = lexer.next()) != '\0') | |
{ | |
nodes.push_back(parseToken(token)); | |
} | |
return nodes; | |
} | |
private: | |
Lexer& lexer; | |
std::unique_ptr<ASTNode> parseToken(char token) | |
{ | |
int count = 1; | |
switch (token) | |
{ | |
case '+': | |
while (lexer.peek() == '+') | |
{ | |
lexer.next(); | |
++count; | |
} | |
return std::make_unique<Operation>(count); | |
case '-': | |
while (lexer.peek() == '-') | |
{ | |
lexer.next(); | |
++count; | |
} | |
return std::make_unique<Operation>(-count); | |
case '>': | |
while (lexer.peek() == '>') | |
{ | |
lexer.next(); | |
++count; | |
} | |
return std::make_unique<Move>(count); | |
case '<': | |
while (lexer.peek() == '<') | |
{ | |
lexer.next(); | |
++count; | |
} | |
return std::make_unique<Move>(-count); | |
case '.': | |
return std::make_unique<Output>(); | |
case ',': | |
return std::make_unique<Input>(); | |
case '[': | |
return parseLoop(); | |
default: | |
throw std::runtime_error("Invalid token"); | |
} | |
} | |
std::unique_ptr<ASTNode> parseLoop() | |
{ | |
Nodes nodes; | |
char token; | |
while ((token = lexer.next()) != ']') | |
{ | |
if (token == '\0') | |
{ | |
throw std::runtime_error("Unmatched '['"); | |
} | |
nodes.push_back(parseToken(token)); | |
} | |
return std::make_unique<Loop>(std::move(nodes)); | |
} | |
}; | |
struct BrainFuckGenerator | |
{ | |
using FunctionType = void(__fastcall*)(void*); | |
static void generate(Label labelFunc, Program& program, const Nodes& nodes) | |
{ | |
x86::Assembler a(program); | |
a.bind(labelFunc); | |
a.mov(x86::rsi, x86::rcx); | |
a.mov(x86::r14, reinterpret_cast<uintptr_t>(&putchar)); | |
a.mov(x86::r15, reinterpret_cast<uintptr_t>(&getchar)); | |
generateAsm(a, nodes); | |
a.ret(); | |
} | |
private: | |
static void generateAsm(x86::Assembler& a, const Nodes& nodes) | |
{ | |
for (const auto& node : nodes) | |
{ | |
if (auto move = dynamic_cast<Move*>(node.get())) | |
{ | |
if (move->value > 0) | |
{ | |
a.add(x86::rsi, move->value); | |
} | |
else | |
{ | |
a.sub(x86::rsi, -move->value); | |
} | |
} | |
else if (auto op = dynamic_cast<Operation*>(node.get())) | |
{ | |
if (op->value > 0) | |
{ | |
a.add(x86::byte_ptr(x86::rsi), op->value); | |
} | |
else | |
{ | |
a.sub(x86::byte_ptr(x86::rsi), -op->value); | |
} | |
} | |
else if (auto output = dynamic_cast<Output*>(node.get())) | |
{ | |
a.mov(x86::cl, x86::byte_ptr(x86::rsi)); | |
a.call(x86::r14); | |
} | |
else if (auto input = dynamic_cast<Input*>(node.get())) | |
{ | |
a.call(x86::r15); | |
a.mov(x86::byte_ptr(x86::rsi), x86::al); | |
} | |
else if (auto setToZero = dynamic_cast<SetToZero*>(node.get())) | |
{ | |
a.mov(x86::byte_ptr(x86::rsi), 0); | |
} | |
else if (auto loop = dynamic_cast<Loop*>(node.get())) | |
{ | |
auto startLabel = a.createLabel(); | |
auto endLabel = a.createLabel(); | |
a.bind(startLabel); | |
a.cmp(x86::byte_ptr(x86::rsi), 0); | |
a.jz(endLabel); | |
generateAsm(a, loop->body); | |
a.jmp(startLabel); | |
a.bind(endLabel); | |
} | |
} | |
} | |
}; | |
size_t estimateCodeSize(const Program& program) | |
{ | |
std::size_t size = 0; | |
for (auto* node = program.getHead(); node != nullptr; node = node->getNext()) | |
{ | |
if (auto* nodeData = node->getIf<Data>(); nodeData != nullptr) | |
{ | |
size += nodeData->getTotalSize(); | |
} | |
else if (auto* nodeInstr = node->getIf<Instruction>(); nodeInstr != nullptr) | |
{ | |
const auto& instrInfo = nodeInstr->getDetail(program.getMode()); | |
if (instrInfo.hasValue()) | |
{ | |
size += instrInfo->getLength(); | |
} | |
else | |
{ | |
fmt::println("Error: Unable to get instruction info"); | |
} | |
} | |
else if (auto* nodeEmbeddedLabel = node->getIf<EmbeddedLabel>(); nodeEmbeddedLabel != nullptr) | |
{ | |
const auto bitSize = nodeEmbeddedLabel->getSize(); | |
if (bitSize == BitSize::_32) | |
size += 4; | |
if (bitSize == BitSize::_64) | |
size += 8; | |
} | |
} | |
return size; | |
} | |
void* allocate(size_t size) | |
{ | |
void* ptr = VirtualAlloc(nullptr, size, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE); | |
if (ptr == nullptr) | |
{ | |
throw std::runtime_error("Failed to allocate memory"); | |
} | |
return ptr; | |
} | |
void jitExecute(const Nodes& nodes) | |
{ | |
Program program(MachineMode::AMD64); | |
auto labelFunc = program.createLabel("BrainFuck"); | |
BrainFuckGenerator::generate(labelFunc, program, nodes); | |
const auto codeSize = estimateCodeSize(program); | |
void* code = VirtualAlloc(nullptr, codeSize, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE); | |
if (code == nullptr) | |
{ | |
throw std::runtime_error("Failed to allocate memory"); | |
} | |
Serializer serializer; | |
if (auto err = serializer.serialize(program, reinterpret_cast<int64_t>(code)); err != zasm::ErrorCode::None) | |
{ | |
fmt::println("Serialization failure: {}", err.getErrorName()); | |
throw std::runtime_error("Serialization failure"); | |
} | |
memcpy(code, serializer.getCode(), serializer.getCodeSize()); | |
const auto funcAddress = serializer.getLabelAddress(labelFunc.getId()); | |
assert(funcAddress != -1); | |
const auto brainFuckEntry = reinterpret_cast<BrainFuckGenerator::FunctionType>(funcAddress); | |
std::vector<uint8_t> memory(30000, 0); | |
brainFuckEntry(memory.data()); | |
VirtualFree(code, 0, MEM_RELEASE); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "brainfuck.h" | |
#include <iostream> | |
#include <fstream> | |
#include <sstream> | |
int main(int argc, char** argv) | |
{ | |
if (argc < 2) | |
{ | |
fmt::println("Usage: {} <file>", argv[0]); | |
return 1; | |
} | |
std::ifstream file(argv[1]); | |
if (!file) | |
{ | |
fmt::println("Failed to open file: {}", argv[1]); | |
return 1; | |
} | |
std::stringstream buffer; | |
buffer << file.rdbuf(); | |
std::string fileContents = buffer.str(); | |
Lexer lexer(fileContents); | |
Parser parser(lexer); | |
auto nodes = parser.parse(); | |
auto optimized = optimizeAST(nodes); | |
/*for( const auto& node : optimized ) | |
{ | |
node->print(); | |
}*/ | |
jitExecute(optimized); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment