/* Implementation of the binary recursive gcd algorithm.
   Version 1.0, 04 November 2005.

   Reference: "A Binary Recursive Gcd Algorithm",
   by D. Stehlé and P. Zimmermann, Proceedings of the Algorithmic
   Number Theory Symposium (ANTS VI), pages 411-425, 2004.
   
Copyright 2001, 2002, 2003, 2004, 2005 Damien Stehlé and Paul Zimmermann

This code is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

This code is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License;
see the file COPYING.LIB.  If not, write to
the Free Software Foundation, Inc., 51 Franklin Place, Fifth Floor, Boston,
MA 02110-1301, USA. */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h> /* for ULONG_MAX */
#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"
#include <sys/types.h>
#include <sys/resource.h>

/* #define DEBUG 1 */
/* #define DEBUG2 */

double mul_cost, mul_cost_hgcd;

/* Operand size (in limbs) from which the instrumented multiplications below
   print a trace line on stdout. */
#define THRESHOLD1 200000000
/* a <- b * c.  Adds size(b)*size(c) (in limbs) to the global mul_cost before
   multiplying.  mpz_size returns a size_t, hence the casts for printf. */
#define MPZ_MUL(a, b, c) do { \
   mul_cost += (double) mpz_size (b) * mpz_size (c); \
   if (mpz_size (b) >= THRESHOLD1 || mpz_size (c) >= THRESHOLD1) \
      printf ("mpz_mul %lu %lu\n", (unsigned long) mpz_size (b), \
              (unsigned long) mpz_size (c)); \
   mpz_mul ((a), (b), (c)); } while (0)
/* a <- a + b * c, with the same cost accounting and tracing as MPZ_MUL. */
#define MPZ_ADDMUL(a, b, c) do { \
   mul_cost += (double) mpz_size (b) * mpz_size (c); \
   if (mpz_size (b) >= THRESHOLD1 || mpz_size (c) >= THRESHOLD1) \
      printf ("mpz_addmul %lu %lu\n", (unsigned long) mpz_size (b), \
              (unsigned long) mpz_size (c)); \
   mpz_addmul ((a), (b), (c)); } while (0)
/* #define DD 1  */
#define BBOUND 63


int cputime ();
void mpz_addmul_si (mpz_t, mpz_t, long);
unsigned long mpz_divi_slow (mpz_t, mpz_t, mpz_t, mpz_t, int);
unsigned long mpz_divi (mpz_t, mpz_t, mpz_t, mpz_t, int);
unsigned long mpz_divi_si (long*, long*, long*, long*, mpz_t, mpz_t, mpz_t,
                           unsigned long, int);
int check (mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, mpz_t,
           unsigned long);
long mpz_int_rfeea (mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, unsigned long,
                    mpz_t*, int);
void mpz_fastgcd (mpz_t, const mpz_t, const mpz_t);
void Hstar (mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
void H (mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
void G (mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
void mpz_bfib2_ui (mpz_t, mpz_t, unsigned long);
long
mpz_int_rfeea0 (mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, unsigned long, 
		mpz_t*);
unsigned long base_case_bin (mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, mpz_t, 
			     unsigned long, unsigned long);





/* Returns the user CPU time consumed so far by the calling process, in
   milliseconds.  Returns 0 if getrusage fails, instead of reading an
   uninitialized structure. */
int
cputime (void)
{
  struct rusage usage;

  if (getrusage (RUSAGE_SELF, &usage) != 0)
    return 0;
  return usage.ru_utime.tv_sec * 1000 + usage.ru_utime.tv_usec / 1000;
}

/* Debug helpers: print x on stdout in base 10 (OUT) or base 2 (OUT2),
   followed by ":" and a newline. */
#define OUT(x) { mpz_out_str (stdout, 10, x); printf (":\n"); }
#define OUT2(x) { mpz_out_str (stdout, 2, x); printf (":\n"); }

/* VAL(P): index of the lowest set bit of P, searched with mpz_scan1 from
   bit 0 (used throughout this file as the 2-adic valuation of P).
   VAL2(P,i): same search, but starting at bit i. */
#define VAL(P) mpz_scan1(P, 0)
#define VAL2(P,i) mpz_scan1(P, i)
/* Value the results of VAL/VAL2 are compared against to detect the
   "no set bit found" case. */
#define INFINITY ULONG_MAX

/* a <- a + b * i, for a signed multiplier i.
   Dispatches to mpz_addmul_ui or mpz_submul_ui on the sign of i. */
void
mpz_addmul_si (mpz_t a, mpz_t b, long i)
{
  if (i >= 0)
    mpz_addmul_ui (a, b, (unsigned long) i);
  else
    /* negate in unsigned arithmetic: -i would overflow for i = LONG_MIN */
    mpz_submul_ui (a, b, 0UL - (unsigned long) i);
}

/* given P and Q with val(P) < val(Q), replaces P by R and returns q
   such that q*Q/2^j+P = R and val(R) > val(Q) with j := val(Q)-val(P)
   with -2^j <= q < 2^j.
   We have q = - P'/Q' (mod 2^(j+1)) with P' = P/2^val(P), Q' = Q/2^val(Q).
   Returns j := val(Q)-val(P).
   tmp is an auxiliary variable.
   q is only written when extended is non-zero.
   This is the bit-by-bit variant; mpz_divi falls back to it when
   j >= GMP_NUMB_BITS.
*/
unsigned long
mpz_divi_slow (mpz_t q, mpz_t P, mpz_t Q, mpz_t tmp, int extended)
{
  unsigned long valP, valQ, v;

  valP = VAL(P);
  valQ = VAL(Q);
#ifdef DEBUG
  if (valP >= valQ)
    {
      fprintf (stderr, "Error in mpz_divi_slow: val(P) >= val(Q)\n");
      exit (1);
    }
#endif
  if (extended)
    mpz_set_ui (q, 0);
  /* v tracks the position of the lowest set bit of the current P */
  v = valP;
  while (v < valQ)
    {
      /* record the quotient bit, then add Q shifted so that its lowest set
         bit lands on bit v of P */
      if (extended)
        mpz_setbit (q, v - valP);
      mpz_div_2exp (tmp, Q, valQ - v);
      mpz_add (P, P, tmp);
      /* next set bit of P, searched from bit v+1 */
      v = VAL2 (P, v + 1);
#ifdef DEBUG
      if (v != VAL(P))
        {
          fprintf (stderr, "Error in mpz_divi_slow: v <> VAL(P)\n");
          exit (1);
        }
#endif
    }
  /* last step: only needed when the lowest set bit of P is exactly at
     val(Q) */
  if (v == valQ)
    {
/* CENTERED is defined unconditionally here, so MPZ_ADD expands to mpz_sub */
#define CENTERED
#ifdef CENTERED
#define MPZ_ADD mpz_sub /* last step is centered, q may be negative */
#else
#define MPZ_ADD mpz_add /* last step is not centered, q is positive */
#endif
      if (extended) /* q = q - 2^j */
        {
          mpz_set_ui (tmp, 1);
          mpz_mul_2exp (tmp, tmp, valQ - valP);
          MPZ_ADD (q, q, tmp);
        }
      MPZ_ADD (P, P, Q);
    }
  return valQ - valP;
}

/* given P and Q with val(P) < val(Q), replaces P by R and returns q
   such that q*Q/2^j+P = R and val(R) > val(Q) with j := val(Q)-val(P)
   with -2^j < q < 2^j (q is always odd).
   We have q = - P'/Q' (mod 2^(j+1)) with P' = P/2^val(P), Q' = Q/2^val(Q).
   Returns j := val(Q)-val(P).
   tmp is an auxiliary variable.
*/
unsigned long
mpz_divi (mpz_t q, mpz_t P, mpz_t Q, mpz_t tmp, int extended)
{
  unsigned long valP, valQ, j, i;
  mp_limb_t pp, qp, mask;
  long qq;
  mp_ptr tp;

#ifdef DEBUG
  if (mpz_cmp_ui (Q, 0) == 0)
    {
      fprintf (stderr, "Error: Q=0\n");
      exit (1);
    }
#endif
  valP = VAL(P);
  valQ = VAL(Q);
  j = valQ - valP;
  if (j >= GMP_NUMB_BITS)
    return mpz_divi_slow (q, P, Q, tmp, extended);

#ifdef DEBUG
  if (mpz_sizeinbase (P, 2) < valP + j + 1)
    {
      fprintf (stderr, "Error: bits(P) < valP+j+1\n");
      exit (1);
    }
#endif    
  tp = PTR(P) + valP / GMP_NUMB_BITS;
  valP %= GMP_NUMB_BITS;
  pp = (valP) ? ((tp[0] >> valP) + (tp[1] << (GMP_NUMB_BITS - valP)))
    : tp[0];
  /* now pp contains the low j+1 significant bits from P */

#ifdef DEBUG
  if (mpz_sizeinbase (Q, 2) < valQ + j + 1)
    {
      fprintf (stderr, "Error: bits(Q) < valQ+j+1\n");
      exit (1);
    }
#endif    
  tp = PTR(Q) + valQ / GMP_NUMB_BITS;
  valQ %= GMP_NUMB_BITS;
  qp = (valQ) ? ((tp[0] >> valQ) + (tp[1] << (GMP_NUMB_BITS - valQ)))
    : tp[0];
  /* now qp contains the low j+1 significant bits from Q */

  /* we now have the low j+1 bits from P and Q in pp and qp */

      qq = (long) 0; /* always odd */
      mask = (mp_limb_t) 1;
      for (i = 0; i < j; i++)
        {
          /* invariant: mask = 1 << i */
          if (pp & mask)
            {
              qq += mask;
              pp += qp;
            }
          qp <<= 1;
          mask <<= 1;
        }

      /* last step is centered */
      qq -= pp & mask;

      if (mpz_sgn (P) != mpz_sgn (Q))
        qq = - qq;

  if (extended)
    mpz_set_si (q, qq);

  /* replace P by q*Q/2^j+P */
  mpz_div_2exp (tmp, Q, j);
  if (qq >= 0)
    {
      if (qq == 1)
        mpz_add (P, P, tmp);
      else
        mpz_addmul_ui (P, tmp, qq);
    }
  else
    {
      if (qq == -1)
        mpz_sub (P, P, tmp);
      else
        mpz_submul_ui (P, tmp, -qq);
    }

  return j;
}






/* returns j and matrix [a,b;c,d] which reduces P0, Q0 into P, Q.
   Assumes 0 = val(P) < val(Q) initially.
   Modifies P, Q so that val(Q) < val(P), and divides by 2^j
   so that val(Q)=0.
   Stop if k < val(P) is attained.
   If swap is non-zero, swap P and Q at the end, s.t. 0 = val(P) <= k < val(Q).

   The reduction steps are carried out on the two low limbs of |P| and |Q|
   only; the matrix is accumulated in doubles, and a step is abandoned as
   soon as one coefficient would leave [-BOUND, BOUND].  The accumulated
   matrix is then applied once to the full P and Q.  tmp is scratch space.
   Aborts when val(Q) >= GMP_NUMB_BITS.
*/
unsigned long
mpz_divi_si (long *a, long *b, long *c, long *d,
             mpz_t P, mpz_t Q, mpz_t tmp, unsigned long k, int swap)
{
  unsigned long jtot;
  long j;
  mp_limb_t pp0[2], qp0[2], *pp, *qp, *tp, invq, twoj;
  long q;
  double sign = (mpz_sgn (P) == mpz_sgn (Q)) ? 1.0 : -1.0;
  double aa, bb, cc, dd;
  long significant_bits;
  double newa, newb, newc, newd, twojd, qq;
  long sizep, sizeq;

#ifdef DEBUG
  mpz_t Pin, Qin;

  /* keep copies of the inputs for the diagnostics at the end.  The former
     "sizep < 2" / "sizeq < 2" checks were dropped: they read both variables
     before assignment, and single-limb operands are handled below. */
  mpz_init_set (Pin, P);
  mpz_init_set (Qin, Q);
#endif

  sizep = ABS(SIZ(P));
  pp = pp0;
  pp[0] = PTR(P)[0];
  pp[1] = (sizep >= 2) ? PTR(P)[1] : 0;
  /* now {pp,2} contains the low 2 significant limbs from |P| */

  j = VAL(Q);
  /* if j >= GMP_NUMB_BITS, then the coefficients may not fit in a long */
  if (j >= GMP_NUMB_BITS)
    abort (); /* should call mpz_divi_slow here */
  /* now j < GMP_NUMB_BITS */

  /* since val(P) < val(Q), j cannot be 0 */
  sizeq = ABS(SIZ(Q));
  qp = qp0;
  if (sizeq == 1)
    {
      mpn_rshift (qp, PTR(Q), 1, j);
      qp[1] = 0;
    }
  else
    {
      mpn_rshift (qp, PTR(Q), 2, j);
      if (sizeq > 2)
        qp[1] |= PTR(Q)[2] << (GMP_NUMB_BITS - j);
    }
  /* now {qp,2} contains the low 2 significant limbs from |Q|/2^j */

  aa = 1.0;
  bb = 0.0;
  cc = 0.0;
  dd = 1.0;
  jtot = 0; /* matrix denominator is 2^jtot */

  /* number of significant bits from p and q*2^j */
  significant_bits = 2 * GMP_NUMB_BITS + j;

  /* invariant: j = valuation difference = VAL(Q) - VAL(P), and jtot=val(P) */
  while (jtot + j <= k && j < significant_bits)
    { /* we need at least j+1 significant bits in pp and qq */
      /* reduce P wrt Q, by computing -P/Q mod 2^(j+1) */
      significant_bits -= 2 * j;
      /* shift in the limb type: 1L << j overflows a signed long when
         j = GMP_NUMB_BITS - 1 */
      twoj = (mp_limb_t) 1 << j;
      modlimb_invert (invq, qp[0]);
      invq *= (~pp[0]) + 1UL;
      invq &= (twoj << 1) - 1UL; /* -P/Q mod 2^(j+1) */
      /* centered representative in [-2^j, 2^j) */
      q = (invq >= twoj) ? invq - (twoj << 1) : invq;
      /*      printf ("q=%1.0f\n", sign*(double)q); */

      /* replace P by P + q*Q/2^j */
      if (q >= 0)
        mpn_addmul_1 (pp, qp, 2, q);
      else
        mpn_submul_1 (pp, qp, 2, -q);

      /* update matrix: [newP, newQ] -> [Q, P+q/2^j*Q]
         thus jtot -> jtot + j,
         [a, b; c, d] -> [0, 2^j; 2^j; q] * [a, b; c, d] */
      /* a[0] <- 2^j*c[0], c[0] <- 2^j*a[0] + q*c[0]
         b[0] <- 2^j*d[0], d[0] <- 2^j*b[0] + q*d[0] */
#define BOUND 2147483647.0
#define ok(x) (-BOUND <= x && x <= BOUND)
      twojd = (double) twoj;
      qq = sign * (double) q;
      newa = twojd * cc;
      newc = twojd * aa + qq * cc;
      newb = twojd * dd;
      newd = twojd * bb + qq * dd;
      if (ok(newa) && ok(newb) && ok(newc) && ok(newd))
        {
          aa = newa;
          bb = newb;
          cc = newc;
          dd = newd;
        }
      else
        break; /* coefficients would leave the BOUND range: drop this step */
      jtot += j;

      /* we know that val(P) >= j+1 here, new j is val(P) - j */
      if (pp[0] == 0)
        break;
      count_trailing_zeros(q, pp[0]);
      mpn_rshift (pp, pp, 2, q);
      j = q - j;

      /* exchange P and Q */
      tp = pp; pp = qp; qp = tp;
    }

  c[0] = (long) aa;
  d[0] = (long) bb;
  a[0] = (long) cc;
  b[0] = (long) dd;

  /* update P and Q: Q <- (a[0]*P + b[0]*Q)/2^jtot,
                     P <- (c[0]*P + d[0]*Q)/2^jtot */
  {
    if (swap == 0) /* last computed term will be in P */
      {
        mpz_mul_si (tmp, P, c[0]);
        mpz_addmul_si (tmp, Q, d[0]);
        mpz_mul_si (P, P, a[0]);
        mpz_addmul_si (P, Q, b[0]);
        mpz_div_2exp (P, P, 2 * jtot);
        mpz_div_2exp (Q, tmp, 2 * jtot);
      }
    else /* swap=1: last computed term will be in Q */
      {
        mpz_mul_si (tmp, P, c[0]);
        mpz_addmul_si (tmp, Q, d[0]);
        mpz_mul_si (Q, Q, b[0]);
        mpz_addmul_si (Q, P, a[0]);
        mpz_div_2exp (Q, Q, 2 * jtot);
        mpz_div_2exp (P, tmp, 2 * jtot);
      }
  }

#ifdef DEBUG
  if ((swap == 0 && VAL(Q) != 0) || (swap == 1 && VAL(P) != 0))
    {
      printf ("VAL(Q)<>0: VAL(P)=%lu VAL(Q)=%lu\n",
              (unsigned long) VAL(P), (unsigned long) VAL(Q));
      printf ("k=%lu Pin=", k); OUT(Pin);
      printf ("Qin="); OUT(Qin);
      exit (1);
    }
  if ((swap == 0 && VAL(P) <= VAL(Q)) || (swap == 1 && VAL(Q) <= VAL(P)))
    {
      fprintf (stderr, "Error in divi_si: VAL(P) <= VAL(Q) at exit\n");
      exit (1);
    }
  mpz_clear (Pin);
  mpz_clear (Qin);
#endif

  return jtot;
}



/* Reduces P and Q in place and stores the accumulated transformation matrix
   in R11, R12, R21, R22.  For l <= BBOUND the work is done by repeated calls
   to mpz_divi_si; otherwise the function recurses, first with l/2 and then
   with the remaining budget, and multiplies the two matrices together.
   Returns the accumulated shift count. */
unsigned long
base_case_bin (mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22, mpz_t P, mpz_t Q,
               unsigned long k, unsigned long l)
{
  mpz_t S11, S12, S21, S22, scratch;
  long coef[4];
  unsigned long total = 0, step;

  mpz_init (S11);
  mpz_init (S12);
  mpz_init (S21);
  mpz_init (S22);
  mpz_init (scratch);

  if (l > BBOUND)
    {
      if (VAL(Q) <= k)
        {
          /* first half of the budget */
          total = base_case_bin (R11, R12, R21, R22, P, Q, k, l / 2);

          if (VAL(Q) <= k)
            {
              /* second part, with what is left of k and l */
              k -= total;
              step = base_case_bin (S11, S12, S21, S22, P, Q, k, l - total);
              total += step;

              /* R = S * R, second column first (columns are independent) */
              mpz_mul (scratch, R12, S21);
              mpz_mul (R12, R12, S11);
              mpz_addmul (R12, R22, S12);
              mpz_mul (R22, R22, S22);
              mpz_add (R22, R22, scratch);

              mpz_mul (scratch, R11, S21);
              mpz_mul (R11, R11, S11);
              mpz_addmul (R11, R21, S12);
              mpz_mul (R21, R21, S22);
              mpz_add (R21, R21, scratch);
            }
        }
      else
        {
          /* nothing to reduce: R is the identity */
          mpz_set_ui (R11, 1);
          mpz_set_ui (R22, 1);
          mpz_set_ui (R12, 0);
          mpz_set_ui (R21, 0);
        }
    }
  else
    {
      total = mpz_divi_si (&coef[0], &coef[1], &coef[2], &coef[3],
                           P, Q, scratch, k, 1);
      mpz_set_si (R11, coef[2]);
      mpz_set_si (R12, coef[3]);
      mpz_set_si (R21, coef[0]);
      mpz_set_si (R22, coef[1]);

      while (total + VAL(Q) <= l)
        {
          step = mpz_divi_si (&coef[0], &coef[1], &coef[2], &coef[3],
                              P, Q, scratch, k - total, 1);
          total += step;

          /* R = [c2,c3;c0,c1] * R, second column first:
             R12 <- c2*R12+c3*R22, R22 <- c0*R12+c1*R22 */
          mpz_mul_si (scratch, R12, coef[0]);
          mpz_mul_si (R12, R12, coef[2]);
          mpz_addmul_si (R12, R22, coef[3]);
          mpz_mul_si (R22, R22, coef[1]);
          mpz_add (R22, R22, scratch);

          /* R11 <- c2*R11+c3*R21, R21 <- c0*R11+c1*R21 */
          mpz_mul_si (scratch, R11, coef[0]);
          mpz_mul_si (R11, R11, coef[2]);
          mpz_addmul_si (R11, R21, coef[3]);
          mpz_mul_si (R21, R21, coef[1]);
          mpz_add (R21, R21, scratch);
        }
    }

  mpz_clear (scratch);
  mpz_clear (S22);
  mpz_clear (S21);
  mpz_clear (S12);
  mpz_clear (S11);

  return total;
}













/* Verifies that R*[Pin,Qin]/2^j = [P,Q].
   Returns 0 when both components agree and 1 otherwise. */
int
check (mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22,
       mpz_t Pin, mpz_t Qin, mpz_t P, mpz_t Q, unsigned long j)
{
  mpz_t top, bottom;
  int mismatch = 1;

  mpz_init (top);
  mpz_init (bottom);

  /* top = (R11*Pin + R12*Qin) / 2^j */
  mpz_mul (top, R11, Pin);
  mpz_addmul (top, R12, Qin);
  mpz_div_2exp (top, top, j);

  /* bottom = (R21*Pin + R22*Qin) / 2^j */
  mpz_mul (bottom, R21, Pin);
  mpz_addmul (bottom, R22, Qin);
  mpz_div_2exp (bottom, bottom, j);

  if (mpz_cmp (top, P) == 0 && mpz_cmp (bottom, Q) == 0)
    mismatch = 0;

  mpz_clear (bottom);
  mpz_clear (top);
  return mismatch;
}

















/* Top-level reduction step used by mpz_fastgcd.
   Splits P and Q at bit `cut` (2*k+1 rounded up to a multiple of
   GMP_NUMB_BITS), runs mpz_int_rfeea with extended=1 on the low parts to get
   the matrix R = [R11,R12;R21,R22] and the shift j1, then applies R to the
   high parts, shifts them left by cut - 2*j1 and adds the reduced low parts.
   P and Q are updated in place; tmp must provide at least two scratch
   variables (tmp[0] and tmp[1]), and is also passed on to mpz_int_rfeea.
   Returns j1. */
long
mpz_int_rfeea0 (mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22,
                mpz_t P, mpz_t Q, unsigned long k,
                mpz_t *tmp)
{
  mpz_t P0, Q0;
  unsigned long cut;
  long j1; /* same type as the value returned by mpz_int_rfeea */

  mpz_init (P0);
  mpz_init (Q0);
#define q   tmp[1]
#define t1  tmp[0]

  cut = 2 * k + 1;
  if (cut % GMP_NUMB_BITS)
    cut += GMP_NUMB_BITS - (cut % GMP_NUMB_BITS);
  /* P = P * 2^cut + P0, Q = Q * 2^cut + Q0 */
  mpz_tdiv_r_2exp (P0, P, cut);
  mpz_tdiv_q_2exp (P,  P, cut);
  mpz_tdiv_r_2exp (Q0, Q, cut);
  mpz_tdiv_q_2exp (Q,  Q, cut);

  j1 = mpz_int_rfeea (R11, R12, R21, R22, P0, Q0, k, tmp, 1);

  /* Q <- (R21*P + R22*Q) * 2^(cut-2*j1) + Q0,
     P <- (R11*P + R12*Q) * 2^(cut-2*j1) + P0;
     R12*Q is saved in q before Q is overwritten */
  MPZ_MUL (q, R12, Q);
  MPZ_MUL (Q, R22, Q);
  MPZ_ADDMUL (Q, R21, P);
  mpz_mul_2exp (Q, Q, cut - 2 * (unsigned long) j1);
  mpz_add (Q, Q, Q0);
  MPZ_MUL (t1, R11, P);
  mpz_add (P, t1, q);
  mpz_mul_2exp (P, P, cut - 2 * (unsigned long) j1);
  mpz_add (P, P, P0);

  mpz_clear (P0);
  mpz_clear (Q0);
  return j1;
}






/* returns a 2x2 integer matrix R = [[R11,R12],[R21,R22]] such that if
   (P')        (P)
   (  )  = R * ( ) / 2^j
   (Q')        (Q)
   P', Q' are two consecutive remainders of
   the modified right binary Euclidean algorithm with
   val(P') <= val(P) + k < val(Q').

   Assumes 0=val(P) < val(Q), i.e. P is odd and Q is even.
   Returns j.
   Modifies in-place P and Q to P'/2^j and Q'/2^j, with val(P'/2^j)=0.
   If extended=0, the matrix R is not computed.
   We have det(R)=2^(2*j).

   For example, for P=646173941, Q=3473871690, k=32, it returns
   R11=3094378856, R12=4734555148, R21=-3473871690, R22=646173941, j=32.

   tmp points to scratch variables; tmp[0] and tmp[1] are used here under
   the names t1 and q, and tmp is passed unchanged to the recursive calls.

   Structure: identity when k < val(Q); mpz_divi_si-based base case when
   k < THRESHOLD; otherwise a first recursive call with d = k/2 on the low
   parts of P and Q, one mpz_divi_si step, and a second recursive call with
   dd = k - (j1 + mm).
*/
long
mpz_int_rfeea (mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22,
               mpz_t P, mpz_t Q, unsigned long k,
               mpz_t *tmp, int extended)
{
  unsigned long m, d, mm, dd, j, j1, j2, cut;
  mpz_t P0, Q0, S11, S12, S21, S22; /* t1, q */
  long a[4];
  int computeS;
#ifdef DEBUG
  /* g = odd part of gcd(P,Q) at entry, compared with the odd part of the
     gcd of the reduced pair before returning */
  mpz_t g, g2;
  mpz_init (g); mpz_init (g2); mpz_gcd (g, P, Q);
  mpz_div_2exp (g, g, VAL(g));
#endif

#define t1  tmp[0]
#define q   tmp[1]

  m = VAL(Q);

#ifdef DEBUG
  if (VAL(P) != 0 || m == 0)
    {
      fprintf (stderr, "Error in mpz_int_rfeea: VAL(P)<>0 or val(Q)=0\n");
      printf ("P:="); OUT(P);
      printf ("Q:="); OUT(Q);
      exit (1);
    }
#endif

  if (k < m) /* val(P) + k = k < val(Q) */
    {
      /* the output condition is satisfied with P'=P, Q'=Q, j=0, and R=Id */
      if (extended)
        {
          mpz_set_ui (R11, 1);
          mpz_set_ui (R12, 0);
          mpz_set_ui (R21, 0);
          mpz_set_ui (R22, 1);
        }
      j1 = 0;
      goto end2;
    }

  /* we start to gain when k/4 > MUL_KARATSUBA_THRESHOLD * GMP_NUMB_BITS */
#define THRESHOLD ( MUL_KARATSUBA_THRESHOLD * GMP_NUMB_BITS )
  if (k < THRESHOLD) /* naive algorithm */
    {
/* DIVI_SI is defined unconditionally here, so the mpz_divi-based branch
   below (after the #else) is not compiled */
#define DIVI_SI
#ifdef DIVI_SI
#ifdef DD

      /*      printf("blabla, %u\n",k); */
      j1 = base_case_bin (R11, R12, R21, R22, P, Q, k, k);
      /* printf ("blabla, %u\n", j1); */
      goto end2;
#else
      /* first step initializes R, the following ones left-multiply it */
      j1 = mpz_divi_si (a, a+1, a+2, a+3, P, Q, t1, k, 1);
      if (extended)
        {
          mpz_set_si (R11, a[2]);
          mpz_set_si (R12, a[3]);
          mpz_set_si (R21, a[0]);
          mpz_set_si (R22, a[1]);
        }

      while (j1 + VAL(Q) <= k)
        {
          j = mpz_divi_si (a, a+1, a+2, a+3, P, Q, t1, k - j1, 1);
          j1 += j;


          if (extended) /* R = [a2,a3;a0,a1] * R, i.e.
                           R11 <- a2*R11+a3*R21, R21 <- a0*R11+a1*R21,
                           R12 <- a2*R12+a3*R22, R22 <- a0*R12+a1*R22 */
            {
              mpz_mul_si (t1, R11, a[0]);
              mpz_mul_si (R11, R11, a[2]);
              mpz_addmul_si (R11, R21, a[3]);
              mpz_mul_si (R21, R21, a[1]);
              mpz_add (R21, R21, t1);
              mpz_mul_si (t1, R12, a[0]);
              mpz_mul_si (R12, R12, a[2]);
              mpz_addmul_si (R12, R22, a[3]);
              mpz_mul_si (R22, R22, a[1]);
              mpz_add (R22, R22, t1);
            }
        }
      goto end2;
#endif

#else
      j1 = mpz_divi (R22, P, Q, t1, extended);
      mpz_swap (P, Q);
      /* P = oldQ, Q = (2^j1*oldP + q*oldQ)/2^j1, val(Q) < val(P) */
      if (extended)
        {
          mpz_set_ui (R11, 0);
          mpz_set_ui (R12, 1);
          mpz_mul_2exp (R12, R12, j1);
          mpz_set (R21, R12);
        }
      while (VAL(Q) <= k)
        {
          j = mpz_divi (q, P, Q, t1, extended);
          mpz_swap (P, Q);
          j1 += j;
          if (extended) /* R = [0,2^j;2^j,q] * R */
            {
              MPZ_MUL (t1, q, R21);
              mpz_mul_2exp (R11, R11, j);
              mpz_add (R11, R11, t1); /* 2^j*R11+q*R21 */
              mpz_mul_2exp (R21, R21, j);
              mpz_swap (R11, R21);
              MPZ_MUL (t1, q, R22);
              mpz_mul_2exp (R12, R12, j);
              mpz_add (R12, R12, t1); /* 2^j*R12+q*R22 */
              mpz_mul_2exp (R22, R22, j);
              mpz_swap (R12, R22);
            }
        }
      mpz_div_2exp (P, P, j1);
      mpz_div_2exp (Q, Q, j1);
      goto end2;
#endif
    }

  /* recursive case: first call with half of k */
  d = k / 2;

  mpz_init (P0);
  mpz_init (Q0);

  /* split P and Q at bit cut = 2d+1 rounded up to a multiple of
     GMP_NUMB_BITS: P0, Q0 receive the low parts, P, Q keep the high parts */
  cut = 2 * d + 1;
  if (cut % GMP_NUMB_BITS)
    cut += GMP_NUMB_BITS - (cut % GMP_NUMB_BITS);
  mpz_tdiv_r_2exp (P0, P, cut);
  mpz_tdiv_q_2exp (P,  P, cut);
  mpz_tdiv_r_2exp (Q0, Q, cut);
  mpz_tdiv_q_2exp (Q,  Q, cut);

#ifdef DEBUG
  if (VAL(P0) != 0) {printf ("1st call\n"); abort();}
#endif
  /* we don't need to compute the matrix R when extended=0 and P=Q=0 */
  j1 = mpz_int_rfeea (R11, R12, R21, R22, P0, Q0, d, tmp, extended
                      || mpz_cmp_ui (P, 0) || mpz_cmp_ui (Q, 0));
  /* we should have 2^j1*P0 = R11*oldP0+R12*oldQ0,
                    2^j1*Q0 = R21*oldP0+R22*oldQ0 */
  /* we now have j1 = j1+val(P0') <= val(P0) + d < j1+val(Q0')
     [val(P0)=val(P)=0] with P0' = R11*P0 + R12*Q0 */

  /* now compute P' = (R11*P1 + R12*Q1)*2^(2d+1-2*j1) + P0'.
     We have val(P') = val(P0') <= val(P) + d < val(Q0') = val(Q') */
#ifdef DEBUG
  if (!(j1 + VAL(P0) <= d && ((VAL(Q0) == INFINITY) || (d < j1 + VAL(Q0)))))
    {
      fprintf (stderr, "j1+VAL(P0) <= d && d < j1+VAL(Q0) not fulfilled: %u %u %u\n",
               VAL(P0), d, VAL(Q0));
      exit (1);
    }
  if (cut < j1)
    {
      fprintf (stderr, "Error: cut < j1\n");
      exit (1);
    }
#endif
#if 0
  if (extended && k > 40000)
    printf ("k:%u d:%u R11:%u R12:%u R21:%u R22:%u P:%u Q:%u\n", k, d,
            mpz_size (R11), mpz_size (R12), mpz_size (R21), mpz_size (R22),
            mpz_size (P), mpz_size (Q));
#endif
#ifdef DEBUG
  if (2 * j1 >= cut)
    {
      fprintf (stderr, "Error: 2 * j1 >= cut\n");
      exit (1);
    }
#endif
  /* apply R to the high parts, shift by cut - 2*j1 and add the reduced low
     parts; R12*Q is saved in q before Q is overwritten */
  MPZ_MUL (q, R12, Q);
  MPZ_MUL (Q, R22, Q);
  MPZ_ADDMUL (Q, R21, P);
  mpz_mul_2exp (Q, Q, cut - 2 * j1);
  mpz_add (Q, Q, Q0);
  MPZ_MUL (t1, R11, P);
  mpz_add (P, t1, q);
  mpz_mul_2exp (P, P, cut - 2 * j1);
  mpz_add (P, P, P0);
#ifdef DEBUG
  if (!(j1 + VAL(P) <= d && ((VAL(Q) == INFINITY) || (d < j1 + VAL(Q)))))
    {
      fprintf (stderr, "j1+VAL(P) <= d && d < j1+VAL(Q) not fulfilled: %u %u %u\n", j1+VAL(P), d, j1+VAL(Q));
      exit (1);
    }
#endif  

#ifdef DEBUG
  mpz_gcd (g2, P, Q);
  mpz_div_2exp (g2, g2, VAL(g2));
  if (mpz_cmp (g, g2))
    {
      fprintf (stderr, "1: g <> g2, extended=%u\n", extended);
      OUT(g);
      OUT(g2);
      exit (1);
    }
  if (j1 > d + 1)
    {
      fprintf (stderr, "Error: j1 > d + 1\n");
      exit (1);
    }
#endif

  /* lowest set bit of Q, searched from bit d + 1 - j1 */
  mm = VAL2(Q, d + 1 - j1);
#ifdef DEBUG
  if (mm != VAL(Q))
    {
      fprintf (stderr, "Error: mm<>VAL(Q)\n");
      printf ("mm=%u d=%u\n", mm, d);
      printf ("P:="); OUT(P);
      printf ("Q:="); OUT(Q);
      printf ("P0:="); OUT(P0);
      printf ("Q0:="); OUT(Q0);
      exit (1);
    }
#endif

  if (mm == INFINITY || k < j1 + mm) /* val(P_init) + k < val(Q') */
    goto end;

  /* one division step between the two recursive calls */
#ifdef DIVI_SI
  j = mpz_divi_si (a, a+1, a+2, a+3, P, Q, t1, ULONG_MAX, 0);
  /* P = a[0]*oldP + a[1]*oldQ, Q = a[2]*oldP + a[3]*oldQ */
  mm = j;
#else
  j = mpz_divi (q, P, Q, t1, 1);
  if (mm)
    {
      mpz_div_2exp (P, P, mm);
      mpz_div_2exp (Q, Q, mm);
    }
#endif
  /* we should now have VAL(Q) < VAL(P) and VAL(Q) <= k */
#ifdef DEBUG
  if (VAL(P) <= VAL(Q) || VAL(Q) > k)
    {
      fprintf (stderr, "Error: VAL(P)<=VAL(Q) or VAL(Q)>k after divi\n");
      exit (1);
    }
#endif

  /* this check is compiled in all builds, not only under DEBUG */
  if (j1 + mm > k)
    {
      fprintf (stderr, "Error: j1 + mm > k\n");
      exit (1);
    }
  /* remaining budget for the second recursive call */
  dd = k - (j1 + mm);

  /* split again, now at 2*dd+1 rounded up to a multiple of GMP_NUMB_BITS */
  cut = 2 * dd + 1;
  if (cut % GMP_NUMB_BITS)
    cut += GMP_NUMB_BITS - (cut % GMP_NUMB_BITS);
  mpz_tdiv_r_2exp (Q0, Q, cut);
  mpz_tdiv_q_2exp (Q,  Q, cut);
  mpz_tdiv_r_2exp (P0, P, cut);
  mpz_tdiv_q_2exp (P,  P, cut);

  if (extended)
    {
#ifdef DIVI_SI
      /* R <- [a[2],a[3]; a[0],a[1]] * R:
         R11 <- a[2]*R11+a[3]*R21
         R21 <- a[0]*R11+a[1]*R21
         R12 <- a[2]*R12+a[3]*R22
         R22 <- a[0]*R12+a[1]*R22 */
      mpz_mul_si (t1, R11, a[2]);
      mpz_addmul_si (t1, R21, a[3]);
      mpz_mul_si (R21, R21, a[1]);
      mpz_addmul_si (R21, R11, a[0]);
      mpz_swap (R11, t1);
      mpz_mul_si (t1, R12, a[2]);
      mpz_addmul_si (t1, R22, a[3]);
      mpz_mul_si (R22, R22, a[1]);
      mpz_addmul_si (R22, R12, a[0]);
      mpz_swap (R12, t1);
#else
      /* R <- [0,2^j; 2^j, q] * R:
         R11 <- 2^j*R21
         R21 <- 2^j*R11+q*R21
         R12 <- 2^j*R22
         R22 <- 2^j*R12+q*R22 */
      MPZ_MUL (t1, q, R21);
      mpz_mul_2exp (R11, R11, j);
      mpz_add (R11, R11, t1); /* 2^j*R11+q*R21 */
      mpz_mul_2exp (R21, R21, j);
      mpz_swap (R11, R21);
      MPZ_MUL (t1, q, R22);
      mpz_mul_2exp (R12, R12, j);
      mpz_add (R12, R12, t1); /* 2^j*R12+q*R22 */
      mpz_mul_2exp (R22, R22, j);
      mpz_swap (R12, R22);
#endif
    }
  
  /* S is needed if the caller wants R, or if the high parts are non-zero
     and therefore have to be multiplied by S below */
  computeS = extended || mpz_cmp_ui (P, 0) || mpz_cmp_ui (Q, 0);

  mpz_init (S11);
  mpz_init (S12);
  mpz_init (S21);
  mpz_init (S22);

#ifdef DEBUG
  if (VAL(Q0) != 0)
    {
      fprintf (stderr, "Error: VAL(Q0) <> 0\n");
      exit (1);
    }
#endif
  /* second recursive call, with the roles of P and Q exchanged */
  j2 = mpz_int_rfeea (S11, S12, S21, S22, Q0, P0, dd, tmp, computeS);
#ifdef DEBUG
  if (!(j2 + VAL(Q0) <= dd && ((VAL(P0) == INFINITY) || (dd < j2 + VAL(P0)))))
    {
      fprintf (stderr, "j2+VAL(Q0) <= dd && dd < j2+VAL(P0) not fulfilled: %u %u %u\n", j2+VAL(Q0), dd, j2+VAL(P0));
      exit (1);
    }
#endif

  if (extended) /* R <- S * R:
     R11 <- S11 * R11 + S12 * R21
     R12 <- S11 * R12 + S12 * R22
     R21 <- S21 * R11 + S22 * R21
     R22 <- S21 * R12 + S22 * R22 */
    {
      MPZ_MUL (t1, S21, R11);
      MPZ_MUL (R11, S11, R11);
      MPZ_MUL (q, S12, R21);
      MPZ_MUL (R21, S22, R21);
      mpz_add (R11, R11, q);
      mpz_add (R21, t1, R21);
      MPZ_MUL (t1, S21, R12);
      MPZ_MUL (R12, S11, R12);
      MPZ_MUL (q, S12, R22);
      MPZ_MUL (R22, S22, R22);
      mpz_add (R12, R12, q);
      mpz_add (R22, t1, R22);
    }

  /* update P <- (S11*Q + S12*P) * 2^(cut-2*j2) + Q0,
            Q <- (S21*Q + S22*P) * 2^(cut-2*j2) + P0
     (Q0 and P0 hold the low parts reduced by the second recursive call) */
#ifdef DEBUG
  if (cut + mm <= j2)
    {
      fprintf (stderr, "Error: cut + mm <= j2: %d %d\n", cut + mm, j2);
      exit (1);
    }
#endif
#if 0
  if (extended && k > 40000)
    printf ("k:%u d:%u S11:%u S12:%u S21:%u S22:%u P:%u Q:%u\n", k, dd,
            mpz_size (S11), mpz_size (S12), mpz_size (S21), mpz_size (S22),
            mpz_size (P), mpz_size (Q));
#endif
#ifdef DEBUG
  if (2 * j2 >= cut)
    {
      fprintf (stderr, "Error: 2 * j2 >= cut: j2=%u cut=%u\n", j2, cut);
      exit (1);
    }
#endif
  MPZ_MUL (t1, S11, Q);
  MPZ_MUL (q, S12, P);
  MPZ_MUL (Q, S21, Q);
  MPZ_MUL (P, S22, P);
  mpz_add (Q, Q, P);
  mpz_add (P, t1, q);
  mpz_mul_2exp (P, P, cut - 2 * j2);
  mpz_mul_2exp (Q, Q, cut - 2 * j2);
  mpz_add (P, P, Q0);
  mpz_add (Q, Q, P0);

  mpz_clear (S11);
  mpz_clear (S12);
  mpz_clear (S21);
  mpz_clear (S22);

  /* total shift: first call + middle step + second call */
  j1 += j + j2;

  /* end: paths that initialized P0 and Q0; end2: paths that did not */
 end:
  mpz_clear (P0);
  mpz_clear (Q0);

 end2:
#ifdef DEBUG
  mpz_gcd (g2, P, Q);
  mpz_div_2exp (g2, g2, VAL(g2));
  if (mpz_cmp (g, g2))
    {
      fprintf (stderr, "g <> g2, extended=%u\n", extended);
      OUT(g);
      OUT(g2);
      exit (1);
    }
  mpz_clear (g);
  mpz_clear (g2);
#endif

  return j1;
}

/* number of temporaries used by mpz_fastgcd: two operands, four matrix
   entries and two scratch variables handed to mpz_int_rfeea0 */
#define TMP_SIZE 8

/* g <- gcd(a, b), computed with the binary recursive algorithm.
   The operands are copied and arranged so that the first has the smaller
   valuation, and the common power of two 2^va is removed.  The pair is then
   reduced by repeated calls to mpz_int_rfeea0 until the remaining significant
   part of the second operand is at most THRESHOLD bits, or no progress is
   made.  The final gcd is delegated to mpz_gcd and multiplied back by
   2^va. */
void
mpz_fastgcd (mpz_t g, const mpz_t a, const mpz_t b)
{
  mpz_t tmp[TMP_SIZE];
  unsigned long i, va, vb, v;
  long v1;

  for (i = 0; i < TMP_SIZE; i++)
    mpz_init (tmp[i]);

  va = VAL(a);
  vb = VAL(b);

  if (va < vb)
    {
      mpz_set (tmp[0], a);
      mpz_set (tmp[1], b);
    }
  else
    {
      mpz_set (tmp[0], b);
      if (va == vb)
        mpz_add (tmp[1], a, b); /* a+b has a larger valuation than b */
      else
        mpz_set (tmp[1], a);
      va = vb;
    }

  /* now VAL(tmp[0]) < VAL(tmp[1]) */

#ifdef DEBUG2
  printf ("t0:="); OUT(tmp[0]);
  printf ("t1:="); OUT(tmp[1]);
#endif

  mpz_div_2exp (tmp[0], tmp[0], va);
  mpz_div_2exp (tmp[1], tmp[1], va);

  while (VAL(tmp[1]) + 1 < mpz_sizeinbase (tmp[1], 2))
    {
#if 0
      mpz_div_2exp (tmp[1], tmp[1], VAL(tmp[1]));
      mpz_add (tmp[1], tmp[0], tmp[1]);
#endif

#ifdef DEBUG
      if ((int) SIZ(tmp[0]) >= 498)
        printf ("    size = %i \n", SIZ(tmp[0]));
#endif
      /* small enough: leave the rest to mpz_gcd */
      if (mpz_sizeinbase (tmp[1], 2) <= VAL(tmp[1]) + THRESHOLD)
        break;
      /* the former test on MUL_FFT_THRESHOLD selected the same value in
         both branches, so it is folded into a single assignment */
      v = mpz_sizeinbase (tmp[1], 2) / 4;
      v1 = mpz_int_rfeea0 (tmp[2], tmp[3], tmp[4], tmp[5], tmp[0], tmp[1], v,
                           tmp + 6);
      if (v1 == 0) /* no progress */
        break;
    }

#ifdef DEBUG2
  printf ("t0:="); OUT(tmp[0]);
  printf ("t1:="); OUT(tmp[1]);
#endif

  mpz_div_2exp (tmp[1], tmp[1], VAL(tmp[1]));
  mpz_gcd (g, tmp[0], tmp[1]);
  mpz_mul_2exp (g, g, va);

  for (i = 0; i < TMP_SIZE; i++)
    mpz_clear (tmp[i]);
}

/* Performs 4 products of sizes (n/2) x (n/4) limbs and 12 products of sizes
   (n/4) x (n/4) limbs, all written to r, then recurses twice on n/2.
   Does nothing when n/4 is zero. */
void
Hstar (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_size_t half = n / 2;
  mp_size_t quarter = n / 4;
  int rep;

  if (quarter == 0)
    return;

  for (rep = 4; rep > 0; rep--)
    mpn_mul (r, a, half, b, quarter);
  for (rep = 12; rep > 0; rep--)
    mpn_mul (r, a, quarter, b, quarter);

  Hstar (r, a, b, half);
  Hstar (r, a, b, half);
}

/* Performs 4 products of sizes (n/2) x (n/4) limbs and 4 products of sizes
   (n/4) x (n/4) limbs, all written to r, then calls Hstar twice on n/2.
   Does nothing when n/4 is zero, as Hstar does: mpn_mul must not be called
   with an empty operand. */
void
H (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  int i;

  if (n / 4 == 0)
    return;
  for (i = 0; i < 4; i++)
    mpn_mul (r, a, n / 2, b, n / 4);
  for (i = 0; i < 4; i++)
    mpn_mul (r, a, n / 4, b, n / 4);
  Hstar (r, a, b, n / 2);
  Hstar (r, a, b, n / 2);
}

/* Calls H on sizes n, ceil(n/2), ceil(ceil(n/2)/2), ... for as long as the
   size is larger than 1. */
void
G (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_size_t len;

  for (len = n; len > 1; len = (len + 1) / 2)
    H (r, a, b, len);
}

/* Sets P = a[n] and Q = 2*a[n-1], where a[0]=0, a[1]=1 and
   a[i] = -a[i-1] + 4*a[i-2].  Aborts when n is 0. */
void
mpz_bfib2_ui (mpz_t P, mpz_t Q, unsigned long n)
{
  unsigned long idx;

  if (n == 0)
    abort ();

  /* (P, Q) = (a[1], a[0]) */
  mpz_set_ui (P, 1);
  mpz_set_ui (Q, 0);

  /* each pass turns (a[i], a[i-1]) into (a[i+1], a[i]) */
  for (idx = 1; idx < n; idx++)
    {
      mpz_mul_2exp (Q, Q, 2);   /* 4*a[i-1] */
      mpz_sub (Q, Q, P);        /* 4*a[i-1] - a[i] */
      mpz_swap (P, Q);
    }

  /* Q = 2*a[n-1] */
  mpz_mul_2exp (Q, Q, 1);
}

/* Prints the command-line help on stderr and exits with status 1. */
static void
usage (void)
{
  fprintf (stderr, "Usage: hgcd_bin [-fib] [-bfib] n [k]\n");
  fprintf (stderr, "       n     - input size (in limbs)\n");
  fprintf (stderr, "       k     - repeat k times\n");
  fprintf (stderr, "       -fib  - use F(n) and F(n-1) as input\n");
  fprintf (stderr, "       -bfib - use bF(n) and bF(n-1) as input\n");
  exit (1);
}

/* Benchmark driver: builds two operands (random, Fibonacci or
   bin-Fibonacci), times k runs of mpz_gcd, mpz_fastgcd and mpz_gcdext on
   them, and exits with status 1 if the computed gcds differ. */
int
main (int argc, char *argv[])
{
  unsigned long n, j, k;
  mpz_t P, Q, R11, R12, g, g2;
  int st, fib = 0;

  /* argv[1] is only inspected while it exists: an option given without
     the size argument must not reach strcmp with a null pointer */
  if (argc >= 2 && strcmp (argv[1], "-fib") == 0)
    {
      fib = 1;
      argv ++;
      argc --;
    }

  if (argc >= 2 && strcmp (argv[1], "-bfib") == 0)
    {
      fib = -1;
      argv ++;
      argc --;
    }

  if (argc < 2)
    usage ();

  n = strtoul (argv[1], NULL, 10); /* number of limbs of input numbers */
  /* n = 0 would loop forever in the random case (P = Q = 0) and abort in
     mpz_bfib2_ui */
  if (n == 0)
    usage ();
  k = (argc > 2) ? strtoul (argv[2], NULL, 10) : 1;

  mpz_init (P);
  mpz_init (Q);
  mpz_init (R11);
  mpz_init (R12);
  mpz_init (g);
  mpz_init (g2);

  if (fib == 1)
    {
      mpz_fib2_ui (P, Q, n);
      n = mpz_size (P);
      if (mpz_size (Q) > n)
        n = mpz_size (Q);
      printf ("Using Fibonacci numbers of %lu limbs\n", n);
    }
  else if (fib == -1)
    {
      mpz_bfib2_ui (P, Q, n);
      n = mpz_size (P);
      if (mpz_size (Q) > n)
        n = mpz_size (Q);
      printf ("Using bin-Fibonacci numbers of %lu limbs\n", n);
    }
  else
    {
      /* draw again until val(P) < val(Q) */
      do
        {
          mpz_random (P, n);
          mpz_random (Q, n);
        }
      while (VAL(P) >= VAL(Q));
    }

  mpz_realloc (g, 2 * n);

  mul_cost_hgcd = 0.0;
  st = cputime ();
  for (j = 0; j < k; j++)
    mpz_gcd (g, P, Q);
  printf ("mpz_gcd took     %dms\n", cputime () - st);
#ifdef DEBUG
  printf ("mul_cost_hgcd = %f\n", mul_cost_hgcd);
  printf ("\n");
#endif

  mul_cost = 0.0;
  st = cputime ();
  for (j = 0; j < k; j++)
    mpz_fastgcd (g2, P, Q);
  printf ("mpz_fastgcd took %dms\n", cputime () - st);
#ifdef DEBUG
  printf ("mul_cost = %f\n", mul_cost);
#endif

  if (mpz_cmp (g, g2) != 0)
    {
      fprintf (stderr, "mpz_gcd and mpz_fastgcd differ\n");
      /*      printf ("g="); OUT(g);
              printf ("g2="); OUT(g2); */
      exit (1);
    }

  st = cputime ();
  for (j = 0; j < k; j++)
    mpz_gcdext (g, R11, R12, P, Q);
  printf ("mpz_gcdext took  %dms\n", cputime () - st);

  if (mpz_cmp (g, g2) != 0)
    {
      fprintf (stderr, "mpz_gcdext and mpz_fastgcd differ\n");
      /*      printf ("g="); OUT(g);
              printf ("g2="); OUT(g2); */
      exit (1);
    }

  mpz_clear (P);
  mpz_clear (Q);
  mpz_clear (R11);
  mpz_clear (R12);
  mpz_clear (g);
  mpz_clear (g2);

  return 0;
}
