Last active
March 19, 2025 13:06
-
-
Save tung/c185c7b5c3af96b79b94537fa93b6747 to your computer and use it in GitHub Desktop.
Calculator for infix math expressions with floating point numbers using Pratt parsing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Calculator for infix math expressions with floating point numbers using Pratt parsing. | |
const std = @import("std"); | |
/// Pratt parser that evaluates an infix math expression while it parses it.
///
/// The parser state is just the unconsumed remainder of the input. Matchers that
/// fail return false/null and leave `text` untouched; when `expr` returns an
/// error, `text` is left at the point where parsing stopped, which `main` uses
/// to position its error caret.
///
/// Precedence, lowest to highest: `+ -`, `* /`, `^`, unary `-`. All binary
/// operators are left-associative (so `2^3^2 == 64`), and unary minus binds
/// tighter than `^` (so `-9^2 == 81`). Parentheses and the functions
/// `pow(x, y)`, `sqrt(x)` and `cbrt(x)` are also understood.
const Parser = struct {
    /// Unconsumed remainder of the input.
    text: []const u8,

    const Self = @This();

    /// Characters skipped between tokens. '\r' is included so that input with
    /// CRLF line endings is still consumed in full.
    const whitespace = " \t\r\n";

    /// Operators that appear in front of their operand.
    const PrefixOperator = enum {
        Negate,

        /// Minimum binding power used for the operand after the operator. It
        /// exceeds every infix `left` power, so no infix operator is pulled into
        /// a negated operand: `-9^2` is `(-9)^2`.
        pub fn right_bind_power(self: PrefixOperator) u8 {
            return switch (self) {
                .Negate => 7,
            };
        }
    };

    /// Operators that appear between two operands.
    const InfixOperator = enum {
        Add,
        Subtract,
        Multiply,
        Divide,
        Exponent,

        /// Pratt binding powers. An operator is only taken when its `left` power
        /// is at least the caller's minimum; its right operand is then parsed with
        /// minimum `right`. Because `right > left` for every operator, chains of
        /// equal-precedence operators group to the left.
        pub fn bind_powers(self: InfixOperator) struct { left: u8, right: u8 } {
            return switch (self) {
                .Add, .Subtract => .{ .left = 1, .right = 2 },
                .Multiply, .Divide => .{ .left = 3, .right = 4 },
                .Exponent => .{ .left = 5, .right = 6 },
            };
        }
    };

    /// Creates a parser over `text`. The slice is borrowed, not copied.
    pub fn init(text: []const u8) Self {
        return .{ .text = text };
    }

    /// Returns an independent copy of the parser, used for speculative matching.
    fn clone(self: Self) Self {
        return .{ .text = self.text };
    }

    /// Consumes `lit` and returns true if `text` starts with it; otherwise
    /// returns false and leaves `text` unchanged.
    fn literal(self: *Self, lit: []const u8) bool {
        if (!std.mem.startsWith(u8, self.text, lit)) return false;
        self.text = self.text[lit.len..];
        return true;
    }

    test "literal" {
        const TestCase = struct {
            in_text: []const u8,
            in_lit: []const u8,
            out_match: bool,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .in_lit = "", .out_match = true, .out_text = "" },
            .{ .in_text = "foo", .in_lit = "foo", .out_match = true, .out_text = "" },
            .{ .in_text = "foobar", .in_lit = "foo", .out_match = true, .out_text = "bar" },
            .{ .in_text = "foobar", .in_lit = "bar", .out_match = false, .out_text = "foobar" },
            .{ .in_text = "foofoo", .in_lit = "foo", .out_match = true, .out_text = "foo" },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            try std.testing.expectEqual(case.out_match, s.literal(case.in_lit));
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Consumes one character and returns true if the first character of
    /// `text` is any of `matches`; otherwise returns false and consumes nothing.
    fn any(self: *Self, matches: []const u8) bool {
        if (self.text.len == 0 or matches.len == 0) return false;
        if (std.mem.indexOfScalar(u8, matches, self.text[0])) |_| {
            self.text = self.text[1..];
            return true;
        }
        return false;
    }

    test "any" {
        const TestCase = struct {
            in_text: []const u8,
            in_matches: []const u8,
            out: bool,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .in_matches = "", .out = false },
            .{ .in_text = "a", .in_matches = "a", .out = true },
            .{ .in_text = "a", .in_matches = "abc", .out = true },
            .{ .in_text = "b", .in_matches = "abc", .out = true },
            .{ .in_text = "c", .in_matches = "abc", .out = true },
            .{ .in_text = "d", .in_matches = "abc", .out = false },
            .{ .in_text = "5", .in_matches = "0123456789", .out = true },
            .{ .in_text = "-", .in_matches = "0123456789", .out = false },
            .{ .in_text = "-", .in_matches = "+-*/^", .out = true },
            .{ .in_text = "5", .in_matches = "+-*/^", .out = false },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            const out = s.any(case.in_matches);
            try std.testing.expectEqual(case.out, out);
        }
    }

    /// Skips every leading character that appears in `whitespace`.
    fn space(self: *Self) void {
        while (self.any(whitespace)) continue;
    }

    test "space" {
        const TestCase = struct {
            in_text: []const u8,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .out_text = "" },
            .{ .in_text = "1", .out_text = "1" },
            .{ .in_text = " 1", .out_text = "1" },
            .{ .in_text = "\t1", .out_text = "1" },
            .{ .in_text = "\n1", .out_text = "1" },
            .{ .in_text = " \n1", .out_text = "1" },
            .{ .in_text = "\n 1", .out_text = "1" },
            .{ .in_text = " \n \t1", .out_text = "1" },
            .{ .in_text = "1 ", .out_text = "1 " },
            .{ .in_text = "\r\n1", .out_text = "1" },
            .{ .in_text = " \r\t1 ", .out_text = "1 " },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            s.space();
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Consumes an unsigned decimal number of the form `digits[.[digits]]` and
    /// returns its value. Returns null and consumes nothing when `text` does not
    /// start with a digit (so `.5` is not accepted).
    fn number(self: *Self) !?f64 {
        var num_match = self.clone();
        while (num_match.any("0123456789")) continue;
        if (num_match.text.len == self.text.len) return null;
        // The fractional part is optional; only commit to it if a '.' follows.
        var full_num_match = num_match.clone();
        if (full_num_match.literal(".")) {
            while (full_num_match.any("0123456789")) continue;
            num_match = full_num_match;
        }
        const num = try std.fmt.parseFloat(f64, self.text[0 .. self.text.len - num_match.text.len]);
        self.text = num_match.text;
        return num;
    }

    /// Test helper: asserts `expected` and `actual` differ by at most machine epsilon.
    fn expectFloatEq(comptime T: type, expected: T, actual: T) !void {
        try std.testing.expect(std.math.approxEqAbs(T, expected, actual, std.math.floatEps(T)));
    }

    test "number" {
        const TestCase = struct {
            in: []const u8,
            out_num: ?f64,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in = "", .out_num = null, .out_text = "" },
            .{ .in = "a", .out_num = null, .out_text = "a" },
            .{ .in = "a1", .out_num = null, .out_text = "a1" },
            .{ .in = "123", .out_num = 123.0, .out_text = "" },
            .{ .in = "123.", .out_num = 123.0, .out_text = "" },
            .{ .in = "123.0", .out_num = 123.0, .out_text = "" },
            .{ .in = "123.456", .out_num = 123.456, .out_text = "" },
            .{ .in = "123.456z", .out_num = 123.456, .out_text = "z" },
        };
        for (cases) |case| {
            var s = Self.init(case.in);
            const opt_res = try s.number();
            if (case.out_num) |out_num| {
                const res = opt_res orelse return error.TestExpectedNonNull;
                try expectFloatEq(f64, out_num, res);
            } else {
                try std.testing.expectEqual(null, opt_res);
            }
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Consumes a prefix operator at the front of `text`, if there is one.
    /// Returns null and consumes nothing otherwise.
    fn prefix_operator(self: *Self) ?PrefixOperator {
        if (self.text.len == 0) return null;
        // Single source of truth for the accepted characters: anything not
        // listed here is simply "not a prefix operator".
        const op: PrefixOperator = switch (self.text[0]) {
            '-' => .Negate,
            else => return null,
        };
        self.text = self.text[1..];
        return op;
    }

    test "prefix_operator" {
        const TestCase = struct {
            in: []const u8,
            out: ?PrefixOperator,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in = "", .out = null, .out_text = "" },
            .{ .in = "-", .out = .Negate, .out_text = "" },
            .{ .in = "1", .out = null, .out_text = "1" },
            .{ .in = "-1", .out = .Negate, .out_text = "1" },
            .{ .in = "+1", .out = null, .out_text = "+1" },
        };
        for (cases) |case| {
            var s = Self.init(case.in);
            const opt_res = s.prefix_operator();
            if (case.out) |out| {
                const res = opt_res orelse return error.TestExpectedNonNull;
                try std.testing.expectEqual(out, res);
            } else {
                try std.testing.expectEqual(null, opt_res);
            }
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Consumes an infix operator at the front of `text` if there is one whose
    /// left binding power is at least `min_bind_power`. Returns null and
    /// consumes nothing otherwise, which ends the caller's operator loop.
    fn infix_operator(self: *Self, min_bind_power: u8) ?InfixOperator {
        if (self.text.len == 0) return null;
        // Single source of truth for the accepted characters: anything not
        // listed here is simply "not an infix operator".
        const op: InfixOperator = switch (self.text[0]) {
            '+' => .Add,
            '-' => .Subtract,
            '*' => .Multiply,
            '/' => .Divide,
            '^' => .Exponent,
            else => return null,
        };
        if (op.bind_powers().left < min_bind_power) return null;
        self.text = self.text[1..];
        return op;
    }

    test "infix_operator" {
        const TestCase = struct {
            in_text: []const u8,
            in_min_bp: u8,
            out: ?InfixOperator,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .in_min_bp = 0, .out = null },
            .{ .in_text = "1", .in_min_bp = 0, .out = null },
            .{ .in_text = "(", .in_min_bp = 0, .out = null },
            .{ .in_text = "+", .in_min_bp = 0, .out = .Add },
            .{ .in_text = "-", .in_min_bp = 0, .out = .Subtract },
            .{ .in_text = "*", .in_min_bp = 0, .out = .Multiply },
            .{ .in_text = "/", .in_min_bp = 0, .out = .Divide },
            .{ .in_text = "^", .in_min_bp = 0, .out = .Exponent },
            .{ .in_text = "+", .in_min_bp = 2, .out = null },
            .{ .in_text = "-", .in_min_bp = 2, .out = null },
            .{ .in_text = "*", .in_min_bp = 2, .out = .Multiply },
            .{ .in_text = "/", .in_min_bp = 2, .out = .Divide },
            .{ .in_text = "^", .in_min_bp = 2, .out = .Exponent },
            .{ .in_text = "+", .in_min_bp = 4, .out = null },
            .{ .in_text = "-", .in_min_bp = 4, .out = null },
            .{ .in_text = "*", .in_min_bp = 4, .out = null },
            .{ .in_text = "/", .in_min_bp = 4, .out = null },
            .{ .in_text = "^", .in_min_bp = 4, .out = .Exponent },
            .{ .in_text = "^", .in_min_bp = 6, .out = null },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            const opt_res = s.infix_operator(case.in_min_bp);
            if (case.out) |out| {
                const res = opt_res orelse return error.TestExpectedNonNull;
                try std.testing.expectEqual(out, res);
            } else {
                try std.testing.expectEqual(null, opt_res);
            }
        }
    }

    /// Errors `expr` and its helpers can return.
    pub const ParseError = error{
        ExpectedExpression,
        ExpectedOpenParen,
        ExpectedCloseParen,
        ExpectedComma,
    } || std.fmt.ParseFloatError;

    /// Parses `expr)`; the opening parenthesis must already be consumed.
    fn paren_expr(self: *Self) ParseError!f64 {
        const inner = try self.expr(0);
        if (!self.literal(")")) return error.ExpectedCloseParen;
        return inner;
    }

    /// Parses and applies the operand of prefix operator `op` (already consumed).
    fn prefix_expr(self: *Self, op: PrefixOperator) ParseError!f64 {
        const inner = try self.expr(op.right_bind_power());
        return switch (op) {
            .Negate => -inner,
        };
    }

    /// Parses `(x, y)` after the `pow` keyword and returns x raised to y.
    fn pow_expr(self: *Self) ParseError!f64 {
        self.space();
        if (!self.literal("(")) return error.ExpectedOpenParen;
        const x = try self.expr(0);
        if (!self.literal(",")) return error.ExpectedComma;
        const y = try self.paren_expr();
        return std.math.pow(f64, x, y);
    }

    /// Parses `(x)` after the `sqrt` keyword and returns its square root.
    fn sqrt_expr(self: *Self) ParseError!f64 {
        self.space();
        if (!self.literal("(")) return error.ExpectedOpenParen;
        const inner = try self.paren_expr();
        return std.math.sqrt(inner);
    }

    /// Parses `(x)` after the `cbrt` keyword and returns its cube root.
    fn cbrt_expr(self: *Self) ParseError!f64 {
        self.space();
        if (!self.literal("(")) return error.ExpectedOpenParen;
        const inner = try self.paren_expr();
        return std.math.cbrt(inner);
    }

    /// Parses and evaluates an expression, taking only infix operators whose left
    /// binding power is at least `min_bind_power` (pass 0 to accept any).
    /// Surrounding whitespace is consumed. Parsing stops without error at the
    /// first character that cannot continue the expression (e.g. `)` or `,`),
    /// which is left in `text` for the caller to inspect.
    fn expr(self: *Self, min_bind_power: u8) ParseError!f64 {
        self.space();
        // Leading operand ("null denotation" in Pratt terms).
        var num = blk: {
            if (try self.number()) |n| break :blk n;
            if (self.literal("(")) break :blk try self.paren_expr();
            if (self.prefix_operator()) |op| break :blk try self.prefix_expr(op);
            if (self.literal("pow")) break :blk try self.pow_expr();
            if (self.literal("sqrt")) break :blk try self.sqrt_expr();
            if (self.literal("cbrt")) break :blk try self.cbrt_expr();
            return error.ExpectedExpression;
        };
        // Fold in infix operators for as long as they bind tightly enough.
        while (self.text.len > 0) {
            self.space();
            const op = self.infix_operator(min_bind_power) orelse break;
            const inner = try self.expr(op.bind_powers().right);
            switch (op) {
                .Add => num += inner,
                .Subtract => num -= inner,
                .Multiply => num *= inner,
                .Divide => num /= inner,
                .Exponent => num = std.math.pow(f64, num, inner),
            }
        }
        self.space();
        return num;
    }

    const ExprTestCase = struct {
        out: f64,
        in: []const u8,
    };

    /// Test helper: each `in` must evaluate to `out` and be consumed in full.
    fn testExprs(cases: []const ExprTestCase) !void {
        for (cases) |case| {
            var s = Self.init(case.in);
            try expectFloatEq(f64, case.out, try s.expr(0));
            try std.testing.expectEqualStrings("", s.text);
        }
    }

    test "expr_solo" {
        try testExprs(&.{
            .{ .out = 1, .in = "1" },
            .{ .out = 0.3, .in = "0.3" },
        });
    }

    test "expr_ops" {
        try testExprs(&.{
            .{ .out = 15, .in = "12+3" },
            .{ .out = 9, .in = "12-3" },
            .{ .out = 36, .in = "12*3" },
            .{ .out = 4, .in = "12/3" },
            .{ .out = 1728, .in = "12^3" },
        });
    }

    test "expr_precedence" {
        try testExprs(&.{
            .{ .out = 10, .in = "5+4*3/2-1" },
            .{ .out = 5, .in = "4+3*6/9-1" },
            .{ .out = 1, .in = "5+2*6/4-7" },
            .{ .out = 81, .in = "-9^2" },
        });
    }

    test "expr_associativity" {
        try testExprs(&.{
            .{ .out = 1, .in = "8-4-3" },
            .{ .out = 1, .in = "24/4/6" },
            .{ .out = 64, .in = "2^3^2" },
        });
    }

    test "expr_whitespace" {
        try testExprs(&.{
            .{ .out = 3, .in = " 1 + 2 " },
            .{ .out = 3, .in = "1\r\n+\t2\r\n" },
            .{ .out = 6, .in = "pow( 2 ,\r3 ) - sqrt( 4 )" },
        });
    }

    test "expr_errors" {
        const TestCase = struct {
            in: []const u8,
            out: ParseError,
        };
        const cases = [_]TestCase{
            .{ .in = "", .out = error.ExpectedExpression },
            .{ .in = "1+", .out = error.ExpectedExpression },
            .{ .in = "(1", .out = error.ExpectedCloseParen },
            .{ .in = "sqrt 1", .out = error.ExpectedOpenParen },
            .{ .in = "pow(1)", .out = error.ExpectedComma },
        };
        for (cases) |case| {
            var s = Self.init(case.in);
            try std.testing.expectError(case.out, s.expr(0));
        }
    }

    test "paren_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "(1)" },
            .{ .out = 1, .in = "(((1)))" },
            .{ .out = 60, .in = "(2+3)*(5+7)" },
        });
    }

    test "prefix_expr" {
        try testExprs(&.{
            .{ .out = -1, .in = "-1" },
            .{ .out = 1, .in = "--1" },
            .{ .out = -1, .in = "---1" },
            .{ .out = -1, .in = "- 1" },
            .{ .out = -1, .in = "(-1)" },
            .{ .out = -1, .in = "-(1)" },
            .{ .out = 1, .in = "-(-1)" },
            .{ .out = 2, .in = "1--1" },
            .{ .out = 0, .in = "1---1" },
        });
    }

    test "pow_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "pow(1 , 1)" },
            .{ .out = 4, .in = "pow(2,2)" },
            .{ .out = 8, .in = "pow(2,3)" },
            .{ .out = 9, .in = "pow(3,2)" },
            .{ .out = 27, .in = "pow(3,3)" },
            .{ .out = 1, .in = "pow(4,0)" },
        });
    }

    test "sqrt_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "sqrt(1)" },
            .{ .out = 2, .in = "sqrt(4)" },
            .{ .out = 3, .in = "sqrt(9)" },
            .{ .out = 4, .in = "sqrt(16)" },
            .{ .out = 5, .in = "sqrt(25)" },
            .{ .out = 1, .in = "sqrt (1)" },
        });
    }

    test "cbrt_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "cbrt(1)" },
            .{ .out = 2, .in = "cbrt(8)" },
            .{ .out = 3, .in = "cbrt(27)" },
            .{ .out = 4, .in = "cbrt(64)" },
            .{ .out = 5, .in = "cbrt(125)" },
            .{ .out = 1, .in = "cbrt (1)" },
        });
    }
};
/// Reads one expression, evaluates it, and prints the result to stdout.
///
/// The expression comes from the first command-line argument when one is given;
/// otherwise the user is prompted for a single line on stdin. If parsing fails,
/// or leaves input unconsumed, in prompt mode a caret is printed beneath the
/// prompt line at the point where parsing stopped. The error is then returned.
pub fn main() !void {
    const stdin = std.io.getStdIn().reader();
    const stdout = std.io.getStdOut().writer();

    var gpa = std.heap.DebugAllocator(.{}).init;
    defer if (gpa.deinit() != .ok) @panic("GPA LEAK");
    const alloc = gpa.allocator();

    // Holds the expression text, whichever source it came from.
    const buf = try alloc.alloc(u8, 4096);
    defer alloc.free(buf);

    const input_and_prompt: struct {
        input: []const u8,
        // Prompt printed before reading stdin; null when input came from argv.
        prompt: ?[]const u8,
    } = blk: {
        // `argsWithAllocator` is the cross-platform argument iterator. It is
        // freed at the end of this block, after the argument has been copied
        // into `buf`.
        var args = try std.process.argsWithAllocator(alloc);
        defer args.deinit();
        _ = args.next(); // Skip the program name.
        if (args.next()) |arg| {
            // Refuse over-long arguments rather than silently truncating them,
            // which would evaluate a different expression than the one given.
            if (arg.len > buf.len) return error.StreamTooLong;
            const arg_input = buf[0..arg.len];
            @memcpy(arg_input, arg);
            break :blk .{ .input = arg_input, .prompt = null };
        }

        const prompt = "Expression: ";
        try stdout.print("{s}", .{prompt});
        // Unlike `readUntilDelimiter`, this keeps a final line that lacks a
        // trailing newline (e.g. `printf '1+2' | calc`); null means stdin was
        // already at end-of-file with nothing left to read.
        const line = (try stdin.readUntilDelimiterOrEof(buf, '\n')) orelse {
            try stdout.print("\n", .{});
            return;
        };
        // Drop the carriage return left behind by CRLF line endings.
        break :blk .{ .input = std.mem.trimRight(u8, line, "\r"), .prompt = prompt };
    };
    const input = input_and_prompt.input;
    const prompt = input_and_prompt.prompt;

    var parser = Parser.init(input);
    // Succeed only when the whole input was consumed; leftover text (such as
    // the `)` in "1)") is reported as `error.ParseError`.
    const result = blk: {
        if (parser.expr(0)) |value| {
            break :blk if (parser.text.len > 0) error.ParseError else value;
        } else |err| {
            break :blk err;
        }
    };

    if (result) |r| {
        try stdout.print("{d}\n", .{r});
    } else |err| {
        if (prompt) |p| {
            // Pad past the prompt and the consumed part of the input so the
            // caret lands under the first unconsumed character.
            const consumed = input.len - parser.text.len;
            try stdout.writeByteNTimes(' ', p.len + consumed);
            try stdout.print("^\n", .{});
        }
        return err;
    }
}
test "refs" {
    // Reference the non-pub `Parser` directly so the test blocks nested inside
    // it are analyzed and run. `refAllDecls` is assumed to visit only public
    // declarations, so it would not reach `Parser` on its own. Confirm this
    // against the targeted Zig version.
    _ = Parser;
    std.testing.refAllDecls(@This());
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment