// Compile with
// cc -std=c99 -Wall -Wextra -pedantic -O3 rho.c -lmpfr -lgmp -pthread -o rho

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <gmp.h>
#include <mpfr.h>
#include <pthread.h>

// Smaller of two values.  Arguments may be evaluated twice, so only pass
// side-effect-free expressions.
#define MIN(x, y) ((y) < (x) ? (y) : (x))

// Global variables: set up by main, read by thread_func.  thread_func also
// writes the entries of e, r0 and r1 that belong to its own share.
static unsigned nt;     // number of worker threads
static unsigned n;      // number of leading primes treated explicitly
static unsigned M;      // largest index i of the exponents e_i
static unsigned prec;   // MPFR working precision, in bits
static unsigned maxb;   // largest bit length among the exponents e_i
static mpz_t   *p;      // p[0..n]: the first n+1 primes
static mpz_t   *e;      // e[2..M]: the exponents e_i
static mpfr_t  *r0;     // r0[id]: lower bound of thread id's partial product
static mpfr_t  *r1;     // r1[id]: upper bound of thread id's partial product
// Allocated and freed by main but not otherwise used by the code shown here.
// NOTE(review): y0/y1 collide with the POSIX <math.h> Bessel functions if
// that header is ever included.
static mpfr_t **y0;
static mpfr_t **y1;

// Worker thread.  `args` carries the thread index id (0 <= id < nt) cast to
// a pointer.  Thread id processes
//   - the exponent indices i0 <= i < i0+ni, a contiguous share of 2..M, and
//   - the Euler-product factors j0 <= j < j1, a contiguous share of 0..n-1,
// where the shares of different threads differ in size by at most one.
// Every quantity is carried as an interval whose lower end is rounded down
// and whose upper end is rounded up.  On return r0[id] and r1[id]
// (initialised here, cleared by main) enclose this thread's partial product.
// Side effect: negative exponents e[i] in this thread's share are negated in
// place.  Threads touch disjoint parts of e, r0 and r1.
void *thread_func(void *args)
{
  unsigned id = (unsigned)(size_t)args;
  unsigned i0 = ((M-1)/nt) * id + MIN(id, (M-1)%nt) + 2;
  unsigned ni =  (M-1)/nt + (id < (M-1)%nt ? 1 : 0);
  unsigned j0 = (n/nt) * id + MIN(id, n%nt);
  unsigned j1 = j0 + n/nt + (id < n%nt ? 1 : 0);

  // Set MPFR precision.  The default precision is per-thread in thread-safe
  // MPFR builds, so main's setting is not inherited.
  mpfr_set_default_prec(prec);

  // Interval enclosures [z0[i], z1[i]] of \zeta(i0+i), 0 <= i < ni.  They are
  // heap-allocated: a VLA would put O(M) bytes on this thread's stack
  // (secondary threads often get much smaller stacks than main), and a
  // zero-length VLA is undefined behaviour.  ni == 0 happens for some threads
  // whenever M-1 < nt, and for every thread when M == 1.
  mpfr_t *z0 = NULL, *z1 = NULL;
  if (ni > 0) {
    z0 = malloc(ni * sizeof *z0);
    z1 = malloc(ni * sizeof *z1);
    assert(z0 != NULL && z1 != NULL);
  }
  for (unsigned i = 0, ii = i0; i < ni; ++i, ++ii) {
    mpfr_inits  (z0[i], z1[i], NULL);
    mpfr_zeta_ui(z0[i], ii, GMP_RNDD);
    mpfr_zeta_ui(z1[i], ii, GMP_RNDU);
  }

  // Remove the contribution of the first n primes from the \zeta(i)'s, i.e.
  // multiply by (1 - p_j^{-i}) for j < n, to obtain the partial zeta values
  // \zeta_n(i) for i0 <= i < i0+ni.
  {
    mpfr_t     p0, p1, t0, t1, u0, u1;
    mpfr_inits(p0, p1, t0, t1, u0, u1, NULL);
    for (unsigned j = 0; j < n; ++j) {
      // [p0, p1] encloses 1/p_j; [t0, t1] starts at p_j^{-(i0-1)}.
      mpfr_set_ui(p0, 1,        GMP_RNDD);
      mpfr_set_ui(p1, 1,        GMP_RNDU);
      mpfr_div_z (p0, p0, p[j], GMP_RNDD);
      mpfr_div_z (p1, p1, p[j], GMP_RNDU);
      mpfr_pow_ui(t0, p0, i0-1, GMP_RNDD);
      mpfr_pow_ui(t1, p1, i0-1, GMP_RNDU);
      for (unsigned i = 0; i < ni; ++i) {
        // [t0, t1] now encloses p_j^{-(i0+i)}.
        mpfr_mul(t0,    t0,    p0, GMP_RNDD);
        mpfr_mul(t1,    t1,    p1, GMP_RNDU);
        // z -= z * t.  The new lower end subtracts the upper product and
        // vice versa.  u0 is taken from z0[i] before z0[i] is overwritten.
        mpfr_mul(u0,    z0[i], t0, GMP_RNDD);
        mpfr_mul(u1,    z1[i], t1, GMP_RNDU);
        mpfr_sub(z0[i], z0[i], u1, GMP_RNDD);
        mpfr_sub(z1[i], z1[i], u0, GMP_RNDU);
      }
    }
    mpfr_clears(p0, p1, t0, t1, u0, u1, NULL);
  }

  // Invert \zeta_n(i) when the corresponding exponent e_i is negative, so that
  // only non-negative exponents remain.  The bounds swap roles under 1/x.
  for (unsigned i = 0, ii = i0; i < ni; ++i, ++ii) {
    if (mpz_sgn(e[ii]) < 0) {
      mpz_neg    (e[ii], e[ii]);
      mpfr_swap  (z0[i], z1[i]);
      mpfr_ui_div(z0[i], 1, z0[i], GMP_RNDD);
      mpfr_ui_div(z1[i], 1, z1[i], GMP_RNDU);
    }
  }

  // Compute the product of \zeta_n(i)^e_i over i0 <= i < i0+ni by a single
  // left-to-right square-and-multiply pass over bits maxb..0 shared by all
  // exponents (multi-exponentiation).
  mpfr_ptr rr0 = r0[id];
  mpfr_ptr rr1 = r1[id];
  mpfr_init_set_ui(rr0, 1, GMP_RNDD);
  mpfr_init_set_ui(rr1, 1, GMP_RNDU);
  for (unsigned b = maxb+1; b--; ) {
    mpfr_sqr(rr0, rr0, GMP_RNDD);
    mpfr_sqr(rr1, rr1, GMP_RNDU);
    for (unsigned i = 0, ii = i0; i < ni; ++i, ++ii) {
      if (mpz_tstbit(e[ii], b)) {
        mpfr_mul(rr0, rr0, z0[i], GMP_RNDD);
        mpfr_mul(rr1, rr1, z1[i], GMP_RNDU);
      }
    }
  }

  // Multiply by terms j0 to j1-1 of the Euler product, i.e. by
  // (1 - 2 / (p_j (p_j + 1))).
  {
    mpfr_t     t0, t1;
    mpfr_inits(t0, t1, NULL);
    mpz_t      t;
    mpz_init  (t);
    for (unsigned j = j0; j < j1; ++j) {
      mpz_add_ui (t, p[j], 1);
      mpz_mul    (t, t, p[j]);
      mpfr_mul_ui(t0,  rr0, 2,  GMP_RNDD);
      mpfr_mul_ui(t1,  rr1, 2,  GMP_RNDU);
      mpfr_div_z (t0,  t0,  t,  GMP_RNDD);
      mpfr_div_z (t1,  t1,  t,  GMP_RNDU);
      mpfr_sub   (rr0, rr0, t1, GMP_RNDD);
      mpfr_sub   (rr1, rr1, t0, GMP_RNDU);
    }
    mpfr_clears(t0, t1, NULL);
    mpz_clear  (t);
  }

  // Cleanup.
  for (unsigned i = 0; i < ni; ++i)
    mpfr_clears(z0[i],  z1[i],  NULL);
  free(z0);
  free(z1);

  // The MPFR manual asks threads to release MPFR's internal caches before
  // terminating; functions such as mpfr_zeta_ui may fill them internally.
  mpfr_free_cache();

  return NULL;
}

int main(int argc, char **argv)
{
  // Default values.
  nt = 1;
  n  = 200;

  // Parse command-line arguments.
  char *argv0 = argv[0];
  if (argc > 2 && !strcmp(argv[1], "-nt")) {
    nt = atoi(argv[2]);
    argc -= 2; argv += 2;
  }

  if (argc < 2 && argc > 3) {
    fprintf(stderr, "Usage: %s [-nt <nt>] <digits> [<n>]\n", argv0);
    return EXIT_FAILURE;
  }

  unsigned nd = atoi(argv[1]);
  if (argc == 3)
    n = atoi(argv[2]);

  // Compute primes up to p_{n+1}.
  p = (mpz_t *)malloc((n+1) * sizeof(mpz_t));
  assert(p != NULL);
  mpz_init_set_ui(p[0], 2);
  for (unsigned i = 1; i <= n; ++i) {
    mpz_init     (p[i]);
    mpz_nextprime(p[i], p[i-1]);
  }

  // Compute optimal M for given precision, then compute bound on the
  // approximation error.
  unsigned  guard = 10;
  mpfr_t    err;
  mpfr_init(err);
  {
    mpfr_t      t, u, v;
    mpfr_inits (t, u, v, NULL);
    mpfr_set_ui(t,   2,         GMP_RNDU);
    mpfr_div_z (t,   t,   p[n], GMP_RNDU);
    mpfr_ui_sub(u,   1,   t,    GMP_RNDU);
    mpfr_div_ui(u,   u,   8,    GMP_RNDD);
    mpfr_log10 (u,   u,         GMP_RNDD);
    mpfr_sub_ui(u,   u,   nd+2, GMP_RNDD);
    mpfr_log10 (v,   t,         GMP_RNDU);
    mpfr_div   (u,   u,   v,    GMP_RNDU);
    M = mpfr_get_ui(u,          GMP_RNDU);

    mpfr_pow_ui(err, t,   M,    GMP_RNDU);
    mpfr_ui_sub(t,   1,   t,    GMP_RNDD);
    mpfr_div   (err, err, t,    GMP_RNDU);
    mpfr_mul_ui(err, err, 8,    GMP_RNDU);
    mpfr_expm1 (err, err,       GMP_RNDU);
    mpfr_log2  (t,   err,       GMP_RNDD);
    prec = -mpfr_get_si(t,      GMP_RNDD);
    mpfr_clears(t, u, v, NULL);
  }
  mpfr_fprintf(stderr, "Using n = %u (p_n = %Zu) and M = %u\n", n, p[n-1], M);

  // Compute exponents (e_i)_{2 <= i <= M} for the Taylor expansion of
  // 1 - f(1/t) / g(1/t).
  e = (mpz_t *)malloc((M+1) * sizeof(mpz_t));
  assert(e != NULL);
  for (unsigned i = 2; i <= M; ++i)
    mpz_init(e[i]);
  {
    mpz_t f[M+1], t[M+1], b;
    for (unsigned i = 0; i <= M; ++i)
      mpz_inits(f[i], t[i], NULL);
    mpz_init(b);

    mpz_set_ui(f[0], 1);
    for (unsigned i = 2; i <= M; ++i)
      mpz_set_si(f[i], i%2 ? 2 : -2);

    for (unsigned i = 2; i <= M; ++i) {
      mpz_set(e[i], f[i]);
      for (unsigned j = 0; j <= M; ++j)
        mpz_set(t[j], f[j]);
      for (unsigned j = 1; i*j <= M; ++j) {
        mpz_bin_ui(b, e[i], j);
        if (j%2)
          mpz_neg (b, b);
        for (unsigned k = 0; k+i*j <= M; ++k)
          mpz_addmul(f[k+i*j], t[k], b);
      }
    }

    for (unsigned i = 0; i <= M; ++i)
      mpz_clears(f[i], t[i], NULL);
    mpz_clear(b);
  }

  // Compute maximum size (in bits) of the exponents, and
  // set MPFR precision accordingly.
  maxb = 0;
  for (unsigned i = 2; i <= M; ++i) {
    unsigned b = mpz_sizeinbase(e[i], 2);
    if (maxb < b)
      maxb = b;
  }
  prec += maxb + guard;
  mpfr_set_default_prec(prec);
  mpfr_fprintf(stderr, "MPFR precision:      %u bits\n", prec);
  mpfr_fprintf(stderr, "Approximation error: %.3Rg\n", err);

  // Allocate thread data.
  r0 = (mpfr_t  *)malloc(nt * sizeof(mpfr_t));
  r1 = (mpfr_t  *)malloc(nt * sizeof(mpfr_t));
  y0 = (mpfr_t **)malloc(nt * sizeof(mpfr_t *));
  y1 = (mpfr_t **)malloc(nt * sizeof(mpfr_t *));
  assert(r0 != NULL && r1 != NULL);
  assert(y0 != NULL && y1 != NULL);

  // Create threads.
  pthread_t thr[nt];
  for (unsigned i = 0; i < nt; ++i)
    pthread_create(&thr[i], NULL, &thread_func, (void *)(size_t)i);

  // Wait for threads to finish.
  for (unsigned i = 0; i < nt; ++i)
    pthread_join(thr[i], NULL);

  // Multiply all products.
  for (unsigned i = 1; i < nt; ++i) {
    mpfr_mul(r0[0], r0[0], r0[i], GMP_RNDD);
    mpfr_mul(r1[0], r1[0], r1[i], GMP_RNDU);
  }

  // Finalize computation of \rho.
  mpfr_add_ui(r0[0], r0[0], 1, GMP_RNDD);
  mpfr_add_ui(r1[0], r1[0], 1, GMP_RNDU);
  mpfr_div_ui(r0[0], r0[0], 2, GMP_RNDD);
  mpfr_div_ui(r1[0], r1[0], 2, GMP_RNDU);

  // Compute the rounding error.
  {
    mpfr_t      rnd;
    mpfr_init  (rnd);
    mpfr_sub   (rnd, r1[0], r0[0], GMP_RNDU);
    mpfr_div_ui(rnd, rnd,   2,     GMP_RNDU);
    mpfr_fprintf(stderr, "Rounding error:      %.3Rg\n", rnd);
    mpfr_clear (rnd);
  }

  // Take approximation error into account.
  mpfr_sub(r0[0], r0[0], err, GMP_RNDD);
  mpfr_add(r1[0], r1[0], err, GMP_RNDU);

  // Print result. If some digits differ between lower and upper bounds,
  // use interval notation for these digits; e.g., 0.xxxxxx[yy..zz].
  char s0[nd+3], s1[nd+3];
  mpfr_snprintf(s0, nd+3, "%.*RNf", nd, r0[0]);
  mpfr_snprintf(s1, nd+3, "%.*RNf", nd, r1[0]);
  {
    unsigned i;
    for (i = 0; i < nd+3 && s0[i] == s1[i]; ++i);
    fwrite(s0, i, 1, stdout);
    if (i < nd+3)
      printf("[%s..%s]", s0+i, s1+i);
    printf("\n");
  }

  // Memory cleanup.
  for (unsigned i = 0; i <= n; ++i) mpz_clear  (p[i]);
  for (unsigned i = 2; i <= M; ++i) mpz_clear  (e[i]);
  for (unsigned i = 0; i < nt; ++i) mpfr_clears(r0[i], r1[i], NULL);
  mpfr_clear(err);
  mpfr_free_cache();

  free(p);
  free(e);
  free(r0);
  free(r1);
  free(y0);
  free(y1);

  return EXIT_SUCCESS;
}
