/* pptest.cc
 * Dan Roche, 21 Jan 2007
 * http://www.cs.uwaterloo.ca/~droche/
 *
 * Program to compare accuracy and running time of three
 * algorithms to determine perfect-poweredness of polynomials;
 * see pp.h for a description of the algorithms tested.
 */

#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
#include <NTL/ZZX.h>
#include "pp.h"

// HoldTime wraps a duration in seconds so that streaming it selects the
// pretty-printing operator<< defined at the end of this file; prtime(t)
// builds one, as in:  std::cout << prtime(seconds);
struct HoldTime { double time; };
std::ostream& operator<<( std::ostream& out, const HoldTime &ht );
inline HoldTime prtime(double t) {
    HoldTime wrapped;
    wrapped.time = t;
    return wrapped;
}

// One randomly generated test polynomial f = h^r, in both representations,
// together with the answer each perfect-power test gave for it.
struct TestCase {
    // f as a dense NTL polynomial
    NTL::ZZX dense_rep;
    // filled by denseToSparse(); presumably (coefficients, exponents) of the
    // nonzero terms -- confirm against pp.h
    std::pair<NTL::vec_ZZ,NTL::vec_long> sparse_rep;
    // sparse_rep.first.length(): number of nonzero terms of f
    long terms;
    // bits_in_inf_norm(dense_rep)
    long bits_inf_norm;
    // results of pp_sqd (used as the reference), pp_lacunary and pp_dense
    long realAnswer, lacunaryAnswer, denseAnswer;
};

int main(void) {
    NTL::ZZ seed;
    long i,j;

    // Try to use /dev/urandom for seed
    std::ifstream randin("/dev/urandom");
    if( randin ) {
        char seedbytes[16];
        randin.read(seedbytes,16);
        randin.close();
        NTL::ZZFromBytes(seed,reinterpret_cast<unsigned char*>(seedbytes),16);
    }
    else { // or default to current time
        seed = time(NULL);
    }
    NTL::SetSeed(seed);

    // Initialize the array of primes
    init_primes();

    // Get parameters from the user
    long deg_h, tms_h, ht_h, r, nruns;
    std::string in;

    while(true) {
        deg_h = 1000;
        std::cout << "Degree of h [" << deg_h << "]: " << std::flush;
        std::getline(std::cin,in);
        size_t firstdigit = in.find_first_not_of(" \t");
        if( firstdigit == std::string::npos ) break;
        //else
        deg_h = atol(in.data() + firstdigit);
        if( deg_h > 0 ) break;
        //else
        std::cout << "Please enter a positive integer value." << std::endl;
    }

    while(true) {
        tms_h = 5;
        std::cout << "Sparsity of h [" << tms_h << "]: " << std::flush;
        std::getline(std::cin,in);
        size_t firstdigit = in.find_first_not_of(" \t");
        if( firstdigit == std::string::npos ) break;
        //else
        tms_h = atol(in.data() + firstdigit);
        if( tms_h > 1 ) break;
        //else
        std::cout << "Please enter an integer value at least 2." << std::endl;
    }

    while(true) {
        ht_h = 20;
        std::cout << "Num. of bits in coefficients of h ["
                  << ht_h << "]: " << std::flush;
        std::getline(std::cin,in);
        size_t firstdigit = in.find_first_not_of(" \t");
        if( firstdigit == std::string::npos ) break;
        //else
        ht_h = atol(in.data() + firstdigit);
        if( ht_h > 0 ) break;
        //else
        std::cout << "Please enter a positive integer value." << std::endl;
    }

    while(true) {
        r = 6;
        std::cout << "Raise h to the power [" << r << "]: " << std::flush;
        std::getline(std::cin,in);
        size_t firstdigit = in.find_first_not_of(" \t");
        if( firstdigit == std::string::npos ) break;
        //else
        r = atol(in.data() + firstdigit);
        if( r > 0 ) break;
        //else
        std::cout << "Please enter a positive integer value." << std::endl;
    }

    while(true) {
        nruns = 100;
        std::cout << "Number of trials [" << nruns << "]: " << std::flush;
        std::getline(std::cin,in);
        size_t firstdigit = in.find_first_not_of(" \t");
        if( firstdigit == std::string::npos ) break;
        //else
        nruns = atol(in.data() + firstdigit);
        if( nruns > 0 ) break;
        //else
        std::cout << "Please enter a positive integer value." << std::endl;
    }

    TestCase *tcases = new TestCase[nruns];
    double timeOff, sqdTime, lacunaryTime, denseTime;
    long deg_f = deg_h * r, pb;
    NTL::ZZX h;
    NTL::ZZ coeff;
    // Use a vector of positions (exponents) to randomly choose tms_h
    // coefficients of h to initialize. The best way to do this is probably
    // by using a balanced binary search tree, but this is all preprocessing
    // anyway.
    std::vector<long> posns(deg_h);
    std::vector<long>::iterator iter;

    std::cout << "Initializing Test Cases......... 0%" << std::flush;

    for(i = 0; i < nruns; ++i) {
    	// Reinitialize the posns vector
	posns.resize(deg_h);
	for(j = 0; j < deg_h; ++j) posns[j]=j;

	// Initialize leading coefficient
	NTL::clear(h);
	do { RandomBits(coeff,ht_h); } while( IsZero(coeff) );
	SetCoeff(h,deg_h,coeff);

	// Initialize the rest of the coefficients
	while( deg_h < (tms_h + posns.size()) ) {
	    iter = posns.begin() + NTL::RandomBnd(posns.size());
	    do { RandomBits(coeff,ht_h); } while( IsZero(coeff) );
	    SetCoeff(h, *iter, coeff);
	    posns.erase(iter);
	}

	// Compute h^r
	j = r;
	NTL::set( tcases[i].dense_rep );
	while(true) {
	    if( j % 2 ) tcases[i].dense_rep *= h;
	    j >>= 1;
	    if( j == 0 ) break;
	    NTL::sqr(h,h);
	}
	
	// Fill in the rest of the test case
	denseToSparse( tcases[i].sparse_rep, tcases[i].dense_rep );
	tcases[i].terms = tcases[i].sparse_rep.first.length();
	tcases[i].bits_inf_norm = bits_in_inf_norm(tcases[i].dense_rep);

	// Progress indicator
	if( i*100L/nruns > (i-1)*100L/nruns ) {
	    std::cout << "\b\b\b" << std::setw(2) << i*100L/nruns << '%'
	              << std::flush;
	}
    }
    std::cout << "\b\b\bdone" << std::endl;

    std::cout << "Computing Timing Offset.........   " << std::flush;
    j=0;
    timeOff = -NTL::GetTime();
    for(i=0; i<nruns; ++i ) {
        // To ensure each struct in the array is accessed
	j ^= tcases[i].bits_inf_norm;
	if( i*100/nruns > (i-1)*100L/nruns ) {
	    std::cout << "\b\b\b" << std::setw(2) << j%100 << ' ' << std::flush;
	}
    }
    timeOff += NTL::GetTime();
    std::cout << "\b\b\bdone\r\r";

    std::cout << "Square-Free Decomposition....... 0% \b" << std::flush;
    sqdTime = -NTL::GetTime();
    for(i=0; i<nruns; ++i) {
	tcases[i].realAnswer = pp_sqd(tcases[i].dense_rep);

	// Progress indicator
	if( i*100L/nruns > (i-1)*100L/nruns ) {
	    std::cout << "\b\b\b" << std::setw(2) << i*100L/nruns << '%'
	              << std::flush;
	}
    }
    sqdTime += NTL::GetTime();
    std::cout << "\b\b\bdone" << std::endl;

    std::cout << "Lacunary Algorithm.............. 0%" << std::flush;
    pb = pbound( deg_h, tms_h, ht_h, r, lacunary_min_p_bits(deg_f) );
    lacunaryTime = -NTL::GetTime();
    for(i=0; i<nruns; ++i) {
	tcases[i].lacunaryAnswer =
	    pp_lacunary( tcases[i].sparse_rep, deg_f, pb,
	                 tcases[i].bits_inf_norm );
	
	// Progress indicator
	if( i*100L/nruns > (i-1)*100L/nruns ) {
	    std::cout << "\b\b\b" << std::setw(2) << i*100L/nruns << '%'
	              << std::flush;
	}
    }
    lacunaryTime += NTL::GetTime();
    std::cout << "\b\b\bdone" << std::endl;

    std::cout << "Newton Iteration................ 0%" << std::flush;
    pb = pbound( deg_h, tms_h, ht_h, r, dense_min_p_bits(deg_f) );
    denseTime = -NTL::GetTime();
    for(i=0; i<nruns; ++i) {
        tcases[i].denseAnswer = 
	    pp_dense( tcases[i].dense_rep, pb, tcases[i].terms );

	// Progress indicator
	if( i*100L/nruns > (i-1)*100L/nruns ) {
	    std::cout << "\b\b\b" << std::setw(2) << i*100L/nruns << '%'
	              << std::flush;
	}
    }
    denseTime += NTL::GetTime();
    std::cout << "\b\b\bdone" << std::endl;

    //std::cout << "Tabulating Results.............. 0%" << std::flush;
    NTL::ZZ avgSparsity;
    long lacunaryWrong = 0L, denseWrong = 0L;
    for(i=0; i<nruns; ++i) {
	avgSparsity += tcases[i].terms;
	if( tcases[i].realAnswer == 1 ) {
	    if( tcases[i].lacunaryAnswer > 1 ) ++lacunaryWrong;
	    if( tcases[i].denseAnswer > 1 ) ++denseWrong;
	}
	else {
	    if( (tcases[i].lacunaryAnswer == 1) ||
	        (tcases[i].realAnswer % tcases[i].lacunaryAnswer) )
	        ++lacunaryWrong;
	    if( (tcases[i].denseAnswer == 1) ||
	        (tcases[i].realAnswer % tcases[i].denseAnswer) ) {
	        ++denseWrong;
		std::cerr << "Real: " << tcases[i].realAnswer << ", Dense: " 
		          << tcases[i].denseAnswer << std::endl;
	    }
	}
	
	// Progress indicator
	/*if( i*100L/nruns > (i-1)*100L/nruns ) {
	    std::cout << "\b\b\b" << std::setw(2) << i*100L/nruns << '%'
	              << std::flush;
	}*/
    }
    avgSparsity = (avgSparsity + (nruns/2))/nruns;
    delete [] tcases;
    //std::cout << "\b\b\bdone" << std::endl;

    std::cout.setf(std::ios_base::showpoint | std::ios_base::fixed);
    std::cout.precision(2);

    if( nruns == 1 )
    	std::cout << std::endl << "Tested f=h^" << r
	          << " with degree " << deg_f << " and "
		  << avgSparsity << " nonzero terms ("
		  << (NTL::to_long(avgSparsity)*100.0/deg_f) << "%)." 
		  << std::endl;
    else
	std::cout << std::endl << "Tested " << nruns 
	          << " polynomials f=h^" << r << ", each with"
		  << std::endl << "degree " << deg_f 
		  << " and on average "
		  << avgSparsity << " nonzero terms ("
		  << (NTL::to_long(avgSparsity)*100.0/deg_f) << "%)." 
		  << std::endl;

    std::cout << "Square-Free Decomposition time: "
              << prtime(sqdTime) << std::endl;

    std::cout << "Lacunary Algorithm time:        "
              << prtime(lacunaryTime) << " with "
	      << lacunaryWrong << (lacunaryWrong==1 ? " mistake" : " mistakes")
	      << " (" << (lacunaryWrong*100.0/nruns) << "%)"
	      << std::endl;

    std::cout << "Newton Iteration time:          "
              << prtime(denseTime) << " with "
	      << denseWrong << (denseWrong==1 ? " mistake" : " mistakes")
	      << " (" << (denseWrong*100.0/nruns) << "%)"
	      << std::endl;

    return 0;
}

// Write the given time (in seconds) so it looks pretty: a fixed-width
// "hh:mm:ss.ss" field whose leading hour/minute parts are replaced by blanks
// when they are zero. The time is rounded to hundredths *before* it is split
// into hours/minutes/seconds, so 59.999 prints as " 1:00.00" (not "60.00").
// Once hours are shown, minutes are always shown ("1:00:05.00", not
// "1:   05.00"). Negative times are printed as plain seconds. The stream's
// flags, fill character and precision are all restored afterwards.
std::ostream& operator<<( std::ostream& out, const HoldTime &ht ) {
    std::ios::fmtflags origflags =
        out.flags( std::ios_base::dec |
                   std::ios_base::fixed |
                   std::ios_base::right );
    std::streamsize origprec = out.precision();
    char origfill = out.fill();

    // Split the rounded time into whole hours, whole minutes and seconds.
    // Whole hundredths are exact in a double, so the subtractions are exact.
    double hours = 0.0, minutes = 0.0;
    double hundredths = std::floor(ht.time * 100.0 + 0.5);
    if( hundredths > 0.0 ) {
        hours = std::floor(hundredths / 360000.0);
        hundredths -= hours * 360000.0;
        minutes = std::floor(hundredths / 6000.0);
        hundredths -= minutes * 6000.0;
    }
    const double seconds = hundredths / 100.0;

    // Hours and minutes are whole numbers: no decimal places.
    out.precision(0);
    if( hours > 0.0 ) {
        out << std::setw(2) << hours << ':';
        out.fill('0');
    }
    else out << "   ";
    if( hours > 0.0 || minutes > 0.0 ) {
        out << std::setw(2) << minutes << ':';
        out.fill('0');
    }
    else out << "   ";
    out.setf(std::ios_base::showpoint);
    out << std::setw(5) << std::setprecision(2) << seconds;

    out.fill(origfill);
    out.precision(origprec);
    out.flags(origflags);
    return out;
}
