#include <gmp.h>
#include <stdio.h>
#include <stdlib.h>
#include "../gmp-6.1.2/gmp-impl.h"
#include "../gmp-6.1.2/longlong.h"

/* Return the nbits-wide window of the limb vector p that ends just below bit
   position bi, i.e. bits [bi - nbits, bi), right-aligned.  If fewer than
   nbits bits lie below bi (bi < nbits), only bits [0, bi) are returned.
   nbits must be less than GMP_NUMB_BITS, since the mask is built with
   1 << nbits.  The limb above the low one is read only when the window
   straddles a limb boundary.  */
static inline mp_limb_t
getbits (const mp_ptr p, mp_bitcnt_t bi, int nbits)
{
  mp_limb_t window;
  mp_size_t limb;
  int shift, avail;

  /* Short window at the very bottom of the number.  */
  if (bi < nbits)
    return p[0] & (((mp_limb_t) 1 << bi) - 1);

  bi -= nbits;                    /* position of the window's lowest bit */
  limb = bi / GMP_NUMB_BITS;      /* limb holding that bit */
  shift = bi % GMP_NUMB_BITS;     /* its offset inside the limb */
  avail = GMP_NUMB_BITS - shift;  /* window bits available from that limb */

  window = p[limb] >> shift;
  if (avail < nbits)              /* window spills into the next limb */
    window |= p[limb + 1] << avail;

  return window & (((mp_limb_t) 1 << nbits) - 1);
}

// Limb-vector variant of mpz_mul_2exp: store {up, un} * 2^cnt in rp and
// return the normalized limb count of the result (0 when un == 0).
// Only non-negative values are handled; there is no sign.
// rp must have room for un + cnt / GMP_NUMB_BITS + 1 limbs: when cnt is not a
// multiple of GMP_NUMB_BITS the carry-out limb is always stored, even if 0.
// rp may equal up (in-place shift).
static mp_size_t
mpn_mul_2exp (mp_ptr rp, mp_srcptr up, mp_size_t un, mp_bitcnt_t cnt)
{
  mp_size_t whole = cnt / GMP_NUMB_BITS;   /* shift by whole limbs */
  unsigned partial = cnt % GMP_NUMB_BITS;  /* remaining sub-limb shift */
  mp_size_t rn;

  if (un == 0)
    return 0;

  rn = un + whole;
  if (partial == 0)
    MPN_COPY_DECR (rp + whole, up, un);
  else
    {
      mp_limb_t carry = mpn_lshift (rp + whole, up, un, partial);
      rp[rn] = carry;
      if (carry != 0)
        rn++;
    }

  /* Clear the vacated low limbs only after the shift/copy, so an in-place
     call (rp == up) has read its source before it is overwritten.  */
  MPN_ZERO (rp, whole);
  return rn;
}

// Left to right k-ary exponentiation
void mpz_pow2m(mpz_ptr rop, mpz_srcptr e, mpz_srcptr m) {
    mp_size_t mn, rn, rnmax, en;
    mp_ptr mp, rp, sqrp, qp, ep;
    int ebi, k, ik;
    mp_limb_t bits;
    TMP_DECL;
    TMP_MARK;

    ep = PTR(e);
    en = SIZ(e);
    MPN_SIZEINBASE_2EXP(ebi, ep, en, 1);

    mn = SIZ(m);
    rnmax = 2*mn+1;
    rn = rnmax;
    rp = MPZ_REALLOC(rop, mn);

    for(k=5; (1<<k)-1+GMP_NUMB_BITS<=(rnmax-mn)*GMP_NUMB_BITS; ++k) ; 
    --k;
    //fprintf(stderr, "k=%d\n",k);
    mp = PTR(m);

    sqrp = TMP_ALLOC_LIMBS(rn+(rn-mn+1));
    qp = sqrp+rn;
    //sqrp = TMP_ALLOC_LIMBS(rn);
    //qp = TMP_ALLOC_LIMBS(rn-mn+1);

    ik = ebi % k;
    if (ik==0)
        ik = k;

    bits = getbits(ep, ebi, ik);
    rn = 1;
    rp[0] = 1L;
    rn = mpn_mul_2exp(rp, rp, rn, bits);
    ebi -= ik;

    while(ebi > 0) {

        // A <-- A^(2^k)
        for(ik=0; ik<k; ++ik) {
            mpn_sqr(sqrp, rp, rn);
            rn *= 2;
            MPN_NORMALIZE (sqrp, rn);

            // A <-- A mod m
            if (rn > mn) {
                mpn_tdiv_qr(qp, rp, 0, sqrp, rn, mp, mn);
                rn = mn;
                MPN_NORMALIZE (rp, rn);
            } else {
                mpn_copyd(rp, sqrp, rn);
            }
        }

        // A <-- A*g_e_i
        bits = getbits(ep, ebi, k);
        if (bits) { 
            rn = mpn_mul_2exp(sqrp, rp, rn, bits);

            // A <-- A mod m
            if (rn > mn) {
                mpn_tdiv_qr(qp, rp, 0, sqrp, rn, mp, mn);
                rn = mn;
                MPN_NORMALIZE (rp, rn);
            } else {
                mpn_copyd(rp, sqrp, rn);
            }
        }

        ebi -= k;
    }

    // final output
    if (rn == mn && mpn_cmp(rp, mp, mn)>=0) {
        mpn_tdiv_qr(qp, rp, 0, rp, rn, mp, mn);
        MPN_NORMALIZE (rp, rn);
    }

    SIZ(rop) = rn;
    TMP_FREE;
}

// Reads integers in GMP raw format (mpz_inp_raw) from stdin and writes to
// stdout, in the same format, only those n that pass the base-2 Fermat test
// 2^(n-1) == 1 (mod n).  Numbers failing it are composite and are dropped.
// A summary line goes to stderr.  Exit status is 1 on usage, read or write
// errors, otherwise 0.
int main(int argc, char**argv) {
    (void) argv;
    if (argc != 1) {
        fprintf(stderr, "usage: filter_fermat2k\n");
        return 1;
    }

    mpz_t n; mpz_init(n);
    mpz_t r; mpz_init(r);
    mpz_t e; mpz_init(e);

    unsigned long total = 0;
    unsigned long good = 0;
    int status = 0;
    while (mpz_inp_raw(n, stdin)) {
        ++total;
        // n < 2 is never prime.  Drop it without testing, so mpz_pow2m never
        // sees a zero/negative modulus or a non-positive exponent.
        if (mpz_cmp_ui(n, 2UL) < 0)
            continue;
        mpz_sub_ui(e, n, 1UL);
        mpz_pow2m(r, e, n);
        if (mpz_cmp_ui(r, 1UL) == 0) {
            ++good;
            // mpz_out_raw returns 0 on error; a broken output stream makes
            // further work pointless.
            if (mpz_out_raw(stdout, n) == 0) {
                perror("filter_fermat2k: write");
                status = 1;
                break;
            }
        }
    }
    // mpz_inp_raw also returns 0 on a read error; distinguish it from EOF.
    if (status == 0 && ferror(stdin)) {
        perror("filter_fermat2k: read");
        status = 1;
    }
    if (fflush(stdout) != 0) {
        perror("filter_fermat2k: flush");
        status = 1;
    }
    fprintf(stderr, "filter_fermat2k: %lu - %lu = %lu remaining\n", total, total-good, good);

    mpz_clear(e);
    mpz_clear(r);
    mpz_clear(n);
    return status;
}