Created
June 17, 2023 23:34
-
-
Save ItsDrike/9311c933e61139ed34c6e830b33f0bac to your computer and use it in GitHub Desktop.
Simple Rust interpreter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::{HashMap, HashSet};
/// Syntax tree of the toy language.
///
/// Variables are identified by a `u32` index; the code generator uses that index
/// directly as the variable's register number.
#[derive(Debug, Clone, PartialEq)]
enum Ast {
    /// Floating point literal.
    Const(f64),
    /// Read of the variable with the given index.
    Var(u32),
    /// Sum of two sub-expressions.
    Add(Box<Ast>, Box<Ast>),
    /// Store the value of the expression into the variable with the given index.
    Assign(u32, Box<Ast>),
    /// Print the value of the expression.
    Print(Box<Ast>),
    /// Sequence of nodes; the code generator treats the last child as the block's result.
    Block(Vec<Ast>),
}
/// Byte code instructions.
///
/// All operands are register numbers; `dst` is the register an instruction writes.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Instr {
    /// `dst = value`
    Load { dst: u32, value: f64 },
    /// `dst = src`
    Copy { dst: u32, src: u32 },
    /// `dst = src1 + src2`
    Add { dst: u32, src1: u32, src2: u32 },
    /// Print the value held in `src`; writes no register.
    Print { src: u32 },
}
struct Codegen { | |
code: Vec<Instr>, | |
dst_reg_map: Vec<Option<usize>>, | |
reg_last_written: Vec<usize>, | |
instrs_used: Vec<bool>, | |
} | |
impl Codegen { | |
fn new(num_locals: u32) -> Self { | |
Self { | |
code: vec![], | |
dst_reg_map: vec![None; num_locals as usize], | |
reg_last_written: vec![0; num_locals as usize], | |
instrs_used: vec![], | |
} | |
} | |
fn generate_bytecode(ast: &Ast) -> Vec<Instr> { | |
let num_locals = Self::count_locals(ast, &mut 0, &mut HashSet::new()); | |
println!( | |
"Found {} distinct local variables in given AST.", | |
num_locals | |
); | |
let mut cg = Codegen::new(num_locals); | |
cg.codegen(&ast); | |
let mut instructions: Vec<Instr> = vec![]; | |
for (instr, enabled) in cg.code.iter().zip(cg.instrs_used.iter()) { | |
if !enabled { | |
print!("[DISABLED] "); | |
} else { | |
instructions.push(*instr); | |
} | |
println!("{:?}", instr); | |
} | |
instructions | |
} | |
fn count_locals(ast: &Ast, tracked_amt: &mut u32, seen: &mut HashSet<u32>) -> u32 { | |
match ast { | |
Ast::Add(inner_ast1, inner_ast2) => { | |
Self::count_locals(inner_ast1, tracked_amt, seen); | |
Self::count_locals(inner_ast2, tracked_amt, seen); | |
} | |
Ast::Assign(register_no, inner_ast) => { | |
if !seen.contains(register_no) { | |
seen.insert(*register_no); | |
*tracked_amt += 1; | |
} | |
Self::count_locals(inner_ast, tracked_amt, seen); | |
} | |
Ast::Print(inner_ast) => { | |
Self::count_locals(inner_ast, tracked_amt, seen); | |
} | |
Ast::Block(inner_asts) => { | |
for inner_ast in inner_asts { | |
Self::count_locals(inner_ast, tracked_amt, seen); | |
} | |
} | |
_ => {} | |
}; | |
return *tracked_amt; | |
} | |
fn codegen(&mut self, ast: &Ast) -> u32 { | |
let result = match ast { | |
Ast::Const(value) => { | |
let dst = self.alloc_reg(); | |
self.add_instr(Instr::Load { dst, value: *value }); | |
dst | |
} | |
Ast::Var(src) => { | |
let dst = self.alloc_reg(); | |
self.add_instr(Instr::Copy { dst, src: *src }); | |
dst | |
} | |
Ast::Add(left, right) => { | |
println!(":=Add LHS {:?}", left); | |
let src1 = self.codegen(left); | |
println!(":=Add RHS {:?}", right); | |
let src2 = self.codegen(right); | |
let src1 = self.forward_copy_src(src1); | |
let src2 = self.forward_copy_src(src2); | |
let dst = self.alloc_reg(); | |
self.add_instr(Instr::Add { dst, src1, src2 }); | |
dst | |
} | |
Ast::Assign(dst, src) => { | |
let dst = *dst; | |
let src = self.codegen(src); | |
let src = self.forward_copy_src(src); | |
self.add_instr(Instr::Copy { dst, src }); | |
u32::MAX | |
} | |
Ast::Print(src) => { | |
let src = self.codegen(src); | |
let src = self.forward_copy_src(src); | |
self.add_instr(Instr::Print { src }); | |
u32::MAX | |
} | |
Ast::Block(children) => { | |
let mut result = u32::MAX; | |
for child in children { | |
// last child is the return value from this block | |
result = self.codegen(child); | |
} | |
result | |
} | |
}; | |
println!( | |
"AST: {:?}\nreg_map: {:?}\nLast written: {:?}\nInstr #{} -> {:?}\n", | |
ast, | |
self.dst_reg_map, | |
self.reg_last_written, | |
self.code.len() - 1, | |
self.code.last().unwrap(), | |
); | |
result | |
} | |
fn forward_copy_src(&mut self, src: u32) -> u32 { | |
// If the instruction that set this register (to be used as source) was a copy instr, | |
// use the original source register, instead of the copied one (copy dst), when this | |
// copy isn't necessary (when it's src wasn't written to). Otherwise, return the | |
// passed src (from copy). | |
// | |
// This will keep the copy instruction here, without actually being used, so also mark it | |
// unused. | |
if let Some(instr_no) = self.dst_reg_map[src as usize] { | |
if let Instr::Copy { dst: _, src } = self.code[instr_no] { | |
// We can only do this if the instruction that last changed this register | |
// is before the copy instruction. Otherwise, we would lose any potential | |
// changes made between the copy and the instruction that set this register | |
let last_write_instr_no = self.reg_last_written[src as usize]; | |
if last_write_instr_no <= instr_no { | |
println!( | |
"->Forwarding src (unneeded) copy instr #{} {:?}, src={}", | |
instr_no, self.code[instr_no], src | |
); | |
self.instrs_used[instr_no] = false; | |
return src; | |
} | |
} | |
} | |
src | |
} | |
fn alloc_reg(&mut self) -> u32 { | |
let result = self.dst_reg_map.len() as u32; | |
self.dst_reg_map.push(None); | |
self.reg_last_written.push(self.code.len()); | |
result | |
} | |
fn add_instr(&mut self, instr: Instr) { | |
// Skip writing a new copy instruction, if this copy's source register was | |
// assigned as some other existing instruction's destination. In this case, | |
// just modify this existing instruction and make the copy's destination (target) | |
// it's new destination. | |
if let Instr::Copy { dst, src } = instr { | |
if let Some(instr_no) = self.dst_reg_map[src as usize] { | |
let instr_dst = match &mut self.code[instr_no] { | |
Instr::Load { dst, value: _ } => Some(dst), | |
Instr::Copy { dst, src: _ } => Some(dst), | |
Instr::Add { | |
dst, | |
src1: _, | |
src2: _, | |
} => Some(dst), | |
Instr::Print { src: _ } => None, | |
}; | |
if let Some(instr_dst) = instr_dst { | |
// We can only forward the destination (override dst of previous instruction and | |
// skip copy) when there wasn't anything else between this copy and the prev instr | |
// that already wrote to the copy's destination (that would mean the copy would | |
// actually change the value in that register, overriding would mean this | |
// change wouldn't happen). Note that this situation can't even happen with the | |
// current implementation, but if it would happen, this would handle it. | |
let latest_write_instr_no = self.reg_last_written[dst as usize]; | |
if instr_no >= latest_write_instr_no { | |
println!( | |
"->Forwarding dst from (skipped) copy to instr #{:?}, new dst: {} (was: {})", | |
instr_no, dst, instr_dst | |
); | |
*instr_dst = dst; | |
self.reg_last_written[dst as usize] = instr_no; | |
return; | |
} | |
} | |
} | |
} | |
// Check if this instruction changes some register, if so, keep track of the change | |
let dst = match instr { | |
Instr::Load { dst, value: _ } => Some(dst), | |
Instr::Copy { dst, src: _ } => Some(dst), | |
Instr::Add { | |
dst, | |
src1: _, | |
src2: _, | |
} => Some(dst), | |
Instr::Print { src: _ } => None, | |
}; | |
if let Some(dst) = dst { | |
let instr_index = self.code.len(); | |
self.dst_reg_map[dst as usize] = Some(instr_index); | |
self.reg_last_written[dst as usize] = instr_index; | |
} | |
println!("->Adding {:?}", instr); | |
self.code.push(instr); | |
// Assume all instructions are used by default | |
self.instrs_used.push(true); | |
} | |
} | |
fn interpret(bytecode: &Vec<Instr>) { | |
let mut registers: HashMap<u32, f64> = HashMap::new(); | |
for instruction in bytecode { | |
match instruction { | |
Instr::Load { dst, value } => { | |
registers.insert(*dst, *value); | |
} | |
Instr::Copy { dst, src } => { | |
registers.insert(*dst, registers[src]); | |
} | |
Instr::Add { dst, src1, src2 } => { | |
registers.insert(*dst, registers[src1] + registers[src2]); | |
} | |
Instr::Print { src } => { | |
println!("INTERPRETER PRINT: {}", registers[src]); | |
} | |
} | |
} | |
} | |
fn main() { | |
if 0 == 1 { | |
// a = 1 | |
// b = 2 | |
// c = a + b | |
// print(c) | |
let ast = Ast::Block(vec![ | |
Ast::Assign(0, Box::new(Ast::Const(1.0))), | |
Ast::Assign(1, Box::new(Ast::Const(2.0))), | |
Ast::Assign( | |
2, | |
Box::new(Ast::Add(Box::new(Ast::Var(0)), Box::new(Ast::Var(1)))), | |
), | |
Ast::Print(Box::new(Ast::Var(2))), | |
]); | |
let bytecode = Codegen::generate_bytecode(&ast); | |
interpret(&bytecode); | |
} | |
if 1 == 1 { | |
// a = 1 | |
// print(a + { a = a + 1; a }) | |
let ast = Ast::Block(vec![ | |
Ast::Assign(0, Box::new(Ast::Const(1.0))), | |
Ast::Print(Box::new(Ast::Add( | |
Box::new(Ast::Var(0)), | |
Box::new(Ast::Block(vec![ | |
Ast::Assign( | |
0, | |
Box::new(Ast::Add(Box::new(Ast::Var(0)), Box::new(Ast::Const(1.0)))), | |
), | |
Ast::Var(0), | |
])), | |
))), | |
]); | |
let bytecode = Codegen::generate_bytecode(&ast); | |
interpret(&bytecode); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment