Skip to content

Instantly share code, notes, and snippets.

@lichray
Last active November 2, 2018 06:12
Show Gist options
  • Save lichray/1e546f7680842be4ba11d9ba7c0165c4 to your computer and use it in GitHub Desktop.
Save lichray/1e546f7680842be4ba11d9ba7c0165c4 to your computer and use it in GitHub Desktop.
Fraction class (MSVC support is experimental)
#include <limits>
#include <ostream>
#include <stdexcept>
#include <system_error>
#include <type_traits>
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#ifdef _MSC_VER
#include <intrin.h>
#pragma intrinsic(_umul128)
#pragma intrinsic(_BitScanReverse64)
#pragma warning(disable : 4146)
#endif
namespace ext
{
// A minimal two-member aggregate.  pair<T> is built by aggregate
// initialization; pair<T&> (as returned by tie) assigns through its
// reference members, so `tie(x, y) = pair<T>{ ... }` updates both x and y.
template <class T>
struct pair
{
    // the by-value pair type accepted on assignment
    using value_type = pair<typename std::remove_reference<T>::type>;

    T a;
    T b;

    pair(pair const&) = default;
    pair(pair&&) = default;

    // copy both members of `src` into a and b
    constexpr pair& operator=(value_type const& src)
    {
        this->a = src.a;
        this->b = src.b;
        return *this;
    }

    // move both members of `src` into a and b
    constexpr pair& operator=(value_type&& src)
    {
        this->a = std::move(src.a);
        this->b = std::move(src.b);
        return *this;
    }
};
// Bundle references to two lvalues of the same type, so that assigning a
// pair<T> to the result writes both: tie(x, y) = pair<T>{ ... }.
template <class T>
constexpr pair<T&> tie(T& first, T& second) noexcept
{
    return pair<T&>{ first, second };
}
// inspired by P0437R0: variable-template shorthands for std::numeric_limits.
// min/max are wrapped in parentheses, presumably so that function-like
// min/max macros (e.g. from <windows.h>) cannot expand there.
template <class T>
constexpr int num_digits = std::numeric_limits<T>::digits;
template <class T>
constexpr int num_digits10 = std::numeric_limits<T>::digits10;
template <class T>
constexpr T num_min = (std::numeric_limits<T>::min)();
template <class T>
constexpr T num_max = (std::numeric_limits<T>::max)();
// Remainder of a divided by b (b must be nonzero), converted back to T.
// gcd below spells its modulus through this function.
template <class T>
constexpr T mod(T a, T b)
{
    T const remainder = a % b;
    return remainder;
}
// Greatest common divisor by Euclid's algorithm; gcd(x, 0) == x and
// gcd(0, 0) == 0.  The callers in this file pass unsigned magnitudes.
template <class T>
constexpr T gcd(T a, T b)
{
    for (; b;)
    {
        T const rem = mod(a, b);
        a = b;
        b = rem;
    }
    return a;
}
// the design follows <bit> (P0553R2)
// countl_zero(x): number of consecutive 0 bits starting from the most
// significant bit of x.  Only the unsigned long long specialization is
// defined.
template <class T>
inline int countl_zero(T x);
template <>
inline int countl_zero(unsigned long long x)
{
// precondition: x is nonzero
assert(x != 0);
#ifndef _MSC_VER
return __builtin_clzll(x);
#else
// NOTE(review): per the MSVC documentation, _BitScanReverse64 stores the
// bit index of the highest set bit, so this branch appears to return
// 63 - (leading zero count) rather than the count the branch above
// returns. Audit every MSVC caller before changing it.
unsigned long i;
_BitScanReverse64(&i, x);
return i;
#endif
}
// Sign-magnitude form of a value of signed type T: `sign` is true for a
// negative value and `val` holds its magnitude in the unsigned counterpart
// of T (see abs below).
template <class T>
struct abs_result
{
    bool sign;
    typename std::make_unsigned<T>::type val;
};
// Split a signed integer into sign-magnitude form.  The sign lives in its own
// bool, so even the most negative value has a representable magnitude.
template <class T>
constexpr abs_result<T> abs(T v)
{
    using mag_t = decltype(abs_result<T>{}.val);
    bool const negative = v < 0;
    mag_t const bits = mag_t(v);
    // unsigned negation is modular, so -bits is |v| whenever v is negative
    return { negative, negative ? static_cast<mag_t>(-bits) : bits };
}
// may be specialized to create a user-defined extended type for T
// if the member typedef-name `type` names the injected-class-name
template <class T>
struct extended_result;
template <>
struct extended_result<char>
{
using type = int;
};
template <>
struct extended_result<signed char>
{
using type = int;
};
template <>
struct extended_result<unsigned char>
{
using type = unsigned;
};
template <>
struct extended_result<short>
{
using type = int;
};
template <>
struct extended_result<unsigned short>
{
using type = unsigned;
};
template <>
struct extended_result<int>
{
using type = int64_t;
};
template <>
struct extended_result<unsigned int>
{
using type = uint64_t;
};
template <>
struct extended_result<long>
{
#ifndef _MSC_VER
using type = std::conditional_t<num_digits<long> == num_digits<int>,
int64_t, __int128_t>;
#else
using type = int64_t;
#endif
};
template <>
struct extended_result<unsigned long>
{
#ifndef _MSC_VER
using type = std::conditional_t<num_digits<long> == num_digits<int>,
uint64_t, __uint128_t>;
#else
using type = uint64_t;
#endif
};
template <>
struct extended_result<long long>
{
#ifndef _MSC_VER
using type = __int128_t;
#else
using type = extended_result;
uint64_t hi, lo;
#endif
};
// Wide type for unsigned long long.  Outside MSVC it is the built-in
// __uint128_t.  Under MSVC the specialization is its own wide type: an
// unsigned 128-bit value stored as two 64-bit halves.  It provides only the
// operations used by fraction and by the MSVC abs/gcd overloads.
template <>
struct extended_result<unsigned long long>
{
#ifndef _MSC_VER
    using type = __uint128_t;
#else
    using type = extended_result;
    uint64_t hi, lo;

    friend constexpr bool operator<(extended_result const& x,
                                    extended_result const& y)
    {
        return x.hi < y.hi || (x.hi == y.hi && x.lo < y.lo);
    }
    friend constexpr bool operator==(extended_result const& x,
                                     extended_result const& y)
    {
        return x.hi == y.hi && x.lo == y.lo;
    }
    constexpr bool is_even() const { return ~lo & 1; }
    constexpr bool is_odd() const { return lo & 1; }
    constexpr explicit operator bool() const { return lo || hi; }

    // in-place left shift by n bits; n must be in [0, 128)
    extended_result& operator<<=(int n) &
    {
        if (n == 0)
            return *this;
        if (n < 64)
        {
            hi = (hi << n) ^ (lo >> (64 - n));
            lo <<= n;
        }
        else
        {
            hi = lo << (n - 64);
            lo = 0;
        }
        return *this;
    }

    // shift left by one bit: in place on lvalues, by copy otherwise
    extended_result& lshift() &
    {
        hi = (hi << 1) ^ (lo >> 63);
        lo <<= 1;
        return *this;
    }
    extended_result lshift() const &
    {
        auto x = *this;
        x.lshift();
        return x;
    }

    // shift right by one bit: in place on lvalues, by copy otherwise
    extended_result& rshift() &
    {
        lo = (lo >> 1) | (hi << 63);
        hi >>= 1;
        return *this;
    }
    extended_result rshift() const &
    {
        auto x = *this;
        x.rshift();
        return x;
    }

    // reinterpret the same bits as the signed wide type
    constexpr operator extended_result<long long>() const
    {
        return { hi, lo };
    }

    // two's-complement negation modulo 2^128
    constexpr extended_result operator-() const
    {
        auto hi_cmpl = ~hi;
        auto lo_cmpl = ~lo;
        auto nlo = lo_cmpl + 1;
        // the +1 carries into the high half exactly when it wraps to zero
        return { hi_cmpl + (lo_cmpl > nlo), nlo };
    }

    // subtraction modulo 2^128, borrowing from the high half
    friend constexpr extended_result operator-(extended_result const& x,
                                               extended_result const& y)
    {
        return { (x.hi - y.hi - (x.lo < y.lo)), x.lo - y.lo };
    }

    // quotient by shift-and-subtract long division; y must be nonzero
    friend constexpr extended_result operator/(extended_result x,
                                               extended_result y)
    {
        if (x < y)
            return { 0, 0 };
        if (x == y)
            return { 0, 1 };
        /*
        if (x.hi == 0 && y.hi == 0)
            return { 0, x.lo / y.lo };
        */
        extended_result q = {};
        // here x > y, so y has at least as many leading zeros as x; shift y
        // so that its highest set bit lines up with x's
        auto diff = clz(y) - clz(x);
        y <<= diff;
        do
        {
            q.lshift();
            if (!(x < y))
            {
                x = x - y;
                q.lo |= 1;
            }
            y.rshift();
        } while (diff--);
        return q;
    }

    // number of leading zero bits of the 128-bit value; x must be nonzero.
    // Computed directly from _BitScanReverse64 (which stores the index of the
    // highest set bit), so it does not depend on countl_zero's MSVC branch.
    static int clz(extended_result const& x)
    {
        assert(x.hi != 0 || x.lo != 0);
        unsigned long msb;
        if (x.hi)
        {
            _BitScanReverse64(&msb, x.hi);
            return 63 - static_cast<int>(msb);
        }
        _BitScanReverse64(&msb, x.lo);
        return 127 - static_cast<int>(msb);
    }
#endif
};
#ifndef _MSC_VER
// sign-magnitude result for the built-in 128-bit signed type; the unsigned
// type is spelled out, presumably because std::make_unsigned is not
// guaranteed to accept __int128_t
template <>
struct abs_result<__int128_t>
{
bool sign;
__uint128_t val;
};
#else
// MSVC: sign-magnitude form of the two-halves 128-bit signed type
template <>
struct abs_result<extended_result<int64_t>>
{
bool sign;
extended_result<uint64_t> val;
};
// split a two's-complement 128-bit value into sign and magnitude; the sign
// is the top bit of the high half
constexpr auto abs(extended_result<int64_t> const& v)
-> abs_result<extended_result<int64_t>>
{
using U = extended_result<uint64_t>;
if (int64_t(v.hi) < 0)
return { true, -U{ v.hi, v.lo } };
else
return { false, U{ v.hi, v.lo } };
}
// avoid 128-bit modulus with the binary GCD algorithm:
// strip factors of two and subtract the smaller odd value from the larger;
// gcd(x, 0) == x.  Recursion depth is bounded by the number of bits.
constexpr auto gcd(extended_result<uint64_t> const& u,
extended_result<uint64_t> const& v)
-> extended_result<uint64_t>
{
if (u == v)
return u;
if (!u)
return v;
if (!v)
return u;
// disabled fast path for operands that fit in 64 bits
/*
if (u.hi == 0 && v.hi == 0)
return { 0, gcd(u.lo, v.lo) };
*/
if (u.is_even())
{
// u even, v odd: 2 is not a common factor, drop it from u
if (v.is_odd())
return gcd(u.rshift(), v);
// both even: gcd(u, v) == 2 * gcd(u/2, v/2)
else
return gcd(u.rshift(), v.rshift()).lshift();
}
// u odd, v even: drop the factor 2 from v
if (v.is_even())
return gcd(u, v.rshift());
// both odd: the difference is even, so halve it immediately
if (v < u)
return gcd((u - v).rshift(), v);
return gcd((v - u).rshift(), u);
}
#endif
// Shorthand for the wide type that extended_result assigns to T.
template <class T>
using extended_result_t = typename extended_result<T>::type;
// Convert v to its wide type without changing its value.  Declared
// constexpr, as the neighbouring helpers are, so that callers such as
// fraction's constexpr operator+ and operator- (equal-denominator path)
// can be evaluated in constant expressions.
template <class T>
constexpr extended_result_t<T> _extended(T v)
{
    return v;
}
// Store a + b into `out` and report whether the mathematical sum did not fit
// in T.  The MSVC fallback detects wrap-around by comparing against an
// addend, which is only valid for unsigned T (the only kind this file
// passes to it).
template <class T>
constexpr bool _add_overflowed(T a, T b, T& out)
{
#ifdef _MSC_VER
    out = a + b;
    return out < a;
#else
    return __builtin_add_overflow(a, b, &out);
#endif
}
// Narrow `a`, a value of T's wide type, into `out` with static_cast, and
// report whether the value lay outside T's range.
// unsigned T: the wide type is unsigned too, so only the upper bound matters
template <class T, std::enable_if_t<std::is_unsigned<T>::value, int> = 0>
constexpr bool _assign_overflowed(extended_result_t<T> a, T& out)
{
    bool const above = num_max<T> < a;
    out = static_cast<T>(a);
    return above;
}
// signed T: the value may fall above or below T's range
template <class T, std::enable_if_t<std::is_signed<T>::value, int> = 0>
constexpr bool _assign_overflowed(extended_result_t<T> a, T& out)
{
    bool const above = num_max<T> < a;
    bool const below = a < num_min<T>;
    out = static_cast<T>(a);
    return above || below;
}
#ifdef _MSC_VER
// MSVC overloads/specializations of the helpers above for the two-halves
// 128-bit wide type.
// widen a uint64_t: the value goes in the low half, the high half is zero
template <>
extended_result<uint64_t> _extended(uint64_t v)
{
return { 0, v };
}
// 128-bit addition: add the low halves, carry into a.hi, then add b.hi.
// Returns true if either high-half addition carried out, i.e. the sum does
// not fit in 128 bits.
inline bool _add_overflowed(extended_result<uint64_t> const& a,
extended_result<uint64_t> const& b,
extended_result<uint64_t>& out)
{
bool carry = _add_overflowed(a.lo, b.lo, out.lo);
uint64_t hi;
carry = _add_overflowed(a.hi, uint64_t(carry), hi);
return carry | _add_overflowed(hi, b.hi, out.hi);
}
// narrow a 128-bit value into T by storing its low half in `out`.
// For signed T the value (read as two's complement) fits when the high half
// equals the sign extension of `out`; for unsigned T when the high half is 0.
// NOTE(review): only the high half is checked, so this appears to assume T
// is 64 bits wide (the low half fits in `out` unchanged) -- confirm.
template <class T>
constexpr bool _assign_overflowed(extended_result<uint64_t> const& a, T& out)
{
out = a.lo;
if (std::is_signed<T>::value)
return a.hi != -(out < 0);
else
return a.hi;
}
#endif
// Arithmetic on two T values carried out in T's wide type: both operands are
// converted to extended_result_t<T> first and the result has that type.
// (The MSVC block below specializes some of these for 64-bit operands.)
template <class T>
constexpr auto _extended_add(T x, T y)
{
    extended_result_t<T> xp = x, yp = y;
    return xp + yp;
}
template <class T>
constexpr auto _extended_sub(T x, T y)
{
    extended_result_t<T> xp = x, yp = y;
    return xp - yp;
}
template <class T>
constexpr auto _extended_mul(T x, T y)
{
    extended_result_t<T> xp = x, yp = y;
    return xp * yp;
}
// wide quotient x / y; y must be nonzero
template <class T>
constexpr auto _extended_div(T x, T y)
{
    extended_result_t<T> xp = x, yp = y;
    return xp / yp;
}
#ifdef _MSC_VER
// 64 x 64 -> 128-bit unsigned product via the _umul128 intrinsic
template <>
inline auto _extended_mul(uint64_t x, uint64_t y)
{
    extended_result<uint64_t> r;
    r.lo = _umul128(x, y, &r.hi);
    return r;
}
// exact sum of two int64_t values as a two's-complement 128-bit value
template <>
inline auto _extended_add(int64_t x, int64_t y)
{
    extended_result<int64_t> r;
    bool sx = x < 0, sy = y < 0;
    if (sx != sy)
    {
        // opposite signs cannot overflow int64_t; sign-extend the result.
        // r.lo is unsigned, so its sign must be read through int64_t.
        r.lo = uint64_t(x + y);
        r.hi = -uint64_t(int64_t(r.lo) < 0);
    }
    else
    {
        // same sign: the low half is the wrapped 64-bit sum and the high
        // half is the common sign
        r.lo = uint64_t(x) + uint64_t(y);
        r.hi = -uint64_t(sx);
    }
    return r;
}
// exact difference of two int64_t values as a two's-complement 128-bit value
template <>
inline auto _extended_sub(int64_t x, int64_t y)
{
    extended_result<int64_t> r;
    bool sx = x < 0, sy = y < 0;
    if (sx == sy)
    {
        // same signs cannot overflow int64_t; sign-extend the result.
        // r.lo is unsigned, so its sign must be read through int64_t.
        r.lo = uint64_t(x - y);
        r.hi = -uint64_t(int64_t(r.lo) < 0);
    }
    else
    {
        // opposite signs: the low half is the wrapped 64-bit difference and
        // the high half carries the sign of x
        r.lo = uint64_t(x) - uint64_t(y);
        r.hi = -uint64_t(sx);
    }
    return r;
}
#endif
// the design follows <charconv>
// Result of fraction::to_chars: `ptr` is one past the last character written
// on success, or `last` on failure; `ec` is std::errc{} on success and
// std::errc::value_too_large when the output did not fit.
struct to_chars_result
{
char* ptr;
std::errc ec;
};
// printf formats for a fraction "num/den" of signed type T: the numerator is
// printed as T and the denominator as the unsigned counterpart of T.
// NOTE(review): the primary template binds `auto&` (a non-const lvalue
// reference) to the prvalue T{}, which appears intended to fail to compile
// for any T without a specialization below.
template <class T>
auto& _fmt = T{};
template <>
auto& _fmt<signed char> = "%hhd/%hhu";
template <>
auto& _fmt<short> = "%hd/%hu";
template <>
auto& _fmt<int> = "%d/%u";
template <>
auto& _fmt<long> = "%ld/%lu";
template <>
auto& _fmt<long long> = "%lld/%llu";
// printf formats for an integer-valued fraction (denominator 1): the
// numerator alone; same specialization scheme as _fmt
template <class T>
auto& _fmt1 = T{};
template <>
auto& _fmt1<signed char> = "%hhd";
template <>
auto& _fmt1<short> = "%hd";
template <>
auto& _fmt1<int> = "%d";
template <>
auto& _fmt1<long> = "%ld";
template <>
auto& _fmt1<long long> = "%lld";
// an object of fraction<intN_t> represents the sign along with the
// numerator in a signed integer of N bits and the denominator in an
// unsigned integer of the same size; passing a denominator greater
// or equal to 2^(N-1) is not allowed.
template <class T>
struct fraction
{
static_assert(std::is_signed<T>::value,
"fraction only models signed arithmetic");
static_assert(!std::is_same<T, char>::value,
"char may be unsigned on some platforms");
constexpr fraction() : fraction(0) {}
constexpr fraction(T num) : num_(num), den_(1) {} // implicit
constexpr fraction(T num, T den) : num_(), den_()
{
auto a = abs(num), b = abs(den);
auto t = gcd(a.val, b.val);
if (a.sign != b.sign)
num_ = -(a.val / t);
else
num_ = a.val / t;
den_ = b.val / t;
}
friend constexpr bool operator<(fraction const& x, fraction const& y)
{
if (x.den_ == y.den_)
return x.num_ < y.num_;
else
{
auto a = abs(x.num_), b = abs(y.num_);
if (a.sign == b.sign)
{
auto lnum = _extended_mul(a.val, y.den_),
rnum = _extended_mul(x.den_, b.val);
return a.sign ^ (lnum < rnum);
}
else
return a.sign;
}
}
friend constexpr bool operator==(fraction const& x, fraction const& y)
{
return x.num_ == y.num_ && x.den_ == y.den_;
}
friend constexpr bool operator!=(fraction const& x, fraction const& y)
{
return !(x == y);
}
friend constexpr bool operator>(fraction const& x, fraction const& y)
{
return y < x;
}
friend constexpr bool operator<=(fraction const& x, fraction const& y)
{
return !(y < x);
}
friend constexpr bool operator>=(fraction const& x, fraction const& y)
{
return !(x < y);
}
friend constexpr fraction operator+(fraction const& x, fraction const& y)
{
if (x.den_ == y.den_)
{
auto n = _extended_add(x.num_, y.num_);
return { extended_tag, n, _extended(x.den_) };
}
else
{
auto a = abs(x.num_), b = abs(y.num_);
if (a.sign == b.sign)
return sum(a, b, x.den_, y.den_);
else
return diff(a, b, x.den_, y.den_);
}
}
friend constexpr fraction operator-(fraction const& x, fraction const& y)
{
if (x.den_ == y.den_)
{
auto n = _extended_sub(x.num_, y.num_);
return { extended_tag, n, _extended(x.den_) };
}
else
{
auto a = abs(x.num_), b = abs(y.num_);
if (a.sign == b.sign)
return diff(a, b, x.den_, y.den_);
else
return sum(a, b, x.den_, y.den_);
}
}
friend fraction operator*(fraction const& x, fraction const& y)
{
auto a = abs(x.num_), b = abs(y.num_);
auto num = _extended_mul(a.val, b.val);
auto den = _extended_mul(x.den_, y.den_);
return { extended_tag, a.sign != b.sign, num, den };
}
friend fraction operator/(fraction const& x, fraction const& y)
{
auto a = abs(x.num_), b = abs(y.num_);
auto num = _extended_mul(a.val, y.den_);
auto den = _extended_mul(x.den_, b.val);
return { extended_tag, a.sign != b.sign, num, den };
}
constexpr fraction& operator+=(fraction const& y) &
{
return *this = *this + y;
}
constexpr fraction& operator-=(fraction const& y) &
{
return *this = *this - y;
}
constexpr fraction& operator*=(fraction const& y) &
{
return *this = *this * y;
}
constexpr fraction& operator/=(fraction const& y) &
{
return *this = *this / y;
}
// this function null-terminates the output at where the `ptr`
// member of the return value points to if succeeds, therefore
// the output range [frist, ptr) is at least one character
// shorter than the input range [first, last)
to_chars_result to_chars(char* first, char* last) const
{
auto sz = last - first;
int n = den_ == 1 ? snprintf(first, sz, _fmt1<T>, num_)
: snprintf(first, sz, _fmt<T>, num_, den_);
assert(0 < n);
if (n < sz)
return { first + n, {} };
else
return { last, std::errc::value_too_large };
}
// very simple output of the form [-]N/N for a fraction,
// [-]N for an integer, works well with field width settings
template <class CharT, class Traits>
friend auto& operator<<(std::basic_ostream<CharT, Traits>& out,
fraction const& fr)
{
using std::begin;
using std::end;
// sign + (n+1) + slash + (m+1) + nul
char buf[num_digits10<T> + num_digits10<rep_type> + 5];
fr.to_chars(begin(buf), end(buf));
return out << buf;
}
private:
using rep_type = std::make_unsigned_t<T>;
using se_type = extended_result_t<T>;
using ir_type = extended_result_t<rep_type>;
using abs_type = abs_result<T>;
struct extended_tag_t
{
};
static constexpr extended_tag_t extended_tag{};
constexpr fraction(extended_tag_t, se_type num, ir_type den)
: fraction(extended_tag, abs(num), den)
{
}
constexpr fraction(extended_tag_t, abs_result<se_type> num, ir_type den)
: fraction(extended_tag, num.sign, num.val, den)
{
}
constexpr fraction(extended_tag_t, bool sign, ir_type num, ir_type den)
: num_(), den_()
{
auto t = gcd(num, den);
bool r = 0;
if (sign)
r |= _assign_overflowed(-(num / t), num_);
else
r |= _assign_overflowed(num / t, num_);
r |= _assign_overflowed(den / t, den_);
if (r)
throw std::overflow_error{ "fraction" };
}
static constexpr fraction diff(abs_type a, abs_type b, rep_type xden,
rep_type yden)
{
auto lnum = _extended_mul(a.val, yden),
rnum = _extended_mul(xden, b.val);
auto den = _extended_mul(xden, yden);
auto num = lnum - rnum;
if (a.sign)
return { extended_tag, se_type(-num), den };
else
return { extended_tag, se_type(num), den };
}
static constexpr fraction sum(abs_type a, abs_type b, rep_type xden,
rep_type yden)
{
auto t = gcd(xden, yden);
rep_type c = xden / t, d = yden / t;
auto lnum = _extended_mul(a.val, d), rnum = _extended_mul(c, b.val);
auto den = _extended_mul(xden, d);
ir_type num{};
if (_add_overflowed(lnum, rnum, num))
throw std::overflow_error{ "sum" };
return { extended_tag, a.sign, num, den };
}
T num_;
rep_type den_;
};
}
#include <iomanip>
#include <iostream>
#include <sstream>
// Checks ext::pair: aggregate initialization (including in a constant
// expression), copying, assignment, and assignment through ext::tie.
void test_pair()
{
using ext::pair;
using ext::tie;
{
constexpr pair<int> x = { 3, 5 };
static_assert(x.a == 3, "");
static_assert(x.b == 5, "");
}
{
pair<int> x = {};
assert(x.a == 0);
assert(x.b == 0);
x = { 1, 2 };
assert(x.a == 1);
assert(x.b == 2);
auto y = x;
assert(x.a == 1);
assert(x.b == 2);
assert(y.a == 1);
assert(y.b == 2);
x = {};
y = x;
assert(y.a == 0);
assert(y.b == 0);
// tie binds references: the first value lands in x.b, the second in x.a
tie(x.b, x.a) = { 3, 4 };
assert(x.a == 4);
assert(x.b == 3);
int a, b;
// assign a whole pair<int> through the references
tie(a, b) = x;
assert(a == 4);
assert(b == 3);
}
}
// fraction<signed char>: intermediate results are computed in the wider
// int / unsigned types (extended_result<signed char> / <unsigned char>).
// Results that do not fit back into signed char must throw.
// Each step updates `a`, so the expected values depend on earlier lines.
void test_fraction_promotion_arithmatic()
{
using T = ext::fraction<signed char>;
T a{ 1, 120 };
assert(a < 1);
assert(a > 0);
assert(a < T(8, 120));
assert(a > T(-7, 120));
assert(T(-4, 7) <= a);
assert(T(7, 120) >= a);
a -= { 8, 120 };
assert(a == T(-7, 120));
a += { -17, 80 };
assert(a == T(-13, 48));
a -= { -17, 48 };
assert(a == T(1, 12));
a += { 7, 36 };
assert(a == T(5, 18));
a += { 17, 18 };
assert(a == T(11, 9));
a -= { -5, 12 };
assert(a == T(59, 36));
a += { -4, 12 };
assert(a == T(47, 36));
a *= -1;
assert(a == T(-47, 36));
a -= { -11, 18 };
assert(a == T(-25, 36));
try
{
// -25/36 + 5 == 155/36, and 155 does not fit in signed char
a + 5;
assert(0);
}
catch (std::overflow_error&)
{
assert(1);
}
// -25/36 / (6/5) == -125/216; 216 only fits in the unsigned denominator.
// NOTE(review): uint8_t(216) is converted to signed char (-40 on
// two's-complement targets), so the right-hand side is 25/8 and this
// check is weak; the exact value is checked after `a *= 2` below.
a /= { 6, 5 };
assert(a != T(-125, uint8_t(216)));
a *= 2;
assert(a == T(-125, 108));
}
// Arithmetic and comparison checks shared by fraction<int> and
// fraction<int64_t>.  Every step updates `a`, so each expected value
// depends on the statements above it.
template <class T>
void test_basics()
{
// -1236/2434 reduces to -618/1217
T a{ -1236, 2434 };
// identities, and negation / reciprocal through division
assert((a - a) == T());
assert((a / a) == 1);
assert((a / 1) == a);
assert((a + 0) == a);
assert((a / -1) == T(618, 1217));
assert((1 / a) == T(-1217, 618));
assert((-1 / a) == T(1217, 618));
// ordering between negative values with different denominators
assert(a < T(-2123543, 4356462));
assert(a > T(-1232374, 232864));
// a chain of compound assignments mixing signs and denominators
a -= { 12362, 43435 };
assert(a == T(-5983912, 7551485));
a += { 1232, 2434 };
assert(a == T(-2161632, 7551485));
a += (a * -57);
assert(a == T(121051392, 7551485));
a /= { 32, 5 };
assert(a == T(3782856, 1510297));
a += { 1232, 24340 };
assert(a == T(19296508, 7551485));
a -= { 1232, 2434 };
assert(a == T(15474228, 7551485));
a *= { -35, 243 };
assert(a == T(36106532, -122334057));
}
// fraction<int>: intermediates are computed in 64-bit types
// (extended_result<int> / <unsigned int>).
void test_fraction_widen_arithmatic()
{
using T = ext::fraction<int>;
test_basics<T>();
try
{
// the reduced sum of these two does not fit fraction<int>, so a + b is
// expected to throw std::overflow_error
constexpr T a{ 36106532, -122334057 };
constexpr T b{ 232, 43535 };
a + b;
assert(0);
}
catch (std::overflow_error&)
{
assert(1);
}
}
// fraction<int64_t>: intermediates need 128-bit types
// (extended_result<long> / <long long>).
void test_fraction_128bit_arithmatic()
{
    using T = ext::fraction<int64_t>;

    test_basics<T>();

    // construction reduces the fraction by the common factor 5
    T a{ 123712403545, 3412459054623245 };
    assert(a == T(24742480709, 682491810924649));

    // same-denominator addition adds the numerators
    a += { 4189236437, 682491810924649 };
    assert(a == T(28931717146, 682491810924649));
}
// Checks stream output (field width and alignment included) and to_chars
// for several fraction widths.
void test_fraction_io()
{
    std::stringstream ss;
    // print x into ss (any pending width applies to this output) and return
    // only that text
    auto str = [&](auto&& x) {
        ss.str("");
        ss << x;
        return ss.str();
    };
    {
        constexpr auto x = ext::fraction<int>(1, -3);
        // "-1/3" is 4 characters, so a width of 7 adds 3 fill characters;
        // the padding is spelled with std::string(n, ' ') so its length is
        // explicit
        ss << std::setw(7) << std::left;
        assert(str(x) == "-1/3" + std::string(3, ' '));
        ss << std::setw(7) << std::right;
        assert(str(x) == std::string(3, ' ') + "-1/3");
        // an integral value prints without "/1"; "-39" padded to width 9
        constexpr auto y = ext::fraction<int>(-3 * 13 * 4, 4);
        ss << std::setw(9) << std::left;
        assert(str(y) == "-39" + std::string(6, ' '));
        ss << std::setw(9) << std::right;
        assert(str(y) == std::string(6, ' ') + "-39");
        // the width is consumed by each output, so no padding from here on
        assert(str(x + y) == "-118/3");
        assert(str(x - y) == "116/3");
        assert(str(x * y) == "13");
        assert(str(x / y) == "1/117");
    }
    {
        // int8_t extremes: -128 numerator and a denominator of 128
        using T = ext::fraction<int8_t>;
        T x{ -128, 127 };
        T y{ -127, -128 };
        assert(str(T{}) == "0");
        assert(str(x) == "-128/127");
        assert(str(y) == "127/128");
        assert(str(x * y) == "-1");
    }
    {
        // int64_t extremes
        using T = ext::fraction<int64_t>;
        T x{ -9223372036854775807L - 1, 9223372036854775807L };
        T y{ -9223372036854775807L, -9223372036854775807L - 1 };
        assert(str(x) == "-9223372036854775808/9223372036854775807");
        assert(str(y) == "9223372036854775807/9223372036854775808");
        assert(str(x * y) == "-1");
    }
    {
        // to_chars needs room for the text plus the terminating nul
        ext::fraction<int8_t> x{ -128, 127 };
        char buf[10];
        auto r = x.to_chars(buf, buf);
        assert(r.ptr == buf);
        assert(r.ec == std::errc::value_too_large);
        // "-128/127" is 8 characters; 8 bytes leave no room for the nul
        r = x.to_chars(buf, buf + 8);
        assert(r.ptr == buf + 8);
        assert(r.ec == std::errc::value_too_large);
        r = x.to_chars(buf, buf + 9);
        assert(r.ptr == buf + 8);
        assert(r.ec == std::errc{});
        // 1 + (-128/127) == -1/127
        r = (1 + x).to_chars(buf, buf + 9);
        assert(std::string(buf) == "-1/127");
        assert(r.ptr == buf + 6);
        assert(r.ec == std::errc{});
    }
}
int main()
{
test_pair();
test_fraction_promotion_arithmatic();
test_fraction_widen_arithmatic();
test_fraction_128bit_arithmatic();
test_fraction_io();
ext::fraction<int64_t> a{ 1237, 12325 };
ext::fraction<int64_t> b{ 237, 1225 };
std::cout << (a < b) << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment