/* pp.cc
 * Dan Roche, 21 Jan 2007
 * http://www.cs.uwaterloo.ca/~droche/
 * See the header file pp.h for more documentation.
 */

#include <NTL/ZZXFactoring.h>
#include <NTL/ZZ_pX.h>
#include "pp.h"

// Table of small primes in increasing order, starting at 2.  It is filled by
// init_primes() (entries 0 .. NUM_PRIMES-1) and read by pp_lacunary() and
// pp_dense() to enumerate candidate exponents r.
// NOTE(review): the array bound is omitted here; presumably pp.h declares
// the array with NUM_PRIMES entries -- confirm.
long FIRST_PRIMES[];

// Deterministic, slow method.
// Computes a square-free decomposition of f with NTL and returns the
// gcd of the multiplicities of all square-free factors.  A result of 1
// means no common exponent greater than 1 was found.
//
// f - the polynomial to test (NTL's SquareFreeDecomp documents that its
//     input must be primitive).
long pp_sqd( const NTL::ZZX& f ) {
    NTL::vec_pair_ZZX_long sfd;

    // Use NTL's built-in procedure to compute the s.q.d.
    NTL::SquareFreeDecomp(sfd,f);

    // With no factors at all there are no exponents to combine; report 1
    // instead of reading sfd[0] out of range.
    if( sfd.length() == 0 ) return 1;

    // Compute r = gcd of all exponents (all terms are nontrivial).
    // The running gcd must be stored back into r; once it drops to 1 it
    // can never grow again, so stop early.
    long r = sfd[0].b;
    for( long i = 1; i < sfd.length(); ++i ) {
        r = NTL::GCD(r,sfd[i].b);
        if( r == 1 ) return 1;
    }

    return r;
}

// Monte Carlo perfect-power test for a polynomial in sparse form.
//
// f             - sparse polynomial: f.first holds the coefficients and
//                 f.second holds, for each coefficient, the power of the
//                 evaluation point that the running sum is multiplied by
//                 after that coefficient has been added (a Horner-style
//                 scheme; this is the layout denseToSparse() produces).
// n             - only primes r dividing n are tried as candidate
//                 exponents (presumably the degree of f -- confirm
//                 against callers).
// pbound        - bit length of the random primes p used for the tests.
//                 Values up to NTL_SP_NBITS use single-precision zz_p
//                 arithmetic, larger values use multi-precision ZZ_p.
// bits_inf_norm - bit length of the largest coefficient of f; used in the
//                 upper bound on r.
//
// Returns the first prime r (candidates are tried in increasing order) for
// which every random evaluation of f was an r'th power modulo every prime
// tried, or 1 if no candidate survived.
// Requires init_primes() to have filled FIRST_PRIMES.
long pp_lacunary( const std::pair<NTL::vec_ZZ,NTL::vec_long> &f,
                  long n,
                  long pbound,
                  long bits_inf_norm )
{
  //Determine whether or not we need multi-precision primes.
  if( pbound <= NTL_SP_NBITS ) {
    long primeind=0L, r=2L, // The index and value of the least prime
         diff=0L, maxpind=0L, niter, // These will be computed below,
         // maxr is an upper bound on r from Corollary 2.10
         maxr = 2*(NTL::NumBits(f.first.length()-1L) + bits_inf_norm),
         i,j,k,l;
    long p, expon; // The prime and exponent used to test if a number
                   // is an r'th power modulo p

    // Generate a random prime for the first iteration
    p = NTL::GenPrime_long(pbound);
    NTL::zz_p::init(p);
    NTL::zz_p fofa; // The computed value of f(a) mod p, for random a
    NTL::vec_zz_p modcoeffs, // The coefficients of f modulo p
                  powsofa;   // Computed powers of a used for evaluation

    if( FIRST_PRIMES[NUM_PRIMES-1] <= maxr )
        NTL::Error
            ("\nNot enough precomputed primes. Increase NUM_PRIMES.");

    // Find the index of the first prime exceeding maxr; every candidate
    // r has an index below this.
    while( FIRST_PRIMES[++maxpind] <= maxr );

    // The number of iterations required for each candidate r
    niter = NTL::NumBits(maxpind-2L) + 1;

    // Find the maximal exponent difference.
    for( i = 0L; i < f.second.length(); ++i )
        diff = NTL::max(diff,f.second[i]);
    // We need to compute this many repeated squares of a for evaluation.
    // Keep at least one slot: powsofa[0] always holds the evaluation
    // point itself, even when every exponent difference is zero.
    powsofa.SetLength(NTL::max(NTL::NumBits(diff),1L));

    while(true) {
        // Find next largest prime which is not larger than the bound
        // and which divides n.
        while( (n % r) && (++primeind < maxpind) )
            r = FIRST_PRIMES[primeind];

        // If we've checked every possible value of r, then we're done,
        // and we conclude w.h.p. that f is not a perfect power.
        if( primeind >= maxpind ) return 1;

        // Check if f is an r'th power using niter iterations
        i=0L;
        while(true) {
            // Find a prime p such that r divides p-1
            while( (p-1)%r )
                p = NTL::GenPrime_long(pbound);
            NTL::zz_p::init(p);

            // Raising an element of Z_p to this power will tell us if
            // it is an r'th power modulo p.
            expon = (p-1)/r;

            // Reduce the coefficients of f modulo p
            NTL::conv(modcoeffs,f.first);

            // We need 5 iterations here because (3/4)^5 < 1/4
            for( j = 0; j < 5; ++j ) {
                // Choose a random evaluation point in Z_p
                NTL::random(powsofa[0]);

                // Compute f(a) mod p.  powsofa[k] = a^(2^k); each
                // exponent difference is applied bit by bit.
                for( k = 1; k < powsofa.length(); ++k )
                    NTL::sqr(powsofa[k],powsofa[k-1]);
                NTL::clear(fofa);
                for( k = 0; k < modcoeffs.length(); ++k ) {
                    fofa += modcoeffs[k];
                    diff = f.second[k];
                    l = 0;
                    while( diff ) {
                        if( diff % 2 )
                            fofa *= powsofa[l];
                        ++l; diff /= 2;
                    }
                }

                // Check if f(a) is NOT an r'th power
                if( !NTL::IsOne(NTL::power(fofa,expon)) ) break;
            }
            // Give up if we've found something that was not an r'th power
            if( j < 5 ) break;

            if( ++i < niter ) p = NTL::GenPrime_long(pbound);
            // If every evaluation was an r'th power, return r with confidence.
            else return r;
        }

        // Move on to the next candidate.  primeind may reach maxpind here;
        // that entry exists (guaranteed by the NUM_PRIMES check above) and
        // the test at the top of the loop then returns 1.
        r = FIRST_PRIMES[++primeind];
    }
   } else {
    long primeind=0L, r=2L, // The index and value of the least prime
         diff=0L, maxpind=0L, niter, // These will be computed below,
         // maxr is an upper bound on r from Corollary 2.10
         maxr = 2*(NTL::NumBits(f.first.length()-1L) + bits_inf_norm),
         i,j,k,l;
    NTL::ZZ p, expon; // The prime and exponent used to test if a number
                 // is an r'th power modulo p

    // Generate a random prime for the first iteration
    NTL::GenPrime(p,pbound);
    NTL::ZZ_p::init(p);
    NTL::ZZ_p fofa; // The computed value of f(a) mod p, for random a
    NTL::vec_ZZ_p modcoeffs, // The coefficients of f modulo p
                  powsofa;   // Computed powers of a used for evaluation

    if( FIRST_PRIMES[NUM_PRIMES-1] <= maxr )
        NTL::Error
            ("\nNot enough precomputed primes. Increase NUM_PRIMES.");

    // Find the index of the first prime exceeding maxr; every candidate
    // r has an index below this.
    while( FIRST_PRIMES[++maxpind] <= maxr );

    // The number of iterations required for each candidate r
    niter = NTL::NumBits(maxpind-2L) + 1;

    // Find the maximal exponent difference.
    for( i = 0L; i < f.second.length(); ++i )
        diff = NTL::max(diff,f.second[i]);
    // We need to compute this many repeated squares of a for evaluation.
    // Keep at least one slot: powsofa[0] always holds the evaluation
    // point itself, even when every exponent difference is zero.
    powsofa.SetLength(NTL::max(NTL::NumBits(diff),1L));

    while(true) {
        // Find next largest prime which is not larger than the bound
        // and which divides n.
        while( (n % r) && (++primeind < maxpind) )
            r = FIRST_PRIMES[primeind];

        // If we've checked every possible value of r, then we're done,
        // and we conclude w.h.p. that f is not a perfect power.
        if( primeind >= maxpind ) return 1;

        // Check if f is an r'th power using niter iterations
        i=0L;
        while(true) {
            // Find a prime p such that r divides p-1
            while( (p-1)%r )
                NTL::GenPrime(p,pbound);
            NTL::ZZ_p::init(p);

            // Raising an element of Z_p to this power will tell us if
            // it is an r'th power modulo p.
            expon = (p-1)/r;

            // Reduce the coefficients of f modulo p
            NTL::conv(modcoeffs,f.first);

            // We need 5 iterations here because (3/4)^5 < 1/4
            for( j = 0; j < 5; ++j ) {
                // Choose a random evaluation point in Z_p
                NTL::random(powsofa[0]);

                // Compute f(a) mod p.  powsofa[k] = a^(2^k); each
                // exponent difference is applied bit by bit.
                for( k = 1; k < powsofa.length(); ++k )
                    NTL::sqr(powsofa[k],powsofa[k-1]);
                NTL::clear(fofa);
                for( k = 0; k < modcoeffs.length(); ++k ) {
                    fofa += modcoeffs[k];
                    diff = f.second[k];
                    l = 0;
                    while( diff ) {
                        if( diff % 2 )
                            fofa *= powsofa[l];
                        ++l; diff /= 2;
                    }
                }

                // Check if f(a) is NOT an r'th power
                if( !NTL::IsOne(NTL::power(fofa,expon)) ) break;
            }
            // Give up if we've found something that was not an r'th power
            if( j < 5 ) break;

            if( ++i < niter ) NTL::GenPrime(p,pbound);
            // If every evaluation was an r'th power, return r with confidence.
            else return r;
        }

        // Move on to the next candidate.  primeind may reach maxpind here;
        // that entry exists (guaranteed by the NUM_PRIMES check above) and
        // the test at the top of the loop then returns 1.
        r = FIRST_PRIMES[++primeind];
    }
  }
}

// Monte Carlo "fast" method for dense polynomials.
// For each candidate prime r, a Newton iteration computes an r'th root h of
// (a normalized image of) f modulo a random prime p, and a random
// evaluation then checks that h^r really agrees with f modulo p.
// Candidates are the precomputed primes that divide
// deg(f) - (index of the lowest-order nonzero term), tried in increasing
// order, so the smallest surviving r is the one returned.
//
// f      - the polynomial to test.
// pbound - bit length of the random primes p.  Values up to NTL_SP_NBITS
//          use single-precision zz_p arithmetic, larger values use ZZ_p.
// tms    - bounds the candidates: r is limited to tms-2 (attributed to
//          Schinzel's Theorem).  Presumably the number of nonzero terms
//          of f -- confirm against callers.
//
// Returns the accepted r, or 1 if no candidate survived (including when f
// is the zero polynomial).
// Requires init_primes() to have filled FIRST_PRIMES.
long pp_dense( const NTL::ZZX& f, long pbound, long tms ) {
  // The zero polynomial has no nonzero term, so the trailing-term scan
  // below would never stop.
  if( NTL::IsZero(f) ) return 1;

  //Determine whether or not we need multi-precision primes.
  if( pbound <= NTL_SP_NBITS ) {
    long primeind=0L, r=2L,
         // These will be computed below
         maxpind=0L, niter, trail_nonzero=0L,
         // maxr is an upper bound on r from Schinzel's Theorem
         maxr = tms-2L,
         i,j,k,n,s,rmul;
    long p;

    // Generate a prime for the first iteration
    p = NTL::GenPrime_long(pbound);
    NTL::zz_p::init(p);
    NTL::zz_pX fp, h, hrm1, ph;
    NTL::zz_p a;

    // Find the lowest-order nonzero term
    while( NTL::IsZero(NTL::coeff(f,trail_nonzero)) ) ++trail_nonzero;

    // r must be a proper divisor of this number
    rmul = NTL::deg(f)-trail_nonzero;

    if( maxr < 2 ) return 1;
    if( maxr*maxr > rmul )
        maxr = static_cast<long>(floor(sqrt(static_cast<double>(rmul))));

    if( FIRST_PRIMES[NUM_PRIMES-1] <= maxr )
        NTL::Error
            ("\nNot enough precomputed primes. Increase NUM_PRIMES.");

    // Find the index of the first prime exceeding maxr; every candidate
    // r has an index below this.
    while( FIRST_PRIMES[++maxpind] <= maxr );

    // The number of iterations required for each candidate r, derived
    // from the number of candidates (same formula as pp_lacunary).
    niter = NTL::NumBits(maxpind-2L) + 1;

    while(true) {
        // Find the next largest prime which divides rmul.
        while( (rmul % r) && (++primeind < maxpind) )
            r = FIRST_PRIMES[primeind];

        // If we've checked every possible value of r, then we're done,
        // and we conclude w.h.p. that f is not a perfect power.
        if( primeind >= maxpind ) return 1;

        // Check if f is an r'th power using niter iterations
        i=0L;
        while(true) {
            // Set fp to the image of f modulo p
            NTL::conv(fp,f);

            // If p divides every coefficient, the image is zero and has
            // no lowest-order term to normalize by; try another prime.
            if( NTL::IsZero(fp) ) {
                p = NTL::GenPrime_long(pbound);
                NTL::zz_p::init(p);
                continue;
            }

            // Find the lowest-order nonzero term modulo p
            for( j=trail_nonzero; NTL::IsZero(NTL::coeff(fp,j)); ++j );
            // Normalize fp so the constant coefficient is 1
            fp >>= j;
            fp /= NTL::ConstTerm(fp);

            // The degrees of fp and h
            n = NTL::deg(fp);
            s = n/r;

            k = 1;
            NTL::set(h);
            // Loop invariant: h^r equiv f mod x^k (over Z_p)
            while(k <= s) {
                k = NTL::min(2*k,s+1);
                j = r-1;

                // Compute hrm1 = h^(r-1) mod x^k by repeated squaring
                NTL::set(hrm1);
                ph = h;
                while(j) {
                    if(j%2) NTL::MulTrunc(hrm1,hrm1,ph,k);
                    NTL::SqrTrunc(ph,ph,k);
                    j /= 2;
                }

                // Newton step: h <- h - (h^r - f) / (r h^(r-1)) mod x^k,
                // which restores the loop invariant for the new k.
                h = h - NTL::MulTrunc(NTL::MulTrunc(hrm1,h,k)
                                       - NTL::trunc(fp,k),
                                      NTL::InvTrunc(hrm1*r,k),
                                      k);
            }

            // Choose a random evaluation point modulo p
            NTL::random(a);
            // Check if h(a)^r != f(a), and if so give up.
            if( NTL::eval(fp,a) != NTL::power(NTL::eval(h,a),r) )
                break;

            if( ++i < niter ) {
                // Generate a prime for the next time around
                p = NTL::GenPrime_long(pbound);
                NTL::zz_p::init(p);
            }
            // If we computed an r'th power each time, return r with confidence.
            else return r;
        }

        // Move on to the next candidate.  primeind may reach maxpind here;
        // that entry exists (guaranteed by the NUM_PRIMES check above) and
        // the test at the top of the loop then returns 1.
        r = FIRST_PRIMES[++primeind];
    }
  } else {
    long primeind=0L, r=2L,
         // These will be computed below
         maxpind=0L, niter, trail_nonzero=0L,
         // maxr is an upper bound on r from Schinzel's Theorem
         maxr = tms-2L,
         i,j,k,n,s,rmul;
    NTL::ZZ p;

    // Generate a prime for the first iteration
    NTL::GenPrime(p,pbound);
    NTL::ZZ_p::init(p);
    NTL::ZZ_pX fp, h, hrm1, ph;
    NTL::ZZ_p a;

    // Find the lowest-order nonzero term
    while( NTL::IsZero(NTL::coeff(f,trail_nonzero)) ) ++trail_nonzero;

    // r must be a proper divisor of this number
    rmul = NTL::deg(f)-trail_nonzero;

    if( maxr < 2 ) return 1;
    if( maxr*maxr > rmul )
        maxr = static_cast<long>(floor(sqrt(static_cast<double>(rmul))));

    if( FIRST_PRIMES[NUM_PRIMES-1] <= maxr )
        NTL::Error
            ("\nNot enough precomputed primes. Increase NUM_PRIMES.");

    // Find the index of the first prime exceeding maxr; every candidate
    // r has an index below this.
    while( FIRST_PRIMES[++maxpind] <= maxr );

    // The number of iterations required for each candidate r, derived
    // from the number of candidates (same formula as pp_lacunary).
    niter = NTL::NumBits(maxpind-2L) + 1;

    while(true) {
        // Find the next largest prime which divides rmul.
        while( (rmul % r) && (++primeind < maxpind) )
            r = FIRST_PRIMES[primeind];

        // If we've checked every possible value of r, then we're done,
        // and we conclude w.h.p. that f is not a perfect power.
        if( primeind >= maxpind ) return 1;

        // Check if f is an r'th power using niter iterations
        i=0L;
        while(true) {
            // Set fp to the image of f modulo p
            NTL::conv(fp,f);

            // If p divides every coefficient, the image is zero and has
            // no lowest-order term to normalize by; try another prime.
            if( NTL::IsZero(fp) ) {
                NTL::GenPrime(p,pbound);
                NTL::ZZ_p::init(p);
                continue;
            }

            // Find the lowest-order nonzero term modulo p
            for( j=trail_nonzero; NTL::IsZero(NTL::coeff(fp,j)); ++j );
            // Normalize fp so the constant coefficient is 1
            fp >>= j;
            fp /= NTL::ConstTerm(fp);

            // The degrees of fp and h
            n = NTL::deg(fp);
            s = n/r;

            k = 1;
            NTL::set(h);
            // Loop invariant: h^r equiv f mod x^k (over Z_p)
            while(k <= s) {
                k = NTL::min(2*k,s+1);
                j = r-1;

                // Compute hrm1 = h^(r-1) mod x^k by repeated squaring
                NTL::set(hrm1);
                ph = h;
                while(j) {
                    if(j%2) NTL::MulTrunc(hrm1,hrm1,ph,k);
                    NTL::SqrTrunc(ph,ph,k);
                    j /= 2;
                }

                // Newton step: h <- h - (h^r - f) / (r h^(r-1)) mod x^k,
                // which restores the loop invariant for the new k.
                h = h - NTL::MulTrunc(NTL::MulTrunc(hrm1,h,k)
                                       - NTL::trunc(fp,k),
                                      NTL::InvTrunc(hrm1*r,k),
                                      k);
            }

            // Choose a random evaluation point modulo p
            NTL::random(a);
            // Check if h(a)^r != f(a), and if so give up.
            if( NTL::eval(fp,a) != NTL::power(NTL::eval(h,a),r) )
                break;

            if( ++i < niter ) {
                // Generate a prime for the next time around
                NTL::GenPrime(p,pbound);
                NTL::ZZ_p::init(p);
            }
            // If we computed an r'th power each time, return r with confidence.
            else return r;
        }

        // Move on to the next candidate.  primeind may reach maxpind here;
        // that entry exists (guaranteed by the NUM_PRIMES check above) and
        // the test at the top of the loop then returns 1.
        r = FIRST_PRIMES[++primeind];
    }
  }
}

// Returns the bit length of the largest-magnitude coefficient of f, i.e. the
// maximum of NTL::NumBits over all coefficients (0 for the zero polynomial).
long bits_in_inf_norm(NTL::ZZX f) {
    long widest = 0;
    // Walk the coefficients from the leading term down to the constant term.
    for( long idx = NTL::deg(f); idx >= 0; --idx ) {
        const long bits = NTL::NumBits(NTL::coeff(f,idx));
        if( bits > widest ) widest = bits;
    }
    return widest;
}

// Converts the dense polynomial `in` into a sparse pair `out`.
// out.first receives the nonzero coefficients from the leading term down.
// out.second receives, for each of them, the gap between its exponent and
// the exponent of the next lower nonzero term; the final entry is the
// exponent of the lowest nonzero term itself.
// Both vectors are emptied first and stay empty when `in` is zero.
void denseToSparse(std::pair<NTL::vec_ZZ,NTL::vec_long>& out, 
                   const NTL::ZZX& in) {
    NTL::vec_ZZ& coeffs = out.first;
    NTL::vec_long& gaps = out.second;

    coeffs.SetLength(0);
    gaps.SetLength(0);
    if( NTL::IsZero(in) ) return;

    // Start with the leading term; its entry in gaps temporarily holds the
    // full degree until a lower term is found.
    const long top = NTL::deg(in);
    NTL::append(coeffs,NTL::LeadCoeff(in));
    NTL::append(gaps,top);

    for( long e = top - 1; e >= 0; --e ) {
        const NTL::ZZ& term = NTL::coeff(in,e);
        if( NTL::IsZero(term) ) continue;

        NTL::append(coeffs,term);
        // Turn the previous entry from an absolute exponent into a gap,
        // then record this term's exponent as the new tail.
        gaps[gaps.length()-1] -= e;
        NTL::append(gaps,e);
    }
}

long pbound( long deg_h, long tms_h, long bits_inf_norm_h, long r,
             long min_p_bits ) {
    // The primes chosen must ultimately be greater than 100 for the
    // bounds to hold.
    if( min_p_bits < 7L ) min_p_bits = 7L;

    // Calculate upper bound on log_2( 2-norm of f )
    NTL::ZZ log2nf;
    log2nf = r;
    log2nf *= NTL::NumBits(tms_h-1L) + bits_inf_norm_h;

    // Calculate an upper bound on the number of "bad primes"
    NTL::ZZ mu;
    mu = deg_h*r;
    mu *= NTL::NumBits(deg_h*r-1L);
    mu += log2nf * 2L * deg_h * r - 2L;
    mu /= min_p_bits - 1L;

    // Calculate gamma = 14 mu ln(14 mu)
    NTL::ZZ gamma = mu * 14L * 
               static_cast<long>(ceil(log(14.0) + 
	                              log(2.0)*NTL::NumBits(mu-1L)));

    return NTL::max(NTL::NumBits(gamma-1L)+1L, min_p_bits);
}

// Fills FIRST_PRIMES[0 .. NUM_PRIMES-1] with consecutive primes starting at
// 2.  Each entry is NTL::NextPrime applied to one more than the previous
// entry, with 50 as NextPrime's trial-count argument.
void init_primes() {
    long prev = 2;
    FIRST_PRIMES[0] = prev;
    for( int slot = 1; slot < NUM_PRIMES; ++slot ) {
        prev = NTL::NextPrime(prev+1,50);
        FIRST_PRIMES[slot] = prev;
    }
}
