LCOV - code coverage report
Current view: top level - ecm - mpzspv.c (source / functions) Hit Total Coverage
Test: unnamed Lines: 299 315 94.9 %
Date: 2022-03-21 11:19:20 Functions: 19 19 100.0 %

          Line data    Source code
       1             : /* mpzspv.c - "mpz small prime polynomial" functions for arithmetic on mpzv's
       2             :    reduced modulo a mpzspm
       3             : 
       4             : Copyright 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012 Dave Newman,
       5             : Jason Papadopoulos, Alexander Kruppa, Paul Zimmermann.
       6             : 
       7             : The SP Library is free software; you can redistribute it and/or modify
       8             : it under the terms of the GNU Lesser General Public License as published by
       9             : the Free Software Foundation; either version 3 of the License, or (at your
      10             : option) any later version.
      11             : 
      12             : The SP Library is distributed in the hope that it will be useful, but
      13             : WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
      14             : or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
      15             : License for more details.
      16             : 
      17             : You should have received a copy of the GNU Lesser General Public License
      18             : along with the SP Library; see the file COPYING.LIB.  If not, write to
      19             : the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston,
      20             : MA 02110-1301, USA. */
      21             : 
      22             : #include <stdio.h> /* for stderr */
      23             : #include <stdlib.h>
      24             : #include <string.h> /* for memset */
      25             : #include "ecm-impl.h"
      26             : #include "sp.h"
      27             : 
      28             : mpzspv_t
      29       85419 : mpzspv_init (spv_size_t len, mpzspm_t mpzspm)
      30             : {
      31             :   unsigned int i;
      32       85419 :   mpzspv_t x = (mpzspv_t) malloc (mpzspm->sp_num * sizeof (spv_t));
      33             :   
      34       85419 :   if (x == NULL)
      35           0 :     return NULL;
      36             :   
      37     1926105 :   for (i = 0; i < mpzspm->sp_num; i++)
      38             :     {
      39     1840686 :       x[i] = (spv_t) sp_aligned_malloc (len * sizeof (sp_t));
      40             :       
      41     1840686 :       if (x[i] == NULL)
      42             :         {
      43           0 :           while (i--)
      44           0 :             sp_aligned_free (x[i]);
      45             :           
      46           0 :           free (x);
      47           0 :           return NULL;
      48             :         }
      49             :     }
      50             :   
      51       85419 :   return x;
      52             : }
      53             : 
      54             : void
      55       85419 : mpzspv_clear (mpzspv_t x, mpzspm_t mpzspm)
      56             : {
      57             :   unsigned int i;
      58             :         
      59             :   ASSERT (mpzspv_verify (x, 0, 0, mpzspm));
      60             :   
      61     1926105 :   for (i = 0; i < mpzspm->sp_num; i++)
      62     1840686 :     sp_aligned_free (x[i]);
      63             :   
      64       85419 :   free (x);
      65       85419 : }
      66             : 
      67             : #ifdef WANT_ASSERT
      68             : /* check that:
      69             :  *  - each of the spv's is at least offset + len long
      70             :  *  - the data specified by (offset, len) is correctly normalised in the
      71             :  *    range [0, sp)
      72             :  *
      73             :  * return 1 for success, 0 for failure */
      74             : 
      75             : int
      76             : mpzspv_verify (mpzspv_t x, spv_size_t offset, spv_size_t len, mpzspm_t mpzspm)
      77             : {
      78             :   unsigned int i;
      79             :   spv_size_t j;
      80             :   
      81             :   for (i = 0; i < mpzspm->sp_num; i++)
      82             :     {
      83             :       for (j = offset; j < offset + len; j++)
      84             :         if (x[i][j] >= mpzspm->spm[i]->sp)
      85             :           return 0;
      86             :     }
      87             : 
      88             :   return 1;
      89             : }
      90             : #endif
      91             : 
      92             : void
      93      281399 : mpzspv_set (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
      94             :     spv_size_t len, mpzspm_t mpzspm)
      95             : {
      96             :   unsigned int i;
      97             :   
      98             :   ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
      99             :   ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
     100             :   
     101     6169480 :   for (i = 0; i < mpzspm->sp_num; i++)
     102     5888081 :     spv_set (r[i] + r_offset, x[i] + x_offset, len);
     103      281399 : }
     104             : 
     105             : #if 0
     106             : void
     107             : mpzspv_revcopy (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, 
     108             :     spv_size_t x_offset, spv_size_t len, mpzspm_t mpzspm)
     109             : {
     110             :   unsigned int i;
     111             :   
     112             :   ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
     113             :   ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
     114             :   
     115             :   for (i = 0; i < mpzspm->sp_num; i++)
     116             :     spv_rev (r[i] + r_offset, x[i] + x_offset, len);
     117             : }
     118             : #endif
     119             : 
     120             : void
     121       24254 : mpzspv_set_sp (mpzspv_t r, spv_size_t offset, sp_t c, spv_size_t len,
     122             :     mpzspm_t mpzspm)
     123             : {
     124             :   unsigned int i;
     125             :   
     126             :   ASSERT (mpzspv_verify (r, offset + len, 0, mpzspm));
     127             :   ASSERT (c < SP_MIN); /* not strictly necessary but avoids mod functions */
     128             :   
     129      519868 :   for (i = 0; i < mpzspm->sp_num; i++)
     130      495614 :     spv_set_sp (r[i] + offset, c, len);
     131       24254 : }
     132             : 
     133             : void
     134         620 : mpzspv_neg (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
     135             :     spv_size_t len, mpzspm_t mpzspm)
     136             : {
     137             :   unsigned int i;
     138             :   
     139             :   ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
     140             :   ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
     141             :   
     142        7135 :   for (i = 0; i < mpzspm->sp_num; i++)
     143        6515 :     spv_neg (r[i] + r_offset, x[i] + x_offset, len, mpzspm->spm[i]->sp);
     144         620 : }
     145             : 
     146             : void
     147         266 : mpzspv_add (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
     148             :             mpzspv_t y, spv_size_t y_offset, spv_size_t len, mpzspm_t mpzspm)
     149             : {
     150             :   unsigned int i;
     151             :   
     152             :   ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
     153             :   ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
     154             :   
     155        2889 :   for (i = 0; i < mpzspm->sp_num; i++)
     156        2623 :     spv_add (r[i] + r_offset, x[i] + x_offset, y[i] + y_offset, len, 
     157        2623 :              mpzspm->spm[i]->sp);
     158         266 : }
     159             : 
     160             : void
     161         916 : mpzspv_reverse (mpzspv_t x, spv_size_t offset, spv_size_t len, mpzspm_t mpzspm)
     162             : {
     163             :   unsigned int i;
     164             :   spv_size_t j;
     165             :   sp_t t;
     166             :   spv_t spv;
     167             :   
     168             :   ASSERT (mpzspv_verify (x, offset, len, mpzspm));
     169             :   
     170       15174 :   for (i = 0; i < mpzspm->sp_num; i++)
     171             :     {
     172       14258 :       spv = x[i] + offset;
     173    10276300 :       for (j = 0; j < len - 1 - j; j++)
     174             :         {
     175    10262042 :           t = spv[j];
     176    10262042 :           spv[j] = spv[len - 1 - j];
     177    10262042 :           spv[len - 1 - j] = t;
     178             :         }
     179             :     }
     180         916 : }
     181             : 
     182             : /* Return {xp, xn} mod p.
     183             :    Assume 2p < B where B = 2^GMP_NUMB_BITS.
     184             :    We first compute {xp, xn} / B^n mod p using Montgomery reduction,
     185             :    where the number N to factor has n limbs.
     186             :    Then we multiply by B^(n+1) mod p (precomputed) and divide by B mod p.
     187             :    Assume invm = -1/p mod B and Bpow = B^(n+1) mod p */
     188             : static mp_limb_t
     189   612931765 : ecm_mod_1 (mp_ptr xp, mp_size_t xn, mp_limb_t p, mp_size_t n,
     190             :            mp_limb_t invm, mp_limb_t Bpow)
     191             : {
     192             :   mp_limb_t q, cy, hi, lo, x0, x1;
     193             : 
     194   612931765 :   if (xn == 0)
     195           0 :     return 0;
     196             : 
     197             :   /* the code below assumes xn <= n+1, thus we call mpn_mod_1 otherwise,
     198             :      but this should never (or rarely) happen */
     199   612931765 :   if (xn > n + 1)
     200      119553 :     return mpn_mod_1 (xp, xn, p);
     201             : 
     202   612812212 :   x0 = xp[0];
     203   612812212 :   cy = (mp_limb_t) 0;
     204  5340841149 :   while (n-- > 0)
     205             :     {
     206             :       /* Invariant: cy is the input carry on xp[1], x0 is xp[0] */
     207  4728028937 :       x1 = (xn > 1) ? xp[1] : 0;
     208  4728028937 :       q = x0 * invm; /* q = -x0/p mod B */
     209  4728028937 :       umul_ppmm (hi, lo, q, p); /* hi*B + lo = -x0 mod B */
     210             :       /* Add hi*B + lo to x1*B + x0. Since p <= B-2 we have
     211             :          hi*B + lo <= (B-1)(B-2) = B^2-3B+2, thus hi <= B-3 */
     212  4728028937 :       hi += cy + (lo != 0); /* cannot overflow */
     213  4728028937 :       x0 = x1 + hi;
     214  4728028937 :       cy = x0 < hi;
     215  4728028937 :       xn --;
     216  4728028937 :       xp ++;
     217             :     }
     218   612812212 :   if (cy != 0)
     219       10314 :     x0 -= p;
     220             :   /* now x0 = {xp, xn} / B^n mod p */
     221   612812212 :   umul_ppmm (x1, x0, x0, Bpow);
     222             :   /* since Bpow < p, x1 <= p-1 */
     223   612812212 :   q = x0 * invm;
     224   612812212 :   umul_ppmm (hi, lo, q, p);
     225             :   /* hi <= p-1 thus hi+x1+1 < 2p-1 < B */
     226   612812212 :   hi = hi + x1 + (lo != 0);
     227   646269133 :   while (hi >= p)
     228    33456921 :     hi -= p;
     229   612812212 :   return hi;
     230             : }
     231             : 
     232             : #ifdef TIMING_CRT
     233             : int mpzspv_from_mpzv_slow_time = 0;
     234             : int mpzspv_to_mpzv_time = 0;
     235             : int mpzspv_normalise_time = 0;
     236             : #endif
     237             : 
     238             : /* convert mpzvi to CRT representation, naive version */
     239             : static void
     240    46284690 : mpzspv_from_mpzv_slow (mpzspv_t x, const spv_size_t offset, mpz_t mpzvi,
     241             :                        mpzspm_t mpzspm)
     242             : {
     243    46284690 :   const unsigned int sp_num = mpzspm->sp_num;
     244             :   unsigned int j;
     245    46284690 :   mp_size_t n = mpz_size (mpzspm->modulus);
     246             : 
     247             : #ifdef TIMING_CRT
     248             :   mpzspv_from_mpzv_slow_time -= cputime ();
     249             : #endif
     250   659216455 :   for (j = 0; j < sp_num; j++)
     251   612931765 :     x[j][offset] = ecm_mod_1 (PTR(mpzvi), SIZ(mpzvi),
     252   612931765 :                               (mp_limb_t) mpzspm->spm[j]->sp, n,
     253   612931765 :                               mpzspm->spm[j]->invm, mpzspm->spm[j]->Bpow);
     254             : #ifdef TIMING_CRT
     255             :   mpzspv_from_mpzv_slow_time += cputime ();
     256             : #endif
     257             :   /* The typecast to mp_limb_t assumes that mp_limb_t is at least
     258             :      as wide as sp_t */
     259    46284690 : }
     260             : 
     261             : /* convert mpzvi to CRT representation, fast version, assumes
     262             :    mpzspm->T has been precomputed (see mpzspm.c).
     263             :    Warning: this function should be thread-safe, since it might be called
     264             :    simultaneously by several threads. */
     265             : static void
     266        1560 : mpzspv_from_mpzv_fast (mpzspv_t x, const spv_size_t offset, mpz_t mpzvi,
     267             :                        mpzspm_t mpzspm)
     268             : {
     269        1560 :   const unsigned int sp_num = mpzspm->sp_num;
     270        1560 :   unsigned int i, j, k, i0 = I0_THRESHOLD, I0;
     271        1560 :   mpzv_t *T = mpzspm->T;
     272             :   mpz_t *U;
     273        1560 :   unsigned int d = mpzspm->d, ni;
     274             : 
     275        1560 :   U = malloc (sp_num * sizeof (mpz_t));
     276      262308 :   for (j = 0; j < sp_num; j++)
     277      260748 :     mpz_init (U[j]);
     278             : 
     279             :   ASSERT (d > i0);
     280             : 
     281             :   /* initially we split mpzvi in two */
     282        1560 :   ni = 1 << (d - 1);
     283        1560 :   mpz_mod (U[0], mpzvi, T[d-1][0]);
     284        1560 :   mpz_mod (U[ni], mpzvi, T[d-1][1]);
     285        1674 :   for (i = d-1; i-- > i0;)
     286             :     { /* goes down from depth i+1 to i */
     287         114 :       ni = 1 << i;
     288         380 :       for (j = k = 0; j + ni < sp_num; j += 2*ni, k += 2)
     289             :         {
     290         266 :           mpz_mod (U[j+ni], U[j], T[i][k+1]);
     291         266 :           mpz_mod (U[j], U[j], T[i][k]);
     292             :         }
     293             :       /* for the last entry U[j] if j < sp_num, there is nothing to do */
     294             :     }
     295             :   /* last steps */
     296        1560 :   I0 = 1 << i0;
     297        4946 :   for (j = 0; j < sp_num; j += I0)
     298             :     {
     299      264134 :       for (k = j; k < j + I0 && k < sp_num; k++)
     300      260748 :         x[k][offset] = mpn_mod_1 (PTR(U[j]), SIZ(U[j]),
     301      260748 :                                   (mp_limb_t) mpzspm->spm[k]->sp);
     302             :     }
     303             :   /* The typecast to mp_limb_t assumes that mp_limb_t is at least
     304             :      as wide as sp_t */
     305             : 
     306      262308 :   for (j = 0; j < sp_num; j++)
     307      260748 :     mpz_clear (U[j]);
     308        1560 :   free(U);
     309        1560 : }
     310             : 
     311             : #if defined(TRACE_mpzspv_from_mpzv) || defined(TRACE_ntt_sqr_reciprocal)
     312             : static void
     313             : ntt_print_vec (const char *msg, const spv_t spv, const spv_size_t l, 
     314             :                const sp_t p)
     315             : {
     316             :   spv_size_t i;
     317             : 
     318             :   /* Warning: on some computers, for example gcc49.fsffrance.org,
     319             :      "unsigned long" might be shorter than "sp_t" */
     320             :   gmp_printf ("%s [%Nd", msg, (mp_ptr) spv, 1);
     321             :   for (i = 1; i < l; i++)
     322             :     gmp_printf (", %Nd", (mp_ptr) spv + i, 1);
     323             :   printf ("] (mod %llu)\n", (long long unsigned int) p);
     324             : }
     325             : #endif
     326             : 
     327             : /* convert an array of len mpz_t numbers to CRT representation modulo
     328             :    sp_num moduli */
     329             : void
     330    18533786 : mpzspv_from_mpzv (mpzspv_t x, const spv_size_t offset, const mpzv_t mpzv,
     331             :     const spv_size_t len, mpzspm_t mpzspm)
     332             : {
     333    18533786 :   const unsigned int sp_num = mpzspm->sp_num;
     334             :   long i;
     335             : 
     336             :   ASSERT (mpzspv_verify (x, offset + len, 0, mpzspm));
     337             :   ASSERT (sizeof (mp_limb_t) >= sizeof (sp_t));
     338             : 
     339             : #ifdef TRACE_mpzspv_from_mpzv
     340             :   for (i = 0; i < (long) len; i++)
     341             :     gmp_printf ("mpzspv_from_mpzv: mpzv[%ld] = %Zd\n", i, mpzv[i]);
     342             : #endif
     343             : 
     344             : #if defined(_OPENMP)
     345             : #pragma omp parallel private(i) if (len > 16384)
     346             :   {
     347             :     /* Multi-threading with dynamic scheduling slows things down */
     348             : #pragma omp for schedule(static)
     349             : #endif
     350    64820855 :     for (i = 0; i < (long) len; i++)
     351             :     {
     352             :       unsigned int j;
     353    46287069 :       if (mpz_sgn (mpzv[i]) == 0)
     354             :         {
     355        8464 :           for (j = 0; j < sp_num; j++)
     356        7645 :             x[j][i + offset] = 0;
     357             :         }
     358             :       else
     359             :         {
     360             :           ASSERT(mpz_sgn (mpzv[i]) > 0); /* We can't handle negative values */
     361    46286250 :           if (mpzspm->T == NULL)
     362    46284690 :             mpzspv_from_mpzv_slow (x, i + offset, mpzv[i], mpzspm);
     363             :           else
     364        1560 :             mpzspv_from_mpzv_fast (x, i + offset, mpzv[i], mpzspm);
     365             :         }
     366             :     }
     367             : #if defined(_OPENMP)
     368             :   }
     369             : #endif
     370             : 
     371             : #ifdef TRACE_mpzspv_from_mpzv
     372             :   for (i = 0; i < (long) sp_num; i++)
     373             :     ntt_print_vec ("mpzspv_from_mpzv: ", x[i] + offset, len, mpzspm->spm[i]->sp);
     374             : #endif
     375    18533786 : }
     376             : 
     377             : /* Convert the len residues x[][offset..offset+len-1] from "spv" (RNS) format
     378             :  * to mpz_t format.
     379             :  * See: Daniel J. Bernstein and Jonathan P. Sorenson,
     380             :  * Modular Exponentiation via the explicit Chinese Remainder Theorem,
     381             :  * Mathematics of Computation 2007,
     382             :  * Theorem 2.1: Let p_1, ..., p_s be pairwise coprime integers. Write
     383             :  * P = p_1 * ... * p_s. Let q_1, ..., q_s be integers with
     384             :  * q_iP/p_i = 1 mod p_i. Let u be an integer with |u| < P/2. Let u_1, ..., u_s
     385             :  * with u = u_i mod p_i. Let t_1, ..., t_s be integers with
     386             :  * t_i = u_i q_i mod p_i. Then u = P \alpha - P round(\alpha) where
     387             :  * \alpha = \sum_i t_i/p_i
     388             :  *
     389             :  * time: O(len * sp_num^2) where sp_num is proportional to the modulus size
     390             :  * memory: MPZSPV_NORMALISE_STRIDE floats */
     391             : void
     392       86720 : mpzspv_to_mpzv (mpzspv_t x, spv_size_t offset, mpzv_t mpzv,
     393             :     spv_size_t len, mpzspm_t mpzspm)
     394             : {
     395             :   unsigned int i;
     396             :   spv_size_t k, l;
     397       86720 :   float *f = (float *) malloc (MPZSPV_NORMALISE_STRIDE * sizeof (float));
     398             :   float prime_recip;
     399             :   sp_t t;
     400       86720 :   spm_t *spm = mpzspm->spm;
     401             :   mpz_t mt;
     402             : 
     403       86720 :   if (f == NULL)
     404             :     {
     405           0 :       fprintf (stderr, "Cannot allocate memory in mpzspv_to_mpzv\n");
     406           0 :       exit (1);
     407             :     }
     408             :   
     409             :   ASSERT (mpzspv_verify (x, offset, len, mpzspm));
     410       86720 :   ASSERT_ALWAYS(mpzspm->sp_num <= 1677721);
     411             : 
     412             : #ifdef TIMING_CRT
     413             :   mpzspv_to_mpzv_time -= cputime ();
     414             : #endif
     415       86720 :   mpz_init (mt);
     416      310194 :   for (l = 0; l < len; l += MPZSPV_NORMALISE_STRIDE)
     417             :     {
     418      223474 :       spv_size_t stride = MIN (MPZSPV_NORMALISE_STRIDE, len - l);
     419             : 
     420             :       /* we apply the above theorem to mpzv[l]...mpzv[l+stride-1] at once */
     421    27369203 :       for (k = 0; k < stride; k++)
     422             :         {
     423    27145729 :           f[k] = 0.5; /* this is performed len times */
     424    27145729 :           mpz_set_ui (mpzv[k + l], 0);
     425             :         }
     426             :   
     427     3423175 :     for (i = 0; i < mpzspm->sp_num; i++)
     428             :       {
     429             :         /* this loop is performed len*sp_num/MPZSPV_NORMALISE_STRIDE times */
     430             : 
     431             :         /* prime_recip = 1/p_i * (1+u)^2 with |u| <= 2^(-24) where one
     432             :            exponent is due to the sp -> float conversion, and one to the
     433             :            division */
     434     3199701 :         prime_recip = 1.0f / (float) spm[i]->sp; /* 1/p_i */
     435             :       
     436   360198271 :         for (k = 0; k < stride; k++)
     437             :           {
     438             :             /* this loop is performed len*sp_num times */
     439             : 
     440             :             /* crt3[i] = p_i/P mod p_i (q_i in the theorem) */
     441   356998570 :             t = sp_mul (x[i][l + k + offset], mpzspm->crt3[i], spm[i]->sp,
     442   356998570 :                   spm[i]->mul_c);
     443             : 
     444             :             /* crt1[i] = P / p_i mod modulus: we accumulate in mpzv[l + k]
     445             :                the sum of P t_i/p_i = t_i (P/p_i) mod N.
     446             :                If N has n limbs, crt1[i] has n limbs too,
     447             :                thus mpzv[l+k] has about n limbs */
     448             :             if (sizeof (sp_t) > sizeof (unsigned long))
     449             :               {
     450             :                 mpz_set_sp (mt, t);
     451             :                 mpz_addmul (mpzv[l + k], mpzspm->crt1[i], mt);
     452             :               }
     453             :             else
     454   356998570 :               mpz_addmul_ui (mpzv[l + k], mpzspm->crt1[i], t);
     455             : 
     456             :             /* After the conversion from t to float and the multiplication,
     457             :                the value of (float) t * prime_recip = t/p_i * (1+v)^4
     458             :                where |v| <= 2^(-24). Since |t| < p_i, the absolute error
     459             :                is bounded by (1+v)^4-1 <= 5*v. Thus the total error on f[k]
     460             :                is bounded by 5*sp_num*2^(-24). Since we want this to be smaller
     461             :                than 0.5, we need sp_num <= 2^23/5 thus sp_num <= 1677721.
     462             :                This corresponds to a number of at most 15656374 digits on a
     463             :                32-bit machine, and at most 31312749 digits on 64-bit. */
     464   356998570 :             f[k] += (float) t * prime_recip;
     465             :           }
     466             :       }
     467             : 
     468             :     /* crt2[i] = -i*P mod modulus */
     469    27369203 :     for (k = 0; k < stride; k++)
     470    27145729 :       mpz_add (mpzv[l + k], mpzv[l + k], mpzspm->crt2[(unsigned int) f[k]]);
     471             :   }
     472             :   
     473       86720 :   mpz_clear (mt);
     474       86720 :   free (f);
     475             : #ifdef TIMING_CRT
     476             :   mpzspv_to_mpzv_time += cputime ();
     477             : #endif
     478       86720 : }  
     479             : 
     480             : #if 0
     481             : void
     482             : mpzspv_pwmul (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
     483             :               mpzspv_t y, spv_size_t y_offset, spv_size_t len, mpzspm_t mpzspm)
     484             : {
     485             :   unsigned int i;
     486             :   
     487             :   ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
     488             :   ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
     489             :   ASSERT (mpzspv_verify (y, y_offset, len, mpzspm));
     490             :   
     491             :   for (i = 0; i < mpzspm->sp_num; i++)
     492             :     spv_pwmul (r[i] + r_offset, x[i] + x_offset, y[i] + y_offset,
     493             :         len, mpzspm->spm[i]->sp, mpzspm->spm[i]->mul_c);
     494             : }
     495             : #endif
     496             : 
     497             : /* Normalise the vector x[][offset..offset+len-1] of RNS residues modulo the
     498             :  * input modulus N.
     499             :  *
     500             :  * Reference: Bernstein & Sorenson: Explicit CRT mod m mod p_j, Theorem 4.1.
     501             :  *
     502             :  * time:   O(len * sp_num^2)
     503             :  * memory: MPZSPV_NORMALISE_STRIDE mpzspv coeffs
     504             :  *         6 * MPZSPV_NORMALISE_STRIDE sp's
     505             :  *         MPZSPV_NORMALISE_STRIDE floats
     506             :  * For a subquadratic version: look at Section 23 of
     507             :  * http://cr.yp.to/papers.html#multapps
     508             : */
     509             : void
     510       24491 : mpzspv_normalise (mpzspv_t x, spv_size_t offset, spv_size_t len,
     511             :     mpzspm_t mpzspm)
     512             : {
     513       24491 :   unsigned int i, j, sp_num = mpzspm->sp_num;
     514             :   spv_size_t k, l;
     515             :   sp_t v;
     516             :   spv_t s, d, w;
     517       24491 :   spm_t *spm = mpzspm->spm;
     518             :   float prime_recip;
     519             :   float *f;
     520             :   mpzspv_t t;
     521             : 
     522             : #ifdef TIMING_CRT
     523             :   mpzspv_normalise_time -= cputime ();
     524             : #endif
     525             :   ASSERT (mpzspv_verify (x, offset, len, mpzspm));
     526             : 
     527       24491 :   f = (float *) malloc (MPZSPV_NORMALISE_STRIDE * sizeof (float));
     528       24491 :   s = (spv_t) malloc (3 * MPZSPV_NORMALISE_STRIDE * sizeof (sp_t));
     529       24491 :   d = (spv_t) malloc (3 * MPZSPV_NORMALISE_STRIDE * sizeof (sp_t));
     530       24491 :   if (f == NULL || s == NULL || d == NULL)
     531             :     {
     532           0 :       fprintf (stderr, "Cannot allocate memory in mpzspv_normalise\n");
     533           0 :       exit (1);
     534             :     }
     535       24491 :   t = mpzspv_init (MPZSPV_NORMALISE_STRIDE, mpzspm);
     536             :   
     537       24491 :   memset (s, 0, 3 * MPZSPV_NORMALISE_STRIDE * sizeof (sp_t));
     538             : 
     539      290652 :   for (l = 0; l < len; l += MPZSPV_NORMALISE_STRIDE)
     540             :     {
     541      266161 :       spv_size_t stride = MIN (MPZSPV_NORMALISE_STRIDE, len - l);
     542             :       
     543             :       /* FIXME: use B&S Theorem 2.2 */
     544    19434723 :       for (k = 0; k < stride; k++)
     545    19168562 :         f[k] = 0.5; /* this is executed len times */
     546             :       
     547     5906402 :       for (i = 0; i < sp_num; i++)
     548             :         {
     549             :           /* this loop is performed len*sp_num/MPZSPV_NORMALISE_STRIDE times */
     550     5640241 :           prime_recip = 1.0f / (float) spm[i]->sp;
     551             :       
     552   359145077 :           for (k = 0; k < stride; k++)
     553             :             {
     554             :               /* this is executed len*sp_num times,
     555             :                  crt3[i] = p_i/P mod p_i (q_i in Theorem 3.1) */
     556   707009672 :               x[i][l + k + offset] = sp_mul (x[i][l + k + offset],
     557   353504836 :                   mpzspm->crt3[i], spm[i]->sp, spm[i]->mul_c);
     558             :               /* now x[i] is t_i in Theorem 3.1 */
     559   353504836 :               f[k] += (float) x[i][l + k + offset] * prime_recip;
     560             :             }
     561             :         }
     562             :       
     563     5906402 :       for (i = 0; i < sp_num; i++)
     564             :         {
     565   359145077 :           for (k = 0; k < stride; k++)
     566             :             {
     567             :               /* this is executed len*sp_num times */
     568             : 
     569             :               /* crt5[i] = (-P mod modulus) mod p_i */
     570   353504836 :               umul_ppmm (d[3 * k + 1], d[3 * k], mpzspm->crt5[i], (sp_t) f[k]);
     571             :               /* {d+3*k,2} = ((-P mod modulus) mod p_i) * round(sum(t_j/p_j)),
     572             :                  this accounts for the right term in Theorem 4.1 */
     573   353504836 :               d[3 * k + 2] = 0;
     574             :             }
     575             :         
     576   132106218 :           for (j = 0; j < sp_num; j++)
     577             :             {
     578             :               /* this is executed len*sp_num^2/MPZSPV_NORMALISE_STRIDE times */
     579   126465977 :               w = x[j] + offset;
     580             :               /* crt4[i][j] = ((P / p[i]) mod modulus) mod p[j] */
     581   126465977 :               v = mpzspm->crt4[i][j];
     582             :             
     583  7409599329 :               for (k = 0; k < stride; k++)
     584             :                 /* this is executed len*sp_num^2 times, and computes the left
     585             :                    term in Theorem 4.1 */
     586  7283133352 :                 umul_ppmm (s[3 * k + 1], s[3 * k], w[k + l], v);
     587             :               
     588             :               /* This mpn_add_n adds in parallel all "stride" contributions,
     589             :                  and accounts for about a third of the function's runtime.
     590             :                  Since d has size O(stride), the cumulated complexity of this
     591             :                  call is O(len*sp_num^2) */
     592   126465977 :               mpn_add_n ((mp_ptr) d, (mp_srcptr) d, (mp_srcptr) s, 3 * stride);
     593             :             }      
     594             : 
     595             :           /* we finally reduce the contribution modulo each p_i */
     596   359145077 :           for (k = 0; k < stride; k++)
     597   353504836 :             t[i][k] = mpn_mod_1 ((mp_ptr) (d + 3 * k), 3, spm[i]->sp);
     598             :         }         
     599      266161 :       mpzspv_set (x, l + offset, t, 0, stride, mpzspm);
     600             :     }
     601             :   
     602       24491 :   mpzspv_clear (t, mpzspm);
     603             :   
     604       24491 :   free (s);
     605       24491 :   free (d);
     606       24491 :   free (f);
     607             : #ifdef TIMING_CRT
     608             :   mpzspv_normalise_time += cputime ();
     609             : #endif
     610       24491 : }
     611             : /* Forward-transform x in place: for each of the mpzspm->sp_num small primes i, x[i] + offset is wrapped (ntt_size < len) or zero-padded (ntt_size > len) to ntt_size entries, gets 1 added at index len % ntt_size when monic is nonzero, and is then passed to spv_ntt_gfp_dif. */
     612             : void
     613        1840 : mpzspv_to_ntt (mpzspv_t x, spv_size_t offset, spv_size_t len,
     614             :     spv_size_t ntt_size, int monic, mpzspm_t mpzspm)
     615             : {
     616             :   unsigned int i;
     617             :   spv_size_t j, log2_ntt_size;
     618             :   spm_t spm;
     619             :   spv_t spv;
     620             :   
     621             :   ASSERT (mpzspv_verify (x, offset, len, mpzspm));
     622             :   ASSERT (mpzspv_verify (x, offset + ntt_size, 0, mpzspm));
     623             :   /* spv_ntt_gfp_dif below is given this logarithm rather than ntt_size itself */
     624        1840 :   log2_ntt_size = ceil_log_2 (ntt_size);
     625             : 
     626       30206 :   for (i = 0; i < mpzspm->sp_num; i++)
     627             :     {
     628       28366 :       spm = mpzspm->spm[i];
     629       28366 :       spv = x[i] + offset;
     630             :       /* Wrap-around: add each later chunk of ntt_size entries onto the first one. NOTE(review): every pass reads ntt_size entries starting at j, i.e. up to index j + ntt_size - 1, which is beyond len - 1 unless len is a multiple of ntt_size -- confirm callers guarantee this. */
     631       28366 :       if (ntt_size < len)
     632             :         {
     633           0 :           for (j = ntt_size; j < len; j += ntt_size)
     634           0 :             spv_add (spv, spv, spv + j, ntt_size, spm->sp);
     635             :         }
     636       28366 :       if (ntt_size > len)
     637       14183 :         spv_set_zero (spv + len, ntt_size - len);
     638             :       /* presumably the implicit leading coefficient 1 of a monic input of degree len, reduced to index len % ntt_size -- TODO confirm */
     639       28366 :       if (monic)
     640       14183 :         spv[len % ntt_size] = sp_add (spv[len % ntt_size], 1, spm->sp);
     641             :       
     642       28366 :       spv_ntt_gfp_dif (spv, log2_ntt_size, spm);
     643             :     }
     644        1840 : }
     645             : 
     646             : #if 0 /* compiled out; the NTT_MUL_STEP_IFFT branch of mpzspv_mul_ntt performs the same three steps */
     647             : void
     648             : mpzspv_from_ntt (mpzspv_t x, spv_size_t offset, spv_size_t ntt_size,
     649             :                  spv_size_t monic_pos, mpzspm_t mpzspm)
     650             : {
     651             :   unsigned int i;
     652             :   spv_size_t log2_ntt_size;
     653             :   spm_t spm;
     654             :   spv_t spv;
     655             :   
     656             :   ASSERT (mpzspv_verify (x, offset, ntt_size, mpzspm));
     657             :   
     658             :   log2_ntt_size = ceil_log_2 (ntt_size);
     659             : 
     660             :   for (i = 0; i < mpzspm->sp_num; i++)
     661             :     {
     662             :       spm = mpzspm->spm[i];
     663             :       spv = x[i] + offset;
     664             :       /* inverse transform in place, then scale by 1/ntt_size */
     665             :       spv_ntt_gfp_dit (spv, log2_ntt_size, spm);
     666             : 
     667             :       /* spm->sp - (spm->sp - 1) / ntt_size is the inverse of ntt_size modulo spm->sp, provided ntt_size divides spm->sp - 1: then ntt_size * (sp - (sp - 1) / ntt_size) = ntt_size * sp - (sp - 1) == 1 (mod sp) */
     668             :       spv_mul_sp (spv, spv, spm->sp - (spm->sp - 1) / ntt_size,
     669             :           ntt_size, spm->sp, spm->mul_c);
     670             :       /* a nonzero monic_pos subtracts 1 at index monic_pos % ntt_size; monic_pos == 0 skips the correction */
     671             :       if (monic_pos)
     672             :         spv[monic_pos % ntt_size] = sp_sub (spv[monic_pos % ntt_size],
     673             :             1, spm->sp);
     674             :     }
     675             : }
     676             : #endif
     677             : /* Calls spv_random on x[i] + offset with length len and modulus mpzspm->spm[i]->sp for every small prime i; presumably this fills the len entries with random residues -- TODO confirm against spv_random. */
     678             : void
     679           1 : mpzspv_random (mpzspv_t x, spv_size_t offset, spv_size_t len, mpzspm_t mpzspm)
     680             : {
     681             :   unsigned int i;
     682             : 
     683             :   ASSERT (mpzspv_verify (x, offset, len, mpzspm));
     684             : 
     685          25 :   for (i = 0; i < mpzspm->sp_num; i++)
     686          24 :     spv_random (x[i] + offset, len, mpzspm->spm[i]->sp);
     687           1 : }
     688             : 
     689             : 
     690             : /* Do multiplication via NTT. Depending on the bits set in "steps", does 
     691             :    in-place forward transform of x (NTT_MUL_STEP_FFT1), in-place forward 
     692             :    transform of y (NTT_MUL_STEP_FFT2), pair-wise multiplication of x by y to 
     693             :    r (NTT_MUL_STEP_MUL), and in-place inverse transform of r including the 
     694             :    division by ntt_size (NTT_MUL_STEP_IFFT). Contrary to calling these 
     695             :    operations separately, this function does all selected steps on one 
     696             :    small-prime vector at a time, resulting in slightly better cache efficiency 
     697             :    (also in preparation to storing NTT vectors on disk and reading them in for the multiplication). */
     698             : void
     699       66531 : mpzspv_mul_ntt (mpzspv_t r, const spv_size_t offsetr, 
     700             :     mpzspv_t x, const spv_size_t offsetx, const spv_size_t lenx,
     701             :     mpzspv_t y, const spv_size_t offsety, const spv_size_t leny,
     702             :     const spv_size_t ntt_size, const int monic, const spv_size_t monic_pos, 
     703             :     mpzspm_t mpzspm, const int steps)
     704             : {
     705             :   spv_size_t log2_ntt_size;
     706             :   int i;
     707             :   /* i is signed and sp_num is cast to int in the loop header; presumably so the loop qualifies as an OpenMP for-loop -- TODO confirm */
     708             :   ASSERT (mpzspv_verify (x, offsetx, lenx, mpzspm));
     709             :   ASSERT (mpzspv_verify (y, offsety, leny, mpzspm));
     710             :   ASSERT (mpzspv_verify (x, offsetx + ntt_size, 0, mpzspm));
     711             :   ASSERT (mpzspv_verify (y, offsety + ntt_size, 0, mpzspm));
     712             :   ASSERT (mpzspv_verify (r, offsetr + ntt_size, 0, mpzspm));
     713             :   
     714       66531 :   log2_ntt_size = ceil_log_2 (ntt_size);
     715             : 
     716             :   /* Need parallelization at higher level (e.g., handling a branch of the 
     717             :      product tree in one thread) to make this worthwhile for ECM */
     718             : #define MPZSPV_MUL_NTT_OPENMP 0 /* 0 disables the OpenMP pragmas guarded below */
     719             : 
     720             : #if defined(_OPENMP) && MPZSPV_MUL_NTT_OPENMP
     721             : #pragma omp parallel if (ntt_size > 16384)
     722             :   {
     723             : #pragma omp for
     724             : #endif
     725     1473287 :   for (i = 0; i < (int) mpzspm->sp_num; i++)
     726             :     {
     727             :       spv_size_t j;
     728     1406756 :       spm_t spm = mpzspm->spm[i];
     729     1406756 :       spv_t spvr = r[i] + offsetr;
     730     1406756 :       spv_t spvx = x[i] + offsetx;
     731     1406756 :       spv_t spvy = y[i] + offsety;
     732             :       /* Step 1: bring x to length ntt_size (wrap or zero-pad), optionally add the monic 1, forward transform */
     733     1406756 :       if ((steps & NTT_MUL_STEP_FFT1) != 0) {
     734     1400241 :         if (ntt_size < lenx)
     735             :           {
     736           0 :             for (j = ntt_size; j < lenx; j += ntt_size)
     737           0 :               spv_add (spvx, spvx, spvx + j, ntt_size, spm->sp);
     738             :           }
     739     1400241 :         if (ntt_size > lenx)
     740      824057 :           spv_set_zero (spvx + lenx, ntt_size - lenx);
     741             :         /* monic != 0: add 1 at index lenx % ntt_size (presumably the implicit leading coefficient -- TODO confirm) */
     742     1400241 :         if (monic)
     743      581358 :           spvx[lenx % ntt_size] = sp_add (spvx[lenx % ntt_size], 1, spm->sp);
     744             : 
     745     1400241 :         spv_ntt_gfp_dif (spvx, log2_ntt_size, spm);
     746             :       }
     747             :       /* Step 2: the same preparation and forward transform for y; the same monic flag applies to both inputs */
     748     1406756 :       if ((steps & NTT_MUL_STEP_FFT2) != 0) {
     749      712553 :         if (ntt_size < leny)
     750             :           {
     751           0 :             for (j = ntt_size; j < leny; j += ntt_size)
     752           0 :               spv_add (spvy, spvy, spvy + j, ntt_size, spm->sp);
     753             :           }
     754      712553 :         if (ntt_size > leny)
     755      608792 :           spv_set_zero (spvy + leny, ntt_size - leny);
     756             : 
     757      712553 :         if (monic)
     758      581358 :           spvy[leny % ntt_size] = sp_add (spvy[leny % ntt_size], 1, spm->sp);
     759             : 
     760      712553 :         spv_ntt_gfp_dif (spvy, log2_ntt_size, spm);
     761             :       }
     762             :       /* Step 3: point-wise product of the two transformed vectors, written to r */
     763     1406756 :       if ((steps & NTT_MUL_STEP_MUL) != 0) {
     764     1406756 :         spv_pwmul (spvr, spvx, spvy, ntt_size, spm->sp, spm->mul_c);
     765             :       }
     766             :       /* Step 4: inverse transform of r, scaling by 1/ntt_size, optional monic correction */
     767     1406756 :       if ((steps & NTT_MUL_STEP_IFFT) != 0) {
     768             :         ASSERT (sizeof (mp_limb_t) >= sizeof (sp_t));
     769             : 
     770     1406756 :         spv_ntt_gfp_dit (spvr, log2_ntt_size, spm);
     771             : 
     772             :         /* spm->sp - (spm->sp - 1) / ntt_size is the inverse of ntt_size modulo spm->sp, provided ntt_size divides spm->sp - 1 */
     773     1406756 :         spv_mul_sp (spvr, spvr, spm->sp - (spm->sp - 1) / ntt_size,
     774             :             ntt_size, spm->sp, spm->mul_c);
     775             :         /* a nonzero monic_pos subtracts 1 from r at index monic_pos % ntt_size; monic_pos == 0 skips the correction */
     776     1406756 :         if (monic_pos)
     777      581358 :           spvr[monic_pos % ntt_size] = sp_sub (spvr[monic_pos % ntt_size],
     778             :               1, spm->sp);
     779             :       }
     780             :     }
     781             : #if defined(_OPENMP) && MPZSPV_MUL_NTT_OPENMP
     782             :   }
     783             : #endif
     784       66531 : }
     785             : 
     786             : /* Computes a DCT-I of the length dctlen, separately for each small prime j.
     787             :    Input is the spvlen coefficients in spv[j]; the dctlen results are written to dct[j][0 .. dctlen-1]. tmp[j] is temp space and must have space for 2*dctlen-2 sp_t's */
     788             : 
     789             : void
     790         561 : mpzspv_to_dct1 (mpzspv_t dct, const mpzspv_t spv, const spv_size_t spvlen, 
     791             :                 const spv_size_t dctlen, mpzspv_t tmp, 
     792             :                 const mpzspm_t mpzspm)
     793             : {
     794         561 :   const spv_size_t l = 2 * (dctlen - 1); /* Length for the DFT */
     795         561 :   const spv_size_t log2_l = ceil_log_2 (l); /* presumably l is a power of two, so that 2^log2_l == l -- TODO confirm */
     796             :   int j;
     797             : 
     798             : #ifdef _OPENMP
     799             : #pragma omp parallel private(j)
     800             :   {
     801             : #pragma omp for
     802             : #endif
     803        8256 :   for (j = 0; j < (int) mpzspm->sp_num; j++)
     804             :     {
     805        7695 :       const spm_t spm = mpzspm->spm[j];
     806             :       spv_size_t i;
     807             :       
     808             :       /* Make a symmetric copy of spv in tmp. I.e. with spv = [3, 2, 1], 
     809             :          spvlen = 3, dctlen = 5 (hence l = 8), we want 
     810             :          tmp = [3, 2, 1, 0, 0, 0, 1, 2] */
     811        7695 :       spv_set (tmp[j], spv[j], spvlen);
     812        7695 :       spv_rev (tmp[j] + l - spvlen + 1, spv[j] + 1, spvlen - 1);
     813             :       /* Now we have [3, 2, 1, ?, ?, ?, 1, 2]. Fill the ?'s with zeros. */
     814        7695 :       spv_set_sp (tmp[j] + spvlen, (sp_t) 0, l - 2 * spvlen + 1);
     815             : 
     816             : #if 0
     817             :       printf ("mpzspv_to_dct1: tmp[%d] = [", j);
     818             :       for (i = 0; i < l; i++)
     819             :           printf ("%lu, ", tmp[j][i]);
     820             :       printf ("]\n");
     821             : #endif
     822             :       
     823        7695 :       spv_ntt_gfp_dif (tmp[j], log2_l, spm);
     824             : 
     825             : #if 0
     826             :       printf ("mpzspv_to_dct1: tmp[%d] = [", j);
     827             :       for (i = 0; i < l; i++)
     828             :           printf ("%lu, ", tmp[j][i]);
     829             :       printf ("]\n");
     830             : #endif
     831             : 
     832             :       /* The forward transform is scrambled. We want elements [0 ... l/2]
     833             :          of the unscrambled data, that is all the coefficients with the most 
     834             :          significant bit in the index (in log2(l) word size) unset, plus the 
     835             :          element at index l/2. By scrambling, these map to the elements with 
     836             :          even index, plus the element at index 1. 
     837             :          The elements with scrambled index 2*i are stored in dct[j][i], the
     838             :          element with scrambled index 1 is stored in dct[j][l / 2] */
     839             :   
     840             : #ifdef WANT_ASSERT
     841             :       /* Test that the coefficients are symmetric (if they were unscrambled)
     842             :          and that our algorithm for finding identical coefficients in the 
     843             :          scrambled data works */
     844             :       {
     845             :         spv_size_t m = 5;
     846             :         for (i = 2; i < l; i += 2L)
     847             :           {
     848             :             /* m steps through 5, 11, 23, ... (3 * 2^k - 1): for even i with 2^k <= i < 2^(k+1), m - i is an odd index in the same range */
     849             :             if (i + i / 2L > m)
     850             :                 m = 2L * m + 1L;
     851             : 
     852             :             ASSERT (tmp[j][i] == tmp[j][m - i]);
     853             : #if 0
     854             :             printf ("mpzspv_to_dct1: DFT[%lu] == DFT[%lu]\n", i, m - i);
     855             : #endif
     856             :           }
     857             :       }
     858             : #endif
     859             : 
     860             :       /* Copy coefficients to dct buffer */
     861    56926031 :       for (i = 0; i < l / 2; i++)
     862    56918336 :         dct[j][i] = tmp[j][i * 2];
     863        7695 :       dct[j][l / 2] = tmp[j][1];
     864             :     }
     865             : #ifdef _OPENMP
     866             :   }
     867             : #endif
     868         561 : }
     869             : 
     870             : 
     871             : /* Multiply the polynomial in "dft" by the RLP in "dct", where "dft" 
     872             :    contains the polynomial coefficients (not FFT'd yet) and "dct" 
     873             :    contains the DCT-I coefficients of the RLP. The latter are 
     874             :    assumed to be in the layout produced by mpzspv_to_dct1().
     875             :    Output are the coefficients of the product polynomial, stored in dft. 
     876             :    "len" is the transform length; dct[j] is read at indices 0 .. len/2.
     877             :    The "steps" parameter controls which steps are computed:
     878             :    NTT_MUL_STEP_FFT1: do forward transform
     879             :    NTT_MUL_STEP_MUL: do point-wise product
     880             :    NTT_MUL_STEP_IFFT: do inverse transform and divide by len */
     881             : 
     882             : void
     883        1133 : mpzspv_mul_by_dct (mpzspv_t dft, const mpzspv_t dct, const spv_size_t len, 
     884             :                    const mpzspm_t mpzspm, const int steps)
     885             : {
     886             :   int j;
     887        1133 :   spv_size_t log2_len = ceil_log_2 (len);
     888             :   
     889             : #ifdef _OPENMP
     890             : #pragma omp parallel private(j)
     891             :   {
     892             : #pragma omp for
     893             : #endif
     894       17065 :     for (j = 0; j < (int) (mpzspm->sp_num); j++)
     895             :       {
     896       15932 :         const spm_t spm = mpzspm->spm[j];
     897       15932 :         const spv_t spv = dft[j];
     898             :         unsigned long i, m;
     899             :         
     900             :         /* Forward DFT of dft[j] */
     901       15932 :         if ((steps & NTT_MUL_STEP_FFT1) != 0)
     902       13309 :           spv_ntt_gfp_dif (spv, log2_len, spm);
     903             :         
     904             :         /* Point-wise product: spv[0] and spv[1] are handled separately, then each even index i and its partner m - i are both multiplied by dct[j][i / 2] */
     905       15932 :         if ((steps & NTT_MUL_STEP_MUL) != 0)
     906             :           {
     907       13309 :             m = 5UL;
     908             :             
     909       13309 :             spv[0] = sp_mul (spv[0], dct[j][0], spm->sp, spm->mul_c);
     910       13309 :             spv[1] = sp_mul (spv[1], dct[j][len / 2UL], spm->sp, spm->mul_c);
     911             :             
     912    64172064 :             for (i = 2UL; i < len; i += 2UL)
     913             :               {
     914             :                 /* m steps through 5, 11, 23, ... (3 * 2^k - 1): for even i with 2^k <= i < 2^(k+1), m - i is an odd index in the same range; presumably i and m - i are the scrambled positions of a pair of equal DCT coefficients -- TODO confirm */
     915    64158755 :                 if (i + i / 2UL > m)
     916       82036 :                   m = 2UL * m + 1;
     917             :                 
     918    64158755 :                 spv[i] = sp_mul (spv[i], dct[j][i / 2UL], spm->sp, spm->mul_c);
     919    64158755 :                 spv[m - i] = sp_mul (spv[m - i], dct[j][i / 2UL], spm->sp, 
     920             :                                      spm->mul_c);
     921             :               }
     922             :           }
     923             :         
     924             :         /* Inverse transform of dft[j] */
     925       15932 :         if ((steps & NTT_MUL_STEP_IFFT) != 0)
     926             :           {
     927       10686 :             spv_ntt_gfp_dit (spv, log2_len, spm);
     928             :             
     929             :             /* Divide by transform length: sp - (sp - 1) / len is 1/len (mod sp) provided len divides sp - 1. FIXME: scale the DCT of h instead */
     930       10686 :             spv_mul_sp (spv, spv, spm->sp - (spm->sp - 1) / len, len, 
     931             :                         spm->sp, spm->mul_c);
     932             :           }
     933             :       }
     934             : #ifdef _OPENMP
     935             :   }
     936             : #endif
     937        1133 : }
     938             : /* For each small prime j, squares in place the length-n input in dft[j] with a weighted NTT of length len = 2 * 2^ceil_log_2(n): mirror and weight the input with powers of a primitive 3rd root of unity, forward transform, square point-wise, inverse transform, un-weight and divide by len, then separate the wrapped-around product coefficients.
     939             :    Presumably dft[j][0 .. n-1] holds one half of a reciprocal (symmetric) polynomial -- TODO confirm against callers. */
     940             : void 
     941        3659 : mpzspv_sqr_reciprocal (mpzspv_t dft, const spv_size_t n, 
     942             :                        const mpzspm_t mpzspm)
     943             : {
     944        3659 :   const spv_size_t log2_n = ceil_log_2 (n);
     945        3659 :   const spv_size_t len = ((spv_size_t) 2) << log2_n;
     946        3659 :   const spv_size_t log2_len = 1 + log2_n;
     947             :   int j;
     948             : 
     949             :   ASSERT(mpzspm->max_ntt_size % 3UL == 0UL);
     950             :   ASSERT(len % 3UL != 0UL);
     951             :   ASSERT(mpzspm->max_ntt_size % len == 0UL);
     952             : 
     953             : #ifdef _OPENMP
     954             : #pragma omp parallel
     955             :   {
     956             : #pragma omp for
     957             : #endif
     958       49943 :     for (j = 0; j < (int) (mpzspm->sp_num); j++)
     959             :       {
     960       46284 :         const spm_t spm = mpzspm->spm[j];
     961       46284 :         const spv_t spv = dft[j];
     962             :         sp_t w1, w2, invlen;
     963       46284 :         const sp_t sp = spm->sp, mul_c = spm->mul_c;
     964             :         spv_size_t i;
     965             : 
     966             :         /* Zero out NTT elements [n .. len-n] */
     967       46284 :         spv_set_sp (spv + n, (sp_t) 0, len - 2*n + 1);
     968             : 
     969             : #ifdef TRACE_ntt_sqr_reciprocal
     970             :         if (j == 0)
     971             :           {
     972             :             printf ("ntt_sqr_reciprocal: NTT vector mod %lu\n", sp);
     973             :             ntt_print_vec ("ntt_sqr_reciprocal: before weighting:", spv, len);
     974             :           }
     975             : #endif
     976             : 
     977             :         /* Compute the root for the weight signal, a 3rd primitive root 
     978             :            of unity */
     979       46284 :         w1 = sp_pow (spm->prim_root, mpzspm->max_ntt_size / 3UL, sp, 
     980             :                      mul_c);
     981             :         /* Compute w2 = 1/w1 */
     982       46284 :         w2 = sp_pow (spm->inv_prim_root, mpzspm->max_ntt_size / 3UL, sp, 
     983             :                      mul_c);
     984             : #ifdef TRACE_ntt_sqr_reciprocal
     985             :         if (j == 0)
     986             :           printf ("w1 = %lu ,w2 = %lu\n", w1, w2);
     987             : #endif
     988             :         ASSERT(sp_mul(w1, w2, sp, mul_c) == (sp_t) 1);
     989             :         ASSERT(w1 != (sp_t) 1);
     990             :         ASSERT(sp_pow (w1, 3UL, sp, mul_c) == (sp_t) 1);
     991             :         ASSERT(w2 != (sp_t) 1);
     992             :         ASSERT(sp_pow (w2, 3UL, sp, mul_c) == (sp_t) 1);
     993             : 
     994             :         /* Fill NTT elements spv[len-n+1 .. len-1] with coefficients and
     995             :            apply weight signal to spv[i] and spv[len-i] for 0 <= i < n
     996             :            Use the fact that w^i + w^{-i} = -1 if i != 0 (mod 3). */
     997     8328609 :         for (i = 0; i + 2 < n; i += 3)
     998             :           {
     999             :             sp_t t, u;
    1000             :             /* i == 0 (mod 3): weight is 1, so the mirrored entry is a plain copy (index 0 has no mirror) */
    1001     8282325 :             if (i > 0)
    1002     8245919 :               spv[len - i] = spv[i];
    1003             :             /* i + 1: entry becomes w1 * t, its mirror becomes -(t + w1 * t) */
    1004     8282325 :             t = spv[i + 1];
    1005     8282325 :             u = sp_mul (t, w1, sp, mul_c);
    1006     8282325 :             spv[i + 1] = u;
    1007     8282325 :             spv[len - i - 1] = sp_neg (sp_add (t, u, sp), sp);
    1008             :             /* i + 2: same with w2 in place of w1 */
    1009     8282325 :             t = spv[i + 2];
    1010     8282325 :             u = sp_mul (t, w2, sp, mul_c);
    1011     8282325 :             spv[i + 2] = u;
    1012     8282325 :             spv[len - i - 2] = sp_neg (sp_add (t, u, sp), sp);
    1013             :           }
    1014       46284 :         if (i < n && i > 0)
    1015             :           {
    1016       19747 :             spv[len - i] = spv[i];
    1017             :           }
    1018       46284 :         if (i + 1 < n)
    1019             :           {
    1020             :             sp_t t, u;
    1021       12559 :             t = spv[i + 1];
    1022       12559 :             u = sp_mul (t, w1, sp, mul_c);
    1023       12559 :             spv[i + 1] = u;
    1024       12559 :             spv[len - i - 1] = sp_neg (sp_add (t, u, sp), sp);
    1025             :           }
    1026             : 
    1027             : #ifdef TRACE_ntt_sqr_reciprocal
    1028             :       if (j == 0)
    1029             :         ntt_print_vec ("ntt_sqr_reciprocal: after weighting:", spv, len);
    1030             : #endif
    1031             : 
    1032             :         /* Forward DFT of dft[j] */
    1033       46284 :         spv_ntt_gfp_dif (spv, log2_len, spm);
    1034             : 
    1035             : #ifdef TRACE_ntt_sqr_reciprocal
    1036             :         if (j == 0)
    1037             :           ntt_print_vec ("ntt_sqr_reciprocal: after forward transform:", 
    1038             :                          spv, len);
    1039             : #endif
    1040             : 
    1041             :         /* Square the transformed vector point-wise */
    1042       46284 :         spv_pwmul (spv, spv, spv, len, sp, mul_c);
    1043             :       
    1044             : #ifdef TRACE_ntt_sqr_reciprocal
    1045             :         if (j == 0)
    1046             :           ntt_print_vec ("ntt_sqr_reciprocal: after point-wise squaring:", 
    1047             :                          spv, len);
    1048             : #endif
    1049             : 
    1050             :         /* Inverse transform of dft[j] */
    1051       46284 :         spv_ntt_gfp_dit (spv, log2_len, spm);
    1052             :       
    1053             : #ifdef TRACE_ntt_sqr_reciprocal
    1054             :         if (j == 0)
    1055             :           ntt_print_vec ("ntt_sqr_reciprocal: after inverse transform:", 
    1056             :                          spv, len);
    1057             : #endif
    1058             : 
    1059             :         /* Un-weight and divide by transform length, for indices 0 .. 2*n-2. NOTE(review): if spv_size_t is unsigned (not visible here), 2 * n - 3 wraps around for n == 1 -- confirm callers pass n >= 2 */
    1060       46284 :         invlen = sp - (sp - (sp_t) 1) / len; /* invlen = 1/len (mod sp) */
    1061       46284 :         w1 = sp_mul (invlen, w1, sp, mul_c);
    1062       46284 :         w2 = sp_mul (invlen, w2, sp, mul_c);
    1063    16606834 :         for (i = 0; i < 2 * n - 3; i += 3)
    1064             :           {
    1065    16560550 :             spv[i] = sp_mul (spv[i], invlen, sp, mul_c);
    1066    16560550 :             spv[i + 1] = sp_mul (spv[i + 1], w2, sp, mul_c);
    1067    16560550 :             spv[i + 2] = sp_mul (spv[i + 2], w1, sp, mul_c);
    1068             :           }
    1069       46284 :         if (i < 2 * n - 1)
    1070       33725 :           spv[i] = sp_mul (spv[i], invlen, sp, mul_c);
    1071       46284 :         if (i < 2 * n - 2)
    1072       16659 :           spv[i + 1] = sp_mul (spv[i + 1], w2, sp, mul_c);
    1073             :         
    1074             : #ifdef TRACE_ntt_sqr_reciprocal
    1075             :         if (j == 0)
    1076             :           ntt_print_vec ("ntt_sqr_reciprocal: after un-weighting:", spv, len);
    1077             : #endif
    1078             : 
    1079             :         /* Separate the coefficients of R in the wrapped-around product. */
    1080             : 
    1081             :         /* Set w1 = cuberoot(1)^len where cuberoot(1) is the same primitive
    1082             :            3rd root of unity we used for the weight signal */
    1083       46284 :         w1 = sp_pow (spm->prim_root, mpzspm->max_ntt_size / 3UL, sp, 
    1084             :                      mul_c);
    1085       46284 :         w1 = sp_pow (w1, len % 3UL, sp, mul_c);
    1086             :         
    1087             :         /* Set w2 = 1/(w1 - 1/w1). Incidentally, w2 = 1/sqrt(-3) */
    1088       46284 :         w2 = sp_inv (w1, sp, mul_c);
    1089       46284 :         w2 = sp_sub (w1, w2, sp);
    1090       46284 :         w2 = sp_inv (w2, sp, mul_c);
    1091             : #ifdef TRACE_ntt_sqr_reciprocal
    1092             :         if (j == 0)
    1093             :           printf ("For separating: w1 = %lu, w2 = %lu\n", w1, w2);
    1094             : #endif
    1095             :         
    1096    23960930 :         for (i = len - (2*n - 2); i <= len / 2; i++)
    1097             :           {
    1098             :             sp_t t, u;
    1099             :             /* spv[i] = s_i + w^{-l} s_{l-i}. 
    1100             :                spv[l-i] = s_{l-i} + w^{-l} s_i */
    1101    23914646 :             t = sp_mul (spv[i], w1, sp, mul_c); /* t = w^l s_i + s_{l-i} */
    1102    23914646 :             t = sp_sub (t, spv[len - i], sp);   /* t = w^l s_i - w^{-l} s_i */
    1103    23914646 :             t = sp_mul (t, w2, sp, mul_c);      /* t = s_i */
    1104             : 
    1105    23914646 :             u = sp_sub (spv[i], t, sp);         /* u = w^{-l} s_{l-i} */
    1106    23914646 :             u = sp_mul (u, w1, sp, mul_c);      /* u = s_{l-i} */
    1107    23914646 :             spv[i] = t;
    1108    23914646 :             spv[len - i] = u;
    1109             :             ASSERT(i < len / 2 || t == u);
    1110             :           }
    1111             : 
    1112             : #ifdef TRACE_ntt_sqr_reciprocal
    1113             :         if (j == 0)
    1114             :           ntt_print_vec ("ntt_sqr_reciprocal: after un-wrapping:", spv, len);
    1115             : #endif
    1116             :       }
    1117             : #ifdef _OPENMP
    1118             :     }
    1119             : #endif
    1120        3659 : }

Generated by: LCOV version 1.14