/* floating-point Newton, with inversion in 3M(n) */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#ifdef MAIN
#include <unistd.h>  /* getpid, used by the MAIN benchmark driver */
#endif
#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"

/* Limb-typed constants.  The whole expansion is parenthesized so the
   macros behave as single operands in any context (e.g. `sizeof ONE'
   or `-ONE').  */
#define ZERO ((mp_limb_t) 0)
#define ONE  ((mp_limb_t) 1)

/* Older GMP releases spell this function mpn_com_n; newer ones call it
   mpn_com.  When the old spelling is not provided, map it onto the new
   one so the code below can keep using mpn_com_n.  */
#if !defined (mpn_com_n)
#define mpn_com_n mpn_com
#endif

/* Input: A = {ap, n} with most significant bit set.
   Output: X = B^n + {xp, n} where B = 2^GMP_NUMB_BITS.

   X is a lower approximation of B^(2n)/A with implicit msb.
   More precisely, one has:

              A*X < B^(2n) <= A*(X+1).
*/
/* Strategy: n = 1 uses invert_limb directly; n = 2 refines a one-limb
   inverse of ap[1] by hand; n >= 3 recursively inverts the high
   h = n - (n-1)/2 limbs of A, then derives the low l = (n-1)/2 limbs
   of X from the low n limbs of B^(n+h) - A*Xh, and finally fixes a
   possible off-by-one with an explicit multiplication.  */
void
mpn_invert2 (mp_ptr xp, mp_srcptr ap, mp_size_t n)
{
  if (n == 1)
    {
      /* invert_limb returns min(B-1, floor(B^2/ap[0])-B),
	 which is B-1 when ap[0]=B/2, and 1 when ap[0]=B-1.
	 For X=B+xp[0], we have A*X < B^2 <= A*(X+1) where
	 the equality holds only when A=B/2.

	 We thus have A*X < B^2 <= A*(X+1).
      */
      invert_limb (xp[0], ap[0]);
    }
  else if (n == 2)
    {
      /* tp: 4-limb accumulator for A times (part of) X; sp, up: scratch.  */
      mp_limb_t tp[4], up[2], sp[2], cy;

      /* High part first: xp[1] from a one-limb inverse of ap[1], so that
         Xh = B + xp[1] (implicit msb).  */
      tp[0] = ZERO;
      invert_limb (xp[1], ap[1]);
      /* {tp, 4} = A*Xh*B mod B^4: A*xp[1] shifted up one limb, plus A*B^2
         for the implicit msb.  A nonzero cy means the product reached B^4.  */
      tp[3] = mpn_mul_1 (tp + 1, ap, 2, xp[1]);
      cy = mpn_add_n (tp + 2, tp + 2, ap, 2);
      while (cy) /* Xh is too large */
	{
	  /* Decrement Xh and subtract A*B from {tp, 4} until the
	     overflow carry is cancelled.  */
	  xp[1] --;
	  cy -= mpn_sub (tp + 1, tp + 1, 3, ap, 2);
	}
      /* tp[3] should be 111...111 */

      /* sp = -{tp+1, 2} mod B^2 (one's complement plus one), i.e. the low
         two limbs of B^3 - A*Xh, since {tp+1, 3} = A*Xh here.  */
      mpn_com_n (sp, tp + 1, 2);
      cy = mpn_add_1 (sp, sp, 2, ONE);
      /* cy should be 0 */

      /* xp[0] = high limb of sp[1] * (B + xp[1]) = floor(sp[1]*Xh / B),
         the correction that supplies the low limb of X.  */
      up[1] = mpn_mul_1 (up, sp + 1, 1, xp[1]);
      cy = mpn_add_1 (up + 1, up + 1, 1, sp[1]);
      /* cy should be 0 */
      xp[0] = up[1];

      /* update tp: add A*xp[0], so that {tp, 4} = A*X with
         X = B^2 + {xp, 2}.  The final carry is not needed.  */
      cy = mpn_addmul_1 (tp, ap, 2, xp[0]);
      cy = mpn_add_1 (tp + 2, tp + 2, 2, cy);
      /* Add A repeatedly: as long as A*(X+1) stays below B^4 (no carry),
         X is too small and is incremented.  The loop stops at the first
         carry, i.e. once B^4 <= A*(X+1).  */
      do
	{
	  cy = mpn_add (tp, tp, 4, ap, 2);
	  if (cy == ZERO)
	    mpn_add_1 (xp, xp, 2, ONE);
	}
      while (cy == ZERO);

      /* now A*X < B^4 <= A*(X+1) */
    }
  else
    {
      mp_size_t l, h;
      mp_ptr tp, up;
      mp_limb_t cy, th; /* th is currently unused */
      TMP_DECL;

      /* The high h limbs of X come from the recursive call on the high
         h limbs of A; the low l limbs are computed below (l < h).  */
      l = (n - 1) / 2;
      h = n - l;

      /* Xh = B^h + {xp+l, h}, inverse of the high h limbs of A.  */
      mpn_invert2 (xp + l, ap + l, h);

      TMP_MARK;
      tp = TMP_ALLOC_LIMBS (n + h);
      up = TMP_ALLOC_LIMBS (2 * h);
      /* {tp, n+h} = A*Xh mod B^(n+h): A*{xp+l, h}, plus A*B^h for the
         implicit msb of Xh.  A nonzero cy means A*Xh >= B^(n+h).  */
      mpn_mul (tp, ap, n, xp + l, h);
      cy = mpn_add_n (tp + h, tp + h, ap, n);
      while (cy)
	{
	  /* Xh too large: decrement it and subtract A from {tp, n+h}.  */
	  mpn_sub_1 (xp + l, xp + l, h, ONE);
	  cy -= mpn_sub (tp, tp, n + h, ap, n);
	}

      /* {tp, n} = -{tp, n} mod B^n, i.e. the low n limbs of
         B^(n+h) - A*Xh.  Only these n limbs are used below.  */
#if 0 /* original code */
      mpn_com_n (tp, tp, n);
      mpn_add_1 (tp, tp, n, ONE);
#else /* suggestion from Marco Bodrato:
         https://gmplib.org/list-archives/gmp-devel/2015-April/003946.html */
      mpn_neg (tp, tp, n);
#endif
      /* {up, 2h} = {tp+l, h} * {xp+l, h}: the high h limbs of that
         value times the explicit limbs of Xh.  */
      mpn_mul_n (up, tp + l, xp + l, h);
      /* Add {tp+l, h} * B^h for the implicit msb of Xh: the low h-l limbs
         of it here, the remaining l limbs {tp+h, l} (with carry cy) in the
         next call, which stores the top l limbs of the sum into {xp, l},
         the low limbs of X.  The carry out of that last addition is
         discarded.  */
      cy = mpn_add_n (up + h, up + h, tp + l, h - l);
      mpn_add_nc (xp, up + 2*h - l, tp + h, l, cy);
      /* With unsigned limb wraparound, this test holds exactly when the
         limb just below the kept part, up[2*h-l-1], is B-3, B-2 or B-1.
         In that case the truncated X may be one too small, so it is
         checked explicitly.  */
      if (up[2*h-l-1] + 3 <= CNST_LIMB(2)) /* X might be off by 1 */
        {
          /* {vp, 2n} = A*X (with the implicit B^n of X), then A*(X+1);
             X is incremented when A*(X+1) still fits below B^(2n).  */
          mp_ptr vp = TMP_ALLOC_LIMBS (n + n);
          mpn_mul_n (vp, ap, xp, n);
          cy = mpn_add_n (vp + n, vp + n, ap, n);
          assert (cy == 0); /* A*X should be less than B^(2n) */
          cy = mpn_add (vp, vp, n + n, ap, n);
          if (cy == 0) /* A*(X+1) < B^(2n) */
            {
              cy = mpn_add_1 (xp, xp, n, ONE);
              assert (cy == 0);
            }
        }
      TMP_FREE;
    }
}

/* Check a result of mpn_invert2.  With A = {ap, n} and
   X = B^n + {xp, n} (implicit most significant limb), return 1 if
   A*X < B^(2n) <= A*(X+1), and 0 otherwise.  {xp, n} and {ap, n}
   are only read.  */
int
test_invert2 (mp_ptr xp, mp_srcptr ap, mp_size_t n)
{
  int res;
  mp_size_t i;
  mp_ptr tp;
  mp_limb_t cy;
  TMP_DECL;

  TMP_MARK;
  tp = TMP_ALLOC_LIMBS (2 * n);

  /* {tp, 2n} = A*X mod B^(2n): A*{xp, n}, plus A*B^n for the implicit
     msb of X.  cy != 0 exactly when A*X >= B^(2n).  */
  mpn_mul_n (tp, xp, ap, n);
  cy = mpn_add_n (tp + n, tp + n, ap, n);

  if (cy != 0)
    res = 0; /* A*X < B^(2n) does not hold */
  else
    {
      /* {tp, 2n} = (B^(2n) - A*X) mod B^(2n).  */
      mpn_com_n (tp, tp, 2 * n);
      mpn_add_1 (tp, tp, 2 * n, 1);
      /* B^(2n) <= A*(X+1)  <=>  B^(2n) - A*X <= A: the high n limbs
         must all be zero and the low n limbs must not exceed A.
         Comparing in place avoids a zero-extended copy of A.  */
      for (i = n; i < 2 * n && tp[i] == 0; i++)
        ;
      res = (i == 2 * n) && (mpn_cmp (tp, ap, n) <= 0);
    }

  TMP_FREE;
  return res;
}

#ifdef MAIN
#include <sys/types.h>
#include <sys/resource.h>

/* Return the user CPU time consumed so far by the calling process, in
   milliseconds (truncated).  Exits with status 1 if getrusage fails,
   rather than reading an uninitialized struct.  */
int
cputime (void)
{
  struct rusage rus;

  /* Use RUSAGE_SELF: POSIX names the constant but does not fix its
     numeric value, so a literal 0 is not portable.  */
  if (getrusage (RUSAGE_SELF, &rus) != 0)
    {
      perror ("getrusage");
      exit (1);
    }
  return (int) (rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000);
}

int
main (int argc, char *argv[])
{
  mp_size_t n = atoi (argv[1]), i, j, k;
  mp_ptr qp, rp, dp, tp, qp2, rp2;
  mp_limb_t cy;
  pid_t pid;
  int st, st0;

  k = (argc <= 2) ? 1 : atoi(argv[2]);

  qp = malloc (n * sizeof (mp_limb_t));
  qp2 = malloc (n * sizeof (mp_limb_t));
  rp = malloc (n * sizeof (mp_limb_t));
  rp2 = malloc (2 * n * sizeof (mp_limb_t));
  dp = malloc (n * sizeof (mp_limb_t));
  tp = malloc (2 * n * sizeof (mp_limb_t));

  pid = getpid ();
  printf ("Seed=%lu\n", pid);
  srand48 (pid);
  for (i = 0; i < n; i++)
    dp[i] = lrand48 ();
  dp[n - 1] |= GMP_NUMB_HIGHBIT;
  for (i = 0; i < n; i++)
    rp[i] = lrand48 ();

  st = cputime ();
  for (i = 0; i < k; i++)
    mpn_mul_n (tp, dp, rp, n);
  st0 = cputime () - st;
  printf ("mpn_mul_n took %dms (%.3f)\n", st0, (double) st0 / (double) k);
  
  st = cputime ();
  for (i = 0; i < k; i++)
    {
#ifdef CHECK2
      if (i < k / 3) /* test small numbers */
        {
          for (j = 0; j < n; j++)
            dp[j] = 0;
          dp[0] = i;
        }
      else if (i < 2 * k / 3)
        {
          for (j = 0; j < n; j++)
            dp[j] = lrand48 ();
        }
      else /* test large numbers */
        {
          for (j = 0; j < n; j++)
            dp[j] = ~0;
          dp[0] = (mp_limb_t) i - (mp_limb_t) k;
        }
      dp[n - 1] |= GMP_NUMB_HIGHBIT;
#endif      
      mpn_invert2 (qp, dp, n);
#ifdef CHECK
  if (test_invert2 (qp, dp, n) == 0)
    {
      fprintf (stderr, "test_invert2 failed at i=%lu\n", i);
      gmp_printf ("A=%Nd\n", dp, n);
      gmp_printf ("X=B^%lu+%Nd\n", n, qp, n);
      exit (1);
    }
#endif
    }
  st = cputime () - st;
  printf ("mpn_invert2 took %dms (%.2f)\n", st, (double) st / (double) st0);

  MPN_ZERO (rp2, 2 * n);
  rp2[2 * n - 1] = GMP_LIMB_HIGHBIT;
  st = cputime ();
  for (i = 0; i < k; i++)
    {
      MPN_ZERO (rp2, 2 * n);
      rp2[2 * n - 1] = GMP_LIMB_HIGHBIT;
      mpn_divrem (qp2, 0, rp2, 2 * n, dp, n);
    }
  st = cputime () - st;
  printf ("mpn_divrem took %dms (%.2f)\n", st, (double) st / (double) st0);

  free (qp);
  free (rp);
  free (dp);
  free (tp);

  return 0;
}
#endif
