LCOV - code coverage report
Current view: top level - ecm - spm.c (source / functions) Hit Total Coverage
Test: unnamed Lines: 86 94 91.5 %
Date: 2022-03-21 11:19:20 Functions: 6 6 100.0 %

          Line data    Source code
       1             : /* spm.c - "small prime modulus" functions to precompute an inverse and a
       2             :    primitive root for a small prime
       3             : 
       4             : Copyright 2005, 2006, 2008, 2009, 2010, 2012 Dave Newman, Jason Papadopoulos,
       5             : Paul Zimmermann, Alexander Kruppa.
       6             : 
       7             : The SP Library is free software; you can redistribute it and/or modify
       8             : it under the terms of the GNU Lesser General Public License as published by
       9             : the Free Software Foundation; either version 3 of the License, or (at your
      10             : option) any later version.
      11             : 
      12             : The SP Library is distributed in the hope that it will be useful, but
      13             : WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
      14             : or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
      15             : License for more details.
      16             : 
      17             : You should have received a copy of the GNU Lesser General Public License
      18             : along with the SP Library; see the file COPYING.LIB.  If not, write to
      19             : the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston,
      20             : MA 02110-1301, USA. */
      21             : 
      22             : #include <stdlib.h>
      23             : #include "sp.h"
      24             : 
      25             : /* Returns the exponent of $q$ in the factorisation of $n$ */
      26             : static int
      27       88852 : exponent (const sp_t q, sp_t n)
      28             : {
      29             :   int i;
      30      735821 :   for (i = 0; n % q == (sp_t) 0; i++, n /= q);
      31       88852 :   return i;
      32             : }
      33             : 
      34             : /* Returns i so that ord(a) = q^i. This assumes that ord(a) is indeed
      35             :    a low power of q. */
      36             : static int
      37       88852 : ordpow (const sp_t q, sp_t a, const sp_t sp, const sp_t mul_c)
      38             : {
      39       88852 :   int i = 0;
      40      644617 :   for (i = 0; a != (sp_t) 1; i++, a = sp_pow (a, q, sp, mul_c));
      41       88852 :   return i;
      42             : }
      43             : 
      44             : /* initialize roots of unity and twiddle factors for one NTT.
      45             :    If successful, returns 1.
      46             :    If unsuccessful, returns 0 (and frees allocated memory) */
      47             : static int
      48       52472 : nttdata_init (const sp_t sp, const sp_t mul_c, 
      49             :                 const sp_t prim_root, const spv_size_t log2_len,
      50             :                 sp_nttdata_t data, spv_size_t breakover)
      51             : {
      52             :   spv_t r, t;
      53             :   spv_size_t i, j, k;
      54             : 
      55       52472 :   r = data->ntt_roots = 
      56       52472 :           (spv_t) sp_aligned_malloc ((log2_len + 1) * sizeof(sp_t));
      57       52472 :   if (r == NULL)
      58           0 :     return 0;
      59             : 
      60       52472 :   i = log2_len;
      61       52472 :   r[i] = prim_root;
      62      453784 :   for (i--; (int)i >= 0; i--)
      63      401312 :     r[i] = sp_sqr (r[i+1], sp, mul_c);
      64             : 
      65       52472 :   k = MIN(log2_len, breakover);
      66       52472 :   t = data->twiddle = (spv_t) sp_aligned_malloc (sizeof(sp_t) << k);
      67       52472 :   if (t == NULL)
      68             :     {
      69           0 :       sp_aligned_free (r);
      70           0 :       return 0;
      71             :     }
      72       52472 :   data->twiddle_size = 1 << k;
      73             : 
      74      450276 :   for (i = k; i; i--) 
      75             :     {
      76      397804 :       sp_t w = r[i];
      77    80195160 :       for (j = t[0] = 1; j < ((spv_size_t) 1 << (i-1)); j++) 
      78    79797356 :         t[j] = sp_mul (t[j-1], w, sp, mul_c);
      79             : 
      80      397804 :       t += j;
      81             :     }
      82       52472 :   return 1;
      83             : }
      84             : 
      85             : static void
      86       52472 : nttdata_clear(sp_nttdata_t data)
      87             : {
      88       52472 :   sp_aligned_free(data->ntt_roots);
      89       52472 :   sp_aligned_free(data->twiddle);
      90       52472 : }
      91             : 
      92             : /* Compute some constants, including a primitive n'th root of unity. 
      93             :    Returns NULL in case of error.
      94             :    k is the number of limbs of the number to factor
      95             : */
      96             : spm_t
      97       26236 : spm_init (spv_size_t n, sp_t sp, mp_size_t k)
      98             : {
      99             :   sp_t a, b, bd, sc;
     100             :   spv_size_t q, nc, ntt_power;
     101       26236 :   spm_t spm = (spm_t) malloc (sizeof (__spm_struct));
     102       26236 :   if (spm == NULL)
     103           0 :     return NULL;
     104             : 
     105             :   ASSERT (sp % (sp_t) n == (sp_t) 1);
     106             : 
     107       26236 :   spm->sp = sp;
     108       26236 :   sp_reciprocal (spm->mul_c, sp);
     109             : 
     110             :   /* compute spm->invm = -1/p mod B where B = 2^GMP_NUMB_BITS */
     111       26236 :   a = sp_pow (2, GMP_NUMB_BITS, sp, spm->mul_c); /* a = B mod p */
     112       26236 :   a = sp_inv (a, sp, spm->mul_c);                /* a = 1/B mod p */
     113             :   /* a = 1/B mod p thus B*a - 1 = invm*p */
     114       26236 :   a --;
     115       26236 :   b = GMP_NUMB_MASK;
     116             : #if SP_NUMB_BITS == W_TYPE_SIZE - 2
     117       26236 :   a = (a << 2) + (b >> (GMP_NUMB_BITS - 2));
     118       26236 :   b = (b << 2) & GMP_NUMB_MASK;
     119       26236 :   udiv_qrnnd (bd, sc, a, b, sp << 2);
     120             : #else
     121             :   a = (a << 1) + (b >> (GMP_NUMB_BITS - 1));
     122             :   b = (b << 1) & GMP_NUMB_MASK;
     123             :   udiv_qrnnd (bd, sc, a, b, sp << 1);
     124             : #endif
     125       26236 :   spm->invm = bd;
     126             : 
     127             :   /* compute spm->Bpow = B^(k+1) mod p */
     128       26236 :   spm->Bpow = sp_pow (2, GMP_NUMB_BITS * (k + 1), sp, spm->mul_c);
     129             : 
     130             :   /* find an $n$-th primitive root $a$ of unity $(mod sp)$. */
     131             : 
     132             :   /* Construct a $b$ whose order $(mod sp)$ is equal to $n$.
     133             :      We try different $a$ values and test if the exponent of $q$ in $ord(a)$
     134             :      is at least as large as in $n$. If it isn't, we move to another $a$.
     135             :      If it is, we optionally exponentiate to make the exponents equal and
     136             :      test for the remaining $q$'s.
     137             :      We assume that the largest prime dividing $n$ is very small, 
     138             :      so no optimizations in factoring n are made. */
     139       26236 :   a = 2;
     140       26236 :   b = a;
     141       26236 :   nc = n; /* nc is remaining cofactor of n */
     142       26236 :   q = 2;
     143       26236 :   sc = sp - 1;
     144             : #ifdef PARI
     145             :   printf ("/* spm_init */ n = %lu; sp = %lu; /* PARI */\n", n, sp);
     146             :   printf ("exponent(a,b) = {local(i); while(b%%a == 0,i++;b/=a); "
     147             :           "return(i)} /* PARI */\n");
     148             : #endif
     149      115088 :   for ( ; nc != (spv_size_t) 1; q++)
     150             :     {
     151       88852 :       if (nc % q == (spv_size_t) 0)
     152             :         {
     153       88852 :           const int k = exponent (q, n); /* q^k || n */
     154             :           sp_t d;
     155             :           int l;
     156             : #ifdef PARI
     157             :           printf ("exponent(%lu, n) == %d /* PARI */\n", q, k);
     158             : #endif
     159             :           /* Remove all factors of $q$ from $sp-1$ */
     160      779670 :           for (d = sp - 1; d % q == (spv_size_t) 0; d /= q);
     161       88852 :           bd = sp_pow (b, d, sp, spm->mul_c);
     162             :           /* Now ord(bd) = q^l, q^l || ord(a) */
     163       88852 :           l = ordpow (q, bd, sp, spm->mul_c);
     164             : #ifdef PARI
     165             :           printf ("exponent(%lu, znorder(Mod(%lu, sp))) == %d /* PARI */\n", 
     166             :                   q, b, l);
     167             : #endif
     168       88852 :           if (l < k)
     169             :             {
     170             :               /* No good, q appears in ord(a) in a lower power than in n. 
     171             :                  Try next $a$ */
     172       54254 :               a++;
     173       54254 :               b = a;
     174       54254 :               nc = n;
     175       54254 :               q = 1; /* Loop increment following "continue" will make q=2 */
     176       54254 :               sc = sp - 1;
     177       54254 :               continue;
     178             :             }
     179             :           else
     180             :             {
     181             :               /* Reduce the exponent of $q$ in $ord(b)$ until is it 
     182             :                  equal to that in $n$ */
     183       51235 :               for ( ; l > k; l--)
     184             :                 {
     185             : #ifdef PARI
     186             :                   printf ("Exponentiating %lu by %lu\n", b, q);
     187             : #endif
     188       16637 :                   b = sp_pow (b, q, sp, spm->mul_c);
     189             :                 }
     190             : #ifdef PARI
     191             :               printf ("New b = %lu\n", b);
     192             : #endif
     193             :             }
     194      218468 :           do {nc /= q;} while (nc % q == 0); /* Divide out all q from nc */
     195      287149 :           while (sc % q == (sp_t) 0) /* Divide out all q from sc */
     196      252551 :             sc /= q;
     197             :         }
     198             :     }
     199             :   
     200       26236 :   b = sp_pow (b, sc, sp, spm->mul_c);
     201             : #ifdef PARI
     202             :   printf ("znorder(Mod(%lu, sp)) == n /* PARI */\n", b, sp, n);
     203             : #endif
     204             : 
     205             :   /* turn this into a primitive n'th root of unity mod p */
     206       26236 :   spm->prim_root = b;
     207       26236 :   spm->inv_prim_root = sp_inv (b, sp, spm->mul_c);
     208             : 
     209             :   /* initialize auxiliary data for all supported power-of-2 NTT sizes */
     210       26236 :   ntt_power = 0;
     211             :   while (1)
     212             :     {
     213      226892 :       if (n & (1 << ntt_power))
     214       26236 :         break;
     215      200656 :       ntt_power++;
     216             :     }
     217             : 
     218       26236 :   if (nttdata_init (sp, spm->mul_c, 
     219             :                     sp_pow (spm->prim_root, 
     220             :                             n >> ntt_power, sp, spm->mul_c),
     221       26236 :                     ntt_power, spm->nttdata, 
     222             :                     NTT_GFP_TWIDDLE_DIF_BREAKOVER))
     223             :     {
     224       26236 :       if (nttdata_init (sp, spm->mul_c, 
     225             :                         sp_pow (spm->inv_prim_root, 
     226             :                                 n >> ntt_power, sp, spm->mul_c),
     227       26236 :                         ntt_power, spm->inttdata, 
     228             :                         NTT_GFP_TWIDDLE_DIT_BREAKOVER))
     229             :         {
     230       26236 :           spm->scratch = (spv_t) sp_aligned_malloc (MAX_NTT_BLOCK_SIZE *
     231             :                                                     sizeof(sp_t));
     232       26236 :           if (spm->scratch != NULL)
     233       26236 :             return spm;
     234           0 :           nttdata_clear (spm->inttdata);
     235             :         }
     236           0 :       nttdata_clear (spm->nttdata);
     237             :     }
     238           0 :   free (spm);
     239           0 :   return NULL;
     240             : }
     241             : 
     242             : void
     243       26236 : spm_clear (spm_t spm)
     244             : {
     245       26236 :   nttdata_clear (spm->nttdata);
     246       26236 :   nttdata_clear (spm->inttdata);
     247       26236 :   sp_aligned_free (spm->scratch);
     248       26236 :   free (spm);
     249       26236 : }

Generated by: LCOV version 1.14