/* Implementation of the Binary Subquadratic Gcd

   Reference: A Binary Recursive Gcd Algorithm, Damien Stehle' and Paul
   Zimmermann, Proceedings of the 6th International Symposium on Algorithmic
   Number Theory (ANTS VI), Burlington, USA, LNCS 3076, pages 411-425, 2004.

   This program must be compiled with a GMP build, for example:
   gcc -g -O2 -I/tmp/gmp-6.1.2 hgcd.c -o hgcd /tmp/gmp-6.1.2/.libs/libgmp.a

Copyright 2009-2019 Paul Zimmermann

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with that file; see the file COPYING.LIB.  If not, write to the Free
Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston,
MA 02110-1301, USA. */

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h> /* for SIZE_MAX */
#include <assert.h>
#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"
#include <sys/types.h>
#include <sys/resource.h>

/* Tuning thresholds (all expressed in bits):
   - HGCD_DC_THRESHOLD: hgcd() falls back to the quadratic hgcd_ref() when
     k is below this value;
   - WRAP_THRESHOLD: hgcd() computes the products with wrap_mul() instead of
     mpz_mul/mpz_addmul once the number of low bits known to be zero (2*j)
     reaches this value;
   - STRASSEN_THRESHOLD: hgcd() combines the two 2x2 matrices with
     strassen_mul() when k exceeds this value. */
#define HGCD_DC_THRESHOLD 256        /* bits */
#define WRAP_THRESHOLD 32768         /* bits */
#define STRASSEN_THRESHOLD (1 << 12) /* bits */

/* Assertions are disabled: ASSERT_ALWAYS (from gmp-impl.h) is redefined as
   an empty macro.  Comment out the #undef and #define lines below to
   enable them again. */
#undef ASSERT_ALWAYS
#define ASSERT_ALWAYS(x) 

/* Return the user CPU time consumed so far by the calling process, in
   milliseconds (used only for the timings printed by main). */
int
cputime (void)
{
  struct rusage rus;

  /* RUSAGE_SELF rather than a literal 0: POSIX does not fix its value */
  getrusage (RUSAGE_SELF, &rus);
  return rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000;
}

/* Sign of the mpz_t x (1, 0 or -1), read directly from its size field. */
#define SGN(x) ((x)->_mp_size > 0 ? 1 : (x)->_mp_size == 0 ? 0 : -1)
/* The limb constants 1 and 0. */
#define ONE ((mp_limb_t) 1)
#define ZERO ((mp_limb_t) 0)

/* a <- a + b*c for a signed long c, via mpz_addmul_ui / mpz_submul_ui.
   c is evaluated twice and negated when it is not positive, so it must be
   a side-effect-free expression different from LONG_MIN. */
#define mpz_addmul_si(a,b,c) do {               \
  if ((c) > 0)                                  \
    mpz_addmul_ui (a, b, c);                    \
  else                                          \
    mpz_submul_ui (a, b, -(c));                 \
  } while (0)

long
BinaryDivide1 (long *a0, long b0, size_t j)
{
  mp_limb_t invb0, u;
  long q;
  int T[] = {1, 171, 205, 183, 57, 163, 197, 239, 241, 27, 61, 167, 41, 19,
             53, 223, 225, 139, 173, 151, 25, 131, 165, 207, 209, 251, 29,
             135, 9, 243, 21, 191, 193, 107, 141, 119, 249, 99, 133, 175, 177,
             219, 253, 103, 233, 211, 245, 159, 161, 75, 109, 87, 217, 67,
             101, 143, 145, 187, 221, 71, 201, 179, 213, 127, 129, 43, 77, 55,
             185, 35, 69, 111, 113, 155, 189, 39, 169, 147, 181, 95, 97, 11,
             45, 23, 153, 3, 37, 79, 81, 123, 157, 7, 137, 115, 149, 63, 65,
             235, 13, 247, 121, 227, 5, 47, 49, 91, 125, 231, 105, 83, 117,
             31, 33, 203, 237, 215, 89, 195, 229, 15, 17, 59, 93, 199, 73,
             51, 85, 255};
  /* T[i] = 1/(2i+1) mod 2^8 */

  ASSERT_ALWAYS (j + 1 < GMP_NUMB_BITS);
  ASSERT_ALWAYS (GMP_NUMB_BITS <= 64);

  ASSERT_ALWAYS (b0 & 1);
  /* we want 1/b0 mod 2^(j+1) */
  invb0 = T[(b0 & 255) >> 1];
  if (j + 1 > 8)
    invb0 += invb0 * (1 - b0 * invb0); /* 2^16 */
  {
    if (j + 1 > 16)
      {
        invb0 += invb0 * (1 - b0 * invb0); /* 2^32 */
        if (j + 1 > 32)
          invb0 += invb0 * (1 - b0 * invb0); /* 2^64 */
      }
  }
  u = (-*a0) * invb0;
  u = u & ((ONE << (j + 1)) - ONE);
  if ((u >> j) != 0)
    {
      q = (long) u - ((long) 1 << (j + 1));
      u = -q;
    }
  else
    q = (long) u;
  ASSERT_ALWAYS (q & 1);

  /* now replace a by a + q*b */
  *a0 += b0 * q;

  return q;
}

/* One binary division step on mpz_t operands with a single-word quotient.

   B must be odd and j + 1 < GMP_NUMB_BITS.  Finds the odd q with
   |q| < 2^j such that A + q*B is divisible by 2^(j+1), replaces A by
   A + q*B, and returns q.  Only the least significant limbs of A and B
   are needed to find q; the signs come from the mpz sizes. */
long
BinaryDivide (mpz_t A, mpz_t B, size_t j)
{
  /* T[i] = 1/(2i+1) mod 2^8 */
  static const unsigned char T[] = {
    1, 171, 205, 183, 57, 163, 197, 239, 241, 27, 61, 167, 41, 19,
    53, 223, 225, 139, 173, 151, 25, 131, 165, 207, 209, 251, 29,
    135, 9, 243, 21, 191, 193, 107, 141, 119, 249, 99, 133, 175, 177,
    219, 253, 103, 233, 211, 245, 159, 161, 75, 109, 87, 217, 67,
    101, 143, 145, 187, 221, 71, 201, 179, 213, 127, 129, 43, 77, 55,
    185, 35, 69, 111, 113, 155, 189, 39, 169, 147, 181, 95, 97, 11,
    45, 23, 153, 3, 37, 79, 81, 123, 157, 7, 137, 115, 149, 63, 65,
    235, 13, 247, 121, 227, 5, 47, 49, 91, 125, 231, 105, 83, 117,
    31, 33, 203, 237, 215, 89, 195, 229, 15, 17, 59, 93, 199, 73,
    51, 85, 255};
  mp_limb_t b0, invb0, u;
  long q;
  mp_ptr a = PTR(A), b = PTR(B);

  ASSERT_ALWAYS (j + 1 < GMP_NUMB_BITS);
  ASSERT_ALWAYS (GMP_NUMB_BITS <= 64);

  b0 = b[0];
  ASSERT_ALWAYS (b0 & 1);

  /* 1/b0 mod 2^(j+1): start from 8 correct bits (table lookup) and apply
     Newton's iteration x <- x + x*(1 - b0*x), which doubles the number of
     correct bits each time. */
  invb0 = T[(b0 & 255) >> 1];
  if (j + 1 > 8)
    invb0 += invb0 * (1 - b0 * invb0); /* 16 bits */
  if (j + 1 > 16)
    invb0 += invb0 * (1 - b0 * invb0); /* 32 bits */
  if (j + 1 > 32)
    invb0 += invb0 * (1 - b0 * invb0); /* 64 bits */

  /* u = -A/B mod 2^(j+1), from the absolute values a[0] and b[0] */
  if (SGN(A) == SGN(B))
    u = (-a[0]) * invb0;
  else
    u = a[0] * invb0;
  u &= (ONE << (j + 1)) - ONE;

  /* take the representative in (-2^j, 2^j) */
  if ((u >> j) != 0)
    q = (long) u - ((long) 1 << (j + 1));
  else
    q = (long) u;
  ASSERT_ALWAYS (q & 1);

  /* now replace A by A + q*B */
  mpz_addmul_si (A, B, q);

  return q;
}

/* Binary division step for large j (j + 1 >= GMP_NUMB_BITS), where the
   quotient does not fit in a long.

   B must be odd.  Stores in q the integer with |q| < 2^j such that
   A + q*B is divisible by 2^(j+1), i.e. q = -A/B mod 2^(j+1) taken in the
   symmetric range, and replaces A by A + q*B. */
void
BinaryDivideSlow (mpz_t A, mpz_t B, size_t j, mpz_t q)
{
  /* we can use slow code here, since it will be called rarely */
  size_t n = 1 + j / GMP_NUMB_BITS; /* ceil ((j+1) / GMP_NUMB_BITS) */
  mpz_t t;
  int invertible;

  mpz_init (t);

  /* q <- -A/B mod 2^(j+1).  The inverse of B is taken modulo the (possibly
     larger) power 2^(n*GMP_NUMB_BITS) >= 2^(j+1); this is harmless since
     the product is reduced modulo 2^(j+1) right after. */
  mpz_set_ui (t, 1);
  mpz_mul_2exp (t, t, n * GMP_NUMB_BITS);
  invertible = mpz_invert (q, B, t);
  ASSERT_ALWAYS (invertible != 0); /* fails only if B is even */
  (void) invertible;
  mpz_mul (q, q, A);
  mpz_neg (q, q);
  /* truncating remainder: |q| < 2^(j+1), with the sign of -A/B */
  mpz_tdiv_r_2exp (q, q, j + 1);

  /* move q into (-2^j, 2^j) if |q| >= 2^j */
  mpz_set_ui (t, 1);
  mpz_mul_2exp (t, t, j);
  if (mpz_cmpabs (q, t) >= 0)
    {
      mpz_mul_2exp (t, t, 1);
      if (mpz_sgn (q) > 0)
        mpz_sub (q, q, t);
      else
        mpz_add (q, q, t);
    }
  ASSERT_ALWAYS (mpz_scan1 (q, 0) == 0);
  ASSERT_ALWAYS (mpz_sizeinbase (q, 2) <= j);

  /* A <- A + q*B */
  mpz_addmul (A, B, q);

  mpz_clear (t);
}

/* In-place 2x2 matrix product A <- B*A, where
   A = [A00 A01; A10 A11] and B = [B00 B01; B10 B11].

   Strassen-Winograd scheme: 7 multiplications and 15 additions
   (8 before the products, 7 after), whereas Strassen's original algorithm
   needs 18 additions.  See http://arxiv.org/abs/0707.2347 for ways to
   reduce the memory used by the temporaries.  The entries of B are only
   read; no entry of A is written before all seven products are formed. */
void
strassen_mul (mpz_t A00, mpz_t A01, mpz_t A10, mpz_t A11,
              mpz_t B00, mpz_t B01, mpz_t B10, mpz_t B11)
{
  mpz_t sb[4], ta[4], prod[7];
  int idx;

  for (idx = 0; idx < 4; idx++)
    {
      mpz_init (sb[idx]);
      mpz_init (ta[idx]);
    }
  for (idx = 0; idx < 7; idx++)
    mpz_init (prod[idx]);

  /* pre-additions on B */
  mpz_add (sb[0], B10, B11);
  mpz_sub (sb[1], sb[0], B00);
  mpz_sub (sb[2], B00, B10);
  mpz_sub (sb[3], B01, sb[1]);

  /* pre-additions on A */
  mpz_sub (ta[0], A01, A00);
  mpz_sub (ta[1], A11, ta[0]);
  mpz_sub (ta[2], A11, A01);
  mpz_sub (ta[3], A10, ta[1]);

  /* the seven products */
  mpz_mul (prod[0], B00, A00);
  mpz_mul (prod[1], B01, A10);
  mpz_mul (prod[2], sb[0], ta[0]);
  mpz_mul (prod[3], sb[1], ta[1]);
  mpz_mul (prod[4], sb[2], ta[2]);
  mpz_mul (prod[5], sb[3], A11);
  mpz_mul (prod[6], B11, ta[3]);

  /* post-additions, written directly into A */
  mpz_add (A00, prod[0], prod[1]);
  mpz_add (A01, prod[0], prod[3]);
  mpz_add (A11, A01, prod[4]);
  mpz_add (A10, A11, prod[6]);
  mpz_add (A11, A11, prod[2]);
  mpz_add (A01, A01, prod[2]);
  mpz_add (A01, A01, prod[5]);

  for (idx = 0; idx < 4; idx++)
    {
      mpz_clear (sb[idx]);
      mpz_clear (ta[idx]);
    }
  for (idx = 0; idx < 7; idx++)
    mpz_clear (prod[idx]);
}

/* Single-word variant of hgcd_ref(), used when 2k+1 < GMP_NUMB_BITS.
   It always computes the matrix, i.e. it behaves like flag = 1.

   A is odd and B even.  The function performs binary division steps while
   the total number j_tot of bits cleared stays <= k.  It accumulates the
   transformation in R11, R12, R21, R22: each step left-multiplies
   [R11 R12; R21 R22] by [0 2^j; 2^j q].

   Returns j_tot.  If B = 0 or nu2(B) > k, it returns 0 and the identity. */
int
hgcd_ref1 (long A, long B, size_t k,
           long *R11, long *R12, long *R21, long *R22)
{
  size_t j, j_tot = 0, jj;
  long q, p, t;

  *R11 = 1;
  *R12 = 0;
  *R21 = 0;
  *R22 = 1;

  /* B = 0 has no finite 2-valuation: nothing to do.  This also keeps
     count_trailing_zeros away from 0, where it is undefined. */
  if (B == 0)
    return 0;

  count_trailing_zeros (j, B);
  ASSERT_ALWAYS (j > 0);

  if (j > k)
    return 0;

  B >>= j;
  for (;;)
    {
      /* one binary division step: A <- A + q*B, divisible by 2^(j+1) */
      q = BinaryDivide1 (&A, B, j);

      /* R <- [0 2^j; 2^j q] * R.  Multiply by 2^j instead of shifting:
         the entries may be negative, and left-shifting a negative value
         is undefined behaviour.  j <= k here, so 2^j fits in a long. */
      p = (long) 1 << j;
      t = *R11 * p + *R21 * q;
      *R11 = *R21 * p;
      *R21 = t;
      t = *R12 * p + *R22 * q;
      *R12 = *R22 * p;
      *R22 = t;
      j_tot += j;

      if (A == 0)
        break;

      /* A is divisible by 2^(j+1); remove all its factors of 2.  The
         next shift amount is nu2 (A / 2^j). */
      count_trailing_zeros (jj, A);
      A >>= jj;
      j = jj - j;

      if (j_tot + j > k)
        break;

      /* exchange roles: the previous divisor becomes the dividend */
      t = A;
      A = B;
      B = t;
    }

  return j_tot;
}

/* One binary division step A <- A + q*B (B odd), where q is the integer
   with |q| < 2^j that makes A + q*B divisible by 2^(j+1).  If UPDATE is
   nonzero, the matrix is also left-multiplied by [0 2^j; 2^j q]:
     R11 <- 2^j R21,  R21 <- 2^j R11 + q R21,
     R12 <- 2^j R22,  R22 <- 2^j R12 + q R22. */
static void
hgcd_ref_step (mpz_t A, mpz_t B, size_t j,
               mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22, int update)
{
  if (j + 1 < GMP_NUMB_BITS)
    {
      /* the quotient fits in a long */
      long q = BinaryDivide (A, B, j);
      if (update)
        {
          mpz_mul_2exp (R11, R11, j);
          mpz_addmul_si (R11, R21, q);
          mpz_mul_2exp (R21, R21, j);
          mpz_swap (R11, R21);
          mpz_mul_2exp (R12, R12, j);
          mpz_addmul_si (R12, R22, q);
          mpz_mul_2exp (R22, R22, j);
          mpz_swap (R12, R22);
        }
    }
  else
    {
      mpz_t Q;
      mpz_init (Q);
      BinaryDivideSlow (A, B, j, Q);
      if (update)
        {
          mpz_mul_2exp (R11, R11, j);
          mpz_addmul (R11, R21, Q);
          mpz_mul_2exp (R21, R21, j);
          mpz_swap (R11, R21);
          mpz_mul_2exp (R12, R12, j);
          mpz_addmul (R12, R22, Q);
          mpz_mul_2exp (R22, R22, j);
          mpz_swap (R12, R22);
        }
      mpz_clear (Q);
    }
}

/* Quadratic reference half-gcd.

   A0 is odd and B0 even.  The function performs binary division steps
   while the total number j of bits cleared stays <= k, and returns j.
   If flag & 1: computes the matrix R with R * (A0, B0) = 2^(2j) * (A', B').
   If flag & 2: replaces A0, B0 by those final terms A' (odd), B'.
   If B0 = 0 or nu2(B0) > k: returns 0.  R is set to the identity when
   flag & 1, and A0, B0 are left unchanged. */
int
hgcd_ref (mpz_ptr A0, mpz_ptr B0, size_t k,
          mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22, int flag)
{
  size_t j, j_tot = 0, jj;
  mpz_t A, B;

  /* B0 = 0 has no finite 2-valuation: identity, nothing to do.  Checking
     it here also keeps the single-limb path below from reading PTR(B0)[0],
     which holds no meaningful value when the size is 0. */
  if (SIZ (B0) == 0)
    {
      if (flag & 1)
        {
          mpz_set_ui (R11, 1);
          mpz_set_ui (R12, 0);
          mpz_set_ui (R21, 0);
          mpz_set_ui (R22, 1);
        }
      return 0;
    }

  if (2 * k + 1 < GMP_NUMB_BITS)
    {
      /* all the bits that matter fit in the low limb */
      long r11, r12, r21, r22, a0, b0;
      a0 = SIZ(A0) >= 0 ? PTR(A0)[0] : -PTR(A0)[0];
      b0 = SIZ(B0) >= 0 ? PTR(B0)[0] : -PTR(B0)[0];
      j = hgcd_ref1 (a0, b0, k, &r11, &r12, &r21, &r22);
      if (flag & 1)
        {
          mpz_set_si (R11, r11);
          mpz_set_si (R12, r12);
          mpz_set_si (R21, r21);
          mpz_set_si (R22, r22);
        }
      if (flag & 2)
        {
          mpz_t tmp;
          mpz_init (tmp);
          mpz_mul_si (tmp, A0, r11);
          mpz_addmul_si (tmp, B0, r12);
          mpz_mul_si (B0, B0, r22);
          mpz_addmul_si (B0, A0, r21);
          mpz_tdiv_q_2exp (A0, tmp, 2 * j);
          mpz_tdiv_q_2exp (B0, B0, 2 * j);
          mpz_clear (tmp);
        }
      return j;
    }

  j = mpz_scan1 (B0, 0);
  ASSERT_ALWAYS (j > 0);

  if (flag & 1)
    {
      mpz_set_ui (R11, 1);
      mpz_set_ui (R12, 0);
      mpz_set_ui (R21, 0);
      mpz_set_ui (R22, 1);
    }

  if (j > k)
    return 0;

  mpz_init_set (A, A0);
  mpz_init_set (B, B0);

  /* B is kept divided by its power of two, hence odd */
  mpz_tdiv_q_2exp (B, B, j);
  for (;;)
    {
      /* A <- A + q*B, divisible by 2^(j+1); update R */
      hgcd_ref_step (A, B, j, R11, R12, R21, R22, flag & 1);
      j_tot += j;

      if (mpz_cmp_ui (A, 0) == 0)
        {
          if (flag & 2)
            {
              mpz_swap (A0, B); /* B was already divided by 2^j, and is odd */
              mpz_swap (B0, A); /* 0 */
            }
          break;
        }

      /* the next remainder is A / 2^j; its 2-valuation jj is the next j */
      jj = mpz_scan1 (A, j) - j;
      mpz_tdiv_q_2exp (A, A, j + jj);
      j = jj;

      if (j_tot + j > k)
        {
          if (flag & 2)
            {
              mpz_swap (A0, B);
              mpz_mul_2exp (B0, A, jj); /* restore the even remainder */
            }
          break;
        }

      /* exchange roles: the previous divisor B becomes the dividend */
      mpz_swap (A, B);
    }

  mpz_clear (A);
  mpz_clear (B);

  return j_tot;
}

/* NOTE(review): this helper is compiled out with #if 0 and is not used in
   the visible code.  Before re-enabling it, check the final mask.  When it
   is applied, it keeps bits 0..r of limb q (q = n / GMP_NUMB_BITS,
   r = n % GMP_NUMB_BITS), which is n + 1 bits rather than the "n bits" the
   header announces.  That may be intended for the 2k+1-bit truncations;
   confirm. */
#if 0
/* extract the n bits from a, starting at bit j */
/* Works in place on the limbs of a; the sign of a is preserved
   (SIZ(a) keeps its sign). */
void
extract_bits (mpz_t a, size_t j, size_t n)
{
  size_t q = j / GMP_NUMB_BITS;
  size_t r = j % GMP_NUMB_BITS;
  size_t q2 = 1 + (j + n - 1) / GMP_NUMB_BITS;
  mp_ptr ap = PTR(a);

  /* the first bit is bit r in limb q, the last bit is in limb q2-1 */
  if (q2 > ABSIZ(a))
    q2 = ABSIZ(a);
  q2 -= q;
  if (r > 0)
    mpn_rshift (ap, ap + q, q2, r);
  else
    MPN_COPY (ap, ap + q, q2);
  q = n / GMP_NUMB_BITS;
  r = n % GMP_NUMB_BITS;
  if (r + 1 < GMP_NUMB_BITS && q + 1 == q2)
    ap[q] &= (ONE << (r + 1)) - ONE;
  MPN_NORMALIZE(ap, q2);
  SIZ(a) = SIZ(a) > 0 ? q2 : -q2;
}
#endif

/* a <- b * c + d * e, where we know that the result is divisible by 2^N.

   Let B = 2^GMP_NUMB_BITS and n = floor (N / GMP_NUMB_BITS); n >= 1 is
   required.  The sum is computed with mpn_mul_fft only modulo B^s + 1,
   where s (in limbs) is just large enough for the quotient by B^n plus two
   spare bits.  The known divisibility by B^n then selects the multiple of
   B^s + 1 to add, which recovers the exact value; the two cases s >= n and
   s < n are explained below.  Signs are handled separately: the residue
   computed is |b*c| + |d*e| or |b*c| - |d*e|, and sbc is the sign applied
   at the end.

   a should be distinct from b, c, d and e: its limb array receives the
   product b*c while the inputs are still being read. */
void
wrap_mul (mpz_t a, mpz_t b, mpz_t c, mpz_t d, mpz_t e, size_t N)
{
  mp_size_t s1 = mpz_sizeinbase (b, 2) + mpz_sizeinbase (c, 2);
  mp_size_t s2 = mpz_sizeinbase (d, 2) + mpz_sizeinbase (e, 2);
  mp_size_t s, i;
  int k, sbc, sde;
  mp_ptr bc, de;
  mp_limb_t cy, c1;
  mp_size_t n = N / GMP_NUMB_BITS;

  ASSERT_ALWAYS(n > 0);

  /* we have |b*c| <= 2^s1 and |d*e| <= 2^s2,
     thus |b*c + d*e| <= 2^(max(s1,s2)+1), so that the possible range
     for b*c + d*e has width < 2^s */
  s = s1 >= s2 ? s1 + 2: s2 + 2;

  s = s - n * GMP_NUMB_BITS;
  s = 1 + (s - 1) / GMP_NUMB_BITS; /* convert to limbs */
  k = mpn_fft_best_k (s, 0);
  s = mpn_fft_next_size (s, k);
  mpz_realloc2 (a, (n + s + 1) * GMP_NUMB_BITS);
  bc = PTR(a);
  de = malloc (s * sizeof *de);
  if (de == NULL)
    {
      fprintf (stderr, "wrap_mul: cannot allocate %ld limbs\n", (long) s);
      abort ();
    }

  sbc = SGN(b) * SGN(c);
  sde = SGN(d) * SGN(e);

  /* A zero product is stored as 0; mpn_mul_fft needs nonempty operands.
     NOTE(review): the limb returned by mpn_mul_fft is ignored; presumably
     it is only nonzero when the residue equals B^s -- confirm. */
  if (sbc == 0)
    MPN_ZERO(bc, s);
  else
    mpn_mul_fft (bc, s, PTR(b), ABSIZ(b), PTR(c), ABSIZ(c), k);
  if (sde == 0)
    MPN_ZERO(de, s);
  else
    mpn_mul_fft (de, s, PTR(d), ABSIZ(d), PTR(e), ABSIZ(e), k);

  /* If b*c = 0 the result has the sign of d*e: use it as the common sign,
     so that |b*c| + |d*e| = |d*e| is computed and sbc is correct below. */
  if (sbc == 0)
    sbc = sde;

  if (sbc == sde)
    {
      /* a carry at B^s is -1 modulo B^s+1 */
      cy = mpn_add_n (bc, bc, de, s);
      cy = mpn_sub_1 (bc, bc, s, cy);
      ASSERT_ALWAYS (cy == 0);
    }
  else
    {
      cy = mpn_sub_n (bc, bc, de, s);
      /* we have to subtract cy at B^s, i.e., add cy at 0 */
      cy = mpn_add_1 (bc, bc, s, cy);
      ASSERT_ALWAYS (cy == 0);
    }

  if (n <= s)
    {
      /* we have now computed A = b * c + d * e mod B^s+1, and we want
         A + x * (B^s+1) = 0 (mod B^n), i.e., A + x = 0 (mod B^n),
         assuming s >= n. Thus if A = H*B^n + L, we want
         H*B^n + L + (B^n - L) * (B^s+1) or
         H*B^n + L - L * (B^s+1)
         (the high bit of L tells which one lies in the symmetric range) */
      for (i = 0; i < n && bc[i] == ZERO; i++)
        ;
      if (i < n) /* if i = n, the low n limbs are already 0 */
        {
          if ((bc[n - 1] & GMP_LIMB_HIGHBIT) != 0)
            { /* compute H*B^n + L + (B^n - L) * (B^s+1) */
              MPN_ZERO(bc + s, n);
              cy = ONE - mpn_sub_n (bc + s, bc + s, bc, n);  /* add B^(n+s) and
                                                                subtract B^n * L */
              MPN_ZERO(bc, n);                          /* subtract L */
              cy += mpn_add_1 (bc + n, bc + n, s, ONE); /* add B^n */
              ASSERT_ALWAYS (cy == 0);
            }
          else
            { /* compute H*B^n + L - L * (B^s+1)
                 = - (L * B^s - H*B^n) */
              MPN_COPY (bc + s, bc, n); /* L * B^s */
              if (s > n)
                {
                  mpn_com (bc + n, bc + n, s - n); /* B^s - H*B^n - B^n */
                  cy = ONE - mpn_add_1 (bc + n, bc + n, s - n, ONE);
                  cy = mpn_sub_1 (bc + s, bc + s, n, cy);
                }
              else
                cy = ZERO;
              ASSERT_ALWAYS (cy == 0);
              MPN_ZERO (bc, n);
              sbc = -sbc;
            }
          s = n + s;
        }
    }
  else /* case s < n: let b*c+d*e mod (B^s+1) = L1 * B^(n-s) + L0,
          where L1 has 2s-n limbs and L0 has n-s limbs. We want
          L0 * B^(2*s) - L1 * B^n if L0 < 1/2*B^(n-s), and
          B^n*(B^s+1) - L0 * B^(2*s) + L1 * B^n otherwise */
    {
      ASSERT_ALWAYS (n <= 2 * s);
      for (i = 0; i < n - s && bc[i] == ZERO; i++)
        ;
      if (i < n - s) /* L0 is not zero */
        {
          if ((bc[n - s - 1] & GMP_LIMB_HIGHBIT) == 0) /* L0 < 1/2*B^(n-s) */
            {
              if (2 * s > n)
                {
                  mpn_com (bc + n, bc + n - s, 2 * s - n); /* -L1 * B^n */
                  cy = ONE - mpn_add_1 (bc + n, bc + n, 2 * s - n, ONE);
                  cy = mpn_sub_1 (bc + 2 * s, bc, n - s, cy);
                }
              else
                {
                  /* 2s = n: L1 is empty and the result is L0 * B^(2s) */
                  MPN_COPY (bc + 2 * s, bc, n - s);
                  cy = ZERO;
                }
              ASSERT_ALWAYS (cy == 0);
            }
          else /* B^n*(B^s+1) - L0 * B^(2*s) + L1 * B^n
                  = (L1+1) * B^n + (B^(n-s) - L0) * B^(2*s) */
            {
              /* (L1+1) * B^n; a carry out of L1+1 lands at B^(2s) */
              if (2 * s > n)
                c1 = mpn_add_1 (bc + n, bc + n - s, 2 * s - n, ONE);
              else
                c1 = ONE; /* L1 is empty: L1 + 1 = 1 = B^(2s) / B^n */
              mpn_com (bc + 2 * s, bc, n - s); /* -L0 * B^(2*s) */
              cy = mpn_add_1 (bc + 2 * s, bc + 2 * s, n - s, ONE);
              /* cannot overflow: L0 >= 1/2*B^(n-s) */
              cy += mpn_add_1 (bc + 2 * s, bc + 2 * s, n - s, c1);
              ASSERT_ALWAYS (cy == 0);
              sbc = -sbc;
            }
        }
      else /* L0 is zero, result is - L1 * B^n */
        {
          MPN_COPY (bc + n, bc + n - s, 2 * s - n);
          MPN_ZERO (bc + 2 * s, n - s);
          sbc = -sbc;
        }
      MPN_ZERO(bc, n);
      s = n + s;
    }
  MPN_NORMALIZE(bc, s);
  SIZ(a) = sbc == 1 ? s : -s;

  free (de);
}

/* Binary half-gcd, divide-and-conquer version.

   Input: A = sa * {a, n} and B = sb * {b, n} with 0 = nu2(A) < nu2(B).
   The function performs binary division steps while the total number j of
   bits cleared stays <= k.
   Return value: j

   If flag & 1: computes the 2x2 matrix R = [R11 R12; R21 R22] with
                R * (A, B) = 2^(2j) * (A', B'), where A' (odd) and B' are
                the final terms reached (does not modify A, B).
   If flag & 2: computes the final terms A', B' of the remainder sequence
                (in place of A and B).
   R11..R22 must be initialized even when flag & 1 is 0: they are then used
   as workspace.  When B = 0 or nu2(B) > k, the function returns 0 and
   leaves A, B unchanged.  R is set to the identity when flag & 1. */
int
hgcd (mpz_ptr A, mpz_ptr B, mp_size_t k,
      mpz_t R11, mpz_t R12, mpz_t R21, mpz_t R22, int flag)
{
  mp_size_t j, k1, j1, j0, k2, j2;
  mpz_t a1, b1, S11, S12, S21, S22, tmp;
  long q;

  if (k < HGCD_DC_THRESHOLD)
    return hgcd_ref (A, B, k, R11, R12, R21, R22, flag);

  ASSERT_ALWAYS (mpz_scan1 (A, 0) == 0);
  ASSERT_ALWAYS (mpz_scan1 (B, 0) != 0);

  /* B = 0 can occur for the truncated operands of the recursive calls.
     It has no finite 2-valuation: mpz_scan1 would return the largest
     bit count, which typically becomes -1 once stored in the signed j and
     defeats the test below.  Treat it like nu2(B) > k. */
  if (SIZ (B) == 0)
    j = k + 1;
  else
    j = mpz_scan1 (B, 0);
  if (j > k)
    {
      if (flag & 1)
        {
          mpz_set_ui (R11, 1);
          mpz_set_ui (R12, 0);
          mpz_set_ui (R21, 0);
          mpz_set_ui (R22, 1);
        }
      return 0;
    }

  mpz_init (a1);
  mpz_init (b1);

  /* first recursive call, on the low 2*k1+1 bits */
  k1 = k / 2;

  mpz_tdiv_r_2exp (a1, A, 2 * k1 + 1);
  mpz_tdiv_r_2exp (b1, B, 2 * k1 + 1);

  j1 = hgcd (a1, b1, k1, R11, R12, R21, R22, 1);

  /* a1 <- {a, n} * R11 + {b, n} * R12 */
  /* b1 <- {a, n} * R21 + {b, n} * R22 */
  if (2 * j1 >= WRAP_THRESHOLD)
    {
      wrap_mul (a1, A, R11, B, R12, 2 * j1);
      wrap_mul (b1, A, R21, B, R22, 2 * j1);
    }
  else
    {
      mpz_mul (a1, A, R11);
      mpz_addmul (a1, B, R12);
      mpz_mul (b1, A, R21);
      mpz_addmul (b1, B, R22);
    }

  /* divide a1 by 2^(2j1); b1 is divided below, together with 2^j0 */
  mpz_tdiv_q_2exp (a1, a1, 2 * j1);

  if (SIZ(b1) == 0)
    {
      if (flag & 2)
        {
          mpz_set (A, a1);
          mpz_set_ui (B, 0);
        }
      mpz_clear (a1);
      mpz_clear (b1);
      return j1;
    }

  j0 = mpz_scan1 (b1, 2 * j1) - 2 * j1;

  ASSERT_ALWAYS (SIZ(b1) != 0);
  if (j0 + j1 > k)
    {
      if (flag & 2)
        {
          mpz_set (A, a1);
          mpz_tdiv_q_2exp (B, b1, 2 * j1);
        }
      mpz_clear (a1);
      mpz_clear (b1);
      return j1;
    }

  mpz_tdiv_q_2exp (b1, b1, 2 * j1 + j0);

  /* one binary division step in the middle */
  if (j0 + 1 < GMP_NUMB_BITS)
    {
      q = BinaryDivide (a1, b1, j0);
      /* R11 <- 2^j0 * R21
         R12 <- 2^j0 * R22
         R21 <- 2^j0 * R11 + q * R21
         R22 <- 2^j0 * R12 + q * R22 */
      mpz_mul_2exp (R11, R11, j0);
      mpz_addmul_si (R11, R21, q);
      mpz_mul_2exp (R21, R21, j0);
      mpz_swap (R11, R21);
      mpz_mul_2exp (R12, R12, j0);
      mpz_addmul_si (R12, R22, q);
      mpz_mul_2exp (R22, R22, j0);
      mpz_swap (R12, R22);
    }
  else
    {
      mpz_t Q;
      mpz_init (Q);
      BinaryDivideSlow (a1, b1, j0, Q);
      mpz_mul_2exp (R11, R11, j0);
      mpz_addmul (R11, R21, Q);
      mpz_mul_2exp (R21, R21, j0);
      mpz_swap (R11, R21);
      mpz_mul_2exp (R12, R12, j0);
      mpz_addmul (R12, R22, Q);
      mpz_mul_2exp (R22, R22, j0);
      mpz_swap (R12, R22);
      mpz_clear (Q);
    }

  if (SIZ (a1) == 0) /* a1 is zero */
    {
      if (flag & 2)
        {
          mpz_set (A, b1);
          mpz_set_ui (B, 0);
        }
      mpz_clear (a1);
      mpz_clear (b1);
      return j0 + j1;
    }

  k2 = k - (j0 + j1);

  mpz_init (tmp);
  mpz_init (S11);
  mpz_init (S12);
  mpz_init (S21);
  mpz_init (S22);

  mpz_tdiv_q_2exp (a1, a1, j0);

  if (flag & 2) /* save a1, b1 */
    {
      mpz_set (A, b1);
      mpz_set (B, a1);
    }

  /* second recursive call, on the low 2*k2+1 bits of (b1, a1) */
  mpz_tdiv_r_2exp (a1, a1, 2 * k2 + 1);
  mpz_tdiv_r_2exp (b1, b1, 2 * k2 + 1);

  j2 = hgcd (b1, a1, k2, S11, S12, S21, S22, 1);

  if (flag & 1)
    {
      /* R11 <- S11 * R11 + S12 * R21
         R12 <- S11 * R12 + S12 * R22
         R21 <- S21 * R11 + S22 * R21
         R22 <- S21 * R12 + S22 * R22 */
      if (k > STRASSEN_THRESHOLD)
        strassen_mul (R11, R12, R21, R22, S11, S12, S21, S22);
      else
        {
          mpz_mul (tmp, S11, R11);
          mpz_addmul (tmp, S12, R21); /* new R11 */
          mpz_mul (R11, S21, R11);
          mpz_addmul (R11, S22, R21); /* new R21 */
          mpz_swap (R21, R11);
          mpz_swap (R11, tmp);
          mpz_mul (tmp, S11, R12);
          mpz_addmul (tmp, S12, R22); /* new R12 */
          mpz_mul (R12, S21, R12);
          mpz_addmul (R12, S22, R22); /* new R22 */
          mpz_swap (R22, R12);
          mpz_swap (R12, tmp);
        }
    }
  if (flag & 2)
    {
      /* A <- 2^(-2j2) * (S11 * A + S12 * B)
         B <- 2^(-2j2) * (S21 * A + S22 * B) */
      if (2 * j2 >= WRAP_THRESHOLD)
        {
          wrap_mul (tmp, S11, A, S12, B, 2 * j2);
          wrap_mul (a1, S21, A, S22, B, 2 * j2);
          mpz_swap (B, a1);
        }
      else
        {
          mpz_mul (tmp, S11, A);
          mpz_addmul (tmp, S12, B);
          mpz_mul (B, S22, B);
          mpz_addmul (B, S21, A);
        }
      mpz_tdiv_q_2exp (B, B, 2 * j2);
      mpz_tdiv_q_2exp (A, tmp, 2 * j2);
    }

  mpz_clear (tmp);
  mpz_clear (S11);
  mpz_clear (S12);
  mpz_clear (S21);
  mpz_clear (S22);

  mpz_clear (a1);
  mpz_clear (b1);

  return j1 + j0 + j2;
}

/* g <- gcd(a0, b0), always nonnegative like mpz_gcd, computed with the
   binary half-gcd hgcd().

   The common power of two 2^vg is removed first.  After that, a is odd
   and b even, and the loop reduces (a, b) in place with hgcd (flag = 2)
   until b = 0. */
void
mpz_bgcd (mpz_t g, mpz_t a0, mpz_t b0)
{
  mpz_t a, b, R11, R12, R21, R22;
  size_t j, va, vb, vg;

  /* gcd(0, x) = |x|: take the absolute value so that negative inputs give
     the same (nonnegative) result as in the general case below */
  if (SIZ(a0) == 0)
    {
      mpz_abs (g, b0);
      return;
    }
  if (SIZ(b0) == 0)
    {
      mpz_abs (g, a0);
      return;
    }

  mpz_init (a);
  mpz_init (b);
  mpz_init (R11);
  mpz_init (R12);
  mpz_init (R21);
  mpz_init (R22);

  va = mpz_scan1 (a0, 0);
  vb = mpz_scan1 (b0, 0);

  if (va < vb)
    {
      vg = va;
      mpz_tdiv_q_2exp (a, a0, vg);
      mpz_tdiv_q_2exp (b, b0, vg);
    }
  else if (vb < va)
    {
      vg = vb;
      mpz_tdiv_q_2exp (a, b0, vg);
      mpz_tdiv_q_2exp (b, a0, vg);
    }
  else
    {
      /* both odd after the shift: replace b by b -/+ a, which is even */
      vg = va;
      mpz_tdiv_q_2exp (a, a0, vg);
      mpz_tdiv_q_2exp (b, b0, vg);
      if (mpz_sgn (a) * mpz_sgn (b) > 0)
        mpz_sub (b, b, a);
      else
        mpz_add (b, b, a);
    }

  /* now a is odd, and b is even */

  while (SIZ(b) != 0)
    {
      size_t nu_b = mpz_scan1 (b, 0);

      j = mpz_sizeinbase (b, 2) / 3;
      /* we need j >= nu2(b), otherwise hgcd() returns the identity matrix
         and makes no progress */
      if (j < nu_b)
        j = nu_b;
      hgcd (a, b, j, R11, R12, R21, R22, 2);
    }

  mpz_tdiv_q_2exp (g, a, mpz_scan1 (a, 0));
  mpz_mul_2exp (g, g, vg);
  mpz_abs (g, g);

  mpz_clear (a);
  mpz_clear (b);
  mpz_clear (R11);
  mpz_clear (R12);
  mpz_clear (R21);
  mpz_clear (R22);
}

/* Stress test and benchmark: repeatedly draws two random n-limb numbers
   and compares mpz_gcd with mpz_bgcd, each timed over I iterations.
   Usage: prog n [I].  Runs until a mismatch is found, then exits with
   status 1. */
int
main (int argc, char *argv[])
{
  mpz_t a, b, g, g_ref;
  size_t n;
  char *end;
  int st, i, I;

  if (argc < 2)
    {
      fprintf (stderr, "Usage: %s n [iterations]\n",
               argc > 0 ? argv[0] : "hgcd");
      return 1;
    }
  n = strtoul (argv[1], &end, 10);
  if (end == argv[1] || *end != '\0')
    {
      fprintf (stderr, "invalid size in limbs: %s\n", argv[1]);
      return 1;
    }

  I = (argc > 2) ? atoi (argv[2]) : 1;

  mpz_init (g);
  mpz_init (g_ref);
  mpz_init2 (a, n * GMP_NUMB_BITS);
  mpz_init2 (b, n * GMP_NUMB_BITS);

  while (1)
    {
      mpz_random (a, n);
      mpz_random (b, n);

      st = cputime ();
      for (i = 0; i < I; i++)
        mpz_gcd (g_ref, a, b);
      printf ("mpz_gcd took %dms\n", cputime () - st);

      st = cputime ();
      for (i = 0; i < I; i++)
        mpz_bgcd (g, a, b);
      printf ("mpz_bgcd took %dms\n", cputime () - st);

      if (mpz_cmp (g, g_ref) != 0)
        {
          gmp_printf ("mpz_gcd and mpz_bgcd differ\n");
          gmp_printf ("g=%Zd\n", g);
          gmp_printf ("g_ref=%Zd\n", g_ref);
          exit (1);
        }
    }

  mpz_clear (a);
  mpz_clear (b);
  mpz_clear (g);
  mpz_clear (g_ref);
  return 0;
}
