/* rounding to the nearest integer */

/* Identifiers for the two supported rounding modes; ROUND is compared
   against them in the #if tests of this file. */
#define NEAREST 0
#define ZERO 1

/* Rounding mode under test.  Swap the two lines below to select the
   other mode. */
#define ROUND NEAREST
// #define ROUND ZERO

/* RINT is the libm reference function the algorithms are compared with,
   and RINTs is its name as a string for the printed reports.  Note that
   the ZERO mode is checked against floor(), i.e. rounding down. */
#if ROUND == NEAREST
#define RINT rint
#define RINTs "rint"
#else
#define RINT floor
#define RINTs "floor"
#endif

/* When CHECK is defined, main() verifies before timing that algo1,
   algo1a and RINT agree on every test value. */
#define CHECK /* to check both algorithms return the same results */

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <assert.h>
#include <math.h>
#include <fpu_control.h>

/* Reprogram the FPU control word so that its precision-selection field
   holds _FPU_DOUBLE: the field (_FPU_EXTENDED, _FPU_DOUBLE, _FPU_SINGLE
   bits) is cleared first, every other bit of the control word is kept. */
static void
set_fpu_prec (void)
{
  fpu_control_t ctrl;

  _FPU_GETCW (ctrl);
  ctrl = (ctrl & ~(_FPU_EXTENDED | _FPU_DOUBLE | _FPU_SINGLE)) | _FPU_DOUBLE;
  _FPU_SETCW (ctrl);
}

/* Number of test values; main() allocates a table of N doubles and runs
   every check and timing loop over all of them. */
#define N 100000000

#if ROUND == NEAREST
/* Algorithm 1 from the paper: round x to an integer by adding and then
   subtracting the constant 2^52 + 2^51.  The sum lands in a range where
   consecutive doubles are 1 apart, so the addition itself performs the
   rounding (in the current rounding mode) and the subtraction recovers
   the integer.  Assumes IEEE-754 binary64 arithmetic and |x| small enough
   for the sum to stay in that range; much larger x come back unchanged
   because the constant is absorbed. */
double
algo1 (double x)
{
  const double magic = 0x1.8p52; /* 2^52 + 2^51 = 6755399441055744 */
  double shifted = x + magic;    /* rounding happens in this addition */
  return shifted - magic;
}
#else
/* Algorithm 2 from the paper, used when ROUND is not NEAREST and compared
   against floor() by main().  It forms base = 2^53 - x and then adds
   x - 0.5 to it, relying on the rounding of these two operations; the
   difference with base is returned.
   NOTE(review): this file only exercises it on non-negative inputs --
   confirm the behaviour for negative x against the paper before relying
   on it there. */
double
algo1 (double x)
{
  double base = 0x1p53 - x;  /* 2^53 - x, rounded to a double */
  double lowered = x - 0.5;
  double sum = base + lowered;
  return sum - base;
}
#endif

#if ROUND == NEAREST
/* with bit manipulations, rounding to nearest (ties to even).

   Operates directly on the IEEE-754 binary64 encoding of x (1 sign bit,
   11 exponent bits, 52 fraction bits) and returns x rounded to an
   integer.  The sign of x is preserved, so inputs in (-1, -0.5) give -1
   and inputs in [-0.5, 0) give -0.  Values with exponent >= 52 (which
   includes infinities and NaN) are returned unchanged.  */
double
algo1a (double x)
{
  /* unsigned long long is at least 64 bits wide on every platform,
     whereas unsigned long is only 32 bits on some of them, which would
     make the shifts below undefined */
  union { double d; unsigned long long n; } z;
  unsigned long long sign;
  int e; /* unbiased exponent */

  z.d = x;
  sign = z.n & 0x8000000000000000ULL;
  e = (int) ((z.n >> 52) & 0x7ff); /* 11 bits of exponent */
  e -= 1023;

  if (e >= 52) /* no fractional bits left (or infinity/NaN) */
    return x;

  if (e < -1) /* |x| < 0.5, including zero and subnormals */
    {
      z.n = sign; /* zero carrying the sign of x */
      return z.d;
    }

  if (e == -1) /* 0.5 <= |x| < 1 */
    {
      /* exactly 0.5 is a tie and goes to the even neighbour 0,
         0.501..0.999 go to 1; the sign of x is kept in both cases */
      if ((z.n ^ sign) == 0x3fe0000000000000ULL)
        z.n = sign;
      else
        z.n = sign | 0x3ff0000000000000ULL;
      return z.d;
    }

  /* 0 <= e <= 51: x = 1.xxx * 2^e thus we should zero the low 52-e bits */
  {
    unsigned long long rb = 1ULL << (51 - e); /* weight 1/2 */
    unsigned long long mask = -(rb << 1);     /* keeps the integer bits */
    /* even rule: after the shift, bit 63 is the units bit and bit 62 the
       1/2 bit.  The only case where we must not round up in magnitude is
       an exact tie (1/2 bit set, nothing below) with an even units bit */
    if (z.n << (e + 11) != 0x4000000000000000ULL)
      z.n += rb; /* a carry may propagate into the exponent, as wanted */
    z.n &= mask;
  }
  return z.d;
}
#else
/* with bit manipulations, rounding down, i.e., floor(x).

   Operates directly on the IEEE-754 binary64 encoding of x (1 sign bit,
   11 exponent bits, 52 fraction bits).  Zeros are returned unchanged
   (keeping their sign), negative integers are returned unchanged, and
   values with exponent >= 52 (which includes infinities and NaN) are
   returned unchanged.  */
double
algo1a (double x)
{
  /* unsigned long long is at least 64 bits wide on every platform,
     whereas unsigned long is only 32 bits on some of them, which would
     make the shift below undefined */
  union { double d; unsigned long long n; } z;
  int e; /* unbiased exponent */

  z.d = x;
  e = (int) ((z.n >> 52) & 0x7ff); /* 11 bits of exponent */
  e -= 1023;

  if (e >= 52) /* no fractional bits left (or infinity/NaN) */
    return x;

  if (e <= -1) /* |x| < 1 */
    {
      if (x == 0.0) /* +0 and -0 are their own floor */
        return x;
      return (x < 0.0) ? -1.0 : 0.0;
    }

  /* 0 <= e <= 51: x = 1.xxx * 2^e thus we should zero the low 52-e bits,
     which truncates toward zero */
  z.n &= -(1ULL << (52 - e));
  /* truncation already is the floor unless x is negative and had a
     fractional part, in which case the truncated value lies above x */
  return (x < z.d) ? z.d - 1.0 : z.d;
}
#endif

/* Benchmark driver.  Builds a table of N test values (random doubles plus
   a few special cases), optionally cross-checks algo1, algo1a and RINT on
   all of them (when CHECK is defined), then times each of the three over
   the whole table.  The sums s, t and u are printed so that the timed
   loops have an observable result.  Returns 0 on success; exits with
   status 1 if the table cannot be allocated or a check fails. */
int
main (void)
{
  double *x;
  clock_t c;

  set_fpu_prec ();

  /* the table is large (N doubles), so the allocation can really fail */
  x = malloc (N * sizeof *x);
  if (x == NULL)
    {
      fprintf (stderr, "Cannot allocate %zu bytes\n", N * sizeof *x);
      exit (1);
    }

#if 0
  for (double d = 0.0; d <= 2.0; d += 0.5)
    printf ("d:%f algo1:%f algo1a:%f %s:%f\n", d, algo1 (d), algo1a (d),
            RINTs, RINT (d));
#endif

  /* fill the table with random floats in [0, 2^32) */
  for (int i = 0; i < N; i++)
    {
      /* warning: drand48() only generates the high 48 bits! The extra
         term adds 5 random bits of weight 2^-53 below them. */
      x[i] = drand48 () + (lrand48 () % 32) / 9007199254740992.0;
      x[i] = ldexp (x[i], 32);
    }

  /* special numbers */
  x[0] = 0.0;
  x[1] = 4.9406564584124654417656879287e-324; /* smallest subnormal */
  x[2] = 2.2250738585072013830902327173e-308; /* smallest normal */
  x[3] = 1.7976931348623157081452742373e308; /* largest number */

#ifdef CHECK
  /* check both algorithms give the same value */
  for (int i = 0; i < N; i++)
    if (algo1 (x[i]) != algo1a (x[i]))
      {
        printf ("Error for x=%.16e\n", x[i]);
        printf ("algo1 gives %.16e\n", algo1 (x[i]));
        printf ("algo1a gives %.16e\n", algo1a (x[i]));
        exit (1);
      }
  /* also check against the libm reference (rint or floor) */
  for (int i = 0; i < N; i++)
    if (algo1 (x[i]) != RINT (x[i]))
      {
        printf ("Error for x=%.16e\n", x[i]);
        printf ("algo1 gives %.16e\n", algo1 (x[i]));
        printf ("%s gives %.16e\n", RINTs, RINT (x[i]));
        exit (1);
      }
#endif

  /* time algo1 */
  double s = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    s += algo1 (x[i]);
  printf ("Algo1:  s=%e time=%e\n",
          s, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  /* time algo1a */
  double t = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    t += algo1a (x[i]);
  printf ("Algo1a: t=%e time=%e\n",
          t, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  /* time the libm reference */
  double u = 0;
  c = clock ();
  for (int i = 0; i < N; i++)
    u += RINT (x[i]);
  printf ("%s:  u=%e time=%e\n",
          RINTs, u, (double) (clock () - c) / (double) CLOCKS_PER_SEC);

  free (x);

  return 0;
}
