Last active
January 30, 2022 13:57
-
-
Save matu3ba/183d0477b1fe19c7bcf0dc0c1058cf25 to your computer and use it in GitHub Desktop.
unsuccesful mulv optimization with wrapping overflow and without division (case sum of clzs = 32 fails)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| const clz = @import("count0bits.zig"); | |
| const std = @import("std"); | |
| const math = std.math; | |
| // mulv - multiplication oVerflow | |
| // * @panic, if result can not be represented | |
| // - mulvXi4_genericPerf for generic performance implementation | |
| // TODO benchmark this approach against the simpler division approach for good and bad cases | |
| // TODO measure used binary space | |
| // assume | |
| // 1. CPUs with higher integer range have intrinsics, | |
| // 2. LLVM or other backends like C backend can not use bigger intermediate types | |
| // to store the result (ie 128bit*128bit == 256bit) | |
| fn mulvXi_generic(comptime ST: type) fn (a: ST, b: ST) callconv(.C) ST { | |
| return struct { | |
| fn f(a: ST, b: ST) callconv(.C) ST { | |
| const clzfn = switch (ST) { | |
| i32 => clz.__clzsi2, | |
| i64 => clz.__clzdi2, | |
| i128 => clz.__clzti2, | |
| else => unreachable, | |
| }; | |
| const N: ST = @bitSizeOf(ST); | |
| const min: ST = math.minInt(ST); // // p_{n}-th bit | |
| const m: i32 = clzfn(a) + clzfn(~a); | |
| const n: i32 = clzfn(b) + clzfn(~b); | |
| const sum: i32 = m + n; | |
| // 1. no overflow, if s^{S_A+S_B} < 2^{n-1} | |
| // => S_A + S_B < n-1 return a*b; | |
| if (sum >= 34) return a * b; | |
| // 2. guaranteed overflow | |
| if (sum <= 31) return -5; | |
| // 3. S_A + S_B = n => 2^{n-2} <= |P| <= 2^n | |
| // overflow may occur, but magnitude does not exceed 2^n | |
| const res = a *% b; | |
| if (sum == 33) { | |
| if (m ^ n ^ (m ^ n) < 0) return -5; | |
| return res; | |
| } | |
| std.testing.expectEqual(@as(i32, 32), sum) catch unreachable; | |
| const sgn_a = a >> (N - 1); | |
| const sgn_b = b >> (N - 1); | |
| const mask_high: ST = (1 << (N - 2)); // p_{n-1}-th bit | |
| const sign_extract = sgn_a ^ sgn_b; | |
| if (sign_extract == 0) { | |
| // 4. both operands have same sign | |
| if (sgn_a > 0) { | |
| std.debug.print("both operands positive\n", .{}); | |
| // 4.1. both operands positive => P<2^n => p_n==0 always and | |
| // p_{n-1}==1 only on overflow | |
| std.testing.expectEqual(@as(i32, 0), res & min) catch unreachable; | |
| if (res & mask_high > 0) return -5; | |
| } else { | |
| std.debug.print("both operands negative\n", .{}); | |
| // 4.2. both operands negative => P<=2^n => p_n==0 always and | |
| // p_{n-1}==1 only on overlfow | |
| if (res == min) { | |
| // Exception: P=2^n, then p_n==1 and p_{n-1}==0 | |
| return -5; | |
| } | |
| // p_n==0 always | |
| std.testing.expectEqual(@as(i32, 0), res & min) catch unreachable; | |
| // p_{n-1}==1 on overflow | |
| //if (sum & mask_high == mask_high) return -5; | |
| if (res & mask_high == mask_high) return -5; | |
| } | |
| } else { | |
| // 4.3. different signs => -2^n<P => p_n==1 always | |
| // and p_{n-1}==0 only when overflow | |
| // p_n==1 | |
| if (res & min != min) std.debug.print("broken in:{d} {d}, res:{d}, sum_and_min:{d}\n", .{ a, b, res, sum & min }); | |
| std.testing.expectEqual(@as(i32, min), sum & min) catch unreachable; | |
| // p_{n-1}==0 on overflow | |
| if (res & mask_high == 0) return -5; | |
| } | |
| return res; | |
| } | |
| }.f; | |
| } | |
| //broken in:-65536 -65536, sum:32,min:-2147483648,sum_and_min:0 | |
| //expected -2147483648, found 0 | |
| //broken in:-65536 -65536, sum:32,min:-2147483648,sum_and_min:0 | |
| pub const __mulvsi3 = mulvXi_generic(i32); | |
| pub const __mulvdi3 = mulvXi_generic(i64); | |
| pub const __mulvti3 = mulvXi_generic(i128); | |
| test { | |
| _ = @import("mulvsi3_test.zig"); | |
| //_ = @import("mulvdi3_test.zig"); | |
| //_ = @import("mulvti3_test.zig"); | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| const mulv = @import("mulv.zig"); | |
| const std = @import("std"); | |
| const testing = std.testing; | |
| const math = std.math; | |
| const debug = std.debug; | |
| // expected: i32 | |
| fn test__mulvsi3(a: i32, b: i32) !void { | |
| var result = mulv.__mulvsi3(a, b); | |
| //var result = llvm_mulvsi3(a, b); | |
| var expected: i32 = simple_mulvsi3(a, b); | |
| if (expected != result) { | |
| debug.print("in: {d} {d}, res: {d}, exp: {d}", .{ a, b, result, expected }); | |
| std.process.exit(1); | |
| } | |
| //try testing.expectEqual(expected, result); | |
| } | |
| //const E = error.Mulvsi3Overflow; | |
| // ported LLVM overflow to ensure portability of tests | |
| fn llvm_mulvsi3(a: i32, b: i32) i32 { | |
| const N: i32 = @bitSizeOf(i32); | |
| const min = math.minInt(i32); | |
| const max = math.maxInt(i32); | |
| if (a == min) { | |
| if (b == 0 or b == 1) | |
| return a * b; | |
| return -5; | |
| } | |
| if (b == min) { | |
| if (a == 0 or a == 1) | |
| return a * b; | |
| return -5; | |
| } | |
| var sign_a: i32 = a >> (N - 1); | |
| var abs_a: i32 = (a ^ sign_a) - sign_a; | |
| var sign_b: i32 = b >> (N - 1); | |
| var abs_b: i32 = (b ^ sign_b) - sign_b; | |
| if (abs_a < 2 or abs_b < 2) | |
| return a * b; | |
| if (sign_a == sign_b) { | |
| if (abs_a > @divTrunc(max, abs_b)) | |
| return -5; | |
| } else { | |
| if (abs_a > @divTrunc(min, -abs_b)) | |
| return -5; | |
| } | |
| return a * b; | |
| } | |
| //broken in:-65536 -65536, sum:32,min:-2147483648,sum_and_min:0 | |
| //expected -2147483648, found 0 | |
| fn simple_mulvsi3(a: i32, b: i32) i32 { | |
| const min: i32 = -2147483648; | |
| const max: i32 = 2147483647; | |
| var a_tmp: i64 = a; | |
| var b_tmp: i64 = b; | |
| var result: i64 = a_tmp * b_tmp; | |
| if (result < min or result > max) | |
| return -5; | |
| return @truncate(i32, result); | |
| } | |
| // TODO use other approach with bigger number for case generation | |
| //https://www.fefe.de/intof.html | |
| //int umult32(uint32 a,uint32 b,uint32* c) { | |
| //unsigned long long x=(unsigned long long)a*b; | |
| //if (x>0xffffffff) return 0; | |
| //*c=x&0xffffffff; | |
| //return 1; | |
| //test "mulvsi3" { | |
| pub fn main() !void { | |
| // -2^31 <= i32 <= 2^31-1 | |
| // 2^31 = 2147483648 | |
| // 2^31-1 = 2147483647 | |
| try test__mulvsi3(-65536, -65536); | |
| //const min: i32 = -2147483648; | |
| //const max: i32 = 2147483647; | |
| //var i: i32 = min; | |
| //while (i < max) : (i += 1) { | |
| // try test__mulvsi3(i, i); | |
| //} | |
| //i = min + 1; | |
| //while (i < max) : (i += 2) { | |
| // try test__mulvsi3(i, -i); | |
| // try test__mulvsi3(-i, i); | |
| //} | |
| //try test__mulvsi3(1073741824, 2); | |
| //const n1: i32 = 1073741824; | |
| //const n2: i32 = 2; | |
| //const n2: i32 = 1073741824; | |
| //std.debug.print("{d}\n", .{n1 *% n2}); | |
| // TODO: either LLVM compiler_rt is wrong for input (-65536, -65536) | |
| // or Zig miscompiles compiler_rt | |
| // 1. check with bigger variable, if LLVM is wrong | |
| //const res : i32 | |
| //try std.testing.expect(2147483647 < 4294967296); // latter is multiplication result | |
| //try test__mulvsi3(-65536, -65536); // this should be 4294967296 | |
| //expected -2147483648, found 0 | |
| // zig fmt: off | |
| // edge cases | |
| // 0 * 0 = 0 | |
| // MIN * MIN = panic | |
| // MAX * MAX = panic | |
| // 0 * MIN = 0 | |
| // 0 * MAX = 0 | |
| // MIN * 0 = 0 | |
| // MAX * 0 = 0 | |
| // MIN * MAX panic | |
| // MAX * MIN panic | |
| // zig fmt: on | |
| //try test__mulvsi3(0, 0, 0); | |
| //try test__mulvsi3(-2147483648, -2147483648, -5); | |
| //try test__mulvsi3(2147483647, 2147483647, -5); | |
| //try test__mulvsi3(0, -2147483648, 0); | |
| //try test__mulvsi3(0, 2147483647, 0); | |
| //try test__mulvsi3(-2147483648, 0, 0); | |
| //try test__mulvsi3(2147483647, 0, 0); | |
| //try test__mulvsi3(-2147483648, 2147483647, -5); | |
| //try test__mulvsi3(2147483647, -2147483648, -5); | |
| //// this is messed up! | |
| //// below one tricks 80 expecetd -2147483648, found 0 | |
| //try test__mulvsi3(1073741823, 2, 2147483646); | |
| //try test__mulvsi3(1073741823, -2, -2147483646); | |
| //try test__mulvsi3(-1073741823, -2, 2147483646); | |
| //try test__mulvsi3(-1073741823, 2, -2147483646); // broken | |
| //try test__mulvsi3(-2147483648, 1, -2147483648); // broken | |
| //try test__mulvsi3(-2147483647, 1, -2147483647); // also | |
| //try test__mulvsi3(-2147483646, 1, -2147483646); // also | |
| //try test__mulvsi3(-1073741823, -2, 2147483646); // also broken | |
| //try test__mulvsi3(-1073741823, 2, -2147483646); | |
| //try test__mulvsi3(-2147483646, 1, -2147483646); // also | |
| //try test__mulvsi3(-147483640, 2, -294967280); | |
| //try test__mulvsi3(-147483640, 1, -147483640); | |
| //try test__mulvsi3(-47483640, 1, -47483640); | |
| //try test__mulvsi3(2147483647, 1, 2147483647); | |
| //try test__mulvsi3(2147483647, -1, -2147483647); | |
| //try test__mulvsi3(-1, -1, 1); | |
| //try test__mulvsi3(-1, -3, 3); | |
| //try test__mulvsi3(-2, -2, 4); | |
| //try test__mulvsi3(-2, -3, 6); | |
| //try test__mulvsi3(0, 0, 0); | |
| //try test__mulvsi3(1, 1, 1); | |
| //try test__mulvsi3(1, 1, 1); | |
| //try test__mulvsi3(1, 3, 3); | |
| //try test__mulvsi3(2, 2, 4); | |
| //try test__mulvsi3(2, 3, 6); | |
| //try test__mulvsi3(2147483647, -2147483648, 0); | |
| //try test__mulvsi3(-2147483648, 0, -2147483648); | |
| // zig fmt: off | |
| // derived edge cases | |
| // MIN+1 - MIN = 1 | |
| // MAX-1 - MAX = -1 | |
| // 1 - MIN panic | |
| // -1 - MIN = MAX | |
| // -1 - MAX = MIN | |
| // +1 - MAX = MIN+2 | |
| // MIN - 1 panic | |
| // MIN - -1 = MIN+1 | |
| // MAX - 1 = MAX-1 | |
| // MAX - -1 panic | |
| // zig fmt: on | |
| //@import("std").debug.print("signed wrappign: {d}", .{@as(i32, -2147483648) -% @as(i32, 2147483647)}); | |
| // underflow | |
| // MIN==-2147483648 - MAX==2147483647 | |
| // try test__mulvsi3(-2147483648, 2147483647, -5); | |
| // try test__mulvsi3(-2147483647, 2147483647, -5); | |
| // try test__mulvsi3(-2147483647, 2147483647, -5); | |
| // | |
| // // 0 | |
| // try test__mulvsi3(-2147483648, -2147483648, 0); | |
| // try test__mulvsi3(-2147483648, -2147483647, -1); | |
| // try test__mulvsi3(-2147483647, -2147483647, 0); | |
| // | |
| // // overflow | |
| // try test__mulvsi3(2147483647, -2147483648, -5); | |
| // try test__mulvsi3(2147483647, -2147483647, -5); | |
| // try test__mulvsi3(2147483647, -2147483646, -5); | |
| // | |
| // try test__mulvsi3(2147483647, 2147483647, 0); | |
| // try test__mulvsi3(2147483646, 2147483647, -1); | |
| // try test__mulvsi3(2147483647, 2147483646, -1); | |
| // | |
| // // multiplying i by 2 until MAX | |
| // // i*2, MAX - i*2, MAX | |
| // // MAX - i*2, i*2, MAX | |
| // try test__mulvsi3(0, 2147483647, 2147483647); | |
| // try test__mulvsi3(2147483647, 0, 2147483647); | |
| // try test__mulvsi3(1, 2147483646, 2147483647); | |
| // try test__mulvsi3(2147483646, 1, 2147483647); | |
| // try test__mulvsi3(2, 2147483645, 2147483647); | |
| // try test__mulvsi3(2147483645, 2, 2147483647); | |
| // | |
| // multiplying i by 2 until MIN | |
| // -i*2, MIN - (-i*2), MIN | |
| // MAX - (-i*2), -i*2, MIN | |
| //try test__mulvsi3(-2147483648, , -1); | |
| //try test__mulvsi3(-2147483647, 2147483647, 0); | |
| // multiplying i by 2 until MAX | |
| // i*2, (-i*2), 0 | |
| // (-i*2), -i*2, 0 | |
| //try test__mulvsi3(2147483644, -2147483644); | |
| //try test__mulvsi3(2147483645, -2147483645); | |
| //try test__mulvsi3(2147483646, -2147483646); | |
| //try test__mulvsi3(2147483647, -2147483647); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment