/* fft.c
 * Daniel S. Roche, January 2009
 * http://www.cs.uwaterloo.ca/~droche/
 *
 * Routines for FFT-based polynomial multiplication.
 * No extra space is used.
 *
 * See LICENSE.txt for copyright and permissions.
 */

#include <assert.h>
#include <string.h>
#include "lsmul.h"

// Computes x*y mod p, returning a representative in the range -p..2p
// rather than the canonical 0..p-1 (callers reduce further with one_red).
// pinv must be 1.0/p. Like MulMod from Shoup's NTL: the quotient q is
// estimated in floating point, and the exact remainder x*y - q*p is then
// recovered from the low-order bits of the integer products.
//
// Those products may exceed the range of a signed lsmul_ele, and signed
// overflow is undefined behaviour in C. So the wrap-around part is done in
// unsigned long long, where wrapping is well defined. The true remainder is
// small, so converting it back to lsmul_ele recovers it exactly. This
// assumes lsmul_ele is at most 64 bits wide and that out-of-range
// unsigned->signed conversion wraps (implementation-defined, but what
// mainstream compilers do).
// NOTE(review): this is a non-static `inline` definition; under C99 inline
// rules an external definition must exist elsewhere (e.g. via a non-inline
// declaration in lsmul.h) or unoptimized builds may fail to link -- confirm.
inline lsmul_ele fudge_mul (lsmul_ele x, lsmul_ele y, lsmul_ele p, double pinv)
{
   lsmul_ele q = (lsmul_ele) (((double)x) * ((double)y) * pinv);
   unsigned long long r = (unsigned long long) x * (unsigned long long) y
                        - (unsigned long long) q * (unsigned long long) p;
   return (lsmul_ele) r;
}

// Compute the image modulo p (so the output is in the range 0..p-1) of x,
// which must be in the range -p..p-1: p is added exactly when x is negative.
// (This file also calls it with 3p as the modulus, to fold -3p..3p-1 into
// 0..3p-1.)
// Branch-free, in the spirit of the arithmetic-right-shift trick from
// Shoup's NTL. The sign mask is built from the comparison (x < 0) rather
// than from `x >> 31`, which is only a full sign mask when |x| < 2^31 and
// relies on implementation-defined right shifts of negative values. The
// comparison form works for any width of lsmul_ele.
inline lsmul_ele one_red (lsmul_ele x, lsmul_ele p) {
   lsmul_ele mask = -(lsmul_ele) (x < 0);   // all ones if x < 0, else 0
   return x + (mask & p);
}

// In-place DFT_omega of the 2^k entries of a; the output is left in
// reverted (bit-reversed) binary ordering.
// tp must be p+p+p and pinv must be 1.0/p. Values are kept only partially
// reduced:
//   - each butterfly sum is folded into 0..tp-1 with one_red;
//   - each difference is scaled by the twiddle factor with fudge_mul and
//     shifted up by p.
inline void dft_forward (lsmul_ele *a, int k, 
   lsmul_ele p, lsmul_ele tp, double pinv, lsmul_ele omega)
{
   const long n = 1L << k;
   long half, blk, lo;
   lsmul_ele root = omega;   // twiddle base for the current stage

   // Stages use half-sizes 2^(k-1), ..., 2, 1; the twiddle base is squared
   // after every stage.
   for (half = n >> 1; half > 0; half >>= 1) {
      for (blk = 0; blk < n; blk += half << 1) {
         lsmul_ele w = (lsmul_ele) 1;   // root^(lo - blk)
         for (lo = blk; lo < blk + half; ++lo) {
            const long hi = lo + half;
            const lsmul_ele u = a[lo], v = a[hi];
            a[hi] = fudge_mul (u - v, w, p, pinv) + p;
            a[lo] = one_red (u - tp + v, tp);
            w = fudge_mul (w, root, p, pinv);
         }
      }
      root = fudge_mul (root, root, p, pinv);
   }
}


// In-place inverse DFT (with respect to omega) of the 2^k entries of a.
// The input must be in reverted (bit-reversed) binary ordering, as left by
// dft_forward. The result is NOT scaled by 2^-k; the caller does that.
// tp must be p+p+p and pinv must be 1.0/p; outputs are folded into 0..tp-1.
inline void dft_reverse (lsmul_ele *a, int k, 
   lsmul_ele p, lsmul_ele tp, double pinv, lsmul_ele omega)
{
   const long n = 1L << k;
   long half = 1, blk, lo;
   lsmul_ele root = (lsmul_ele) 1;
   int level, sq;

   for (level = k - 1; level >= 0; --level, half <<= 1) {
      // step = omega^(2^level), by repeated squaring.
      lsmul_ele step = omega;
      for (sq = 0; sq < level; ++sq) step = fudge_mul (step, step, p, pinv);
      // The running product over all stages so far makes
      // root = omega^(2^k - 2^level), i.e. omega^(-2^level).
      root = fudge_mul (root, step, p, pinv);

      for (blk = 0; blk < n; blk += half << 1) {
         lsmul_ele w = (lsmul_ele) 1;   // root^(lo - blk)
         for (lo = blk; lo < blk + half; ++lo) {
            const long hi = lo + half;
            const lsmul_ele t = fudge_mul (w, a[hi], p, pinv);
            const lsmul_ele u = a[lo] - p;
            a[hi] = one_red (u - t, tp);
            a[lo] = one_red (u + (t - p), tp);
            w = fudge_mul (w, root, p, pinv);
         }
      }
   }
}

// Computes c = a*b modulo p. a has sa coefficients and b has sb
// coefficients, so the product has sc = sa+sb-1 coefficients.
//
//   c      Output, also used as FFT workspace. For sc >= 2, entries
//          c[0 .. 2^k - 1] are overwritten, where 2^k is the smallest
//          power of two >= sc (k >= 1); c must have room for that many
//          elements. For sc == 1 only c[0] is written. Every entry written
//          ends up in 0..p-1.
//   a, b   Input coefficients, lengths sa >= 1 and sb >= 1; only read.
//          NOTE(review): presumably already reduced to 0..p-1 -- confirm
//          with callers.
//   p      Odd modulus ((p+1)/2 is used as the inverse of 2).
//   k      Requires sc <= 2^k; reduced internally to the smallest
//          transform size that fits sc.
//   omega  Must be a primitive 2^k-th root of unity mod p (the DFT /
//          inverse-DFT pair below relies on it). Squared as k shrinks.
//
// Method:
//   1. For i = k-1 down to 0, with theta = omega^(2^(k-1-i)):
//      a. fold a (twisted by powers of theta) modulo x^(2^i) - 1 into
//         c[0 .. 2^i - 1];
//      b. fold b the same way into c[2^i .. 2^(i+1) - 1];
//      c. transform both halves with dft_forward;
//      d. multiply them pointwise into the upper half, which is then kept.
//   2. Set c[0] to the evaluation at 1, i.e. (sum of a) * (sum of b).
//   3. Interpolate back with dft_reverse and scale by 2^-k.
void mul_fft (lsmul_ele *c, const lsmul_ele *a, long sa,
   const lsmul_ele *b, long sb, lsmul_ele p, int k, lsmul_ele omega)
{
   long j, ind, sc = (sa+sb-1);
   long n, pow2, mask;
   unsigned long long sum;
   int i;
   double pinv = 1.0/p;
   lsmul_ele tinv = (p+1)/2, ninv = tinv, accum,
      twop=p+p, threep = twop+p, theta;

   assert (sa >= 1 && sb >= 1);
   // The transform length 2^k must hold all sc product coefficients.
   assert (k >= 0 && ((sc-1) >> k) == 0);

   // A 1x1 product needs no transform. Handling it here also keeps the
   // size-reduction loop below from shifting by a negative amount: with
   // sc-1 == 0 its condition holds for every k.
   if (sc == 1) {
      unsigned long long up = (unsigned long long) p;
      c[0] = (lsmul_ele) ((((unsigned long long) a[0] % up)
                         * ((unsigned long long) b[0] % up)) % up);
      return;
   }

   // Shrink to the smallest 2^k >= sc (k stays >= 1 because sc >= 2);
   // omega^2 is a primitive 2^(k-1)-th root when omega is a 2^k-th one.
   while (k > 1 && ((sc-1) >> (k-1)) == 0) {
      --k;
      omega = fudge_mul (omega, omega, p, pinv);
   }
   theta = omega;

   assert ( (sc <= (1L << k)) && (sc > (1L << (k-1))) );

   n = 1L << k;
   pow2 = n;
   mask = pow2-1;

   for (i=k-1; i >= 0; --i) {
      // Clear the region still being worked on; everything at or above
      // the old pow2 already holds finished pointwise products.
      memset (c, 0, pow2 * sizeof(lsmul_ele));

      pow2 >>= 1;
      mask >>= 1;

      // Lower half: sum of a[j]*theta^j over j = ind (mod pow2).
      accum = (lsmul_ele) 1;
      for (j=0; j<sa; ++j) {
         ind = j & mask;
         c[ind] = 
            one_red ((c[ind] - twop) + fudge_mul(a[j],accum,p,pinv), threep);
         accum = fudge_mul(accum,theta,p,pinv);
      }

      // Upper half: the same folding for b.
      accum = (lsmul_ele) 1;
      for (j=0; j<sb; ++j) {
         ind = (j & mask) | pow2;
         c[ind] = 
            one_red ((c[ind] - twop) + fudge_mul(b[j],accum,p,pinv), threep);
         accum = fudge_mul(accum,theta,p,pinv);
      }

      theta = fudge_mul(theta,theta,p,pinv);
      dft_forward (c, i, p, threep, pinv, theta);
      dft_forward (c+pow2, i, p, threep, pinv, theta);

      // Pointwise products of the evaluations, kept in the upper half.
      for (j=0; j<pow2; ++j)
         c[j|pow2] = fudge_mul (c[j], c[j|pow2], p, pinv) + p;
   }

   // The evaluation at 1 is (sum of a) * (sum of b).
   sum = (unsigned long long) a[0];
   for (j=1; j<sa; ++j) sum += a[j];
   c[0] = (lsmul_ele) (sum % ((unsigned long long)p));
   sum = (unsigned long long) b[0];
   for (j=1; j<sb; ++j) sum += b[j];
   c[0] = fudge_mul (c[0], (lsmul_ele) (sum % ((unsigned long long)p)), p, pinv)
      + p;

   dft_reverse (c, k, p, threep, pinv, omega);

   // ninv = (1/2)^k = 1/n mod p.
   for (i=1; i<k; ++i) ninv = fudge_mul(ninv,tinv,p,pinv);

   // Scale by 1/n and reduce fully: fudge_mul gives -p..2p, the first
   // one_red maps that to 0..2p-1, and the second pass maps it to 0..p-1.
   for (j=0; j<n; ++j) {
      c[j] = one_red (fudge_mul(c[j],ninv,p,pinv), p);
      c[j] -= p;
      c[j] = one_red(c[j], p);
   }
}
