Skip to content

Instantly share code, notes, and snippets.

@matu3ba
Last active January 30, 2022 13:57
Show Gist options
  • Save matu3ba/183d0477b1fe19c7bcf0dc0c1058cf25 to your computer and use it in GitHub Desktop.
Save matu3ba/183d0477b1fe19c7bcf0dc0c1058cf25 to your computer and use it in GitHub Desktop.
unsuccesful mulv optimization with wrapping overflow and without division (case sum of clzs = 32 fails)
const clz = @import("count0bits.zig");
const std = @import("std");
const math = std.math;
// mulv - multiplication oVerflow
// * @panic, if result can not be represented
// - mulvXi4_genericPerf for generic performance implementation
// TODO benchmark this approach against the simpler division approach for good and bad cases
// TODO measure used binary space
// assume
// 1. CPUs with higher integer range have intrinsics,
// 2. LLVM or other backends like C backend can not use bigger intermediate types
// to store the result (ie 128bit*128bit == 256bit)
fn mulvXi_generic(comptime ST: type) fn (a: ST, b: ST) callconv(.C) ST {
return struct {
fn f(a: ST, b: ST) callconv(.C) ST {
const clzfn = switch (ST) {
i32 => clz.__clzsi2,
i64 => clz.__clzdi2,
i128 => clz.__clzti2,
else => unreachable,
};
const N: ST = @bitSizeOf(ST);
const min: ST = math.minInt(ST); // // p_{n}-th bit
const m: i32 = clzfn(a) + clzfn(~a);
const n: i32 = clzfn(b) + clzfn(~b);
const sum: i32 = m + n;
// 1. no overflow, if s^{S_A+S_B} < 2^{n-1}
// => S_A + S_B < n-1 return a*b;
if (sum >= 34) return a * b;
// 2. guaranteed overflow
if (sum <= 31) return -5;
// 3. S_A + S_B = n => 2^{n-2} <= |P| <= 2^n
// overflow may occur, but magnitude does not exceed 2^n
const res = a *% b;
if (sum == 33) {
if (m ^ n ^ (m ^ n) < 0) return -5;
return res;
}
std.testing.expectEqual(@as(i32, 32), sum) catch unreachable;
const sgn_a = a >> (N - 1);
const sgn_b = b >> (N - 1);
const mask_high: ST = (1 << (N - 2)); // p_{n-1}-th bit
const sign_extract = sgn_a ^ sgn_b;
if (sign_extract == 0) {
// 4. both operands have same sign
if (sgn_a > 0) {
std.debug.print("both operands positive\n", .{});
// 4.1. both operands positive => P<2^n => p_n==0 always and
// p_{n-1}==1 only on overflow
std.testing.expectEqual(@as(i32, 0), res & min) catch unreachable;
if (res & mask_high > 0) return -5;
} else {
std.debug.print("both operands negative\n", .{});
// 4.2. both operands negative => P<=2^n => p_n==0 always and
// p_{n-1}==1 only on overlfow
if (res == min) {
// Exception: P=2^n, then p_n==1 and p_{n-1}==0
return -5;
}
// p_n==0 always
std.testing.expectEqual(@as(i32, 0), res & min) catch unreachable;
// p_{n-1}==1 on overflow
//if (sum & mask_high == mask_high) return -5;
if (res & mask_high == mask_high) return -5;
}
} else {
// 4.3. different signs => -2^n<P => p_n==1 always
// and p_{n-1}==0 only when overflow
// p_n==1
if (res & min != min) std.debug.print("broken in:{d} {d}, res:{d}, sum_and_min:{d}\n", .{ a, b, res, sum & min });
std.testing.expectEqual(@as(i32, min), sum & min) catch unreachable;
// p_{n-1}==0 on overflow
if (res & mask_high == 0) return -5;
}
return res;
}
}.f;
}
//broken in:-65536 -65536, sum:32,min:-2147483648,sum_and_min:0
//expected -2147483648, found 0
//broken in:-65536 -65536, sum:32,min:-2147483648,sum_and_min:0
pub const __mulvsi3 = mulvXi_generic(i32);
pub const __mulvdi3 = mulvXi_generic(i64);
pub const __mulvti3 = mulvXi_generic(i128);
test {
_ = @import("mulvsi3_test.zig");
//_ = @import("mulvdi3_test.zig");
//_ = @import("mulvti3_test.zig");
}
const mulv = @import("mulv.zig");
const std = @import("std");
const testing = std.testing;
const math = std.math;
const debug = std.debug;
// expected: i32
fn test__mulvsi3(a: i32, b: i32) !void {
var result = mulv.__mulvsi3(a, b);
//var result = llvm_mulvsi3(a, b);
var expected: i32 = simple_mulvsi3(a, b);
if (expected != result) {
debug.print("in: {d} {d}, res: {d}, exp: {d}", .{ a, b, result, expected });
std.process.exit(1);
}
//try testing.expectEqual(expected, result);
}
//const E = error.Mulvsi3Overflow;
// ported LLVM overflow to ensure portability of tests
fn llvm_mulvsi3(a: i32, b: i32) i32 {
const N: i32 = @bitSizeOf(i32);
const min = math.minInt(i32);
const max = math.maxInt(i32);
if (a == min) {
if (b == 0 or b == 1)
return a * b;
return -5;
}
if (b == min) {
if (a == 0 or a == 1)
return a * b;
return -5;
}
var sign_a: i32 = a >> (N - 1);
var abs_a: i32 = (a ^ sign_a) - sign_a;
var sign_b: i32 = b >> (N - 1);
var abs_b: i32 = (b ^ sign_b) - sign_b;
if (abs_a < 2 or abs_b < 2)
return a * b;
if (sign_a == sign_b) {
if (abs_a > @divTrunc(max, abs_b))
return -5;
} else {
if (abs_a > @divTrunc(min, -abs_b))
return -5;
}
return a * b;
}
//broken in:-65536 -65536, sum:32,min:-2147483648,sum_and_min:0
//expected -2147483648, found 0
fn simple_mulvsi3(a: i32, b: i32) i32 {
const min: i32 = -2147483648;
const max: i32 = 2147483647;
var a_tmp: i64 = a;
var b_tmp: i64 = b;
var result: i64 = a_tmp * b_tmp;
if (result < min or result > max)
return -5;
return @truncate(i32, result);
}
// TODO use other approach with bigger number for case generation
//https://www.fefe.de/intof.html
//int umult32(uint32 a,uint32 b,uint32* c) {
//unsigned long long x=(unsigned long long)a*b;
//if (x>0xffffffff) return 0;
//*c=x&0xffffffff;
//return 1;
//test "mulvsi3" {
pub fn main() !void {
// -2^31 <= i32 <= 2^31-1
// 2^31 = 2147483648
// 2^31-1 = 2147483647
try test__mulvsi3(-65536, -65536);
//const min: i32 = -2147483648;
//const max: i32 = 2147483647;
//var i: i32 = min;
//while (i < max) : (i += 1) {
// try test__mulvsi3(i, i);
//}
//i = min + 1;
//while (i < max) : (i += 2) {
// try test__mulvsi3(i, -i);
// try test__mulvsi3(-i, i);
//}
//try test__mulvsi3(1073741824, 2);
//const n1: i32 = 1073741824;
//const n2: i32 = 2;
//const n2: i32 = 1073741824;
//std.debug.print("{d}\n", .{n1 *% n2});
// TODO: either LLVM compiler_rt is wrong for input (-65536, -65536)
// or Zig miscompiles compiler_rt
// 1. check with bigger variable, if LLVM is wrong
//const res : i32
//try std.testing.expect(2147483647 < 4294967296); // latter is multiplication result
//try test__mulvsi3(-65536, -65536); // this should be 4294967296
//expected -2147483648, found 0
// zig fmt: off
// edge cases
// 0 * 0 = 0
// MIN * MIN = panic
// MAX * MAX = panic
// 0 * MIN = 0
// 0 * MAX = 0
// MIN * 0 = 0
// MAX * 0 = 0
// MIN * MAX panic
// MAX * MIN panic
// zig fmt: on
//try test__mulvsi3(0, 0, 0);
//try test__mulvsi3(-2147483648, -2147483648, -5);
//try test__mulvsi3(2147483647, 2147483647, -5);
//try test__mulvsi3(0, -2147483648, 0);
//try test__mulvsi3(0, 2147483647, 0);
//try test__mulvsi3(-2147483648, 0, 0);
//try test__mulvsi3(2147483647, 0, 0);
//try test__mulvsi3(-2147483648, 2147483647, -5);
//try test__mulvsi3(2147483647, -2147483648, -5);
//// this is messed up!
//// below one tricks 80 expecetd -2147483648, found 0
//try test__mulvsi3(1073741823, 2, 2147483646);
//try test__mulvsi3(1073741823, -2, -2147483646);
//try test__mulvsi3(-1073741823, -2, 2147483646);
//try test__mulvsi3(-1073741823, 2, -2147483646); // broken
//try test__mulvsi3(-2147483648, 1, -2147483648); // broken
//try test__mulvsi3(-2147483647, 1, -2147483647); // also
//try test__mulvsi3(-2147483646, 1, -2147483646); // also
//try test__mulvsi3(-1073741823, -2, 2147483646); // also broken
//try test__mulvsi3(-1073741823, 2, -2147483646);
//try test__mulvsi3(-2147483646, 1, -2147483646); // also
//try test__mulvsi3(-147483640, 2, -294967280);
//try test__mulvsi3(-147483640, 1, -147483640);
//try test__mulvsi3(-47483640, 1, -47483640);
//try test__mulvsi3(2147483647, 1, 2147483647);
//try test__mulvsi3(2147483647, -1, -2147483647);
//try test__mulvsi3(-1, -1, 1);
//try test__mulvsi3(-1, -3, 3);
//try test__mulvsi3(-2, -2, 4);
//try test__mulvsi3(-2, -3, 6);
//try test__mulvsi3(0, 0, 0);
//try test__mulvsi3(1, 1, 1);
//try test__mulvsi3(1, 1, 1);
//try test__mulvsi3(1, 3, 3);
//try test__mulvsi3(2, 2, 4);
//try test__mulvsi3(2, 3, 6);
//try test__mulvsi3(2147483647, -2147483648, 0);
//try test__mulvsi3(-2147483648, 0, -2147483648);
// zig fmt: off
// derived edge cases
// MIN+1 - MIN = 1
// MAX-1 - MAX = -1
// 1 - MIN panic
// -1 - MIN = MAX
// -1 - MAX = MIN
// +1 - MAX = MIN+2
// MIN - 1 panic
// MIN - -1 = MIN+1
// MAX - 1 = MAX-1
// MAX - -1 panic
// zig fmt: on
//@import("std").debug.print("signed wrappign: {d}", .{@as(i32, -2147483648) -% @as(i32, 2147483647)});
// underflow
// MIN==-2147483648 - MAX==2147483647
// try test__mulvsi3(-2147483648, 2147483647, -5);
// try test__mulvsi3(-2147483647, 2147483647, -5);
// try test__mulvsi3(-2147483647, 2147483647, -5);
//
// // 0
// try test__mulvsi3(-2147483648, -2147483648, 0);
// try test__mulvsi3(-2147483648, -2147483647, -1);
// try test__mulvsi3(-2147483647, -2147483647, 0);
//
// // overflow
// try test__mulvsi3(2147483647, -2147483648, -5);
// try test__mulvsi3(2147483647, -2147483647, -5);
// try test__mulvsi3(2147483647, -2147483646, -5);
//
// try test__mulvsi3(2147483647, 2147483647, 0);
// try test__mulvsi3(2147483646, 2147483647, -1);
// try test__mulvsi3(2147483647, 2147483646, -1);
//
// // multiplying i by 2 until MAX
// // i*2, MAX - i*2, MAX
// // MAX - i*2, i*2, MAX
// try test__mulvsi3(0, 2147483647, 2147483647);
// try test__mulvsi3(2147483647, 0, 2147483647);
// try test__mulvsi3(1, 2147483646, 2147483647);
// try test__mulvsi3(2147483646, 1, 2147483647);
// try test__mulvsi3(2, 2147483645, 2147483647);
// try test__mulvsi3(2147483645, 2, 2147483647);
//
// multiplying i by 2 until MIN
// -i*2, MIN - (-i*2), MIN
// MAX - (-i*2), -i*2, MIN
//try test__mulvsi3(-2147483648, , -1);
//try test__mulvsi3(-2147483647, 2147483647, 0);
// multiplying i by 2 until MAX
// i*2, (-i*2), 0
// (-i*2), -i*2, 0
//try test__mulvsi3(2147483644, -2147483644);
//try test__mulvsi3(2147483645, -2147483645);
//try test__mulvsi3(2147483646, -2147483646);
//try test__mulvsi3(2147483647, -2147483647);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment