LCOV - code coverage report
Current view: top level - ecm - mpzspm.c (source / functions) Hit Total Coverage
Test: unnamed Lines: 137 153 89.5 %
Date: 2022-03-21 11:19:20 Functions: 5 5 100.0 %

          Line data    Source code
       1             : /* mpzspm.c - "mpz small prime moduli" - pick a set of small primes large
       2             :    enough to represent a mpzv
       3             : 
       4             : Copyright 2005, 2006, 2007, 2008, 2009, 2010 Dave Newman, Jason Papadopoulos,
       5             : Paul Zimmermann, Alexander Kruppa.
       6             : 
       7             : The SP Library is free software; you can redistribute it and/or modify
       8             : it under the terms of the GNU Lesser General Public License as published by
       9             : the Free Software Foundation; either version 3 of the License, or (at your
      10             : option) any later version.
      11             : 
      12             : The SP Library is distributed in the hope that it will be useful, but
      13             : WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
      14             : or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
      15             : License for more details.
      16             : 
      17             : You should have received a copy of the GNU Lesser General Public License
      18             : along with the SP Library; see the file COPYING.LIB.  If not, write to
      19             : the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston,
      20             : MA 02110-1301, USA. */
      21             : 
      22             : #include <stdio.h> /* for printf */
      23             : #include <stdlib.h>
      24             : #include "sp.h"
      25             : #include "ecm-impl.h"
      26             : 
      27             : /* Tables for the maximum possible modulus (in bit size) for different 
      28             :    transform lengths l.
      29             :    The modulus is limited by the condition that primes must be 
      30             :    p_i == 1 (mod l), and \Prod_i p_i >= 4l (modulus * S)^2, 
      31             :    where S=\Sum_i p_i.
      32             :    Hence for each l=2^k, we take the product P and sum S of primes p_i,
      33             :    SP_MIN <= p_i <= SP_MAX and p_i == 1 (mod l), and store 
      34             :    floor (log_2 (sqrt (P / (4l S^2)))) in the table.
      35             :    We only consider power-of-two transform lengths <= 2^31 here.
      36             : 
      37             :    Table entries generated with
      38             :    
      39             :    l=2^k;p=1;P=1;S=0;while(p<=SP_MAX, if(p>=SP_MIN && isprime(p), S+=p; P*=p); \
      40             :    p+=l);print(floor (log (sqrt (P / (4*l * S^2)))/log(2)))
      41             : 
      42             :    in Pari/GP for k=9 ... 24. k<9 simply were doubled and rounded down in 
      43             :    each step.
      44             : 
      45             :    We curently assume that SP_MIN = 2^(SP_NUMB_BITS-1) and 
      46             :    SP_MAX = 2^(SP_NUMB_BITS).
      47             :    
      48             : */
      49             : 
      50             : #if (SP_NUMB_BITS == 30)
      51             : static unsigned long sp_max_modulus_bits[32] = 
      52             :   {0, 380000000, 190000000, 95000000, 48000000, 24000000, 12000000, 6000000, 
      53             :    3000000, 1512786, 756186, 378624, 188661, 93737, 46252, 23342, 11537, 5791, 
      54             :    3070, 1563, 782, 397, 132, 43, 0, 0, 0, 0, 0, 0, 0, 0};
      55             : #elif (SP_NUMB_BITS == 31)
      56             : static unsigned long sp_max_modulus_bits[32] = 
      57             :   {0, 750000000, 380000000, 190000000, 95000000, 48000000, 24000000, 12000000, 
      58             :    6000000, 3028766, 1512573, 756200, 379353, 190044, 94870, 47414, 23322, 
      59             :    11620, 5891, 2910, 1340, 578, 228, 106, 60, 30, 0, 0, 0, 0, 0, 0};
      60             : #elif (SP_NUMB_BITS == 32)
      61             : static unsigned long sp_max_modulus_bits[32] = 
      62             :   {0, 1520000000, 760000000, 380000000, 190000000, 95000000, 48000000, 
      63             :    24000000, 12000000, 6041939, 3022090, 1509176, 752516, 376924, 190107, 
      64             :    95348, 47601, 24253, 11971, 6162, 3087, 1557, 833, 345, 172, 78, 46, 15, 
      65             :    0, 0, 0, 0};
      66             : #elif (SP_NUMB_BITS >= 60)
      67             :   /* There are so many primes, we can do pretty much any modulus with 
      68             :      any transform length. I didn't bother computing the actual values. */
      69             : static unsigned long sp_max_modulus_bits[32] =  
      70             :   {0, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, 
      71             :    ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, 
      72             :    ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, 
      73             :    ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, ULONG_MAX, 
      74             :    ULONG_MAX, ULONG_MAX, 0, 0};
      75             : #else
      76             : #error Table of maximal modulus for transform lengths not defined for this SP_MIN
      77             : ;
      78             : #endif
      79             : 
      80             : #define CHECK(cond,msg,label)       \
      81             :   if (cond)                         \
      82             :     {                               \
      83             :       outputf (OUTPUT_ERROR, msg);  \
      84             :       goto label;                   \
      85             :     }                               \
      86             : 
      87             : /* Returns the largest possible transform length we can do for modulus
      88             :    without running out of primes */
      89             : 
      90             : spv_size_t
      91         425 : mpzspm_max_len (mpz_t modulus)
      92             : {
      93             :   int i;
      94             :   size_t b;
      95             : 
      96         425 :   b = mpz_sizeinbase (modulus, 2); /* b = floor (log_2 (modulus)) + 1 */
      97             :   /* Transform length 2^k is ok if log2(modulus) <= sp_max_modulus_bits[k]
      98             :      <==> ceil(log2(modulus)) <= sp_max_modulus_bits[k] 
      99             :      <==> floor(log_2(modulus)) + 1 <= sp_max_modulus_bits[k] if modulus 
     100             :      isn't a power of 2 */
     101             :      
     102       12750 :   for (i = 0; i < 30; i++)
     103             :     {
     104       12750 :       if (b > sp_max_modulus_bits[i + 1])
     105         425 :         break;
     106             :     }
     107             : 
     108         425 :   return (spv_size_t)1 << i;
     109             : }
     110             : 
     111             : /* initialize mpzspm->T such that with m[j] := mpzspm->spm[j]->sp
     112             :    T[0][0] = m[0], ..., T[0][n-1] = m[n-1]
     113             :    ...
     114             :    T[d-1][0] = m[0]*...*m[ceil(n/2)-1], T[d-1][1] = m[ceil(n/2)] * ... * m[n-1]
     115             :    T[d][0] = m[0] * ... * m[n-1]
     116             :    where d = ceil(log(n)/log(2)).
     117             :    If n = 5, T[0]: 1, 1, 1, 1, 1
     118             :              T[1]: 2, 2, 1
     119             :              T[2]: 4, 1
     120             : */
     121             : static void
     122        1678 : mpzspm_product_tree_init (mpzspm_t mpzspm)
     123             : {
     124             :   unsigned int d, i, j, oldn;
     125        1678 :   unsigned int n = mpzspm->sp_num;
     126             :   mpzv_t *T;
     127             : 
     128        7653 :   for (i = n, d = 0; i > 1; i = (i + 1) / 2, d ++);
     129        1678 :   if (d <= I0_THRESHOLD)
     130             :     {
     131        1656 :       mpzspm->T = NULL;
     132        1656 :       return;
     133             :     }
     134          22 :   T = (mpzv_t*) malloc ((d + 1) * sizeof (mpzv_t));
     135          22 :   T[0] = (mpzv_t) malloc (n * sizeof (mpz_t));
     136        5992 :   for (j = 0; j < n; j++)
     137             :     {
     138        5970 :       mpz_init (T[0][j]);
     139        5970 :       mpz_set_sp (T[0][j], mpzspm->spm[j]->sp);
     140             :     }
     141         204 :   for (i = 1; i <= d; i++)
     142             :     {
     143         182 :       oldn = n;
     144         182 :       n = (n + 1) / 2;
     145         182 :       T[i] = (mpzv_t) malloc (n * sizeof (mpz_t));
     146        6214 :       for (j = 0; j < n; j++)
     147             :         {
     148        6032 :           mpz_init (T[i][j]);
     149        6032 :           if (2 * j + 1 < oldn)
     150        5948 :             mpz_mul (T[i][j], T[i-1][2*j], T[i-1][2*j+1]);
     151             :           else /* oldn is odd */
     152          84 :             mpz_set (T[i][j], T[i-1][2*j]);
     153             :         }
     154             :     }
     155          22 :   mpzspm->T = T;
     156          22 :   mpzspm->d = d;
     157             : }
     158             : 
     159             : /* This function initializes a mpzspm_t structure which contains the number
     160             :    of small primes, the small primes with associated primitive roots and 
     161             :    precomputed data for the CRT to allow convolution products of length up 
     162             :    to "max_len" with modulus "modulus". 
     163             :    Returns NULL in case of an error. */
     164             : 
     165             : mpzspm_t
     166        1678 : mpzspm_init (spv_size_t max_len, mpz_t modulus)
     167             : {
     168             :   unsigned int ub, i, j;
     169             :   mpz_t P, S, T, mp, mt; /* mp is p as mpz_t, mt is a temp mpz_t */
     170             :   sp_t p, a;
     171             :   mpzspm_t mpzspm;
     172             :   long st;
     173             : 
     174        1678 :   st = cputime ();
     175             : 
     176        1678 :   mpzspm = (mpzspm_t) malloc (sizeof (__mpzspm_struct));
     177        1678 :   if (mpzspm == NULL)
     178           0 :     return NULL;
     179             :   
     180             :   /* Upper bound for the number of primes we need.
     181             :    * Let minp, maxp denote the min, max permissible prime,
     182             :    * S the sum of p_1, p_2, ..., p_ub,
     183             :    * P the product of p_1, p_2, ..., p_ub/
     184             :    * 
     185             :    * Choose ub s.t.
     186             :    *
     187             :    *     ub * log(minp) >= log(4 * max_len * modulus^2 * maxp^4)
     188             :    * 
     189             :    * =>  P >= minp ^ ub >= 4 * max_len * modulus^2 * maxp^4
     190             :    *                    >= 4 * max_len * modulus^2 * (ub * maxp)^2
     191             :    *                    >= 4 * max_len * modulus^2 * S^2
     192             :    * 
     193             :    * So we need at most ub primes to satisfy this condition. */
     194             :   
     195        1678 :   ub = (2 + 2 * mpz_sizeinbase (modulus, 2) + ceil_log_2 (max_len) + \
     196        1678 :       4 * SP_NUMB_BITS) / (SP_NUMB_BITS - 1);
     197             :   
     198        1678 :   mpzspm->spm = (spm_t *) malloc (ub * sizeof (spm_t));
     199        1678 :   CHECK(mpzspm->spm == NULL, "Out of memory in mpzspm_init()\n",
     200             :         error_clear_mpzspm);
     201        1678 :   mpzspm->sp_num = 0;
     202             : 
     203             :   /* product of primes selected so far */
     204        1678 :   mpz_init_set_ui (P, 1UL);
     205             :   /* sum of primes selected so far */
     206        1678 :   mpz_init (S);
     207             :   /* T is len*modulus^2, the upper bound on output coefficients of a 
     208             :      convolution */
     209        1678 :   mpz_init (T); 
     210        1678 :   mpz_mul (T, modulus, modulus);
     211        1678 :   mpz_mul_ui (T, T, max_len);
     212        1678 :   mpz_init (mp);
     213        1678 :   mpz_init (mt);
     214             :   
     215             :   /* find primes congruent to 1 mod max_len so we can do
     216             :    * a ntt of size max_len */
     217             :   /* Find the largest p <= SP_MAX that is p == 1 (mod max_len) */
     218        1678 :   p = ((SP_MAX - 1) / (sp_t) max_len) * (sp_t) max_len + 1;
     219             :   
     220             :   do
     221             :     {
     222      453560 :       while (p >= SP_MIN && p > (sp_t) max_len && !sp_prime(p))
     223      427324 :         p -= (sp_t) max_len;
     224             : 
     225             :       /* all primes must be in range */
     226       26236 :       if (p < SP_MIN || p <= (sp_t) max_len)
     227             :         {
     228           0 :           outputf (OUTPUT_ERROR, 
     229             :                    "not enough primes == 1 (mod %lu) in interval\n", 
     230             :                    (unsigned long) max_len);
     231           0 :           goto error_clear_mpzspm_spm;
     232             :         }
     233             :       
     234       26236 :       mpzspm->spm[mpzspm->sp_num] = spm_init (max_len, p, mpz_size (modulus));
     235       26236 :       CHECK(mpzspm->spm[mpzspm->sp_num] == NULL,
     236             :             "Out of memory in mpzspm_init()\n", error_clear_mpzspm_spm);
     237       26236 :       mpzspm->sp_num++;
     238             :       
     239       26236 :       mpz_set_sp (mp, p);
     240       26236 :       mpz_mul (P, P, mp);
     241       26236 :       mpz_add (S, S, mp);
     242             : 
     243             :       /* we want P > 4 * max_len * (modulus * S)^2. The S^2 term is due to 
     244             :          theorem 3.1 in Bernstein and Sorenson's paper */
     245       26236 :       mpz_mul (T, S, modulus);
     246       26236 :       mpz_mul (T, T, T);
     247       26236 :       mpz_mul_ui (T, T, max_len);
     248       26236 :       mpz_mul_2exp (T, T, 2UL);
     249             :       
     250       26236 :       p -= (sp_t) max_len;
     251             :     }
     252       26236 :   while (mpz_cmp (P, T) <= 0);
     253             : 
     254             :   /* we add the test_verbose() call to avoid calls to cputime() even if
     255             :      nothing is printed */
     256        1678 :   if (test_verbose (OUTPUT_DEVVERBOSE))
     257          60 :     outputf (OUTPUT_DEVVERBOSE, "mpzspm_init: finding %u primes took %lums\n", 
     258          60 :              mpzspm->sp_num, cputime() - st);
     259             : 
     260        1678 :   mpz_init_set (mpzspm->modulus, modulus);
     261             :   
     262        1678 :   mpzspm->max_ntt_size = max_len;
     263             :   
     264        1678 :   mpzspm->crt1 = (mpzv_t) malloc (mpzspm->sp_num * sizeof (mpz_t));
     265        1678 :   mpzspm->crt2 = (mpzv_t) malloc ((mpzspm->sp_num + 2) * sizeof (mpz_t));
     266        1678 :   mpzspm->crt3 = (spv_t) malloc (mpzspm->sp_num * sizeof (sp_t));
     267        1678 :   mpzspm->crt4 = (spv_t *) malloc (mpzspm->sp_num * sizeof (spv_t));
     268        1678 :   mpzspm->crt5 = (spv_t) malloc (mpzspm->sp_num * sizeof (sp_t));
     269        1678 :   CHECK(mpzspm->crt1 == NULL || mpzspm->crt2 == NULL || mpzspm->crt3 == NULL ||
     270             :         mpzspm->crt4 == NULL || mpzspm->crt5 == NULL,
     271             :         "Out of memory in mpzspm_init()\n", error_clear_crt);
     272             : 
     273       27914 :   for (i = 0; i < mpzspm->sp_num; i++)
     274       26236 :     mpzspm->crt4[i] = NULL;
     275       27914 :   for (i = 0; i < mpzspm->sp_num; i++)
     276             :     {
     277       26236 :       mpzspm->crt4[i] = (spv_t) malloc (mpzspm->sp_num * sizeof (sp_t));
     278       26236 :       CHECK(mpzspm->crt4[i] == NULL, "Out of memory in mpzspm_init()\n",
     279             :             error_clear_crt);
     280             :     }
     281             :   
     282       27914 :   for (i = 0; i < mpzspm->sp_num; i++)
     283             :     {
     284       26236 :       p = mpzspm->spm[i]->sp;
     285       26236 :       mpz_set_sp (mp, p);
     286             :       
     287             :       /* crt3[i] = (P / p)^{-1} mod p */
     288       26236 :       mpz_fdiv_q (T, P, mp);
     289       26236 :       mpz_fdiv_r (mt, T, mp);
     290       26236 :       a = mpz_get_sp (mt);
     291       26236 :       mpzspm->crt3[i] = sp_inv (a, p, mpzspm->spm[i]->mul_c);
     292             :      
     293             :       /* crt1[i] = (P / p) mod modulus */
     294       26236 :       mpz_init (mpzspm->crt1[i]);
     295       26236 :       mpz_mod (mpzspm->crt1[i], T, modulus);
     296             : 
     297             :       /* crt4[i][j] = ((P / p[i]) mod modulus) mod p[j] */
     298     3564590 :       for (j = 0; j < mpzspm->sp_num; j++)
     299             :         {
     300     3538354 :           mpz_set_sp (mp, mpzspm->spm[j]->sp);
     301     3538354 :           mpz_fdiv_r (mt, mpzspm->crt1[i], mp);
     302     3538354 :           mpzspm->crt4[j][i] = mpz_get_sp (mt);
     303             :         }
     304             :       
     305             :       /* crt5[i] = (-P mod modulus) mod p */
     306       26236 :       mpz_mod (T, P, modulus);
     307       26236 :       mpz_sub (T, modulus, T);
     308       26236 :       mpz_set_sp (mp, p);
     309       26236 :       mpz_fdiv_r (mt, T, mp);
     310       26236 :       mpzspm->crt5[i] = mpz_get_sp (mt);
     311             :     }
     312             :   
     313        1678 :   mpz_set_ui (T, 0);
     314             : 
     315             :   /* set crt2[i] = -i*P mod modulus */
     316       31270 :   for (i = 0; i < mpzspm->sp_num + 2; i++)
     317             :     {
     318       29592 :       mpz_mod (T, T, modulus);
     319       29592 :       mpz_init_set (mpzspm->crt2[i], T);
     320       29592 :       mpz_sub (T, T, P);
     321             :     }
     322             :   
     323        1678 :   mpz_clear (mp);
     324        1678 :   mpz_clear (mt);
     325        1678 :   mpz_clear (P);
     326        1678 :   mpz_clear (S);
     327        1678 :   mpz_clear (T);
     328             : 
     329        1678 :   mpzspm_product_tree_init (mpzspm);
     330             : 
     331        1678 :   if (test_verbose (OUTPUT_DEVVERBOSE))
     332          60 :     outputf (OUTPUT_DEVVERBOSE, "mpzspm_init took %lums\n", cputime() - st);
     333             : 
     334        1678 :   return mpzspm;
     335             :   
     336             :   /* Error cases: free memory we allocated so far */
     337             : 
     338           0 :   error_clear_crt:
     339           0 :   free (mpzspm->crt1);
     340           0 :   free (mpzspm->crt2);
     341           0 :   free (mpzspm->crt3);
     342           0 :   free (mpzspm->crt4);
     343           0 :   free (mpzspm->crt5);
     344             :   
     345           0 :   error_clear_mpzspm_spm:
     346           0 :   for (i = 0; i < mpzspm->sp_num; i++)
     347           0 :     free (mpzspm->spm[i]);
     348           0 :   free (mpzspm->spm);
     349             : 
     350           0 :   error_clear_mpzspm:
     351           0 :   free (mpzspm);
     352             : 
     353           0 :   return NULL;
     354             : }
     355             : 
     356             : /* clear the product tree T */
     357             : static void
     358        1678 : mpzspm_product_tree_clear (mpzspm_t mpzspm)
     359             : {
     360             :   unsigned int i, j;
     361        1678 :   unsigned int n = mpzspm->sp_num;
     362        1678 :   unsigned int d = mpzspm->d;
     363        1678 :   mpzv_t *T = mpzspm->T;
     364             : 
     365        1678 :   if (T == NULL) /* use the slow method */
     366        1656 :     return;
     367             : 
     368         226 :   for (i = 0; i <= d; i++)
     369             :     {
     370       12206 :       for (j = 0; j < n; j++)
     371       12002 :         mpz_clear (T[i][j]);
     372         204 :       free (T[i]);
     373         204 :       n = (n + 1) / 2;
     374             :     }
     375          22 :   free (T);
     376             : }
     377             : 
     378        1678 : void mpzspm_clear (mpzspm_t mpzspm)
     379             : {
     380             :   unsigned int i;
     381             : 
     382        1678 :   mpzspm_product_tree_clear (mpzspm);
     383             : 
     384       27914 :   for (i = 0; i < mpzspm->sp_num; i++)
     385             :     {
     386       26236 :       mpz_clear (mpzspm->crt1[i]);
     387       26236 :       free (mpzspm->crt4[i]);
     388       26236 :       spm_clear (mpzspm->spm[i]);
     389             :     }
     390             : 
     391       31270 :   for (i = 0; i < mpzspm->sp_num + 2; i++)
     392       29592 :     mpz_clear (mpzspm->crt2[i]);
     393             :   
     394        1678 :   free (mpzspm->crt1);
     395        1678 :   free (mpzspm->crt2);
     396        1678 :   free (mpzspm->crt3);
     397        1678 :   free (mpzspm->crt4);
     398        1678 :   free (mpzspm->crt5);
     399             :   
     400        1678 :   mpz_clear (mpzspm->modulus);
     401        1678 :   free (mpzspm->spm);
     402        1678 :   free (mpzspm);
     403        1678 : }
     404             : 

Generated by: LCOV version 1.14