//! risc-v emulator thing
//!
//! Source: GitHub gist notcancername/c0472da5d2a4d7400a967aa63228cf72
//! by @notcancername, created November 21, 2024 12:24.
const std = @import("std");
const assert = std.debug.assert;
const Log2Int = std.math.Log2Int;
// Reference every declaration in this file (recursively) so that all of them
// are semantically analyzed, not only the ones reachable from an entry point.
comptime {
std.testing.refAllDeclsRecursive(@This());
}
/// A 32-bit value that can be read either as a two's-complement signed
/// integer or as an unsigned integer. Both fields alias the same 32 bits.
const Register = extern union {
/// The value interpreted as a signed 32-bit integer.
signed: i32,
/// The same bits interpreted as an unsigned 32-bit integer.
unsigned: u32,
};
/// Register-register ("R") view of a 32-bit instruction word.
/// Fields of a packed struct are laid out starting at the least significant
/// bit, so `opcode` occupies bits 6:0 and `funct7` bits 31:25.
const InsnR = packed struct(u32) {
/// Bits 6:0.
opcode: u7,
/// Bits 11:7, destination register index.
rd: u5,
/// Bits 14:12.
funct3: u3,
/// Bits 19:15, first source register index.
rs1: u5,
/// Bits 24:20, second source register index.
rs2: u5,
/// Bits 31:25.
funct7: u7,
};
/// Immediate ("I") view of a 32-bit instruction word: opcode, rd, funct3,
/// rs1 and a 12-bit immediate in bits 31:20.
const InsnI = packed struct(u32) {
    opcode: u7,
    rd: u5,
    funct3: u3,
    rs1: u5,
    /// imm[11:0]; declared signed so that widening it sign-extends.
    imm0_11: i12,

    /// Returns the 12-bit immediate sign-extended to 32 bits.
    fn getImmediate(i: InsnI) Register {
        const widened: i32 = i.imm0_11;
        return Register{ .signed = widened };
    }
};
/// Store ("S") view of a 32-bit instruction word. The 12-bit immediate is
/// split into imm[4:0] (bits 11:7) and imm[11:5] (bits 31:25).
const InsnS = packed struct(u32) {
    opcode: u7,
    imm0_4: u5,
    funct3: u3,
    rs1: u5,
    rs2: u5,
    imm5_11: u7,

    /// Reassembles the two immediate pieces and sign-extends the 12-bit
    /// result to 32 bits.
    fn getImmediate(s: InsnS) Register {
        const raw: u12 = (@as(u12, s.imm5_11) << 5) | @as(u12, s.imm0_4);
        // Reinterpret the 12 raw bits as two's complement before widening.
        const offset: i12 = @bitCast(raw);
        return Register{ .signed = offset };
    }
};
/// Branch ("B") view of a 32-bit instruction word. The 13-bit, always-even
/// offset is scattered over four fields; bit 0 is implicit and zero.
const InsnB = packed struct(u32) {
    opcode: u7,
    imm11: u1,
    imm1_4: u4,
    funct3: u3,
    rs1: u5,
    rs2: u5,
    imm5_10: u6,
    imm12: u1,

    /// Reassembles imm[12:1] into a byte offset and sign-extends the 13-bit
    /// result to 32 bits.
    fn getImmediate(b: InsnB) Register {
        const raw: u13 =
            (@as(u13, b.imm12) << 12) |
            (@as(u13, b.imm11) << 11) |
            (@as(u13, b.imm5_10) << 5) |
            (@as(u13, b.imm1_4) << 1);
        // Reinterpret the 13 raw bits as two's complement before widening.
        const offset: i13 = @bitCast(raw);
        return Register{ .signed = offset };
    }
};
/// Upper-immediate ("U") view of a 32-bit instruction word: opcode, rd and
/// a 20-bit immediate held in bits 31:12.
const InsnU = packed struct(u32) {
    opcode: u7,
    rd: u5,
    imm12_31: u20,

    /// Returns the immediate placed in the upper 20 bits of a 32-bit value,
    /// with the low 12 bits zero.
    fn getImmediate(u: InsnU) Register {
        const raw = @as(u32, u.imm12_31) << 12;
        return Register{ .signed = @as(i32, @bitCast(raw)) };
    }
};
/// Jump ("J") view of a 32-bit instruction word. The 21-bit, always-even
/// offset is scattered over four fields; bit 0 is implicit and zero.
const InsnJ = packed struct(u32) {
    opcode: u7,
    rd: u5,
    imm12_19: u8,
    imm11: u1,
    imm1_10: u10,
    imm20: u1,

    /// Reassembles imm[20:1] into a byte offset and sign-extends the 21-bit
    /// result to 32 bits.
    fn getImmediate(j: InsnJ) Register {
        const raw: u21 =
            (@as(u21, j.imm20) << 20) |
            (@as(u21, j.imm12_19) << 12) |
            (@as(u21, j.imm11) << 11) |
            (@as(u21, j.imm1_10) << 1);
        // Reinterpret the 21 raw bits as two's complement before widening.
        const offset: i21 = @bitCast(raw);
        return Register{ .signed = offset };
    }
};
/// A raw 32-bit instruction word. Every field aliases the same 32 bits, so
/// the word can be viewed through whichever encoding format the opcode
/// selects, or as a plain signed/unsigned integer.
const Insn = extern union {
    r: InsnR,
    i: InsnI,
    s: InsnS,
    b: InsnB,
    u: InsnU,
    j: InsnJ,
    signed: i32,
    unsigned: u32,

    /// Decodes the low 7 bits of the instruction word.
    /// Returns `error.InvalidOpcode` when those bits do not name a member
    /// of `Opcode`.
    fn getOpcode(insn: Insn) error{InvalidOpcode}!Opcode {
        const op: u7 = @truncate(insn.unsigned);
        // `intToEnum` reports failure through an error union (not an
        // optional), so it has to be handled with `catch`.
        return std.meta.intToEnum(Opcode, op) catch return error.InvalidOpcode;
    }
};
/// Major opcodes: the low 7 bits of a 32-bit instruction word.
/// Members are listed in ascending numeric order.
pub const Opcode = enum(u7) {
    // zig fmt: off
    load      = 0b0000011,
    load_fp   = 0b0000111,
    misc_mem  = 0b0001111,
    op_imm    = 0b0010011,
    auipc     = 0b0010111,
    op_imm_32 = 0b0011011,
    store     = 0b0100011,
    store_fp  = 0b0100111,
    amo       = 0b0101111,
    op        = 0b0110011,
    lui       = 0b0110111,
    op_32     = 0b0111011,
    madd      = 0b1000011,
    msub      = 0b1000111,
    nmsub     = 0b1001011,
    nmadd     = 0b1001111,
    op_fp     = 0b1010011,
    branch    = 0b1100011,
    jalr      = 0b1100111,
    jal       = 0b1101111,
    system    = 0b1110011,
    // zig fmt: on
};
/// Panics with a formatted message.
/// The message is rendered into a fixed 8192-byte stack buffer; if
/// formatting fails, the raw format string is used as the panic message
/// instead.
fn panicFmt(comptime fmt: []const u8, args: anytype) noreturn {
    @setCold(true);
    var scratch: [8192]u8 = undefined;
    const message: []const u8 = std.fmt.bufPrint(&scratch, fmt, args) catch fmt;
    @panic(message);
}
/// Interpreter state for a 32-bit RISC-V program: the integer register
/// file, the program counter, the instruction slice being executed and a
/// flat byte array used as data memory.
///
/// Execution is driven by calling `step` until it returns `null`.
/// `ecall` and `ebreak` are forwarded to the user-supplied handlers together
/// with `context`.
const ExecutionState = struct {
    /// Integer registers x0..x31. x0 always reads as zero; `storeRegister`
    /// discards writes to it.
    registers: [32]Register = [1]Register{.{ .unsigned = 0 }} ** 32,
    /// Byte address of the instruction that `step` executes next.
    pc: Register = .{ .unsigned = 0 },
    /// Program text. Instruction `n` lives at byte address `n * 4`.
    instructions: []const Insn,
    /// Data memory addressed by load and store instructions, little-endian.
    memory: []u8,
    /// Opaque pointer handed unchanged to the ecall/ebreak handlers.
    context: ?*anyopaque = null,
    /// Called when an `ebreak` instruction executes.
    ebreak_handler: *const fn (?*anyopaque, *ExecutionState) void = handlers.nop,
    /// Called when an `ecall` instruction executes.
    ecall_handler: *const fn (?*anyopaque, *ExecutionState) void = handlers.nop,

    const handlers = struct {
        fn nop(_: ?*anyopaque, _: *ExecutionState) void {}
    };

    /// Size of one instruction in bytes.
    const insn_len = @sizeOf(Insn);

    /// ABI names of x0..x31, in register-index order.
    const register_names = [32][]const u8{
        "zero", "ra", "sp",  "gp",  "tp", "t0", "t1", "t2",
        "s0",   "s1", "a0",  "a1",  "a2", "a3", "a4", "a5",
        "a6",   "a7", "s2",  "s3",  "s4", "s5", "s6", "s7",
        "s8",   "s9", "s10", "s11", "t3", "t4", "t5", "t6",
    };

    /// Prints `pc` and all 32 registers, one per line, in hexadecimal.
    pub fn format(value: ExecutionState, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
        try writer.print("pc = 0x{x}\n", .{value.pc.unsigned});
        for (register_names, value.registers) |name, register| {
            try writer.print("{s} = 0x{x}\n", .{ name, register.unsigned });
        }
    }

    /// Reads register `r`.
    inline fn loadRegister(s: *ExecutionState, r: u5) Register {
        assert(s.registers[0].unsigned == 0);
        return s.registers[r];
    }

    /// Writes `w` to register `r`; writes to x0 are discarded.
    /// `w` may be a `Register`, an `i32` or a `u32`.
    inline fn storeRegister(s: *ExecutionState, r: u5, w: anytype) void {
        if (r == 0) return;
        s.registers[r] = switch (@TypeOf(w)) {
            Register => w,
            else => .{ .unsigned = @as(u32, @bitCast(w)) },
        };
    }

    // signExtend and zeroExtend exist separately to prevent mistakes.
    inline fn signExtend(w: anytype) Register {
        return .{ .signed = w };
    }

    inline fn zeroExtend(w: anytype) Register {
        return .{ .unsigned = w };
    }

    /// Validates that `len` bytes starting at `addr` lie inside `memory` and
    /// returns the start offset. Panics with a state dump otherwise.
    /// `what` names the access kind ("load"/"store") in the panic message.
    fn memoryOffset(s: *ExecutionState, addr: Register, len: usize, comptime what: []const u8) usize {
        const start = std.math.cast(usize, addr.unsigned) orelse
            panicFmt("addr out of range: {d}: {}", .{ addr.unsigned, s.* });
        // Written so that `start + len` can never overflow usize.
        if (len > s.memory.len or start > s.memory.len - len)
            panicFmt("oob " ++ what ++ " of {d} byte(s) at 0x{x} (memory was 0x{x} long): {}", .{ len, start, s.memory.len, s.* });
        return start;
    }

    /// Reads a little-endian `T` from data memory at `addr`.
    /// Panics when the access is out of bounds.
    inline fn load(s: *ExecutionState, addr: Register, comptime T: type) T {
        const len = @sizeOf(T);
        const start = s.memoryOffset(addr, len, "load");
        return std.mem.readIntLittle(T, s.memory[start..][0..len]);
    }

    /// Writes `v` as a little-endian `T` to data memory at `addr`.
    /// Panics when the access is out of bounds.
    inline fn store(s: *ExecutionState, addr: Register, comptime T: type, v: T) void {
        const len = @sizeOf(T);
        const start = s.memoryOffset(addr, len, "store");
        std.mem.writeIntLittle(T, s.memory[start..][0..len], v);
    }

    /// Executes the instruction at `pc` and advances `pc`.
    /// Returns `null` once `pc` points past the end of `instructions`.
    fn step(s: *ExecutionState) ?void {
        if (s.pc.unsigned % insn_len != 0) {
            @panic("TODO: handle unaligned pc");
        }
        const cur_insn = std.math.cast(usize, s.pc.unsigned / insn_len) orelse return null;
        if (cur_insn >= s.instructions.len) return null;
        s.interpretInsn(s.instructions[cur_insn]);
        // Control-transfer instructions pre-compensate for this increment.
        s.pc.unsigned +%= insn_len;
        return {};
    }

    /// Executes a single instruction against the current state.
    ///
    /// `pc` is NOT advanced here for ordinary instructions; `step` adds
    /// `insn_len` afterwards. Taken branches and jumps therefore leave
    /// `pc` at `target - insn_len`.
    /// Panics on invalid, unsupported or misaligned encodings.
    fn interpretInsn(s: *ExecutionState, insn: Insn) void {
        const opc = insn.getOpcode() catch
            panicFmt("invalid opcode in insn 0x{x}\n{}", .{ insn.unsigned, s.* });
        switch (opc) {
            .lui => {
                const u = insn.u;
                s.storeRegister(u.rd, u.getImmediate());
            },
            .auipc => {
                const u = insn.u;
                s.storeRegister(u.rd, u.getImmediate().unsigned +% s.pc.unsigned);
            },
            .jal => {
                const j = insn.j;
                // getImmediate already yields a byte offset (bit 0 is zero).
                const target = s.pc.unsigned +% j.getImmediate().unsigned;
                if (target & 0b11 != 0) @panic("misaligned jal");
                s.storeRegister(j.rd, s.pc.unsigned +% insn_len);
                s.pc.unsigned = target -% insn_len;
            },
            .jalr => {
                const i = insn.i;
                // Read rs1 before writing rd: both may name the same register.
                const sum = s.loadRegister(i.rs1).unsigned +% i.getImmediate().unsigned;
                const target = sum & ~@as(u32, 1);
                if (target & 0b11 != 0) @panic("misaligned jalr");
                s.storeRegister(i.rd, s.pc.unsigned +% insn_len);
                s.pc.unsigned = target -% insn_len;
            },
            .branch => {
                const b = insn.b;
                const x = s.loadRegister(b.rs1);
                const y = s.loadRegister(b.rs2);
                const cond = switch (b.funct3) {
                    0b000 => x.signed == y.signed, // beq
                    0b001 => x.signed != y.signed, // bne
                    0b100 => x.signed < y.signed, // blt
                    0b101 => x.signed >= y.signed, // bge
                    0b110 => x.unsigned < y.unsigned, // bltu
                    0b111 => x.unsigned >= y.unsigned, // bgeu
                    else => @panic("invalid branch funct3"),
                };
                if (cond) s.pc.unsigned +%= b.getImmediate().unsigned -% insn_len;
            },
            .load => {
                const i = insn.i;
                // Wrapping add: a negative offset is a large unsigned value.
                const addr = zeroExtend(s.loadRegister(i.rs1).unsigned +% i.getImmediate().unsigned);
                const value = switch (i.funct3) {
                    0b000 => signExtend(s.load(addr, i8)), // lb
                    0b001 => signExtend(s.load(addr, i16)), // lh
                    0b010 => signExtend(s.load(addr, i32)), // lw
                    0b011 => @panic("64 bit insn"),
                    0b100 => zeroExtend(s.load(addr, u8)), // lbu
                    0b101 => zeroExtend(s.load(addr, u16)), // lhu
                    0b110 => @panic("64 bit insn"),
                    else => @panic("invalid load funct3"),
                };
                s.storeRegister(i.rd, value);
            },
            .store => {
                const is = insn.s;
                const addr = zeroExtend(s.loadRegister(is.rs1).unsigned +% is.getImmediate().unsigned);
                const value = s.loadRegister(is.rs2).unsigned;
                switch (is.funct3) {
                    0b000 => s.store(addr, u8, @as(u8, @truncate(value))), // sb
                    0b001 => s.store(addr, u16, @as(u16, @truncate(value))), // sh
                    0b010 => s.store(addr, u32, value), // sw
                    0b011 => @panic("64-bit insn"),
                    else => @panic("invalid store funct3"),
                }
            },
            .op_imm => {
                const i = insn.i;
                const a = s.loadRegister(i.rs1);
                const imm = i.getImmediate();
                // Shift-immediates reuse the R-format rs2 field as the shift
                // amount and funct7 to tell logical from arithmetic.
                const shamt: Log2Int(u32) = insn.r.rs2;
                const funct7 = insn.r.funct7;
                const value = switch (i.funct3) {
                    0b000 => signExtend(a.signed +% imm.signed), // addi
                    0b010 => zeroExtend(@intFromBool(a.signed < imm.signed)), // slti
                    0b011 => zeroExtend(@intFromBool(a.unsigned < imm.unsigned)), // sltiu
                    0b100 => zeroExtend(a.unsigned ^ imm.unsigned), // xori
                    0b110 => zeroExtend(a.unsigned | imm.unsigned), // ori
                    0b111 => zeroExtend(a.unsigned & imm.unsigned), // andi
                    0b001 => switch (funct7) {
                        0b0000000 => zeroExtend(a.unsigned << shamt), // slli
                        else => @panic("invalid op_imm funct7"),
                    },
                    0b101 => switch (funct7) {
                        0b0000000 => zeroExtend(a.unsigned >> shamt), // srli
                        0b0100000 => signExtend(a.signed >> shamt), // srai
                        else => @panic("invalid op_imm funct7"),
                    },
                };
                s.storeRegister(i.rd, value);
            },
            .op => {
                const r = insn.r;
                const a = s.loadRegister(r.rs1);
                const b = s.loadRegister(r.rs2);
                // Register shifts use only the low 5 bits of rs2.
                const shamt: Log2Int(u32) = @truncate(b.unsigned);
                const value = switch (r.funct7) {
                    0b0000000 => switch (r.funct3) {
                        0b000 => signExtend(a.signed +% b.signed), // add
                        0b001 => zeroExtend(a.unsigned << shamt), // sll
                        0b010 => zeroExtend(@intFromBool(a.signed < b.signed)), // slt
                        0b011 => zeroExtend(@intFromBool(a.unsigned < b.unsigned)), // sltu
                        0b100 => zeroExtend(a.unsigned ^ b.unsigned), // xor
                        0b101 => zeroExtend(a.unsigned >> shamt), // srl
                        0b110 => zeroExtend(a.unsigned | b.unsigned), // or
                        0b111 => zeroExtend(a.unsigned & b.unsigned), // and
                    },
                    0b0100000 => switch (r.funct3) {
                        0b000 => signExtend(a.signed -% b.signed), // sub
                        0b101 => signExtend(a.signed >> shamt), // sra
                        else => @panic("invalid op funct7"),
                    },
                    0b0000001 => mulDiv(r.funct3, a, b),
                    else => @panic("invalid op funct7"),
                };
                s.storeRegister(r.rd, value);
            },
            .system => {
                const i = insn.i;
                switch (i.funct3) {
                    0b000 => switch (i.imm0_11) {
                        0 => s.ecall_handler(s.context, s),
                        1 => s.ebreak_handler(s.context, s),
                        else => @panic("invalid imm system"),
                    },
                    0b001 => @panic("TODO IMPLEMENT"),
                    else => @panic("invalid funct3 system"),
                }
            },
            else => |op| panicFmt("TODO opcode {} insn 0x{x}\n", .{ op, insn.unsigned }),
        }
    }

    /// Evaluates the multiply/divide group (funct7 = 0b0000001) selected
    /// by `funct3`. Division never traps: dividing by zero and signed
    /// overflow produce the fixed results the RISC-V M extension defines.
    fn mulDiv(funct3: u3, a: Register, b: Register) Register {
        const min_i32 = std.math.minInt(i32);
        return switch (funct3) {
            // mul: low 32 bits of the product.
            0b000 => signExtend(a.signed *% b.signed),
            // mulh: high 32 bits of signed x signed.
            0b001 => signExtend(@as(i32, @truncate(std.math.mulWide(i32, a.signed, b.signed) >> 32))),
            // mulhsu: high 32 bits of signed x unsigned (fits in i64).
            0b010 => signExtend(@as(i32, @truncate((@as(i64, a.signed) * @as(i64, b.unsigned)) >> 32))),
            // mulhu: high 32 bits of unsigned x unsigned.
            0b011 => zeroExtend(@as(u32, @truncate(std.math.mulWide(u32, a.unsigned, b.unsigned) >> 32))),
            // div: x / 0 = -1, MIN / -1 = MIN.
            0b100 => signExtend(if (b.signed == 0)
                @as(i32, -1)
            else if (a.signed == min_i32 and b.signed == -1)
                a.signed
            else
                @divTrunc(a.signed, b.signed)),
            // divu: x / 0 = all ones.
            0b101 => zeroExtend(if (b.unsigned == 0)
                @as(u32, std.math.maxInt(u32))
            else
                a.unsigned / b.unsigned),
            // rem: x % 0 = x, MIN % -1 = 0.
            0b110 => signExtend(if (b.signed == 0)
                a.signed
            else if (a.signed == min_i32 and b.signed == -1)
                @as(i32, 0)
            else
                @rem(a.signed, b.signed)),
            // remu: x % 0 = x.
            0b111 => zeroExtend(if (b.unsigned == 0)
                a.unsigned
            else
                a.unsigned % b.unsigned),
        };
    }
};
// const jit_instructions = struct {
// const names = [_][]const u8{ "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "rbp", "rsp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" };
// inline fn get(comptime r: usize) u64 {
// return asm volatile ("mov %" ++ names[r] ++ ", %[dst]"
// : [dst] "=r" (-> u64),
// );
// }
// inline fn set(comptime r: usize, v: u64) void {
// return asm volatile ("mov %[src], %" ++ names[r]
// :
// : [src] "r" (v),
// );
// }
// export fn save_regs_and_call(f: *const fn () void) void {
// const g = struct {
// var regs: [16]u64 = undefined;
// };
// inline for (0..16) |i| g.regs[i] = get(i);
// f();
// inline for (0..16) |i| set(i, g.regs[i]);
// }
// const save_regs_and_call_addr = &save_regs_and_call;
// const GeneratedInsn = union(enum) {};
// };
// test "example program" {
// const instructions = [_]u32{
// 0x00f5d613,
// 0x00a64633,
// 0x735a36b7,
// 0xd9768693,
// 0x00d60633,
// 0x00a64633,
// 0x00f65513,
// 0x00b54533,
// 0xcaf656b7,
// 0x9a968693,
// 0x00d50533,
// 0x00b545b3,
// 0x0105d513,
// 0x00c54533,
// 0x01065613,
// 0x00c5c5b3,
// };
// var s = ExecutionState{
// .memory = &.{},
// .instructions = @ptrCast(@as([]const u32, &instructions)),
// };
// while (s.step()) |_| {}
// try std.testing.expectEqual(@as(i32, 1935337312), s.registers[10]);
// try std.testing.expectEqual(@as(i32, -889765113), s.registers[11]);
// try std.testing.expectEqual(@as(i32, 29530), s.registers[12]);
// try std.testing.expectEqual(@as(i32, -889828951), s.registers[13]);
// }
// test "addi" {
// const instructions = [_]u32{
// 0x53900093,
// 0xac700113,
// 0xe5c00193,
// 0x1a400213,
// };
// var s = ExecutionState{
// .instructions = &instructions,
// .memory = &.{},
// };
// while (s.step()) |_| {}
// try std.testing.expectEqual(@as(i32, 1337), s.registers[1]);
// try std.testing.expectEqual(@as(i32, -1337), s.registers[2]);
// try std.testing.expectEqual(@as(i32, -420), s.registers[3]);
// try std.testing.expectEqual(@as(i32, 420), s.registers[4]);
// }
// test "hello world" {
// const instructions = [_]u32{
// 0x04800513,
// 0x00a00023,
// 0x06500513,
// 0x00a000a3,
// 0x06c00513,
// 0x00a00123,
// 0x00a001a3,
// 0x06f00593,
// 0x00b00223,
// 0x02c00613,
// 0x00c002a3,
// 0x02000613,
// 0x00c00323,
// 0x05700613,
// 0x00c003a3,
// 0x00b00423,
// 0x07200593,
// 0x00b004a3,
// 0x00a00523,
// 0x06400513,
// 0x00a005a3,
// };
// var memory: ["Hello, World"[0..].len]u8 = undefined;
// var s = ExecutionState{
// .instructions = &instructions,
// .memory = &memory,
// };
// while (s.step()) |_| {}
// try std.testing.expectEqualSlices(u8, "Hello, World", &memory);
// }
// test "ecall" {
// const instructions = [_]u32{
// // store "Hello, World" in 0
// 0x04800513,
// 0x00a00023,
// 0x06500513,
// 0x00a000a3,
// 0x06c00513,
// 0x00a00123,
// 0x00a001a3,
// 0x06f00593,
// 0x00b00223,
// 0x02c00613,
// 0x00c002a3,
// 0x02000613,
// 0x00c00323,
// 0x05700613,
// 0x00c003a3,
// 0x00b00423,
// 0x07200593,
// 0x00b004a3,
// 0x00a00523,
// 0x06400513,
// 0x00a005a3,
// // Ecall with a0 = 69, a1 = 0, a2 = 12
// 0x04500513,
// 0x00000593,
// 0x00c00613,
// 0x00000073,
// };
// var memory: [8192]u8 = undefined;
// const HandlerState = struct {
// got_ecall: bool = false,
// fn handleEcall(state_p: ?*anyopaque, s: *ExecutionState) void {
// var state = @as(*@This(), @ptrCast(@alignCast(state_p)));
// if (s.registers[10] == 69 and s.registers[11] == 0 and s.registers[12] == 12)
// state.got_ecall = true;
// }
// };
// var hs = HandlerState{};
// var s = ExecutionState{
// .instructions = &instructions,
// .memory = &memory,
// .context = @as(?*anyopaque, @ptrCast(&hs)),
// .ecall_handler = HandlerState.handleEcall,
// };
// while (s.step()) |_| {}
// try std.testing.expect(hs.got_ecall);
// }
// test "branching" {
// const instructions = [_]u32{
// 0x04500593, // li a1, 69
// 0x00b51863, // bne a0, a1, .LBB0_2
// 0x00100513, // li a0, 1
// 0x00000073, // ecall
// 0x00008067, // ret
// 0x00000513, // .LBB0_2: li a0, 0
// 0x00000073, // ecall
// 0x00008067, // ret
// };
// const HandlerState = struct {
// arg1: ?i32 = null,
// fn handleEcall(state_p: ?*anyopaque, s: *ExecutionState) void {
// var state = @as(*@This(), @ptrCast(@alignCast(state_p)));
// state.arg1 = s.registers[10];
// }
// };
// var hs = HandlerState{};
// var s = ExecutionState{
// .instructions = &instructions,
// .memory = &.{},
// .context = @as(?*anyopaque, @ptrCast(&hs)),
// .ecall_handler = HandlerState.handleEcall,
// };
// s.registers[1] = instructions.len * @sizeOf(u32) + 1;
// s.registers[10] = 69;
// while (s.step()) |_| {}
// try std.testing.expectEqual(@as(i32, 1), hs.arg1.?);
// }
// test "matmul non-vector" {
// const instructions = b: {
// const t = comptime std.mem.bytesAsSlice(u32, @embedFile("matmul_test.bin"));
// var d: [t.len]u32 = undefined;
// @memcpy(&d, t);
// break :b d;
// };
// const stack_size = 8192;
// var memory = std.testing.allocator.create([stack_size + 16 * 3]u8) catch unreachable;
// defer std.testing.allocator.free(memory);
// @memset(memory[0..stack_size], undefined);
// @memset(memory[stack_size..], 2);
// var s = ExecutionState{
// .instructions = &instructions,
// .memory = memory,
// };
// s.registers[1] = @as(i32, @intCast((instructions.len + 1) * @sizeOf(u32))); // return address
// s.registers[2] = stack_size; // stack pointer
// s.registers[10] = stack_size; // dst
// s.registers[11] = stack_size + 16; // a
// s.registers[12] = stack_size + 16 * 2; // b
// while (s.step()) |_| {}
// for (memory[stack_size..][0..16]) |e| try std.testing.expectEqual(@as(u8, 16), e);
// }
// Compile-time sanity check of the packed-struct field layout against a
// known encoding. (A leftover `@compileLog` here made every build fail.)
comptime {
    const decoded: InsnI = @bitCast(@as(u32, 0x22858407));
    assert(decoded.opcode == 0b0000111);
    assert(decoded.rd == 8);
    assert(decoded.funct3 == 0b000);
    assert(decoded.rs1 == 11);
    assert(decoded.imm0_11 == 0x228);
}
// End of gist.