Created
November 21, 2024 12:24
-
-
Save notcancername/c0472da5d2a4d7400a967aa63228cf72 to your computer and use it in GitHub Desktop.
risc-v emulator thing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
const assert = std.debug.assert; | |
const Log2Int = std.math.Log2Int; | |
// Reference every declaration reachable from this file's root container at
// compile time by handing the container to `std.testing.refAllDeclsRecursive`.
comptime {
    const Root = @This();
    std.testing.refAllDeclsRecursive(Root);
}
/// A 32-bit machine word that can be viewed as either a signed or an unsigned
/// integer. Both fields alias the same four bytes, so writing one and reading
/// the other reinterprets the bits without any conversion.
const Register = extern union {
    /// Two's-complement view of the word.
    signed: i32,
    /// Plain unsigned view of the word.
    unsigned: u32,
};
/// Register-register (R-type) instruction word. Fields are listed from the
/// least significant bit upwards, as `packed struct` lays them out.
const InsnR = packed struct(u32) {
    /// Bits 0..6: major opcode.
    opcode: u7,
    /// Bits 7..11: destination register index.
    rd: u5,
    /// Bits 12..14: minor operation selector.
    funct3: u3,
    /// Bits 15..19: first source register index.
    rs1: u5,
    /// Bits 20..24: second source register index.
    rs2: u5,
    /// Bits 25..31: additional operation selector.
    funct7: u7,
};
/// Immediate (I-type) instruction word. Fields are listed from the least
/// significant bit upwards.
const InsnI = packed struct(u32) {
    /// Bits 0..6: major opcode.
    opcode: u7,
    /// Bits 7..11: destination register index.
    rd: u5,
    /// Bits 12..14: minor operation selector.
    funct3: u3,
    /// Bits 15..19: source register index.
    rs1: u5,
    /// Bits 20..31: the 12-bit immediate, already typed as signed.
    imm0_11: i12,

    /// Returns the immediate sign-extended to a full 32-bit register value.
    fn getImmediate(i: InsnI) Register {
        // Widening i12 -> i32 performs the sign extension.
        const widened: i32 = i.imm0_11;
        return .{ .signed = widened };
    }
};
/// Store (S-type) instruction word. The 12-bit immediate is split into two
/// pieces around the register fields.
const InsnS = packed struct(u32) {
    /// Bits 0..6: major opcode.
    opcode: u7,
    /// Bits 7..11: immediate bits 0..4.
    imm0_4: u5,
    /// Bits 12..14: minor operation selector.
    funct3: u3,
    /// Bits 15..19: first source register index.
    rs1: u5,
    /// Bits 20..24: second source register index.
    rs2: u5,
    /// Bits 25..31: immediate bits 5..11.
    imm5_11: u7,

    /// Reassembles the split immediate and returns it sign-extended to 32 bits.
    fn getImmediate(s: InsnS) Register {
        const raw: u12 = (@as(u12, s.imm5_11) << 5) | @as(u12, s.imm0_4);
        // Reinterpret as i12 so that widening to i32 sign-extends.
        return .{ .signed = @as(i12, @bitCast(raw)) };
    }
};
/// Branch (B-type) instruction word. The immediate is a 13-bit signed byte
/// offset whose bit 0 is implicit (always zero) and whose remaining bits are
/// scattered across the word.
const InsnB = packed struct(u32) {
    /// Bits 0..6: major opcode.
    opcode: u7,
    /// Bit 7: immediate bit 11.
    imm11: u1,
    /// Bits 8..11: immediate bits 1..4.
    imm1_4: u4,
    /// Bits 12..14: minor operation selector.
    funct3: u3,
    /// Bits 15..19: first source register index.
    rs1: u5,
    /// Bits 20..24: second source register index.
    rs2: u5,
    /// Bits 25..30: immediate bits 5..10.
    imm5_10: u6,
    /// Bit 31: immediate bit 12 (the sign bit).
    imm12: u1,

    /// Reassembles the scattered immediate and returns it sign-extended to
    /// 32 bits. The result is already a byte offset (bit 0 is zero).
    fn getImmediate(b: InsnB) Register {
        const raw: u13 = (@as(u13, b.imm12) << 12) |
            (@as(u13, b.imm11) << 11) |
            (@as(u13, b.imm5_10) << 5) |
            (@as(u13, b.imm1_4) << 1);
        // Reinterpret as i13 so that widening to i32 sign-extends.
        return .{ .signed = @as(i13, @bitCast(raw)) };
    }
};
/// Upper-immediate (U-type) instruction word.
const InsnU = packed struct(u32) {
    /// Bits 0..6: major opcode.
    opcode: u7,
    /// Bits 7..11: destination register index.
    rd: u5,
    /// Bits 12..31: immediate bits 12..31.
    imm12_31: u20,

    /// Returns the immediate placed in the upper 20 bits of a 32-bit value,
    /// with the low 12 bits cleared.
    fn getImmediate(u: InsnU) Register {
        const shifted: u32 = @as(u32, u.imm12_31) << 12;
        return .{ .signed = @as(i32, @bitCast(shifted)) };
    }
};
/// Jump (J-type) instruction word. The immediate is a 21-bit signed byte
/// offset whose bit 0 is implicit (always zero) and whose remaining bits are
/// scattered across the word.
const InsnJ = packed struct(u32) {
    /// Bits 0..6: major opcode.
    opcode: u7,
    /// Bits 7..11: destination register index.
    rd: u5,
    /// Bits 12..19: immediate bits 12..19.
    imm12_19: u8,
    /// Bit 20: immediate bit 11.
    imm11: u1,
    /// Bits 21..30: immediate bits 1..10.
    imm1_10: u10,
    /// Bit 31: immediate bit 20 (the sign bit).
    imm20: u1,

    /// Reassembles the scattered immediate and returns it sign-extended to
    /// 32 bits. The result is already a byte offset (bit 0 is zero).
    fn getImmediate(j: InsnJ) Register {
        const raw: u21 = (@as(u21, j.imm20) << 20) |
            (@as(u21, j.imm12_19) << 12) |
            (@as(u21, j.imm11) << 11) |
            (@as(u21, j.imm1_10) << 1);
        // Reinterpret as i21 so that widening to i32 sign-extends.
        return .{ .signed = @as(i21, @bitCast(raw)) };
    }
};
/// One 32-bit instruction word, viewable through any of the instruction
/// format layouts or as a raw integer. All fields alias the same four bytes.
const Insn = extern union {
    r: InsnR,
    i: InsnI,
    s: InsnS,
    b: InsnB,
    u: InsnU,
    j: InsnJ,
    signed: i32,
    unsigned: u32,

    /// Decodes the major opcode from the low seven bits of the word.
    ///
    /// Returns `error.InvalidOpcode` when those bits do not name any member
    /// of `Opcode`.
    fn getOpcode(insn: Insn) error{InvalidOpcode}!Opcode {
        const op: u7 = @truncate(insn.unsigned);
        // `intToEnum` yields an error union (not an optional), so it has to
        // be handled with `catch`; its error is mapped onto this function's
        // own error set.
        return std.meta.intToEnum(Opcode, op) catch return error.InvalidOpcode;
    }
};
/// Major opcodes, i.e. the values that can appear in the low seven bits of an
/// instruction word. Any other 7-bit value is not a member of this enum.
pub const Opcode = enum(u7) {
    // zig fmt: off
    load      = 0b0000011,
    load_fp   = 0b0000111,
    misc_mem  = 0b0001111,
    op_imm    = 0b0010011,
    auipc     = 0b0010111,
    op_imm_32 = 0b0011011,
    store     = 0b0100011,
    store_fp  = 0b0100111,
    amo       = 0b0101111,
    op        = 0b0110011,
    lui       = 0b0110111,
    op_32     = 0b0111011,
    madd      = 0b1000011,
    msub      = 0b1000111,
    nmsub     = 0b1001011,
    nmadd     = 0b1001111,
    op_fp     = 0b1010011,
    branch    = 0b1100011,
    jalr      = 0b1100111,
    jal       = 0b1101111,
    system    = 0b1110011,
    // zig fmt: on
};
/// Panics with a formatted message.
///
/// The message is rendered into a fixed 8192-byte stack buffer; if rendering
/// fails (for example because the text does not fit), the raw format string
/// is used as the panic message instead.
fn panicFmt(comptime fmt: []const u8, args: anytype) noreturn {
    @setCold(true);
    var scratch: [8192]u8 = undefined;
    const message: []const u8 = std.fmt.bufPrint(&scratch, fmt, args) catch fmt;
    @panic(message);
}
/// State of a single RV32 hart together with the host hooks it may call.
///
/// Instructions are fetched from `instructions` (indexed by `pc / 4`) and
/// data accesses go to `memory` (guest address 0 is `memory[0]`). Decoding
/// problems and out-of-range accesses are reported by panicking.
const ExecutionState = struct {
    /// Integer register file x0..x31. x0 is kept at zero by `storeRegister`.
    registers: [32]Register = [1]Register{.{ .unsigned = 0 }} ** 32,
    /// Byte address of the instruction `step` executes next.
    pc: Register = .{ .unsigned = 0 },
    /// Program text.
    instructions: []const Insn,
    /// Data memory used by load and store instructions.
    memory: []u8,
    /// Opaque pointer passed as the first argument to both handlers.
    context: ?*anyopaque = null,
    /// Called when an `ebreak` instruction is executed.
    ebreak_handler: *const fn (?*anyopaque, *ExecutionState) void = handlers.nop,
    /// Called when an `ecall` instruction is executed.
    ecall_handler: *const fn (?*anyopaque, *ExecutionState) void = handlers.nop,

    const handlers = struct {
        /// Default handler: ignores the event.
        fn nop(_: ?*anyopaque, _: *ExecutionState) void {}
    };

    /// ABI names of x0..x31, in register-index order.
    const abi_names = [32][]const u8{
        "zero", "ra", "sp",  "gp",  "tp", "t0", "t1", "t2",
        "s0",   "s1", "a0",  "a1",  "a2", "a3", "a4", "a5",
        "a6",   "a7", "s2",  "s3",  "s4", "s5", "s6", "s7",
        "s8",   "s9", "s10", "s11", "t3", "t4", "t5", "t6",
    };

    /// `std.fmt` hook: prints `pc` followed by every register, one per line,
    /// in hexadecimal.
    pub fn format(value: ExecutionState, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
        // `Register` is an untagged union, so an integer view has to be
        // selected explicitly before it can be printed with `{x}`.
        try writer.print("pc = 0x{x}\n", .{value.pc.unsigned});
        for (abi_names, value.registers) |name, reg| {
            try writer.print("{s} = 0x{x}\n", .{ name, reg.unsigned });
        }
    }

    /// Reads register `r`. Asserts that x0 still holds zero.
    inline fn loadRegister(s: *ExecutionState, r: u5) Register {
        assert(s.registers[0].unsigned == 0);
        return s.registers[r];
    }

    /// Writes `w` to register `r`; a write to x0 stores zero instead.
    ///
    /// `w` may be a `Register`, an `i32` or a `u32`; any other type is a
    /// compile error.
    inline fn storeRegister(s: *ExecutionState, r: u5, w: anytype) void {
        const value: Register = switch (@TypeOf(w)) {
            Register => w,
            i32 => Register{ .signed = w },
            u32 => Register{ .unsigned = w },
            else => @compileError("storeRegister expects a Register, i32 or u32, got " ++ @typeName(@TypeOf(w))),
        };
        s.registers[r] = if (r == 0) Register{ .unsigned = 0 } else value;
    }

    // signExtend and zeroExtend exist separately to prevent mistakes.

    /// Widens a signed integer of at most 32 bits to a register value.
    inline fn signExtend(w: anytype) Register {
        return .{ .signed = w };
    }

    /// Widens an unsigned integer of at most 32 bits to a register value.
    inline fn zeroExtend(w: anytype) Register {
        return .{ .unsigned = w };
    }

    /// Converts guest address `addr` into an index into `memory`, panicking
    /// with a state dump unless `len` bytes starting there lie inside it.
    /// `what` names the access ("load" or "store") in the panic message.
    fn memoryIndex(s: *ExecutionState, addr: Register, comptime len: usize, comptime what: []const u8) usize {
        const start = std.math.cast(usize, addr.unsigned) orelse
            panicFmt("addr out of range: {d}: {}", .{ addr.unsigned, s.* });
        // Written so that the comparison itself cannot overflow.
        if (start > s.memory.len or len > s.memory.len - start)
            panicFmt("oob " ++ what ++ " of {d} byte(s) at 0x{x} (memory was 0x{x} long): {}", .{ len, start, s.memory.len, s.* });
        return start;
    }

    /// Reads a little-endian `T` from guest address `addr`. Panics when the
    /// access does not fit inside `memory`.
    inline fn load(s: *ExecutionState, addr: Register, comptime T: type) T {
        const start = s.memoryIndex(addr, @sizeOf(T), "load");
        // RISC-V memory is little-endian regardless of the host.
        return std.mem.readIntLittle(T, s.memory[start..][0..@sizeOf(T)]);
    }

    /// Writes `v` as a little-endian `T` to guest address `addr`. Panics when
    /// the access does not fit inside `memory`.
    inline fn store(s: *ExecutionState, addr: Register, comptime T: type, v: T) void {
        const start = s.memoryIndex(addr, @sizeOf(T), "store");
        std.mem.writeIntLittle(T, s.memory[start..][0..@sizeOf(T)], v);
    }

    /// Executes the instruction at `pc` and advances `pc`.
    ///
    /// Returns `null` without executing anything once `pc` points past the
    /// end of `instructions`. Panics when `pc` is not 4-byte aligned.
    fn step(s: *ExecutionState) ?void {
        const insn_len = @sizeOf(Insn);
        if (s.pc.unsigned % insn_len != 0) {
            @panic("TODO: handle unaligned pc");
        }
        const cur_insn: usize = s.pc.unsigned / insn_len;
        if (cur_insn >= s.instructions.len) return null;
        s.interpretInsn(s.instructions[cur_insn]);
        // Control-transfer instructions leave `pc` one instruction short of
        // their target, so this unconditional advance is always correct.
        s.pc.unsigned +%= insn_len;
        return {};
    }

    /// Signed division with the results the ISA defines for the cases that
    /// would otherwise trap: x / 0 = -1 and minInt / -1 = minInt.
    fn divSigned(a: i32, b: i32) i32 {
        if (b == 0) return -1;
        if (a == std.math.minInt(i32) and b == -1) return a;
        return @divTrunc(a, b);
    }

    /// Signed remainder (sign follows the dividend) with the ISA-defined
    /// special cases: x rem 0 = x and minInt rem -1 = 0.
    fn remSigned(a: i32, b: i32) i32 {
        if (b == 0) return a;
        if (a == std.math.minInt(i32) and b == -1) return 0;
        return a - b * @divTrunc(a, b);
    }

    /// Unsigned division; x / 0 yields all ones as the ISA defines.
    fn divUnsigned(a: u32, b: u32) u32 {
        return if (b == 0) std.math.maxInt(u32) else a / b;
    }

    /// Unsigned remainder; x rem 0 yields x as the ISA defines.
    fn remUnsigned(a: u32, b: u32) u32 {
        return if (b == 0) a else a % b;
    }

    /// Executes a single instruction against this state.
    ///
    /// Implements the RV32I integer instructions plus the multiply/divide
    /// group under the `op` opcode. `pc` is NOT advanced here; taken
    /// branches and jumps set it to `target - 4` because `step` adds 4
    /// afterwards. Panics on undecodable or unimplemented instructions.
    fn interpretInsn(s: *ExecutionState, insn: Insn) void {
        const insn_len: u32 = @sizeOf(Insn);
        const opc = insn.getOpcode() catch
            panicFmt("invalid opcode in insn 0x{x}\n{}", .{ insn.unsigned, s.* });
        switch (opc) {
            .lui => {
                const u = insn.u;
                s.storeRegister(u.rd, u.getImmediate());
            },
            .auipc => {
                const u = insn.u;
                s.storeRegister(u.rd, u.getImmediate().unsigned +% s.pc.unsigned);
            },
            .jal => {
                const j = insn.j;
                const ret_addr = s.pc.unsigned +% insn_len;
                // The decoded immediate is already a byte offset.
                const target = s.pc.unsigned +% j.getImmediate().unsigned;
                if ((target & 0b11) != 0) @panic("misaligned jal");
                s.storeRegister(j.rd, ret_addr);
                s.pc.unsigned = target -% insn_len;
            },
            .jalr => {
                const i = insn.i;
                const ret_addr = s.pc.unsigned +% insn_len;
                // Read rs1 before writing rd: they may be the same register.
                const base = s.loadRegister(i.rs1).unsigned;
                const target = (base +% i.getImmediate().unsigned) & ~@as(u32, 1);
                if ((target & 0b11) != 0) @panic("misaligned jalr");
                s.storeRegister(i.rd, ret_addr);
                s.pc.unsigned = target -% insn_len;
            },
            .branch => {
                const b = insn.b;
                const x = s.loadRegister(b.rs1);
                const y = s.loadRegister(b.rs2);
                const cond = switch (b.funct3) {
                    0b000 => x.signed == y.signed, // beq
                    0b001 => x.signed != y.signed, // bne
                    0b100 => x.signed < y.signed, // blt
                    0b101 => x.signed >= y.signed, // bge
                    0b110 => x.unsigned < y.unsigned, // bltu
                    0b111 => x.unsigned >= y.unsigned, // bgeu
                    else => @panic("invalid branch funct3"),
                };
                if (cond) s.pc.unsigned +%= b.getImmediate().unsigned -% insn_len;
            },
            .load => {
                const i = insn.i;
                // Wrapping add: a negative offset is a large unsigned value.
                const addr = Register{ .unsigned = s.loadRegister(i.rs1).unsigned +% i.getImmediate().unsigned };
                const value: Register = switch (i.funct3) {
                    0b000 => signExtend(s.load(addr, i8)), // lb
                    0b001 => signExtend(s.load(addr, i16)), // lh
                    0b010 => signExtend(s.load(addr, i32)), // lw
                    0b011 => @panic("64 bit insn"),
                    0b100 => zeroExtend(s.load(addr, u8)), // lbu
                    0b101 => zeroExtend(s.load(addr, u16)), // lhu
                    0b110 => @panic("64 bit insn"),
                    else => @panic("invalid load funct3"),
                };
                s.storeRegister(i.rd, value);
            },
            .store => {
                const is = insn.s;
                const addr = Register{ .unsigned = s.loadRegister(is.rs1).unsigned +% is.getImmediate().unsigned };
                const value = s.loadRegister(is.rs2).unsigned;
                switch (is.funct3) {
                    0b000 => s.store(addr, u8, @as(u8, @truncate(value))), // sb
                    0b001 => s.store(addr, u16, @as(u16, @truncate(value))), // sh
                    0b010 => s.store(addr, u32, value), // sw
                    0b011 => @panic("64-bit insn"),
                    else => @panic("invalid store funct3"),
                }
            },
            .op_imm => {
                const i = insn.i;
                const a = s.loadRegister(i.rs1);
                const imm = i.getImmediate();
                // For the shift forms the R-type view lines up with the
                // encoding: `rs2` holds the shift amount and `funct7`
                // selects the variant.
                const shamt: Log2Int(u32) = insn.r.rs2;
                const funct7 = insn.r.funct7;
                const value: Register = switch (i.funct3) {
                    0b000 => Register{ .signed = a.signed +% imm.signed }, // addi
                    0b010 => Register{ .unsigned = @intFromBool(a.signed < imm.signed) }, // slti
                    0b011 => Register{ .unsigned = @intFromBool(a.unsigned < imm.unsigned) }, // sltiu
                    0b100 => Register{ .unsigned = a.unsigned ^ imm.unsigned }, // xori
                    0b110 => Register{ .unsigned = a.unsigned | imm.unsigned }, // ori
                    0b111 => Register{ .unsigned = a.unsigned & imm.unsigned }, // andi
                    0b001 => switch (funct7) {
                        0b0000000 => Register{ .unsigned = a.unsigned << shamt }, // slli
                        else => @panic("invalid op_imm funct7"),
                    },
                    0b101 => switch (funct7) {
                        0b0000000 => Register{ .unsigned = a.unsigned >> shamt }, // srli
                        0b0100000 => Register{ .signed = a.signed >> shamt }, // srai
                        else => @panic("invalid op_imm funct7"),
                    },
                };
                s.storeRegister(i.rd, value);
            },
            .op => {
                const r = insn.r;
                const a = s.loadRegister(r.rs1);
                const b = s.loadRegister(r.rs2);
                // Register shifts use only the low five bits of rs2.
                const shamt: Log2Int(u32) = @truncate(b.unsigned);
                const value: Register = switch (r.funct7) {
                    0b0000000 => switch (r.funct3) {
                        0b000 => Register{ .unsigned = a.unsigned +% b.unsigned }, // add
                        0b001 => Register{ .unsigned = a.unsigned << shamt }, // sll
                        0b010 => Register{ .unsigned = @intFromBool(a.signed < b.signed) }, // slt
                        0b011 => Register{ .unsigned = @intFromBool(a.unsigned < b.unsigned) }, // sltu
                        0b100 => Register{ .unsigned = a.unsigned ^ b.unsigned }, // xor
                        0b101 => Register{ .unsigned = a.unsigned >> shamt }, // srl
                        0b110 => Register{ .unsigned = a.unsigned | b.unsigned }, // or
                        0b111 => Register{ .unsigned = a.unsigned & b.unsigned }, // and
                    },
                    0b0100000 => switch (r.funct3) {
                        0b000 => Register{ .unsigned = a.unsigned -% b.unsigned }, // sub
                        0b101 => Register{ .signed = a.signed >> shamt }, // sra
                        else => @panic("invalid op funct3"),
                    },
                    0b0000001 => switch (r.funct3) {
                        0b000 => Register{ .unsigned = a.unsigned *% b.unsigned }, // mul
                        // The 64-bit products below cannot overflow, and the
                        // truncation keeps exactly the upper 32 bits.
                        0b001 => Register{ .signed = @as(i32, @truncate((@as(i64, a.signed) * @as(i64, b.signed)) >> 32)) }, // mulh
                        0b010 => Register{ .signed = @as(i32, @truncate((@as(i64, a.signed) * @as(i64, b.unsigned)) >> 32)) }, // mulhsu
                        0b011 => Register{ .unsigned = @as(u32, @truncate((@as(u64, a.unsigned) * @as(u64, b.unsigned)) >> 32)) }, // mulhu
                        0b100 => Register{ .signed = divSigned(a.signed, b.signed) }, // div
                        0b101 => Register{ .unsigned = divUnsigned(a.unsigned, b.unsigned) }, // divu
                        0b110 => Register{ .signed = remSigned(a.signed, b.signed) }, // rem
                        0b111 => Register{ .unsigned = remUnsigned(a.unsigned, b.unsigned) }, // remu
                    },
                    else => @panic("invalid op funct7"),
                };
                s.storeRegister(r.rd, value);
            },
            .system => {
                const i = insn.i;
                switch (i.funct3) {
                    0b000 => switch (i.imm0_11) {
                        0 => s.ecall_handler(s.context, s),
                        1 => s.ebreak_handler(s.context, s),
                        else => @panic("invalid imm system"),
                    },
                    0b001 => {
                        @panic("TODO IMPLEMENT");
                    },
                    else => @panic("invalid funct3 system"),
                }
            },
            else => panicFmt("TODO opcode {} insn 0x{x}\n", .{ opc, insn.unsigned }),
        }
    }
};
// const jit_instructions = struct { | |
// const names = [_][]const u8{ "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "rbp", "rsp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" }; | |
// inline fn get(comptime r: usize) u64 { | |
// return asm volatile ("mov %" ++ names[r] ++ ", %[dst]" | |
// : [dst] "=r" (-> u64), | |
// ); | |
// } | |
// inline fn set(comptime r: usize, v: u64) void { | |
// return asm volatile ("mov %[src], %" ++ names[r] | |
// : | |
// : [src] "r" (v), | |
// ); | |
// } | |
// export fn save_regs_and_call(f: *const fn () void) void { | |
// const g = struct { | |
// var regs: [16]u64 = undefined; | |
// }; | |
// inline for (0..16) |i| g.regs[i] = get(i); | |
// f(); | |
// inline for (0..16) |i| set(i, g.regs[i]); | |
// } | |
// const save_regs_and_call_addr = &save_regs_and_call; | |
// const GeneratedInsn = union(enum) {}; | |
// }; | |
// test "example program" { | |
// const instructions = [_]u32{ | |
// 0x00f5d613, | |
// 0x00a64633, | |
// 0x735a36b7, | |
// 0xd9768693, | |
// 0x00d60633, | |
// 0x00a64633, | |
// 0x00f65513, | |
// 0x00b54533, | |
// 0xcaf656b7, | |
// 0x9a968693, | |
// 0x00d50533, | |
// 0x00b545b3, | |
// 0x0105d513, | |
// 0x00c54533, | |
// 0x01065613, | |
// 0x00c5c5b3, | |
// }; | |
// var s = ExecutionState{ | |
// .memory = &.{}, | |
// .instructions = @ptrCast(@as([]const u32, &instructions)), | |
// }; | |
// while (s.step()) |_| {} | |
// try std.testing.expectEqual(@as(i32, 1935337312), s.registers[10]); | |
// try std.testing.expectEqual(@as(i32, -889765113), s.registers[11]); | |
// try std.testing.expectEqual(@as(i32, 29530), s.registers[12]); | |
// try std.testing.expectEqual(@as(i32, -889828951), s.registers[13]); | |
// } | |
// test "addi" { | |
// const instructions = [_]u32{ | |
// 0x53900093, | |
// 0xac700113, | |
// 0xe5c00193, | |
// 0x1a400213, | |
// }; | |
// var s = ExecutionState{ | |
// .instructions = &instructions, | |
// .memory = &.{}, | |
// }; | |
// while (s.step()) |_| {} | |
// try std.testing.expectEqual(@as(i32, 1337), s.registers[1]); | |
// try std.testing.expectEqual(@as(i32, -1337), s.registers[2]); | |
// try std.testing.expectEqual(@as(i32, -420), s.registers[3]); | |
// try std.testing.expectEqual(@as(i32, 420), s.registers[4]); | |
// } | |
// test "hello world" { | |
// const instructions = [_]u32{ | |
// 0x04800513, | |
// 0x00a00023, | |
// 0x06500513, | |
// 0x00a000a3, | |
// 0x06c00513, | |
// 0x00a00123, | |
// 0x00a001a3, | |
// 0x06f00593, | |
// 0x00b00223, | |
// 0x02c00613, | |
// 0x00c002a3, | |
// 0x02000613, | |
// 0x00c00323, | |
// 0x05700613, | |
// 0x00c003a3, | |
// 0x00b00423, | |
// 0x07200593, | |
// 0x00b004a3, | |
// 0x00a00523, | |
// 0x06400513, | |
// 0x00a005a3, | |
// }; | |
// var memory: ["Hello, World"[0..].len]u8 = undefined; | |
// var s = ExecutionState{ | |
// .instructions = &instructions, | |
// .memory = &memory, | |
// }; | |
// while (s.step()) |_| {} | |
// try std.testing.expectEqualSlices(u8, "Hello, World", &memory); | |
// } | |
// test "ecall" { | |
// const instructions = [_]u32{ | |
// // store "Hello, World" in 0 | |
// 0x04800513, | |
// 0x00a00023, | |
// 0x06500513, | |
// 0x00a000a3, | |
// 0x06c00513, | |
// 0x00a00123, | |
// 0x00a001a3, | |
// 0x06f00593, | |
// 0x00b00223, | |
// 0x02c00613, | |
// 0x00c002a3, | |
// 0x02000613, | |
// 0x00c00323, | |
// 0x05700613, | |
// 0x00c003a3, | |
// 0x00b00423, | |
// 0x07200593, | |
// 0x00b004a3, | |
// 0x00a00523, | |
// 0x06400513, | |
// 0x00a005a3, | |
// // Ecall with a0 = 69, a1 = 0, a2 = 12 | |
// 0x04500513, | |
// 0x00000593, | |
// 0x00c00613, | |
// 0x00000073, | |
// }; | |
// var memory: [8192]u8 = undefined; | |
// const HandlerState = struct { | |
// got_ecall: bool = false, | |
// fn handleEcall(state_p: ?*anyopaque, s: *ExecutionState) void { | |
// var state = @as(*@This(), @ptrCast(@alignCast(state_p))); | |
// if (s.registers[10] == 69 and s.registers[11] == 0 and s.registers[12] == 12) | |
// state.got_ecall = true; | |
// } | |
// }; | |
// var hs = HandlerState{}; | |
// var s = ExecutionState{ | |
// .instructions = &instructions, | |
// .memory = &memory, | |
// .context = @as(?*anyopaque, @ptrCast(&hs)), | |
// .ecall_handler = HandlerState.handleEcall, | |
// }; | |
// while (s.step()) |_| {} | |
// try std.testing.expect(hs.got_ecall); | |
// } | |
// test "branching" { | |
// const instructions = [_]u32{ | |
// 0x04500593, // li a1, 69 | |
// 0x00b51863, // bne a0, a1, .LBB0_2 | |
// 0x00100513, // li a0, 1 | |
// 0x00000073, // ecall | |
// 0x00008067, // ret | |
// 0x00000513, // .LBB0_2: li a0, 0 | |
// 0x00000073, // ecall | |
// 0x00008067, // ret | |
// }; | |
// const HandlerState = struct { | |
// arg1: ?i32 = null, | |
// fn handleEcall(state_p: ?*anyopaque, s: *ExecutionState) void { | |
// var state = @as(*@This(), @ptrCast(@alignCast(state_p))); | |
// state.arg1 = s.registers[10]; | |
// } | |
// }; | |
// var hs = HandlerState{}; | |
// var s = ExecutionState{ | |
// .instructions = &instructions, | |
// .memory = &.{}, | |
// .context = @as(?*anyopaque, @ptrCast(&hs)), | |
// .ecall_handler = HandlerState.handleEcall, | |
// }; | |
// s.registers[1] = instructions.len * @sizeOf(u32) + 1; | |
// s.registers[10] = 69; | |
// while (s.step()) |_| {} | |
// try std.testing.expectEqual(@as(i32, 1), hs.arg1.?); | |
// } | |
// test "matmul non-vector" { | |
// const instructions = b: { | |
// const t = comptime std.mem.bytesAsSlice(u32, @embedFile("matmul_test.bin")); | |
// var d: [t.len]u32 = undefined; | |
// @memcpy(&d, t); | |
// break :b d; | |
// }; | |
// const stack_size = 8192; | |
// var memory = std.testing.allocator.create([stack_size + 16 * 3]u8) catch unreachable; | |
// defer std.testing.allocator.free(memory); | |
// @memset(memory[0..stack_size], undefined); | |
// @memset(memory[stack_size..], 2); | |
// var s = ExecutionState{ | |
// .instructions = &instructions, | |
// .memory = memory, | |
// }; | |
// s.registers[1] = @as(i32, @intCast((instructions.len + 1) * @sizeOf(u32))); // return address | |
// s.registers[2] = stack_size; // stack pointer | |
// s.registers[10] = stack_size; // dst | |
// s.registers[11] = stack_size + 16; // a | |
// s.registers[12] = stack_size + 16 * 2; // b | |
// while (s.step()) |_| {} | |
// for (memory[stack_size..][0..16]) |e| try std.testing.expectEqual(@as(u8, 16), e); | |
// } | |
// Compile-time sanity check of the I-type field layout, using the word
// 0x22858407. This replaces a leftover `@compileLog`, which made every build
// fail because a compile-log statement is always reported as an error.
comptime {
    const decoded: InsnI = @bitCast(@as(u32, 0x22858407));
    assert(decoded.opcode == 0b0000111);
    assert(decoded.rd == 8);
    assert(decoded.funct3 == 0b000);
    assert(decoded.rs1 == 11);
    assert(decoded.imm0_11 == 0x228);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment