Line data Source code
1 : /* spm.c - "small prime modulus" functions to precompute an inverse and a
2 : primitive root for a small prime
3 :
4 : Copyright 2005, 2006, 2008, 2009, 2010, 2012 Dave Newman, Jason Papadopoulos,
5 : Paul Zimmermann, Alexander Kruppa.
6 :
7 : The SP Library is free software; you can redistribute it and/or modify
8 : it under the terms of the GNU Lesser General Public License as published by
9 : the Free Software Foundation; either version 3 of the License, or (at your
10 : option) any later version.
11 :
12 : The SP Library is distributed in the hope that it will be useful, but
13 : WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14 : or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
15 : License for more details.
16 :
17 : You should have received a copy of the GNU Lesser General Public License
18 : along with the SP Library; see the file COPYING.LIB. If not, write to
19 : the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston,
20 : MA 02110-1301, USA. */
21 :
22 : #include <stdlib.h>
23 : #include "sp.h"
24 :
25 : /* Returns the exponent of $q$ in the factorisation of $n$ */
26 : static int
27 88852 : exponent (const sp_t q, sp_t n)
28 : {
29 : int i;
30 735821 : for (i = 0; n % q == (sp_t) 0; i++, n /= q);
31 88852 : return i;
32 : }
33 :
34 : /* Returns i so that ord(a) = q^i. This assumes that ord(a) is indeed
35 : a low power of q. */
36 : static int
37 88852 : ordpow (const sp_t q, sp_t a, const sp_t sp, const sp_t mul_c)
38 : {
39 88852 : int i = 0;
40 644617 : for (i = 0; a != (sp_t) 1; i++, a = sp_pow (a, q, sp, mul_c));
41 88852 : return i;
42 : }
43 :
44 : /* initialize roots of unity and twiddle factors for one NTT.
45 : If successful, returns 1.
46 : If unsuccessful, returns 0 (and frees allocated memory) */
47 : static int
48 52472 : nttdata_init (const sp_t sp, const sp_t mul_c,
49 : const sp_t prim_root, const spv_size_t log2_len,
50 : sp_nttdata_t data, spv_size_t breakover)
51 : {
52 : spv_t r, t;
53 : spv_size_t i, j, k;
54 :
55 52472 : r = data->ntt_roots =
56 52472 : (spv_t) sp_aligned_malloc ((log2_len + 1) * sizeof(sp_t));
57 52472 : if (r == NULL)
58 0 : return 0;
59 :
60 52472 : i = log2_len;
61 52472 : r[i] = prim_root;
62 453784 : for (i--; (int)i >= 0; i--)
63 401312 : r[i] = sp_sqr (r[i+1], sp, mul_c);
64 :
65 52472 : k = MIN(log2_len, breakover);
66 52472 : t = data->twiddle = (spv_t) sp_aligned_malloc (sizeof(sp_t) << k);
67 52472 : if (t == NULL)
68 : {
69 0 : sp_aligned_free (r);
70 0 : return 0;
71 : }
72 52472 : data->twiddle_size = 1 << k;
73 :
74 450276 : for (i = k; i; i--)
75 : {
76 397804 : sp_t w = r[i];
77 80195160 : for (j = t[0] = 1; j < ((spv_size_t) 1 << (i-1)); j++)
78 79797356 : t[j] = sp_mul (t[j-1], w, sp, mul_c);
79 :
80 397804 : t += j;
81 : }
82 52472 : return 1;
83 : }
84 :
85 : static void
86 52472 : nttdata_clear(sp_nttdata_t data)
87 : {
88 52472 : sp_aligned_free(data->ntt_roots);
89 52472 : sp_aligned_free(data->twiddle);
90 52472 : }
91 :
92 : /* Compute some constants, including a primitive n'th root of unity.
93 : Returns NULL in case of error.
94 : k is the number of limbs of the number to factor
95 : */
96 : spm_t
97 26236 : spm_init (spv_size_t n, sp_t sp, mp_size_t k)
98 : {
99 : sp_t a, b, bd, sc;
100 : spv_size_t q, nc, ntt_power;
101 26236 : spm_t spm = (spm_t) malloc (sizeof (__spm_struct));
102 26236 : if (spm == NULL)
103 0 : return NULL;
104 :
105 : ASSERT (sp % (sp_t) n == (sp_t) 1);
106 :
107 26236 : spm->sp = sp;
108 26236 : sp_reciprocal (spm->mul_c, sp);
109 :
110 : /* compute spm->invm = -1/p mod B where B = 2^GMP_NUMB_BITS */
111 26236 : a = sp_pow (2, GMP_NUMB_BITS, sp, spm->mul_c); /* a = B mod p */
112 26236 : a = sp_inv (a, sp, spm->mul_c); /* a = 1/B mod p */
113 : /* a = 1/B mod p thus B*a - 1 = invm*p */
114 26236 : a --;
115 26236 : b = GMP_NUMB_MASK;
116 : #if SP_NUMB_BITS == W_TYPE_SIZE - 2
117 26236 : a = (a << 2) + (b >> (GMP_NUMB_BITS - 2));
118 26236 : b = (b << 2) & GMP_NUMB_MASK;
119 26236 : udiv_qrnnd (bd, sc, a, b, sp << 2);
120 : #else
121 : a = (a << 1) + (b >> (GMP_NUMB_BITS - 1));
122 : b = (b << 1) & GMP_NUMB_MASK;
123 : udiv_qrnnd (bd, sc, a, b, sp << 1);
124 : #endif
125 26236 : spm->invm = bd;
126 :
127 : /* compute spm->Bpow = B^(k+1) mod p */
128 26236 : spm->Bpow = sp_pow (2, GMP_NUMB_BITS * (k + 1), sp, spm->mul_c);
129 :
130 : /* find an $n$-th primitive root $a$ of unity $(mod sp)$. */
131 :
132 : /* Construct a $b$ whose order $(mod sp)$ is equal to $n$.
133 : We try different $a$ values and test if the exponent of $q$ in $ord(a)$
134 : is at least as large as in $n$. If it isn't, we move to another $a$.
135 : If it is, we optionally exponentiate to make the exponents equal and
136 : test for the remaining $q$'s.
137 : We assume that the largest prime dividing $n$ is very small,
138 : so no optimizations in factoring n are made. */
139 26236 : a = 2;
140 26236 : b = a;
141 26236 : nc = n; /* nc is remaining cofactor of n */
142 26236 : q = 2;
143 26236 : sc = sp - 1;
144 : #ifdef PARI
145 : printf ("/* spm_init */ n = %lu; sp = %lu; /* PARI */\n", n, sp);
146 : printf ("exponent(a,b) = {local(i); while(b%%a == 0,i++;b/=a); "
147 : "return(i)} /* PARI */\n");
148 : #endif
149 115088 : for ( ; nc != (spv_size_t) 1; q++)
150 : {
151 88852 : if (nc % q == (spv_size_t) 0)
152 : {
153 88852 : const int k = exponent (q, n); /* q^k || n */
154 : sp_t d;
155 : int l;
156 : #ifdef PARI
157 : printf ("exponent(%lu, n) == %d /* PARI */\n", q, k);
158 : #endif
159 : /* Remove all factors of $q$ from $sp-1$ */
160 779670 : for (d = sp - 1; d % q == (spv_size_t) 0; d /= q);
161 88852 : bd = sp_pow (b, d, sp, spm->mul_c);
162 : /* Now ord(bd) = q^l, q^l || ord(a) */
163 88852 : l = ordpow (q, bd, sp, spm->mul_c);
164 : #ifdef PARI
165 : printf ("exponent(%lu, znorder(Mod(%lu, sp))) == %d /* PARI */\n",
166 : q, b, l);
167 : #endif
168 88852 : if (l < k)
169 : {
170 : /* No good, q appears in ord(a) in a lower power than in n.
171 : Try next $a$ */
172 54254 : a++;
173 54254 : b = a;
174 54254 : nc = n;
175 54254 : q = 1; /* Loop increment following "continue" will make q=2 */
176 54254 : sc = sp - 1;
177 54254 : continue;
178 : }
179 : else
180 : {
181 : /* Reduce the exponent of $q$ in $ord(b)$ until is it
182 : equal to that in $n$ */
183 51235 : for ( ; l > k; l--)
184 : {
185 : #ifdef PARI
186 : printf ("Exponentiating %lu by %lu\n", b, q);
187 : #endif
188 16637 : b = sp_pow (b, q, sp, spm->mul_c);
189 : }
190 : #ifdef PARI
191 : printf ("New b = %lu\n", b);
192 : #endif
193 : }
194 218468 : do {nc /= q;} while (nc % q == 0); /* Divide out all q from nc */
195 287149 : while (sc % q == (sp_t) 0) /* Divide out all q from sc */
196 252551 : sc /= q;
197 : }
198 : }
199 :
200 26236 : b = sp_pow (b, sc, sp, spm->mul_c);
201 : #ifdef PARI
202 : printf ("znorder(Mod(%lu, sp)) == n /* PARI */\n", b, sp, n);
203 : #endif
204 :
205 : /* turn this into a primitive n'th root of unity mod p */
206 26236 : spm->prim_root = b;
207 26236 : spm->inv_prim_root = sp_inv (b, sp, spm->mul_c);
208 :
209 : /* initialize auxiliary data for all supported power-of-2 NTT sizes */
210 26236 : ntt_power = 0;
211 : while (1)
212 : {
213 226892 : if (n & (1 << ntt_power))
214 26236 : break;
215 200656 : ntt_power++;
216 : }
217 :
218 26236 : if (nttdata_init (sp, spm->mul_c,
219 : sp_pow (spm->prim_root,
220 : n >> ntt_power, sp, spm->mul_c),
221 26236 : ntt_power, spm->nttdata,
222 : NTT_GFP_TWIDDLE_DIF_BREAKOVER))
223 : {
224 26236 : if (nttdata_init (sp, spm->mul_c,
225 : sp_pow (spm->inv_prim_root,
226 : n >> ntt_power, sp, spm->mul_c),
227 26236 : ntt_power, spm->inttdata,
228 : NTT_GFP_TWIDDLE_DIT_BREAKOVER))
229 : {
230 26236 : spm->scratch = (spv_t) sp_aligned_malloc (MAX_NTT_BLOCK_SIZE *
231 : sizeof(sp_t));
232 26236 : if (spm->scratch != NULL)
233 26236 : return spm;
234 0 : nttdata_clear (spm->inttdata);
235 : }
236 0 : nttdata_clear (spm->nttdata);
237 : }
238 0 : free (spm);
239 0 : return NULL;
240 : }
241 :
242 : void
243 26236 : spm_clear (spm_t spm)
244 : {
245 26236 : nttdata_clear (spm->nttdata);
246 26236 : nttdata_clear (spm->inttdata);
247 26236 : sp_aligned_free (spm->scratch);
248 26236 : free (spm);
249 26236 : }
|