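/* Splitting a double x into xh + xl, where xh has at most 26 significant bits
   and xl at most 26 (algo3) or 27 (algo4, algo4a, bitman) significant bits. */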

// #define CHECK /* uncomment to check that algo3 and algo4 return the same results */

#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <fpu_control.h>

/* Switch the FPU precision-control field to double precision (_FPU_DOUBLE),
   leaving every other bit of the control word untouched.
   NOTE(review): on x86 this is the x87 control word; presumably the goal is to
   avoid double rounding when intermediates would otherwise be kept in extended
   precision -- confirm for the target platform.  */
static void
set_fpu_prec (void)
{
  const fpu_control_t precision_bits = _FPU_EXTENDED | _FPU_DOUBLE | _FPU_SINGLE;
  fpu_control_t word;

  _FPU_GETCW (word);
  /* Clear the whole precision field, then select double.  */
  word = (word & ~precision_bits) | _FPU_DOUBLE;
  _FPU_SETCW (word);
}

/* Number of random inputs checked, and number of calls per algorithm in each
   timing loop.  The input table takes N * sizeof (double) bytes (about 800 MB
   for the default), so it can be overridden at build time, e.g. -DN=1000000.  */
#ifndef N
#define N 100000000
#endif

/* Algorithm 3 from the paper (Veltkamp splitting with C = 2^27 + 1).
   Returns the high part xh and stores the low part x - xh in *xl.
   NOTE(review): relies on every operation being rounded separately; the
   compiler must not contract splitter * x and x - gamma into one FMA
   (e.g. build with -ffp-contract=off) -- confirm build flags.  */
double
algo3 (double x, double *xl)
{
  static const double splitter = 134217729.0; /* 2^27 + 1 */
  const double gamma = splitter * x;
  const double diff = x - gamma;
  const double hi = gamma + diff;

  *xl = x - hi;
  return hi;
}

/* Algorithm 4 from the paper: xh = RN(C*x) - 2^27*x with C = 2^27 + 1, and
   *xl = x - xh.  The product by 2^27 is exact (power of two, barring
   overflow/underflow).
   x = 9.9999999777892101e-01 requires 27 bits for xl */
double
algo4 (double x, double *xl)
{
  static const double splitter = 134217729.0; /* 2^27 + 1 */
  static const double two_pow_27 = 134217728.0; /* 2^27 */
  const double gamma = splitter * x;
  const double hi = gamma - two_pow_27 * x;

  *xl = x - hi;
  return hi;
}

/* Algorithm 4 from the paper, with xl = RN(C*x-gamma) */
double
algo4a (double x, double *xl)
{
  static const double C = 134217729.0;
  double gamma = C * x;
  double xh = gamma - 134217728.0 * x;
  *xl = fma (C, x , -gamma);
  return xh;
}

/* Same split using bit manipulations: xh is x with the low 27 bits of its
   binary64 encoding cleared (so at most 53 - 27 = 26 significant bits), and
   *xl = x - xh receives the cleared bits.  For finite x that subtraction is
   exact, since xh keeps x's sign and exponent.
   A 64-bit integer type is used explicitly: unsigned long is only 32 bits on
   LLP64/ILP32 targets.  NOTE(review): assumes double and uint64_t share byte
   order, which holds on mainstream platforms.  */
double
bitman (double x, double *xl)
{
  union { double d; uint64_t n; } z;

  z.d = x;
  z.n &= ~(uint64_t) 0x7ffffff; /* zero the low 27 bits */
  *xl = x - z.d;
  return z.d;
}

/* return the least number of bits needed to store the significand of x */
int
nbits (double x)
{
  int e, n;
  double t, x0 = x;

  if (x == 0)
    return 0;
  x = (x > 0) ? x : -x;
  /* now x > 0 */
  x = frexp (x, &e);
  assert (0.5 <= x && x < 1);
  t = 0.5;
  n = 0;
  while (x != 0)
    {
      n ++;
      if (x >= t)
        x -= t;
      t = t * 0.5;
    }
  return n;
}

/* Test driver and benchmark for the splitting routines.

   Fills a table with N pseudo-random doubles in [0, 2^32) (drand48 is not
   seeded here, so the inputs are the same on every run), then:
     - with CHECK defined, checks that algo3 and algo4 return the same pairs;
     - asserts x == xh + xl for algo3, algo4, algo4a and bitman, with xh on at
       most 26 bits and xl on at most 26 (algo3) or 27 bits (the others);
     - checks that algo4 and algo4a return identical pairs;
     - times each routine over the whole table.
   Returns EXIT_FAILURE if the table cannot be allocated; a failed check
   aborts via assert or exits with status 1.  */
int
main (void)
{
  clock_t c;

  set_fpu_prec ();

  /* N doubles (about 800 MB for N = 10^8): allocation can realistically fail.  */
  double *x = malloc (N * sizeof *x);
  if (x == NULL)
    {
      perror ("malloc");
      return EXIT_FAILURE;
    }

  /* fill the table with random floats */
  for (int i = 0; i < N; i++)
    {
      /* warning: drand48() only generates the high 48 bits!
         The lrand48 term fills the 5 remaining low-order bits (2^-49..2^-53). */
      x[i] = drand48 () + (lrand48 () % 32) / 9007199254740992.0;
      x[i] = ldexp (x[i], 32);
    }

  double xh, xl, xh2, xl2;
#ifdef CHECK
  /* check all algorithms give the same value */
  for (int i = 0; i < N; i++)
    {
      xh = algo3 (x[i], &xl);
      xh2 = algo4 (x[i], &xl2);
      if (xh != xh2 || xl != xl2)
        {
          printf ("Error for x=%.16e\n", x[i]);
          printf ("algo3 gives xh=%.16e, xl=%.16e\n", xh, xl);
          printf ("algo4 gives xh=%.16e, xl=%.16e\n", xh2, xl2);
          exit (1);
        }
    }
#endif
  for (int i = 0; i < N; i++)
    {
      xh = algo3 (x[i], &xl);
      assert (x[i] == xh + xl);
      assert (nbits (xh) <= 26);
      assert (nbits (xl) <= 26);
      xh = algo4 (x[i], &xl);
      assert (x[i] == xh + xl);
      assert (nbits (xh) <= 26);
      assert (nbits (xl) <= 27);
      xh = algo4a (x[i], &xl);
      assert (x[i] == xh + xl);
      assert (nbits (xh) <= 26);
      assert (nbits (xl) <= 27);
      xh = bitman (x[i], &xl);
      assert (x[i] == xh + xl);
      assert (nbits (xh) <= 26);
      assert (nbits (xl) <= 27);
    }
  for (int i = 0; i < N; i++)
    {
      xh = algo4 (x[i], &xl);
      xh2 = algo4a (x[i], &xl2);
      if (xh != xh2 || xl != xl2)
        {
          printf ("Error for x=%.16e\n", x[i]);
          printf ("algo4 gives xh=%.16e, xl=%.16e\n", xh, xl);
          printf ("algo4a gives xh=%.16e, xl=%.16e\n", xh2, xl2);
          exit (1);
        }
    }

  /* Timing loops: the high parts are summed and printed so the calls have
     an observable result.  */
  double s = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    s += algo3 (x[i], &xl);
  printf ("Algo3:  s=%e time=%e\n",
          s, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  double t = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    t += algo4 (x[i], &xl);
  printf ("Algo4: t=%e time=%e\n",
          t, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  double u = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    u += algo4a (x[i], &xl);
  printf ("Algo4a:  u=%e time=%e\n",
          u, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  double v = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    v += bitman (x[i], &xl);
  printf ("bitman:  v=%e time=%e\n",
          v, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  free (x);

  return 0;
}
