Line data Source code
1 : /* mpzspv.c - "mpz small prime polynomial" functions for arithmetic on mpzv's
2 : reduced modulo a mpzspm
3 :
4 : Copyright 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012 Dave Newman,
5 : Jason Papadopoulos, Alexander Kruppa, Paul Zimmermann.
6 :
7 : The SP Library is free software; you can redistribute it and/or modify
8 : it under the terms of the GNU Lesser General Public License as published by
9 : the Free Software Foundation; either version 3 of the License, or (at your
10 : option) any later version.
11 :
12 : The SP Library is distributed in the hope that it will be useful, but
13 : WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14 : or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
15 : License for more details.
16 :
17 : You should have received a copy of the GNU Lesser General Public License
18 : along with the SP Library; see the file COPYING.LIB. If not, write to
19 : the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston,
20 : MA 02110-1301, USA. */
21 :
22 : #include <stdio.h> /* for stderr */
23 : #include <stdlib.h>
24 : #include <string.h> /* for memset */
25 : #include "ecm-impl.h"
26 : #include "sp.h"
27 :
28 : mpzspv_t
29 85419 : mpzspv_init (spv_size_t len, mpzspm_t mpzspm)
30 : {
31 : unsigned int i;
32 85419 : mpzspv_t x = (mpzspv_t) malloc (mpzspm->sp_num * sizeof (spv_t));
33 :
34 85419 : if (x == NULL)
35 0 : return NULL;
36 :
37 1926105 : for (i = 0; i < mpzspm->sp_num; i++)
38 : {
39 1840686 : x[i] = (spv_t) sp_aligned_malloc (len * sizeof (sp_t));
40 :
41 1840686 : if (x[i] == NULL)
42 : {
43 0 : while (i--)
44 0 : sp_aligned_free (x[i]);
45 :
46 0 : free (x);
47 0 : return NULL;
48 : }
49 : }
50 :
51 85419 : return x;
52 : }
53 :
54 : void
55 85419 : mpzspv_clear (mpzspv_t x, mpzspm_t mpzspm)
56 : {
57 : unsigned int i;
58 :
59 : ASSERT (mpzspv_verify (x, 0, 0, mpzspm));
60 :
61 1926105 : for (i = 0; i < mpzspm->sp_num; i++)
62 1840686 : sp_aligned_free (x[i]);
63 :
64 85419 : free (x);
65 85419 : }
66 :
#ifdef WANT_ASSERT
/* Sanity check used inside ASSERT()s.  For every small prime p_i of mpzspm,
 * read x[i][offset .. offset+len-1] and check that each value is normalised,
 * i.e. strictly smaller than p_i.
 *
 * The vector lengths themselves are not stored anywhere, so "the vectors are
 * at least offset + len long" is only checked implicitly, by touching those
 * elements.  With len == 0 no element is read and the result is always 1.
 *
 * return 1 for success, 0 for failure */

int
mpzspv_verify (mpzspv_t x, spv_size_t offset, spv_size_t len, mpzspm_t mpzspm)
{
  const spv_size_t end = offset + len;
  unsigned int p;
  spv_size_t pos;

  for (p = 0; p < mpzspm->sp_num; p++)
    for (pos = offset; pos < end; pos++)
      {
        if (x[p][pos] >= mpzspm->spm[p]->sp)
          return 0;
      }

  return 1;
}
#endif
91 :
92 : void
93 281399 : mpzspv_set (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
94 : spv_size_t len, mpzspm_t mpzspm)
95 : {
96 : unsigned int i;
97 :
98 : ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
99 : ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
100 :
101 6169480 : for (i = 0; i < mpzspm->sp_num; i++)
102 5888081 : spv_set (r[i] + r_offset, x[i] + x_offset, len);
103 281399 : }
104 :
#if 0
/* Currently compiled out.  For every small prime, hand len coefficients of
   x (from x_offset) and the destination r (from r_offset) to spv_rev().  */
void
mpzspv_revcopy (mpzspv_t r, spv_size_t r_offset, mpzspv_t x,
                spv_size_t x_offset, spv_size_t len, mpzspm_t mpzspm)
{
  unsigned int p;

  ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
  ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));

  for (p = 0; p < mpzspm->sp_num; p++)
    {
      spv_t dst = r[p] + r_offset;
      spv_t src = x[p] + x_offset;

      spv_rev (dst, src, len);
    }
}
#endif
119 :
120 : void
121 24254 : mpzspv_set_sp (mpzspv_t r, spv_size_t offset, sp_t c, spv_size_t len,
122 : mpzspm_t mpzspm)
123 : {
124 : unsigned int i;
125 :
126 : ASSERT (mpzspv_verify (r, offset + len, 0, mpzspm));
127 : ASSERT (c < SP_MIN); /* not strictly necessary but avoids mod functions */
128 :
129 519868 : for (i = 0; i < mpzspm->sp_num; i++)
130 495614 : spv_set_sp (r[i] + offset, c, len);
131 24254 : }
132 :
133 : void
134 620 : mpzspv_neg (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
135 : spv_size_t len, mpzspm_t mpzspm)
136 : {
137 : unsigned int i;
138 :
139 : ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
140 : ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
141 :
142 7135 : for (i = 0; i < mpzspm->sp_num; i++)
143 6515 : spv_neg (r[i] + r_offset, x[i] + x_offset, len, mpzspm->spm[i]->sp);
144 620 : }
145 :
146 : void
147 266 : mpzspv_add (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
148 : mpzspv_t y, spv_size_t y_offset, spv_size_t len, mpzspm_t mpzspm)
149 : {
150 : unsigned int i;
151 :
152 : ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
153 : ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
154 :
155 2889 : for (i = 0; i < mpzspm->sp_num; i++)
156 2623 : spv_add (r[i] + r_offset, x[i] + x_offset, y[i] + y_offset, len,
157 2623 : mpzspm->spm[i]->sp);
158 266 : }
159 :
160 : void
161 916 : mpzspv_reverse (mpzspv_t x, spv_size_t offset, spv_size_t len, mpzspm_t mpzspm)
162 : {
163 : unsigned int i;
164 : spv_size_t j;
165 : sp_t t;
166 : spv_t spv;
167 :
168 : ASSERT (mpzspv_verify (x, offset, len, mpzspm));
169 :
170 15174 : for (i = 0; i < mpzspm->sp_num; i++)
171 : {
172 14258 : spv = x[i] + offset;
173 10276300 : for (j = 0; j < len - 1 - j; j++)
174 : {
175 10262042 : t = spv[j];
176 10262042 : spv[j] = spv[len - 1 - j];
177 10262042 : spv[len - 1 - j] = t;
178 : }
179 : }
180 916 : }
181 :
182 : /* Return {xp, xn} mod p.
183 : Assume 2p < B where B = 2^GMP_NUMB_LIMB.
184 : We first compute {xp, xn} / B^n mod p using Montgomery reduction,
185 : where the number N to factor has n limbs.
186 : Then we multiply by B^(n+1) mod p (precomputed) and divide by B mod p.
187 : Assume invm = -1/p mod B and Bpow = B^n mod p */
188 : static mp_limb_t
189 612931765 : ecm_mod_1 (mp_ptr xp, mp_size_t xn, mp_limb_t p, mp_size_t n,
190 : mp_limb_t invm, mp_limb_t Bpow)
191 : {
192 : mp_limb_t q, cy, hi, lo, x0, x1;
193 :
194 612931765 : if (xn == 0)
195 0 : return 0;
196 :
197 : /* the code below assumes xn <= n+1, thus we call mpn_mod_1 otherwise,
198 : but this should never (or rarely) happen */
199 612931765 : if (xn > n + 1)
200 119553 : return mpn_mod_1 (xp, xn, p);
201 :
202 612812212 : x0 = xp[0];
203 612812212 : cy = (mp_limb_t) 0;
204 5340841149 : while (n-- > 0)
205 : {
206 : /* Invariant: cy is the input carry on xp[1], x0 is xp[0] */
207 4728028937 : x1 = (xn > 1) ? xp[1] : 0;
208 4728028937 : q = x0 * invm; /* q = -x0/p mod B */
209 4728028937 : umul_ppmm (hi, lo, q, p); /* hi*B + lo = -x0 mod B */
210 : /* Add hi*B + lo to x1*B + x0. Since p <= B-2 we have
211 : hi*B + lo <= (B-1)(B-2) = B^2-3B+2, thus hi <= B-3 */
212 4728028937 : hi += cy + (lo != 0); /* cannot overflow */
213 4728028937 : x0 = x1 + hi;
214 4728028937 : cy = x0 < hi;
215 4728028937 : xn --;
216 4728028937 : xp ++;
217 : }
218 612812212 : if (cy != 0)
219 10314 : x0 -= p;
220 : /* now x0 = {xp, xn} / B^n mod p */
221 612812212 : umul_ppmm (x1, x0, x0, Bpow);
222 : /* since Bpow < p, x1 <= p-1 */
223 612812212 : q = x0 * invm;
224 612812212 : umul_ppmm (hi, lo, q, p);
225 : /* hi <= p-1 thus hi+x1+1 < 2p-1 < B */
226 612812212 : hi = hi + x1 + (lo != 0);
227 646269133 : while (hi >= p)
228 33456921 : hi -= p;
229 612812212 : return hi;
230 : }
231 :
232 : #ifdef TIMING_CRT /* counters below accumulate cputime() differences across calls */
233 : int mpzspv_from_mpzv_slow_time = 0; /* updated around the loop in mpzspv_from_mpzv_slow() */
234 : int mpzspv_to_mpzv_time = 0; /* updated around the conversion in mpzspv_to_mpzv() */
235 : int mpzspv_normalise_time = 0; /* updated on entry to and exit from mpzspv_normalise() */
236 : #endif
237 :
238 : /* convert mpzvi to CRT representation, naive version */
239 : static void
240 46284690 : mpzspv_from_mpzv_slow (mpzspv_t x, const spv_size_t offset, mpz_t mpzvi,
241 : mpzspm_t mpzspm)
242 : {
243 46284690 : const unsigned int sp_num = mpzspm->sp_num;
244 : unsigned int j;
245 46284690 : mp_size_t n = mpz_size (mpzspm->modulus);
246 :
247 : #ifdef TIMING_CRT
248 : mpzspv_from_mpzv_slow_time -= cputime ();
249 : #endif
250 659216455 : for (j = 0; j < sp_num; j++)
251 612931765 : x[j][offset] = ecm_mod_1 (PTR(mpzvi), SIZ(mpzvi),
252 612931765 : (mp_limb_t) mpzspm->spm[j]->sp, n,
253 612931765 : mpzspm->spm[j]->invm, mpzspm->spm[j]->Bpow);
254 : #ifdef TIMING_CRT
255 : mpzspv_from_mpzv_slow_time += cputime ();
256 : #endif
257 : /* The typecast to mp_limb_t assumes that mp_limb_t is at least
258 : as wide as sp_t */
259 46284690 : }
260 :
261 : /* convert mpzvi to CRT representation, fast version, assumes
262 : mpzspm->T has been precomputed (see mpzspm.c).
263 : Warning: this function should be thread-safe, since it might be called
264 : simultaneously by several threads. */
265 : static void
266 1560 : mpzspv_from_mpzv_fast (mpzspv_t x, const spv_size_t offset, mpz_t mpzvi,
267 : mpzspm_t mpzspm)
268 : {
269 1560 : const unsigned int sp_num = mpzspm->sp_num;
270 1560 : unsigned int i, j, k, i0 = I0_THRESHOLD, I0;
271 1560 : mpzv_t *T = mpzspm->T;
272 : mpz_t *U;
273 1560 : unsigned int d = mpzspm->d, ni;
274 :
275 1560 : U = malloc (sp_num * sizeof (mpz_t));
276 262308 : for (j = 0; j < sp_num; j++)
277 260748 : mpz_init (U[j]);
278 :
279 : ASSERT (d > i0);
280 :
281 : /* initially we split mpzvi in two */
282 1560 : ni = 1 << (d - 1);
283 1560 : mpz_mod (U[0], mpzvi, T[d-1][0]);
284 1560 : mpz_mod (U[ni], mpzvi, T[d-1][1]);
285 1674 : for (i = d-1; i-- > i0;)
286 : { /* goes down from depth i+1 to i */
287 114 : ni = 1 << i;
288 380 : for (j = k = 0; j + ni < sp_num; j += 2*ni, k += 2)
289 : {
290 266 : mpz_mod (U[j+ni], U[j], T[i][k+1]);
291 266 : mpz_mod (U[j], U[j], T[i][k]);
292 : }
293 : /* for the last entry U[j] if j < sp_num, there is nothing to do */
294 : }
295 : /* last steps */
296 1560 : I0 = 1 << i0;
297 4946 : for (j = 0; j < sp_num; j += I0)
298 : {
299 264134 : for (k = j; k < j + I0 && k < sp_num; k++)
300 260748 : x[k][offset] = mpn_mod_1 (PTR(U[j]), SIZ(U[j]),
301 260748 : (mp_limb_t) mpzspm->spm[k]->sp);
302 : }
303 : /* The typecast to mp_limb_t assumes that mp_limb_t is at least
304 : as wide as sp_t */
305 :
306 262308 : for (j = 0; j < sp_num; j++)
307 260748 : mpz_clear (U[j]);
308 1560 : free(U);
309 1560 : }
310 :
#if defined(TRACE_mpzspv_from_mpzv) || defined(TRACE_ntt_sqr_reciprocal)
/* Debugging aid: print msg, then the elements of spv as a bracketed,
   comma-separated list, then the modulus p.  Element 0 is printed
   unconditionally, elements 1 .. l-1 follow.

   Each element is printed as a one-limb number through gmp_printf's %Nd,
   which consumes a limb pointer and an mp_size_t limb count; the count is
   therefore cast explicitly instead of being passed as a plain int through
   the variable argument list.  */
static void
ntt_print_vec (const char *msg, const spv_t spv, const spv_size_t l,
               const sp_t p)
{
  spv_size_t i;

  /* Warning: on some computers, for example gcc49.fsffrance.org,
     "unsigned long" might be shorter than "sp_t" */
  gmp_printf ("%s [%Nd", msg, (mp_ptr) spv, (mp_size_t) 1);
  for (i = 1; i < l; i++)
    gmp_printf (", %Nd", (mp_ptr) spv + i, (mp_size_t) 1);
  printf ("] (mod %llu)\n", (long long unsigned int) p);
}
#endif
326 :
327 : /* convert an array of len mpz_t numbers to CRT representation modulo
328 : sp_num moduli */
329 : void
330 18533786 : mpzspv_from_mpzv (mpzspv_t x, const spv_size_t offset, const mpzv_t mpzv,
331 : const spv_size_t len, mpzspm_t mpzspm)
332 : {
333 18533786 : const unsigned int sp_num = mpzspm->sp_num;
334 : long i;
335 :
336 : ASSERT (mpzspv_verify (x, offset + len, 0, mpzspm));
337 : ASSERT (sizeof (mp_limb_t) >= sizeof (sp_t));
338 :
339 : #ifdef TRACE_mpzspv_from_mpzv
340 : for (i = 0; i < (long) len; i++)
341 : gmp_printf ("mpzspv_from_mpzv: mpzv[%ld] = %Zd\n", i, mpzv[i]);
342 : #endif
343 :
344 : #if defined(_OPENMP)
345 : #pragma omp parallel private(i) if (len > 16384)
346 : {
347 : /* Multi-threading with dynamic scheduling slows things down */
348 : #pragma omp for schedule(static)
349 : #endif
350 64820855 : for (i = 0; i < (long) len; i++)
351 : {
352 : unsigned int j;
353 46287069 : if (mpz_sgn (mpzv[i]) == 0)
354 : {
355 8464 : for (j = 0; j < sp_num; j++)
356 7645 : x[j][i + offset] = 0;
357 : }
358 : else
359 : {
360 : ASSERT(mpz_sgn (mpzv[i]) > 0); /* We can't handle negative values */
361 46286250 : if (mpzspm->T == NULL)
362 46284690 : mpzspv_from_mpzv_slow (x, i + offset, mpzv[i], mpzspm);
363 : else
364 1560 : mpzspv_from_mpzv_fast (x, i + offset, mpzv[i], mpzspm);
365 : }
366 : }
367 : #if defined(_OPENMP)
368 : }
369 : #endif
370 :
371 : #ifdef TRACE_mpzspv_from_mpzv
372 : for (i = 0; i < (long) sp_num; i++)
373 : ntt_print_vec ("mpzspv_from_mpzv: ", x[i] + offset, len, mpzspm->spm[i]->sp);
374 : #endif
375 18533786 : }
376 :
377 : /* Convert the len residues x[][offset..offset+len-1] from "spv" (RNS) format
378 : * to mpz_t format.
379 : * See: Daniel J. Bernstein and Jonathan P. Sorenson,
380 : * Modular Exponentiation via the explicit Chinese Remainder Theorem,
381 : * Mathematics of Computation 2007,
382 : * Theorem 2.1: Let p_1, ..., p_s be pairwise coprime integers. Write
383 : * P = p_1 * ... * p_s. Let q_1, ..., q_s be integers with
384 : * q_iP/p_i = 1 mod p_i. Let u be an integer with |u| < P/2. Let u_1, ..., u_s
385 : * with u = u_i mod p_i. Let t_1, ..., t_s be integers with
386 : * t_i = u_i q_i mod p_i. Then u = P \alpha - P round(\alpha) where
387 : * \alpha = \sum_i t_i/p_i
388 : *
389 : * time: O(len * sp_num^2) where sp_num is proportional to the modulus size
390 : * memory: MPZSPV_NORMALISE_STRIDE floats */
391 : void
392 86720 : mpzspv_to_mpzv (mpzspv_t x, spv_size_t offset, mpzv_t mpzv,
393 : spv_size_t len, mpzspm_t mpzspm)
394 : {
395 : unsigned int i;
396 : spv_size_t k, l;
397 86720 : float *f = (float *) malloc (MPZSPV_NORMALISE_STRIDE * sizeof (float));
398 : float prime_recip;
399 : sp_t t;
400 86720 : spm_t *spm = mpzspm->spm;
401 : mpz_t mt;
402 :
403 86720 : if (f == NULL)
404 : {
405 0 : fprintf (stderr, "Cannot allocate memory in mpzspv_to_mpzv\n");
406 0 : exit (1);
407 : }
408 :
409 : ASSERT (mpzspv_verify (x, offset, len, mpzspm));
410 86720 : ASSERT_ALWAYS(mpzspm->sp_num <= 1677721);
411 :
412 : #ifdef TIMING_CRT
413 : mpzspv_to_mpzv_time -= cputime ();
414 : #endif
415 86720 : mpz_init (mt);
416 310194 : for (l = 0; l < len; l += MPZSPV_NORMALISE_STRIDE)
417 : {
418 223474 : spv_size_t stride = MIN (MPZSPV_NORMALISE_STRIDE, len - l);
419 :
420 : /* we apply the above theorem to mpzv[l]...mpzv[l+stride-1] at once */
421 27369203 : for (k = 0; k < stride; k++)
422 : {
423 27145729 : f[k] = 0.5; /* this is performed len times */
424 27145729 : mpz_set_ui (mpzv[k + l], 0);
425 : }
426 :
427 3423175 : for (i = 0; i < mpzspm->sp_num; i++)
428 : {
429 : /* this loop is performed len*sp_num/MPZSPV_NORMALISE_STRIDE times */
430 :
431 : /* prime_recip = 1/p_i * (1+u)^2 wih |u| <= 2^(-24) where one
432 : exponent is due to the sp -> float conversion, and one to the
433 : division */
434 3199701 : prime_recip = 1.0f / (float) spm[i]->sp; /* 1/p_i */
435 :
436 360198271 : for (k = 0; k < stride; k++)
437 : {
438 : /* this loop is performed len*sp_num times */
439 :
440 : /* crt3[i] = p_i/P mod p_i (q_i in the theorem) */
441 356998570 : t = sp_mul (x[i][l + k + offset], mpzspm->crt3[i], spm[i]->sp,
442 356998570 : spm[i]->mul_c);
443 :
444 : /* crt1[i] = P / p_i mod modulus: we accumulate in mpzv[l + k]
445 : the sum of P t_i/p_i = t_i (P/p_i) mod N.
446 : If N has n limbs, crt1[i] has n limbs too,
447 : thus mpzv[l+k] has about n limbs */
448 : if (sizeof (sp_t) > sizeof (unsigned long))
449 : {
450 : mpz_set_sp (mt, t);
451 : mpz_addmul (mpzv[l + k], mpzspm->crt1[i], mt);
452 : }
453 : else
454 356998570 : mpz_addmul_ui (mpzv[l + k], mpzspm->crt1[i], t);
455 :
456 : /* After the conversion from t to float and the multiplication,
457 : the value of (float) t * prime_recip = t/p_i * (1+v)^4
458 : where |v| <= 2^(-24). Since |t| < p_i, the absolute error
459 : is bounded by (1+v^4)-1 <= 5*v. Thus the total error on f[k]
460 : is bounded by 5*sp_num*2^(-24). Since we want this to be smaller
461 : than 0.5, we need sp_num <= 2^23/5 thus sp_num <= 1677721.
462 : This corresponds to a number of at most 15656374 digits on a
463 : 32-bit machine, and at most 31312749 digits on 64-bit. */
464 356998570 : f[k] += (float) t * prime_recip;
465 : }
466 : }
467 :
468 : /* crt2[i] = -i*P mod modulus */
469 27369203 : for (k = 0; k < stride; k++)
470 27145729 : mpz_add (mpzv[l + k], mpzv[l + k], mpzspm->crt2[(unsigned int) f[k]]);
471 : }
472 :
473 86720 : mpz_clear (mt);
474 86720 : free (f);
475 : #ifdef TIMING_CRT
476 : mpzspv_to_mpzv_time += cputime ();
477 : #endif
478 86720 : }
479 :
#if 0
/* Currently compiled out.  For every small prime p_i, apply spv_pwmul()
   with modulus p_i to the len coefficients of x (from x_offset) and of y
   (from y_offset), storing into r from r_offset.  */
void
mpzspv_pwmul (mpzspv_t r, spv_size_t r_offset, mpzspv_t x, spv_size_t x_offset,
              mpzspv_t y, spv_size_t y_offset, spv_size_t len, mpzspm_t mpzspm)
{
  unsigned int p;

  ASSERT (mpzspv_verify (r, r_offset + len, 0, mpzspm));
  ASSERT (mpzspv_verify (x, x_offset, len, mpzspm));
  ASSERT (mpzspv_verify (y, y_offset, len, mpzspm));

  for (p = 0; p < mpzspm->sp_num; p++)
    {
      spv_t dst = r[p] + r_offset;
      spv_t lhs = x[p] + x_offset;
      spv_t rhs = y[p] + y_offset;

      spv_pwmul (dst, lhs, rhs, len, mpzspm->spm[p]->sp,
                 mpzspm->spm[p]->mul_c);
    }
}
#endif
496 :
497 : /* Normalise the vector x[][offset..offset+len-1] of RNS residues modulo the
498 : * input modulus N.
499 : *
500 : * Reference: Bernstein & Sorenson: Explicit CRT mod m mod p_j, Theorem 4.1.
501 : *
502 : * time: O(len * sp_num^2)
503 : * memory: MPZSPV_NORMALISE_STRIDE mpzspv coeffs
504 : * 6 * MPZSPV_NORMALISE_STRIDE sp's
505 : * MPZSPV_NORMALISE_STRIDE floats
506 : * For a subquadratic version: look at Section 23 of
507 : * http://cr.yp.to/papers.html#multapps
508 : */
509 : void
510 24491 : mpzspv_normalise (mpzspv_t x, spv_size_t offset, spv_size_t len,
511 : mpzspm_t mpzspm)
512 : {
513 24491 : unsigned int i, j, sp_num = mpzspm->sp_num;
514 : spv_size_t k, l;
515 : sp_t v;
516 : spv_t s, d, w;
517 24491 : spm_t *spm = mpzspm->spm;
518 : float prime_recip;
519 : float *f;
520 : mpzspv_t t;
521 :
522 : #ifdef TIMING_CRT
523 : mpzspv_normalise_time -= cputime ();
524 : #endif
525 : ASSERT (mpzspv_verify (x, offset, len, mpzspm));
526 :
527 24491 : f = (float *) malloc (MPZSPV_NORMALISE_STRIDE * sizeof (float));
528 24491 : s = (spv_t) malloc (3 * MPZSPV_NORMALISE_STRIDE * sizeof (sp_t));
529 24491 : d = (spv_t) malloc (3 * MPZSPV_NORMALISE_STRIDE * sizeof (sp_t));
530 24491 : if (f == NULL || s == NULL || d == NULL)
531 : {
532 0 : fprintf (stderr, "Cannot allocate memory in mpzspv_normalise\n");
533 0 : exit (1);
534 : }
535 24491 : t = mpzspv_init (MPZSPV_NORMALISE_STRIDE, mpzspm);
536 :
537 24491 : memset (s, 0, 3 * MPZSPV_NORMALISE_STRIDE * sizeof (sp_t));
538 :
539 290652 : for (l = 0; l < len; l += MPZSPV_NORMALISE_STRIDE)
540 : {
541 266161 : spv_size_t stride = MIN (MPZSPV_NORMALISE_STRIDE, len - l);
542 :
543 : /* FIXME: use B&S Theorem 2.2 */
544 19434723 : for (k = 0; k < stride; k++)
545 19168562 : f[k] = 0.5; /* this is executed len times */
546 :
547 5906402 : for (i = 0; i < sp_num; i++)
548 : {
549 : /* this loop is performed len*sp_num/MPZSPV_NORMALISE_STRIDE times */
550 5640241 : prime_recip = 1.0f / (float) spm[i]->sp;
551 :
552 359145077 : for (k = 0; k < stride; k++)
553 : {
554 : /* this is executed len*sp_num times,
555 : crt3[i] = p_i/P mod p_i (q_i in Theorem 3.1) */
556 707009672 : x[i][l + k + offset] = sp_mul (x[i][l + k + offset],
557 353504836 : mpzspm->crt3[i], spm[i]->sp, spm[i]->mul_c);
558 : /* now x[i] is t_i in Theorem 3.1 */
559 353504836 : f[k] += (float) x[i][l + k + offset] * prime_recip;
560 : }
561 : }
562 :
563 5906402 : for (i = 0; i < sp_num; i++)
564 : {
565 359145077 : for (k = 0; k < stride; k++)
566 : {
567 : /* this is executed len*sp_num times */
568 :
569 : /* crt5[i] = (-P mod modulus) mod p_i */
570 353504836 : umul_ppmm (d[3 * k + 1], d[3 * k], mpzspm->crt5[i], (sp_t) f[k]);
571 : /* {d+3*k,2} = ((-P mod modulus) mod p_i) * round(sum(t_j/p_j)),
572 : this accounts for the right term in Theorem 4.1 */
573 353504836 : d[3 * k + 2] = 0;
574 : }
575 :
576 132106218 : for (j = 0; j < sp_num; j++)
577 : {
578 : /* this is executed len*sp_num^2/MPZSPV_NORMALISE_STRIDE times */
579 126465977 : w = x[j] + offset;
580 : /* crt4[i][j] = ((P / p[i]) mod modulus) mod p[j] */
581 126465977 : v = mpzspm->crt4[i][j];
582 :
583 7409599329 : for (k = 0; k < stride; k++)
584 : /* this is executed len*sp_num^2 times, and computes the left
585 : term in Theorem 4.1 */
586 7283133352 : umul_ppmm (s[3 * k + 1], s[3 * k], w[k + l], v);
587 :
588 : /* This mpn_add_n adds in parallel all "stride" contributions,
589 : and accounts for about a third of the function's runtime.
590 : Since d has size O(stride), the cumulated complexity of this
591 : call is O(len*sp_num^2) */
592 126465977 : mpn_add_n ((mp_ptr) d, (mp_srcptr) d, (mp_srcptr) s, 3 * stride);
593 : }
594 :
595 : /* we finally reduce the contribution modulo each p_i */
596 359145077 : for (k = 0; k < stride; k++)
597 353504836 : t[i][k] = mpn_mod_1 ((mp_ptr) (d + 3 * k), 3, spm[i]->sp);
598 : }
599 266161 : mpzspv_set (x, l + offset, t, 0, stride, mpzspm);
600 : }
601 :
602 24491 : mpzspv_clear (t, mpzspm);
603 :
604 24491 : free (s);
605 24491 : free (d);
606 24491 : free (f);
607 : #ifdef TIMING_CRT
608 : mpzspv_normalise_time += cputime ();
609 : #endif
610 24491 : }
611 :
612 : void
613 1840 : mpzspv_to_ntt (mpzspv_t x, spv_size_t offset, spv_size_t len,
614 : spv_size_t ntt_size, int monic, mpzspm_t mpzspm)
615 : {
616 : unsigned int i;
617 : spv_size_t j, log2_ntt_size;
618 : spm_t spm;
619 : spv_t spv;
620 :
621 : ASSERT (mpzspv_verify (x, offset, len, mpzspm));
622 : ASSERT (mpzspv_verify (x, offset + ntt_size, 0, mpzspm));
623 :
624 1840 : log2_ntt_size = ceil_log_2 (ntt_size);
625 :
626 30206 : for (i = 0; i < mpzspm->sp_num; i++)
627 : {
628 28366 : spm = mpzspm->spm[i];
629 28366 : spv = x[i] + offset;
630 :
631 28366 : if (ntt_size < len)
632 : {
633 0 : for (j = ntt_size; j < len; j += ntt_size)
634 0 : spv_add (spv, spv, spv + j, ntt_size, spm->sp);
635 : }
636 28366 : if (ntt_size > len)
637 14183 : spv_set_zero (spv + len, ntt_size - len);
638 :
639 28366 : if (monic)
640 14183 : spv[len % ntt_size] = sp_add (spv[len % ntt_size], 1, spm->sp);
641 :
642 28366 : spv_ntt_gfp_dif (spv, log2_ntt_size, spm);
643 : }
644 1840 : }
645 :
#if 0
/* Currently compiled out.  In-place inverse transform, for every small
   prime, of x[][offset .. offset+ntt_size-1]: spv_ntt_gfp_dit(), scaling by
   the inverse of ntt_size and, if monic_pos is non-zero, subtraction of 1 at
   index monic_pos % ntt_size.  */
void
mpzspv_from_ntt (mpzspv_t x, spv_size_t offset, spv_size_t ntt_size,
                 spv_size_t monic_pos, mpzspm_t mpzspm)
{
  unsigned int p;
  spv_size_t log2_ntt_size;
  spm_t spm;
  spv_t vec;

  ASSERT (mpzspv_verify (x, offset, ntt_size, mpzspm));

  log2_ntt_size = ceil_log_2 (ntt_size);

  for (p = 0; p < mpzspm->sp_num; p++)
    {
      spm = mpzspm->spm[p];
      vec = x[p] + offset;

      spv_ntt_gfp_dit (vec, log2_ntt_size, spm);

      /* spm->sp - (spm->sp - 1) / ntt_size is the inverse of ntt_size */
      spv_mul_sp (vec, vec, spm->sp - (spm->sp - 1) / ntt_size,
                  ntt_size, spm->sp, spm->mul_c);

      if (monic_pos)
        {
          const spv_size_t lead = monic_pos % ntt_size;

          vec[lead] = sp_sub (vec[lead], 1, spm->sp);
        }
    }
}
#endif
677 :
678 : void
679 1 : mpzspv_random (mpzspv_t x, spv_size_t offset, spv_size_t len, mpzspm_t mpzspm)
680 : {
681 : unsigned int i;
682 :
683 : ASSERT (mpzspv_verify (x, offset, len, mpzspm));
684 :
685 25 : for (i = 0; i < mpzspm->sp_num; i++)
686 24 : spv_random (x[i] + offset, len, mpzspm->spm[i]->sp);
687 1 : }
688 :
689 :
690 : /* Do multiplication via NTT. Depending on the value of "steps", does
691 : in-place forward transform of x, in-place forward transform of y,
692 : pair-wise multiplication of x by y to r, in-place inverse transform of r.
693 : Contrary to calling these three operations separately, this function does
694 : all three steps on a small-prime vector at a time, resulting in slightly
695 : better cache efficiency (also in preparation to storing NTT vectors on disk
696 : and reading them in for the multiplication). */
697 :
698 : void
699 66531 : mpzspv_mul_ntt (mpzspv_t r, const spv_size_t offsetr,
700 : mpzspv_t x, const spv_size_t offsetx, const spv_size_t lenx,
701 : mpzspv_t y, const spv_size_t offsety, const spv_size_t leny,
702 : const spv_size_t ntt_size, const int monic, const spv_size_t monic_pos,
703 : mpzspm_t mpzspm, const int steps)
704 : {
705 : spv_size_t log2_ntt_size;
706 : int i;
707 :
708 : ASSERT (mpzspv_verify (x, offsetx, lenx, mpzspm));
709 : ASSERT (mpzspv_verify (y, offsety, leny, mpzspm));
710 : ASSERT (mpzspv_verify (x, offsetx + ntt_size, 0, mpzspm));
711 : ASSERT (mpzspv_verify (y, offsety + ntt_size, 0, mpzspm));
712 : ASSERT (mpzspv_verify (r, offsetr + ntt_size, 0, mpzspm));
713 :
714 66531 : log2_ntt_size = ceil_log_2 (ntt_size);
715 :
716 : /* Need parallelization at higher level (e.g., handling a branch of the
717 : product tree in one thread) to make this worthwhile for ECM */
718 : #define MPZSPV_MUL_NTT_OPENMP 0
719 :
720 : #if defined(_OPENMP) && MPZSPV_MUL_NTT_OPENMP
721 : #pragma omp parallel if (ntt_size > 16384)
722 : {
723 : #pragma omp for
724 : #endif
725 1473287 : for (i = 0; i < (int) mpzspm->sp_num; i++)
726 : {
727 : spv_size_t j;
728 1406756 : spm_t spm = mpzspm->spm[i];
729 1406756 : spv_t spvr = r[i] + offsetr;
730 1406756 : spv_t spvx = x[i] + offsetx;
731 1406756 : spv_t spvy = y[i] + offsety;
732 :
733 1406756 : if ((steps & NTT_MUL_STEP_FFT1) != 0) {
734 1400241 : if (ntt_size < lenx)
735 : {
736 0 : for (j = ntt_size; j < lenx; j += ntt_size)
737 0 : spv_add (spvx, spvx, spvx + j, ntt_size, spm->sp);
738 : }
739 1400241 : if (ntt_size > lenx)
740 824057 : spv_set_zero (spvx + lenx, ntt_size - lenx);
741 :
742 1400241 : if (monic)
743 581358 : spvx[lenx % ntt_size] = sp_add (spvx[lenx % ntt_size], 1, spm->sp);
744 :
745 1400241 : spv_ntt_gfp_dif (spvx, log2_ntt_size, spm);
746 : }
747 :
748 1406756 : if ((steps & NTT_MUL_STEP_FFT2) != 0) {
749 712553 : if (ntt_size < leny)
750 : {
751 0 : for (j = ntt_size; j < leny; j += ntt_size)
752 0 : spv_add (spvy, spvy, spvy + j, ntt_size, spm->sp);
753 : }
754 712553 : if (ntt_size > leny)
755 608792 : spv_set_zero (spvy + leny, ntt_size - leny);
756 :
757 712553 : if (monic)
758 581358 : spvy[leny % ntt_size] = sp_add (spvy[leny % ntt_size], 1, spm->sp);
759 :
760 712553 : spv_ntt_gfp_dif (spvy, log2_ntt_size, spm);
761 : }
762 :
763 1406756 : if ((steps & NTT_MUL_STEP_MUL) != 0) {
764 1406756 : spv_pwmul (spvr, spvx, spvy, ntt_size, spm->sp, spm->mul_c);
765 : }
766 :
767 1406756 : if ((steps & NTT_MUL_STEP_IFFT) != 0) {
768 : ASSERT (sizeof (mp_limb_t) >= sizeof (sp_t));
769 :
770 1406756 : spv_ntt_gfp_dit (spvr, log2_ntt_size, spm);
771 :
772 : /* spm->sp - (spm->sp - 1) / ntt_size is the inverse of ntt_size */
773 1406756 : spv_mul_sp (spvr, spvr, spm->sp - (spm->sp - 1) / ntt_size,
774 : ntt_size, spm->sp, spm->mul_c);
775 :
776 1406756 : if (monic_pos)
777 581358 : spvr[monic_pos % ntt_size] = sp_sub (spvr[monic_pos % ntt_size],
778 : 1, spm->sp);
779 : }
780 : }
781 : #if defined(_OPENMP) && MPZSPV_MUL_NTT_OPENMP
782 : }
783 : #endif
784 66531 : }
785 :
786 : /* Computes a DCT-I of the length dctlen. Input is the spvlen coefficients
787 : in spv. tmp is temp space and must have space for 2*dctlen-2 sp_t's */
788 :
789 : void
790 561 : mpzspv_to_dct1 (mpzspv_t dct, const mpzspv_t spv, const spv_size_t spvlen,
791 : const spv_size_t dctlen, mpzspv_t tmp,
792 : const mpzspm_t mpzspm)
793 : {
794 561 : const spv_size_t l = 2 * (dctlen - 1); /* Length for the DFT */
795 561 : const spv_size_t log2_l = ceil_log_2 (l);
796 : int j;
797 :
798 : #ifdef _OPENMP
799 : #pragma omp parallel private(j)
800 : {
801 : #pragma omp for
802 : #endif
803 8256 : for (j = 0; j < (int) mpzspm->sp_num; j++)
804 : {
805 7695 : const spm_t spm = mpzspm->spm[j];
806 : spv_size_t i;
807 :
808 : /* Make a symmetric copy of spv in tmp. I.e. with spv = [3, 2, 1],
809 : spvlen = 3, dctlen = 5 (hence l = 8), we want
810 : tmp = [3, 2, 1, 0, 0, 0, 1, 2] */
811 7695 : spv_set (tmp[j], spv[j], spvlen);
812 7695 : spv_rev (tmp[j] + l - spvlen + 1, spv[j] + 1, spvlen - 1);
813 : /* Now we have [3, 2, 1, ?, ?, ?, 1, 2]. Fill the ?'s with zeros. */
814 7695 : spv_set_sp (tmp[j] + spvlen, (sp_t) 0, l - 2 * spvlen + 1);
815 :
816 : #if 0
817 : printf ("mpzspv_to_dct1: tmp[%d] = [", j);
818 : for (i = 0; i < l; i++)
819 : printf ("%lu, ", tmp[j][i]);
820 : printf ("]\n");
821 : #endif
822 :
823 7695 : spv_ntt_gfp_dif (tmp[j], log2_l, spm);
824 :
825 : #if 0
826 : printf ("mpzspv_to_dct1: tmp[%d] = [", j);
827 : for (i = 0; i < l; i++)
828 : printf ("%lu, ", tmp[j][i]);
829 : printf ("]\n");
830 : #endif
831 :
832 : /* The forward transform is scrambled. We want elements [0 ... l/2]
833 : of the unscrabled data, that is all the coefficients with the most
834 : significant bit in the index (in log2(l) word size) unset, plus the
835 : element at index l/2. By scrambling, these map to the elements with
836 : even index, plus the element at index 1.
837 : The elements with scrambled index 2*i are stored in h[i], the
838 : element with scrambled index 1 is stored in h[params->l] */
839 :
840 : #ifdef WANT_ASSERT
841 : /* Test that the coefficients are symmetric (if they were unscrambled)
842 : and that our algorithm for finding identical coefficients in the
843 : scrambled data works */
844 : {
845 : spv_size_t m = 5;
846 : for (i = 2; i < l; i += 2L)
847 : {
848 : /* This works, but why? */
849 : if (i + i / 2L > m)
850 : m = 2L * m + 1L;
851 :
852 : ASSERT (tmp[j][i] == tmp[j][m - i]);
853 : #if 0
854 : printf ("mpzspv_to_dct1: DFT[%lu] == DFT[%lu]\n", i, m - i);
855 : #endif
856 : }
857 : }
858 : #endif
859 :
860 : /* Copy coefficients to dct buffer */
861 56926031 : for (i = 0; i < l / 2; i++)
862 56918336 : dct[j][i] = tmp[j][i * 2];
863 7695 : dct[j][l / 2] = tmp[j][1];
864 : }
865 : #ifdef _OPENMP
866 : }
867 : #endif
868 561 : }
869 :
870 :
871 : /* Multiply the polynomial in "dft" by the RLP in "dct", where "dft"
872 : contains the polynomial coefficients (not FFT'd yet) and "dct"
873 : contains the DCT-I coefficients of the RLP. The latter are
874 : assumed to be in the layout produced by mpzspv_to_dct1().
875 : Output are the coefficients of the product polynomial, stored in dft.
876 : The "steps" parameter controls which steps are computed:
877 : NTT_MUL_STEP_FFT1: do forward transform
878 : NTT_MUL_STEP_MUL: do point-wise product
879 : NTT_MUL_STEP_IFFT: do inverse transform
880 : */
881 :
882 : void
883 1133 : mpzspv_mul_by_dct (mpzspv_t dft, const mpzspv_t dct, const spv_size_t len,
884 : const mpzspm_t mpzspm, const int steps)
885 : {
886 : int j;
887 1133 : spv_size_t log2_len = ceil_log_2 (len);
888 :
889 : #ifdef _OPENMP
890 : #pragma omp parallel private(j)
891 : {
892 : #pragma omp for
893 : #endif
894 17065 : for (j = 0; j < (int) (mpzspm->sp_num); j++)
895 : {
896 15932 : const spm_t spm = mpzspm->spm[j];
897 15932 : const spv_t spv = dft[j];
898 : unsigned long i, m;
899 :
900 : /* Forward DFT of dft[j] */
901 15932 : if ((steps & NTT_MUL_STEP_FFT1) != 0)
902 13309 : spv_ntt_gfp_dif (spv, log2_len, spm);
903 :
904 : /* Point-wise product */
905 15932 : if ((steps & NTT_MUL_STEP_MUL) != 0)
906 : {
907 13309 : m = 5UL;
908 :
909 13309 : spv[0] = sp_mul (spv[0], dct[j][0], spm->sp, spm->mul_c);
910 13309 : spv[1] = sp_mul (spv[1], dct[j][len / 2UL], spm->sp, spm->mul_c);
911 :
912 64172064 : for (i = 2UL; i < len; i += 2UL)
913 : {
914 : /* This works, but why? */
915 64158755 : if (i + i / 2UL > m)
916 82036 : m = 2UL * m + 1;
917 :
918 64158755 : spv[i] = sp_mul (spv[i], dct[j][i / 2UL], spm->sp, spm->mul_c);
919 64158755 : spv[m - i] = sp_mul (spv[m - i], dct[j][i / 2UL], spm->sp,
920 : spm->mul_c);
921 : }
922 : }
923 :
924 : /* Inverse transform of dft[j] */
925 15932 : if ((steps & NTT_MUL_STEP_IFFT) != 0)
926 : {
927 10686 : spv_ntt_gfp_dit (spv, log2_len, spm);
928 :
929 : /* Divide by transform length. FIXME: scale the DCT of h instead */
930 10686 : spv_mul_sp (spv, spv, spm->sp - (spm->sp - 1) / len, len,
931 : spm->sp, spm->mul_c);
932 : }
933 : }
934 : #ifdef _OPENMP
935 : }
936 : #endif
937 1133 : }
938 :
939 :
940 : void
941 3659 : mpzspv_sqr_reciprocal (mpzspv_t dft, const spv_size_t n,
942 : const mpzspm_t mpzspm)
943 : {
944 3659 : const spv_size_t log2_n = ceil_log_2 (n);
945 3659 : const spv_size_t len = ((spv_size_t) 2) << log2_n;
946 3659 : const spv_size_t log2_len = 1 + log2_n;
947 : int j;
948 :
949 : ASSERT(mpzspm->max_ntt_size % 3UL == 0UL);
950 : ASSERT(len % 3UL != 0UL);
951 : ASSERT(mpzspm->max_ntt_size % len == 0UL);
952 :
953 : #ifdef _OPENMP
954 : #pragma omp parallel
955 : {
956 : #pragma omp for
957 : #endif
958 49943 : for (j = 0; j < (int) (mpzspm->sp_num); j++)
959 : {
960 46284 : const spm_t spm = mpzspm->spm[j];
961 46284 : const spv_t spv = dft[j];
962 : sp_t w1, w2, invlen;
963 46284 : const sp_t sp = spm->sp, mul_c = spm->mul_c;
964 : spv_size_t i;
965 :
966 : /* Zero out NTT elements [n .. len-n] */
967 46284 : spv_set_sp (spv + n, (sp_t) 0, len - 2*n + 1);
968 :
969 : #ifdef TRACE_ntt_sqr_reciprocal
970 : if (j == 0)
971 : {
972 : printf ("ntt_sqr_reciprocal: NTT vector mod %lu\n", sp);
973 : ntt_print_vec ("ntt_sqr_reciprocal: before weighting:", spv, len);
974 : }
975 : #endif
976 :
977 : /* Compute the root for the weight signal, a 3rd primitive root
978 : of unity */
979 46284 : w1 = sp_pow (spm->prim_root, mpzspm->max_ntt_size / 3UL, sp,
980 : mul_c);
981 : /* Compute iw= 1/w */
982 46284 : w2 = sp_pow (spm->inv_prim_root, mpzspm->max_ntt_size / 3UL, sp,
983 : mul_c);
984 : #ifdef TRACE_ntt_sqr_reciprocal
985 : if (j == 0)
986 : printf ("w1 = %lu ,w2 = %lu\n", w1, w2);
987 : #endif
988 : ASSERT(sp_mul(w1, w2, sp, mul_c) == (sp_t) 1);
989 : ASSERT(w1 != (sp_t) 1);
990 : ASSERT(sp_pow (w1, 3UL, sp, mul_c) == (sp_t) 1);
991 : ASSERT(w2 != (sp_t) 1);
992 : ASSERT(sp_pow (w2, 3UL, sp, mul_c) == (sp_t) 1);
993 :
994 : /* Fill NTT elements spv[len-n+1 .. len-1] with coefficients and
995 : apply weight signal to spv[i] and spv[l-i] for 0 <= i < n
996 : Use the fact that w^i + w^{-i} = -1 if i != 0 (mod 3). */
997 8328609 : for (i = 0; i + 2 < n; i += 3)
998 : {
999 : sp_t t, u;
1000 :
1001 8282325 : if (i > 0)
1002 8245919 : spv[len - i] = spv[i];
1003 :
1004 8282325 : t = spv[i + 1];
1005 8282325 : u = sp_mul (t, w1, sp, mul_c);
1006 8282325 : spv[i + 1] = u;
1007 8282325 : spv[len - i - 1] = sp_neg (sp_add (t, u, sp), sp);
1008 :
1009 8282325 : t = spv[i + 2];
1010 8282325 : u = sp_mul (t, w2, sp, mul_c);
1011 8282325 : spv[i + 2] = u;
1012 8282325 : spv[len - i - 2] = sp_neg (sp_add (t, u, sp), sp);
1013 : }
1014 46284 : if (i < n && i > 0)
1015 : {
1016 19747 : spv[len - i] = spv[i];
1017 : }
1018 46284 : if (i + 1 < n)
1019 : {
1020 : sp_t t, u;
1021 12559 : t = spv[i + 1];
1022 12559 : u = sp_mul (t, w1, sp, mul_c);
1023 12559 : spv[i + 1] = u;
1024 12559 : spv[len - i - 1] = sp_neg (sp_add (t, u, sp), sp);
1025 : }
1026 :
1027 : #ifdef TRACE_ntt_sqr_reciprocal
1028 : if (j == 0)
1029 : ntt_print_vec ("ntt_sqr_reciprocal: after weighting:", spv, len);
1030 : #endif
1031 :
1032 : /* Forward DFT of dft[j] */
1033 46284 : spv_ntt_gfp_dif (spv, log2_len, spm);
1034 :
1035 : #ifdef TRACE_ntt_sqr_reciprocal
1036 : if (j == 0)
1037 : ntt_print_vec ("ntt_sqr_reciprocal: after forward transform:",
1038 : spv, len);
1039 : #endif
1040 :
1041 : /* Square the transformed vector point-wise */
1042 46284 : spv_pwmul (spv, spv, spv, len, sp, mul_c);
1043 :
1044 : #ifdef TRACE_ntt_sqr_reciprocal
1045 : if (j == 0)
1046 : ntt_print_vec ("ntt_sqr_reciprocal: after point-wise squaring:",
1047 : spv, len);
1048 : #endif
1049 :
1050 : /* Inverse transform of dft[j] */
1051 46284 : spv_ntt_gfp_dit (spv, log2_len, spm);
1052 :
1053 : #ifdef TRACE_ntt_sqr_reciprocal
1054 : if (j == 0)
1055 : ntt_print_vec ("ntt_sqr_reciprocal: after inverse transform:",
1056 : spv, len);
1057 : #endif
1058 :
1059 : /* Un-weight and divide by transform length */
1060 46284 : invlen = sp - (sp - (sp_t) 1) / len; /* invlen = 1/len (mod sp) */
1061 46284 : w1 = sp_mul (invlen, w1, sp, mul_c);
1062 46284 : w2 = sp_mul (invlen, w2, sp, mul_c);
1063 16606834 : for (i = 0; i < 2 * n - 3; i += 3)
1064 : {
1065 16560550 : spv[i] = sp_mul (spv[i], invlen, sp, mul_c);
1066 16560550 : spv[i + 1] = sp_mul (spv[i + 1], w2, sp, mul_c);
1067 16560550 : spv[i + 2] = sp_mul (spv[i + 2], w1, sp, mul_c);
1068 : }
1069 46284 : if (i < 2 * n - 1)
1070 33725 : spv[i] = sp_mul (spv[i], invlen, sp, mul_c);
1071 46284 : if (i < 2 * n - 2)
1072 16659 : spv[i + 1] = sp_mul (spv[i + 1], w2, sp, mul_c);
1073 :
1074 : #ifdef TRACE_ntt_sqr_reciprocal
1075 : if (j == 0)
1076 : ntt_print_vec ("ntt_sqr_reciprocal: after un-weighting:", spv, len);
1077 : #endif
1078 :
1079 : /* Separate the coefficients of R in the wrapped-around product. */
1080 :
1081 : /* Set w1 = cuberoot(1)^l where cuberoot(1) is the same primitive
1082 : 3rd root of unity we used for the weight signal */
1083 46284 : w1 = sp_pow (spm->prim_root, mpzspm->max_ntt_size / 3UL, sp,
1084 : mul_c);
1085 46284 : w1 = sp_pow (w1, len % 3UL, sp, mul_c);
1086 :
1087 : /* Set w2 = 1/(w1 - 1/w1). Incidentally, w2 = 1/sqrt(-3) */
1088 46284 : w2 = sp_inv (w1, sp, mul_c);
1089 46284 : w2 = sp_sub (w1, w2, sp);
1090 46284 : w2 = sp_inv (w2, sp, mul_c);
1091 : #ifdef TRACE_ntt_sqr_reciprocal
1092 : if (j == 0)
1093 : printf ("For separating: w1 = %lu, w2 = %lu\n", w1, w2);
1094 : #endif
1095 :
1096 23960930 : for (i = len - (2*n - 2); i <= len / 2; i++)
1097 : {
1098 : sp_t t, u;
1099 : /* spv[i] = s_i + w^{-l} s_{l-i}.
1100 : spv[l-i] = s_{l-i} + w^{-l} s_i */
1101 23914646 : t = sp_mul (spv[i], w1, sp, mul_c); /* t = w^l s_i + s_{l-i} */
1102 23914646 : t = sp_sub (t, spv[len - i], sp); /* t = w^l s_i + w^{-l} s_i */
1103 23914646 : t = sp_mul (t, w2, sp, mul_c); /* t = s_1 */
1104 :
1105 23914646 : u = sp_sub (spv[i], t, sp); /* u = w^{-l} s_{l-i} */
1106 23914646 : u = sp_mul (u, w1, sp, mul_c); /* u = s_{l-i} */
1107 23914646 : spv[i] = t;
1108 23914646 : spv[len - i] = u;
1109 : ASSERT(i < len / 2 || t == u);
1110 : }
1111 :
1112 : #ifdef TRACE_ntt_sqr_reciprocal
1113 : if (j == 0)
1114 : ntt_print_vec ("ntt_sqr_reciprocal: after un-wrapping:", spv, len);
1115 : #endif
1116 : }
1117 : #ifdef _OPENMP
1118 : }
1119 : #endif
1120 3659 : }
|