Skip to content

Instantly share code, notes, and snippets.

@tung
Last active March 19, 2025 13:06
Show Gist options
  • Save tung/c185c7b5c3af96b79b94537fa93b6747 to your computer and use it in GitHub Desktop.
Save tung/c185c7b5c3af96b79b94537fa93b6747 to your computer and use it in GitHub Desktop.
Calculator for infix math expressions with floating point numbers using Pratt parsing
//! Calculator for infix math expressions with floating point numbers using Pratt parsing.
const std = @import("std");
/// Pratt parser that evaluates an infix expression while parsing it.
///
/// `text` is the not-yet-consumed remainder of the input. Every successful match advances
/// it, so after `expr` returns (or fails), `text` starts where parsing stopped.
const Parser = struct {
    text: []const u8,

    const Self = @This();

    /// Unary operators that may precede an operand.
    const PrefixOperator = enum {
        Negate,

        /// Binding power used to parse the operand after the operator.
        ///
        /// It is higher than every infix power (including `^`), so negation applies to
        /// the base: `-9^2` is `(-9)^2 = 81`.
        pub fn right_bind_power(self: PrefixOperator) u8 {
            return switch (self) {
                .Negate => 7,
            };
        }
    };

    /// Binary operators that may appear between two operands.
    const InfixOperator = enum {
        Add,
        Subtract,
        Multiply,
        Divide,
        Exponent,

        /// Left/right binding powers consumed by the Pratt loop in `expr`.
        ///
        /// Every operator has `left < right`, which makes all of them left-associative,
        /// `^` included: `2^3^2` is `(2^3)^2 = 64`.
        /// NOTE(review): `^` is right-associative in conventional math notation; confirm
        /// the left-associative behaviour is intended before relying on it.
        pub fn bind_powers(self: InfixOperator) struct { left: u8, right: u8 } {
            return switch (self) {
                .Add, .Subtract => .{ .left = 1, .right = 2 },
                .Multiply, .Divide => .{ .left = 3, .right = 4 },
                .Exponent => .{ .left = 5, .right = 6 },
            };
        }
    };

    /// Create a parser over `text`. The slice is borrowed, not copied.
    pub fn init(text: []const u8) Self {
        return .{ .text = text };
    }

    /// Copy of the parser for speculative matching; commit with `self.text = copy.text`.
    fn clone(self: Self) Self {
        return .{ .text = self.text };
    }

    /// Consume `lit` if the remaining text starts with it. Returns whether it matched.
    fn literal(self: *Self, lit: []const u8) bool {
        if (!std.mem.startsWith(u8, self.text, lit)) return false;
        self.text = self.text[lit.len..];
        return true;
    }

    test "literal" {
        const TestCase = struct {
            in_text: []const u8,
            in_lit: []const u8,
            out_match: bool,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .in_lit = "", .out_match = true, .out_text = "" },
            .{ .in_text = "foo", .in_lit = "foo", .out_match = true, .out_text = "" },
            .{ .in_text = "foobar", .in_lit = "foo", .out_match = true, .out_text = "bar" },
            .{ .in_text = "foobar", .in_lit = "bar", .out_match = false, .out_text = "foobar" },
            .{ .in_text = "foofoo", .in_lit = "foo", .out_match = true, .out_text = "foo" },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            try std.testing.expectEqual(case.out_match, s.literal(case.in_lit));
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Consume one character if it is any of `matches`. Returns whether it matched.
    fn any(self: *Self, matches: []const u8) bool {
        if (self.text.len == 0 or matches.len == 0) return false;
        if (std.mem.indexOfScalar(u8, matches, self.text[0]) == null) return false;
        self.text = self.text[1..];
        return true;
    }

    test "any" {
        const TestCase = struct {
            in_text: []const u8,
            in_matches: []const u8,
            out: bool,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .in_matches = "", .out = false },
            .{ .in_text = "a", .in_matches = "a", .out = true },
            .{ .in_text = "a", .in_matches = "abc", .out = true },
            .{ .in_text = "b", .in_matches = "abc", .out = true },
            .{ .in_text = "c", .in_matches = "abc", .out = true },
            .{ .in_text = "d", .in_matches = "abc", .out = false },
            .{ .in_text = "5", .in_matches = "0123456789", .out = true },
            .{ .in_text = "-", .in_matches = "0123456789", .out = false },
            .{ .in_text = "-", .in_matches = "+-*/^", .out = true },
            .{ .in_text = "5", .in_matches = "+-*/^", .out = false },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            const out = s.any(case.in_matches);
            try std.testing.expectEqual(case.out, out);
        }
    }

    /// Skip any run of whitespace. '\r' counts as whitespace so CRLF-terminated input
    /// (e.g. from a Windows console) parses instead of leaving a stray '\r' behind.
    fn space(self: *Self) void {
        while (self.any(" \t\r\n")) {}
    }

    test "space" {
        const TestCase = struct {
            in_text: []const u8,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .out_text = "" },
            .{ .in_text = "1", .out_text = "1" },
            .{ .in_text = " 1", .out_text = "1" },
            .{ .in_text = "\t1", .out_text = "1" },
            .{ .in_text = "\n1", .out_text = "1" },
            .{ .in_text = " \n1", .out_text = "1" },
            .{ .in_text = "\n 1", .out_text = "1" },
            .{ .in_text = " \n \t1", .out_text = "1" },
            .{ .in_text = "\r\n1", .out_text = "1" },
            .{ .in_text = " \r\n\t1", .out_text = "1" },
            .{ .in_text = "1 ", .out_text = "1 " },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            s.space();
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Parse an unsigned decimal literal of the form `digits` or `digits.[digits]`.
    ///
    /// Returns null, consuming nothing, when the text does not start with a digit; a
    /// leading `-` is handled separately as a prefix operator.
    fn number(self: *Self) !?f64 {
        var num_match = self.clone();
        while (num_match.any("0123456789")) continue;
        if (num_match.text.len == self.text.len) return null;
        // The fractional part is optional, and so are the digits after the '.'.
        var full_num_match = num_match.clone();
        if (full_num_match.literal(".")) {
            while (full_num_match.any("0123456789")) continue;
            num_match = full_num_match;
        }
        const num = try std.fmt.parseFloat(f64, self.text[0 .. self.text.len - num_match.text.len]);
        self.text = num_match.text;
        return num;
    }

    /// Test helper: succeed when `actual` is within one epsilon of `expected`, either
    /// absolutely (values near zero) or relatively (large magnitudes). Prints both values
    /// on failure so a mismatch is diagnosable.
    fn expectFloatEq(comptime T: type, expected: T, actual: T) !void {
        const tolerance = std.math.floatEps(T);
        if (std.math.approxEqAbs(T, expected, actual, tolerance) or
            std.math.approxEqRel(T, expected, actual, tolerance)) return;
        std.debug.print("expected {d}, found {d}\n", .{ expected, actual });
        return error.TestExpectedApproxEq;
    }

    test "number" {
        const TestCase = struct {
            in: []const u8,
            out_num: ?f64,
            out_text: []const u8,
        };
        const cases = [_]TestCase{
            .{ .in = "", .out_num = null, .out_text = "" },
            .{ .in = "a", .out_num = null, .out_text = "a" },
            .{ .in = "a1", .out_num = null, .out_text = "a1" },
            .{ .in = "123", .out_num = 123.0, .out_text = "" },
            .{ .in = "123.", .out_num = 123.0, .out_text = "" },
            .{ .in = "123.0", .out_num = 123.0, .out_text = "" },
            .{ .in = "123.456", .out_num = 123.456, .out_text = "" },
            .{ .in = "123.456z", .out_num = 123.456, .out_text = "z" },
        };
        for (cases) |case| {
            var s = Self.init(case.in);
            const opt_res = try s.number();
            if (case.out_num) |out_num| {
                const res = opt_res orelse return error.TestExpectedNonNull;
                try expectFloatEq(f64, out_num, res);
            } else {
                try std.testing.expectEqual(null, opt_res);
            }
            try std.testing.expectEqualStrings(case.out_text, s.text);
        }
    }

    /// Consume a prefix operator (currently only `-`) if one comes next.
    fn prefix_operator(self: *Self) ?PrefixOperator {
        if (self.literal("-")) return .Negate;
        return null;
    }

    test "prefix_operator" {
        const TestCase = struct {
            in: []const u8,
            out: ?PrefixOperator,
        };
        const cases = [_]TestCase{
            .{ .in = "", .out = null },
            .{ .in = "-", .out = .Negate },
            .{ .in = "1", .out = null },
            .{ .in = "+", .out = null },
        };
        for (cases) |case| {
            var s = Self.init(case.in);
            const opt_res = s.prefix_operator();
            if (case.out) |out| {
                const res = opt_res orelse return error.TestExpectedNonNull;
                try std.testing.expectEqual(out, res);
            } else {
                try std.testing.expectEqual(null, opt_res);
            }
        }
    }

    /// Consume the next infix operator if its left binding power is at least
    /// `min_bind_power`; otherwise leave the text untouched and return null.
    fn infix_operator(self: *Self, min_bind_power: u8) ?InfixOperator {
        if (self.text.len == 0) return null;
        // One switch both recognises and classifies the operator, so there is no separate
        // character list that could drift out of sync with it.
        const op: InfixOperator = switch (self.text[0]) {
            '+' => .Add,
            '-' => .Subtract,
            '*' => .Multiply,
            '/' => .Divide,
            '^' => .Exponent,
            else => return null,
        };
        if (op.bind_powers().left < min_bind_power) return null;
        self.text = self.text[1..];
        return op;
    }

    test "infix_operator" {
        const TestCase = struct {
            in_text: []const u8,
            in_min_bp: u8,
            out: ?InfixOperator,
        };
        const cases = [_]TestCase{
            .{ .in_text = "", .in_min_bp = 0, .out = null },
            .{ .in_text = "1", .in_min_bp = 0, .out = null },
            .{ .in_text = "a", .in_min_bp = 0, .out = null },
            .{ .in_text = "+", .in_min_bp = 0, .out = .Add },
            .{ .in_text = "-", .in_min_bp = 0, .out = .Subtract },
            .{ .in_text = "*", .in_min_bp = 0, .out = .Multiply },
            .{ .in_text = "/", .in_min_bp = 0, .out = .Divide },
            .{ .in_text = "^", .in_min_bp = 0, .out = .Exponent },
            .{ .in_text = "+", .in_min_bp = 2, .out = null },
            .{ .in_text = "-", .in_min_bp = 2, .out = null },
            .{ .in_text = "*", .in_min_bp = 2, .out = .Multiply },
            .{ .in_text = "/", .in_min_bp = 2, .out = .Divide },
            .{ .in_text = "^", .in_min_bp = 2, .out = .Exponent },
            .{ .in_text = "+", .in_min_bp = 4, .out = null },
            .{ .in_text = "-", .in_min_bp = 4, .out = null },
            .{ .in_text = "*", .in_min_bp = 4, .out = null },
            .{ .in_text = "/", .in_min_bp = 4, .out = null },
            .{ .in_text = "^", .in_min_bp = 4, .out = .Exponent },
            .{ .in_text = "^", .in_min_bp = 6, .out = null },
        };
        for (cases) |case| {
            var s = Self.init(case.in_text);
            const opt_res = s.infix_operator(case.in_min_bp);
            if (case.out) |out| {
                const res = opt_res orelse return error.TestExpectedNonNull;
                try std.testing.expectEqual(out, res);
            } else {
                try std.testing.expectEqual(null, opt_res);
            }
        }
    }

    /// Errors `expr` and its helpers can return; number syntax errors come from
    /// `std.fmt.parseFloat`.
    pub const ParseError = error{
        ExpectedExpression,
        ExpectedOpenParen,
        ExpectedCloseParen,
        ExpectedComma,
    } || std.fmt.ParseFloatError;

    /// Skip whitespace and consume the `(` that must follow a function name.
    fn call_open_paren(self: *Self) ParseError!void {
        self.space();
        if (!self.literal("(")) return error.ExpectedOpenParen;
    }

    /// Parse the body of a parenthesised expression and its closing `)`; the caller has
    /// already consumed the opening `(`.
    fn paren_expr(self: *Self) ParseError!f64 {
        const inner = try self.expr(0);
        if (!self.literal(")")) return error.ExpectedCloseParen;
        return inner;
    }

    /// Parse the operand of an already-consumed prefix operator and apply the operator.
    fn prefix_expr(self: *Self, op: PrefixOperator) ParseError!f64 {
        const inner = try self.expr(op.right_bind_power());
        return switch (op) {
            .Negate => -inner,
        };
    }

    /// Parse `(x, y)` after the `pow` keyword and return `x` raised to `y`.
    fn pow_expr(self: *Self) ParseError!f64 {
        try self.call_open_paren();
        const x = try self.expr(0);
        if (!self.literal(",")) return error.ExpectedComma;
        const y = try self.paren_expr();
        return std.math.pow(f64, x, y);
    }

    /// Parse `(x)` after the `sqrt` keyword and return the square root of `x`.
    fn sqrt_expr(self: *Self) ParseError!f64 {
        try self.call_open_paren();
        const inner = try self.paren_expr();
        return std.math.sqrt(inner);
    }

    /// Parse `(x)` after the `cbrt` keyword and return the cube root of `x`.
    fn cbrt_expr(self: *Self) ParseError!f64 {
        try self.call_open_paren();
        const inner = try self.paren_expr();
        return std.math.cbrt(inner);
    }

    /// Parse and evaluate an expression whose infix operators all have a left binding
    /// power of at least `min_bind_power`. Call with 0 for a whole expression.
    ///
    /// Parsing stops, without error, at the first character that cannot continue the
    /// expression; trailing whitespace is consumed.
    fn expr(self: *Self, min_bind_power: u8) ParseError!f64 {
        self.space();
        // Operand: a number, a parenthesised group, a prefix operation or a function call.
        var num = blk: {
            if (try self.number()) |n| break :blk n;
            if (self.literal("(")) break :blk try self.paren_expr();
            if (self.prefix_operator()) |op| break :blk try self.prefix_expr(op);
            if (self.literal("pow")) break :blk try self.pow_expr();
            if (self.literal("sqrt")) break :blk try self.sqrt_expr();
            if (self.literal("cbrt")) break :blk try self.cbrt_expr();
            return error.ExpectedExpression;
        };
        // Pratt loop: fold in operators that bind at least as tightly as the caller allows;
        // the right operand is parsed with the operator's right power, which sets
        // associativity.
        while (self.text.len > 0) {
            self.space();
            const op = self.infix_operator(min_bind_power) orelse break;
            const inner = try self.expr(op.bind_powers().right);
            switch (op) {
                .Add => num += inner,
                .Subtract => num -= inner,
                .Multiply => num *= inner,
                .Divide => num /= inner,
                .Exponent => num = std.math.pow(f64, num, inner),
            }
        }
        self.space();
        return num;
    }

    const ExprTestCase = struct {
        out: f64,
        in: []const u8,
    };

    /// Test helper: each input must evaluate to `out` AND be consumed completely.
    fn testExprs(cases: []const ExprTestCase) !void {
        for (cases) |case| {
            var s = Self.init(case.in);
            try expectFloatEq(f64, case.out, try s.expr(0));
            try std.testing.expectEqualStrings("", s.text);
        }
    }

    test "expr_solo" {
        try testExprs(&.{
            .{ .out = 1, .in = "1" },
            .{ .out = 0.3, .in = "0.3" },
        });
    }

    test "expr_ops" {
        try testExprs(&.{
            .{ .out = 15, .in = "12+3" },
            .{ .out = 9, .in = "12-3" },
            .{ .out = 36, .in = "12*3" },
            .{ .out = 4, .in = "12/3" },
            .{ .out = 1728, .in = "12^3" },
        });
    }

    test "expr_precedence" {
        try testExprs(&.{
            .{ .out = 10, .in = "5+4*3/2-1" },
            .{ .out = 5, .in = "4+3*6/9-1" },
            .{ .out = 1, .in = "5+2*6/4-7" },
            .{ .out = 81, .in = "-9^2" },
        });
    }

    test "paren_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "(1)" },
            .{ .out = 1, .in = "(((1)))" },
            .{ .out = 60, .in = "(2+3)*(5+7)" },
        });
    }

    test "prefix_expr" {
        try testExprs(&.{
            .{ .out = -1, .in = "-1" },
            .{ .out = 1, .in = "--1" },
            .{ .out = -1, .in = "---1" },
            .{ .out = -1, .in = "- 1" },
            .{ .out = -1, .in = "(-1)" },
            .{ .out = -1, .in = "-(1)" },
            .{ .out = 1, .in = "-(-1)" },
            .{ .out = 2, .in = "1--1" },
            .{ .out = 0, .in = "1---1" },
        });
    }

    test "pow_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "pow(1 , 1)" },
            .{ .out = 4, .in = "pow(2,2)" },
            .{ .out = 8, .in = "pow(2,3)" },
            .{ .out = 9, .in = "pow(3,2)" },
            .{ .out = 27, .in = "pow(3,3)" },
            .{ .out = 1, .in = "pow(4,0)" },
        });
    }

    test "sqrt_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "sqrt(1)" },
            .{ .out = 2, .in = "sqrt(4)" },
            .{ .out = 3, .in = "sqrt(9)" },
            .{ .out = 4, .in = "sqrt(16)" },
            .{ .out = 5, .in = "sqrt(25)" },
            .{ .out = 1, .in = "sqrt (1)" },
        });
    }

    test "cbrt_expr" {
        try testExprs(&.{
            .{ .out = 1, .in = "cbrt(1)" },
            .{ .out = 2, .in = "cbrt(8)" },
            .{ .out = 3, .in = "cbrt(27)" },
            .{ .out = 4, .in = "cbrt(64)" },
            .{ .out = 5, .in = "cbrt(125)" },
            .{ .out = 1, .in = "cbrt (1)" },
        });
    }

    test "expr_crlf" {
        try testExprs(&.{
            .{ .out = 3, .in = "1+2\r\n" },
            .{ .out = 6, .in = "\r\n2 *\r\n3\r\n" },
        });
    }

    test "expr_errors" {
        const ErrorCase = struct {
            err: ParseError,
            in: []const u8,
        };
        const cases = [_]ErrorCase{
            .{ .err = error.ExpectedExpression, .in = "" },
            .{ .err = error.ExpectedExpression, .in = "1+" },
            .{ .err = error.ExpectedCloseParen, .in = "(1" },
            .{ .err = error.ExpectedOpenParen, .in = "sqrt 4" },
            .{ .err = error.ExpectedComma, .in = "pow(1 2)" },
        };
        for (cases) |case| {
            var s = Self.init(case.in);
            try std.testing.expectError(case.err, s.expr(0));
        }
    }
};
/// Entry point. Evaluates the expression given as the first command-line argument, or
/// prompts for one on stdin when no argument is given, and prints the result.
///
/// If the input cannot be fully parsed, interactive mode prints a `^` under the position
/// where parsing stopped. In both modes the error is then returned (non-zero exit).
pub fn main() !void {
    const stdin = std.io.getStdIn().reader();
    const stdout = std.io.getStdOut().writer();

    var gpa = std.heap.DebugAllocator(.{}).init;
    defer if (gpa.deinit() != .ok) @panic("GPA LEAK");
    const alloc = gpa.allocator();

    // Line buffer for interactive input; longer lines fail with error.StreamTooLong.
    const buf = try alloc.alloc(u8, 4096);
    defer alloc.free(buf);

    // `argsWithAllocator` is portable (`std.process.args()` does not compile for Windows).
    // The iterator owns the argument strings, so it stays alive until `main` returns.
    var args = try std.process.argsWithAllocator(alloc);
    defer args.deinit();
    _ = args.skip(); // program name

    const input_and_prompt: struct {
        input: []const u8,
        prompt: ?[]const u8,
    } = blk: {
        if (args.next()) |arg| {
            // Parse the argument in place: no copy, so no silent truncation of long input.
            break :blk .{ .input = arg, .prompt = null };
        }
        const prompt = "Expression: ";
        try stdout.print("{s}", .{prompt});
        const line = stdin.readUntilDelimiter(buf, '\n') catch |err| switch (err) {
            error.EndOfStream => {
                try stdout.print("\n", .{});
                return;
            },
            else => return err,
        };
        // Strip the '\r' that CRLF line endings (e.g. Windows consoles) leave behind.
        break :blk .{ .input = std.mem.trimRight(u8, line, "\r"), .prompt = prompt };
    };
    const input = input_and_prompt.input;
    const prompt = input_and_prompt.prompt;

    var parser = Parser.init(input);
    const result = blk: {
        if (parser.expr(0)) |value| {
            // Leftover text (e.g. "1 2") means the input was not a single expression.
            break :blk if (parser.text.len > 0) error.ParseError else value;
        } else |err| {
            break :blk err;
        }
    };

    if (result) |r| {
        try stdout.print("{d}\n", .{r});
    } else |err| {
        if (prompt) |p| {
            // Line the caret up under where parsing stopped, past the echoed prompt.
            const consumed = input.len - parser.text.len;
            try stdout.writeByteNTimes(' ', p.len + consumed);
            try stdout.print("^\n", .{});
        }
        return err;
    }
}
// Pull every test in this file into the test build. `refAllDecls` only reaches the public
// declarations of this file, so the non-`pub` `Parser` (whose nested tests must run too)
// is referenced explicitly.
test "refs" {
    _ = Parser;
    std.testing.refAllDecls(@This());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment