/* newinterp.inc
 * Daniel S. Roche, January 2011
 * See COPYING.txt for permissions.
 *
 * Black-box interpolation over a finite field using diversification.
 * See "Diversification improves interpolation", Giesbrecht & Roche, 2011
 *
 * Include file (template implementations)
 */

#include <cassert>
#include <cmath>
#include <complex>
#include <utility>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <NTL/ZZX.h>
#include <NTL/lzz_pX.h>
#include <NTL/ZZXFactoring.h>
#include "misc.h"

/* Using an STL map keyed on complex coefficients is
 * troublesome because they are only known to a certain approximation.
 * The following templatized class IndMap, which maps coefficients to
 * indices, handles this for complex keys by taking an epsilon error
 * bound in the constructor (closeness is tested with std::norm of the
 * difference). For any other key type it is simply a std::map.
 *
 * Note that an assertion will fail in the special case that a
 * coefficient being searched for is at least epsilon away from all keys,
 * but less than 2*epsilon away from at least one key. This is a problem
 * because this new coefficient is not epsilon-equal to any key, but if
 * we insert it as a new key, the map becomes ill-defined as some
 * _other_ coefficient could be less than epsilon distance from two
 * distinct keys.
 */
// Coefficient -> index map. For key types other than std::complex this is
// nothing more than an exact std::map<T,long>.
template <typename T>
class IndMap : public std::map<T,long> { };

// Approximate coefficient -> index map for complex keys.
// Keys live in a vector in insertion order and are matched by a linear
// scan: a query c matches the first stored key k with
// std::norm(k - c) < epsilon.
// NOTE(review): std::norm is the *squared* magnitude, so the effective
// matching radius is sqrt(epsilon) rather than epsilon; diversity_bound in
// this file measures distances with std::abs instead -- confirm intended.
template <typename T>
class IndMap< std::complex<T> > {
  public:
  typedef std::complex<T> key_type;
  typedef long data_type;
  typedef std::pair<key_type, data_type> value_type;

  private:
  std::vector<value_type> entries;  // (key, index) pairs, insertion order
  T epsilon;                        // tolerance compared against std::norm

  // Position in entries of the first key matching c, or -1 if none does.
  // Asserts that c is not within 2*epsilon (by std::norm) of any key it
  // fails to match, since inserting it would make the map ambiguous.
  long search (const std::complex<T>& c) const {
    bool ambiguous = false;
    for (typename std::vector<value_type>::size_type k = 0;
         k != entries.size(); ++k) {
      const T dist = std::norm(entries[k].first - c);
      if (dist < epsilon) return k;
      if (dist < 2*epsilon) ambiguous = true;
    }
    assert (!ambiguous);
    return -1;
  }

  public:
  IndMap (T eps = 0.1) :epsilon(eps) { }

  // Returns a reference to the index stored for the key matching c,
  // first inserting c with a default-constructed (zero) index if no
  // stored key matches.
  data_type& operator[] (const key_type& c) {
    const long pos = search(c);
    if (pos >= 0) return entries[pos].second;
    entries.push_back (value_type(c, data_type()));
    return entries.back().second;
  }

  // Number of distinct keys stored.
  unsigned long size() const { return entries.size(); }
};

/* Comparison functor that orders two objects through pointers to them, so
 * that a vector of pointers can be sorted instead of copying potentially
 * large objects (e.g. NTL::ZZ). Requires operator< on T.
 *
 * std::binary_function was deprecated in C++11 and removed in C++17, so the
 * adaptor typedefs it used to provide are declared directly here; code that
 * relied on them keeps working.
 */
template <typename T>
struct PointerSort {
  typedef T* first_argument_type;
  typedef T* second_argument_type;
  typedef bool result_type;

  bool operator() (const T* a, const T* b) const {
    return *a < *b;
  }
};

// Specialization of PointerSort for zz_p: NTL does not order zz_p values,
// so they are compared by their integer representatives rep(x).
// Declared inline because an explicit specialization is an ordinary
// (non-template) function; defining it in an included file without inline
// yields duplicate definitions when several translation units include it.
template <>
inline bool PointerSort<NTL::zz_p>::operator() 
(const NTL::zz_p* a, const NTL::zz_p* b) const 
  { return rep(*a) < rep(*b); }

// Compares complex numbers by their absolute values
template <typename T>
static bool complex_comp (const std::complex<T>& a, const std::complex<T>& b)
{ return std::abs(a) < std::abs(b); }

// Returns the larger of a and b (b when neither is greater).
template <typename T>
static T mymax (T a, T b) { return ((a>b) ? a : b); }

/* Computes a threshold d >= 2e for which the coefficient list f is
 * (e,d)-diverse: all coefficients of f with absolute value at least d
 * should have pairwise differences of at least 2e.
 *
 * f: coefficients to examine (not modified; a sorted copy is made)
 * e: per-coefficient error bound
 * T: upper bound on the number of nonzero terms. Only the T largest
 *    coefficients can be genuine terms, so d starts at
 *    max(2e, |T-th largest coefficient| - e). When f has at most T entries
 *    the smallest entry is used instead; T < 1 is treated as T = 1.
 *
 * Returns 2e if f is empty.
 */
template <typename FloatT>
FloatT diversity_bound 
(const std::vector< std::complex<FloatT> >& f, FloatT e, long T) 
{
  typedef std::vector< std::complex<FloatT> > VecT;
  typedef typename VecT::size_type SizeT;

  // FloatT(2)*e rather than 2.0*e keeps every mymax call on a single type,
  // so this also compiles for FloatT = float.
  const FloatT twoe = FloatT(2) * e;
  if (f.empty()) return twoe;

  // Create a copy so we can sort it by absolute value (ascending).
  VecT v (f);
  std::sort (v.begin(), v.end(), complex_comp<FloatT>);

  // Position of the T-th largest coefficient, computed without unsigned
  // wrap-around: v.size()-T underflows to a huge index when T > v.size().
  const SizeT nterms = (T < 1) ? SizeT(1) : static_cast<SizeT>(T);
  const SizeT first = (nterms < v.size()) ? (v.size() - nterms) : SizeT(0);

  // Use the upper bound on nonzero terms to get a lower bound on the
  // smallest coefficient.
  FloatT d = mymax(twoe, std::abs(v[first]) - e);

  // Skip past elements already less than d.
  typename VecT::const_iterator lo = v.begin();
  while (lo != v.end() && std::abs(*lo) < d) ++lo;

  // Check whether each remaining element is 2e-distinct from all those
  // with larger absolute values; if not, raise d above the larger one.
  for (; lo != v.end(); ++lo) {
    for (typename VecT::const_iterator hi = lo + 1; hi != v.end(); ++hi) {
      if (std::abs(*lo - *hi) < twoe) {
        d = mymax(d, std::abs(*hi) + e);
      }
      else if (std::abs(*hi) - std::abs(*lo) >= twoe) {
        // v is sorted by absolute value, so by the triangle inequality
        // every later element is also at least 2e away from *lo.
        break;
      }
    }
  }
  return d;
}


/* Approximates the 2-norm of the polynomial f behind the black box bb,
 * provided f has at most T terms (D is passed on to goodp_bound).
 * BB should be a subclass of ApproxBB.
 *
 * A prime p with as many bits as goodp_bound(D,T) is drawn with
 * NTL::GenPrime_long, bb.eval(k, p) is called for k = 0..p-1, and the
 * square root of the mean of |eval*eval| is returned.
 * NOTE(review): presumably bb.eval(k, p) evaluates at exp(2*Pi*i*k/p),
 * making this a Parseval-style estimate -- confirm against ApproxBB.
 */
template <typename BB>
typename BB::FloatT approxnorm (BB& bb, const NTL::ZZ& D, long T) {
  typedef typename BB::FloatT Real;
  typedef typename BB::BaseT Value;

  // A prime that is "good" with high probability.
  const long p = NTL::GenPrime_long (NTL::NumBits(goodp_bound(D,T)));
  NTL::ZZ bigp;
  NTL::conv(bigp, p);

  Real sumsq = 0.0;
  Value sq;
  for (NTL::ZZ k; k < bigp; ++k) {
    sq = bb.eval (k, bigp);
    sq *= sq;
    sumsq += std::abs(sq);
  }

  const Real mean = sumsq / static_cast<Real>(p);
  return std::sqrt(mean);
}


/* Probabilistic interpolation method "A".
 * f: will hold the output as (coefficient, exponent) pairs, in increasing
 *    order of exponent
 * BBT: should be a subclass of UniModBB
 * bb: uni-modular black box for unknown polynomial f
 * D: upper bound on degree of f
 * T: upper bound on sparsity of f (values below 2 are raised to 2)
 *
 * Returns true on success. Returns false, leaving f unmodified, when the
 * approximate norm of the black box is zero or NaN, when no trial
 * evaluation yields a finite diversity threshold, or when a later modular
 * image has more terms than the first one.
 */
template <typename BBT>
bool new_interpA (SparsePoly< typename BBT::FloatT >& f, 
                  BBT& bb, const NTL::ZZ& D, long T)
{
  long i,j;

  // Avoid some trivial errors
  if (T < 2) T = 2;

  // First scale the black box so that the norm is 1
  // Then set up the spun black box and modular black boxes
  typename BBT::FloatT norm = approxnorm(bb,D,T);
  // A zero (or NaN) norm would make the scale factor below infinite and
  // every scaled evaluation inf/NaN, which the thresholding below cannot
  // handle (it would never record a first good evaluation). Give up.
  if (!(norm > ((typename BBT::FloatT) 0.0))) return false;
  ScaleBB<BBT> scbb (bb, (((typename BBT::FloatT) 1.0)/norm));
  SpinBB< ScaleBB<BBT> > spbb (scbb);
  ApproxModBB< typename SpinBB< ScaleBB<BBT> >::FloatT, SpinBB< ScaleBB<BBT> > >
    ambb (spbb);

  // Compute lambda s.t. a prime in the range lambda..2*lambda is good w.h.p.
  long lambda = goodp_bound (D, T);
  long goodplen = NTL::NumBits(lambda) + 1;

  // For modular black box evaluation
  std::vector<typename BBT::BaseT> fp;
  long t = 0;
  long sfp;
  std::vector< std::pair<long, std::vector<typename BBT::BaseT> > > goodevals(1);
  long p;
#ifdef VERBOSE
  long nbad = 0;
  long nalphatrials = 0;
#endif

  // Now search for a good zero threshold delta and diversifying
  // root of unity alpha = exp(2*Pi*anum/adenom).
  // In the process, we also get the first "good prime" evaluation.
  typename BBT::FloatT inf = 
    ((typename BBT::FloatT)1.0)/((typename BBT::FloatT)0.0);
  typename BBT::FloatT delta = inf;
  { 
    long s = 2;
    typename BBT::FloatT d;
    NTL::ZZ goodanum;
    NTL::ZZ goodadenom;
    while (delta > (((typename BBT::FloatT)2.0)*bb.epsilon) && (s < T*T)) {
#ifdef VERBOSE
      ++nalphatrials;
#endif
      conv (spbb.wdenom, s);
      NTL::RandomBnd (spbb.wnum, spbb.wdenom);
      p = NTL::GenPrime_long (goodplen);
      ambb.eval (fp, p);
      d = diversity_bound 
        (fp, bb.epsilon*std::sqrt((typename BBT::FloatT)p), T);
      sfp = sparsity(fp, d);
      // A trial with more terms supersedes all earlier ones.
      if (sfp > t) {
        delta = inf;
        t = sfp;
      }
      // Among trials with the current sparsity, keep the smallest threshold.
      if (sfp == t) {
        if (d < delta) {
          delta = d;
          goodevals[0].first = p;
          goodevals[0].second = fp;
          goodanum = spbb.wnum;
          goodadenom = spbb.wdenom;
        }
      }
      s = NTL::NextPrime(2*s);
    }
    spbb.wnum = goodanum;
    spbb.wdenom = goodadenom;
  }

#ifdef VERBOSE
  std::cout << "It took " << nalphatrials << " trials to find a good spin."
            << std::endl
            << "Finally, we have delta=" << delta 
            << " (compare to epsilon=" << bb.epsilon
            << ")" << std::endl
            << "and sparsity = " << t << std::endl;
#endif

  // delta is still infinite only if no trial produced a finite threshold
  // for the final sparsity t, so goodevals[0] is unset or stale. The
  // sampling loop below could then never terminate (e.g. goodprod == 0 is
  // divisible by every prime), so report failure instead.
  if (!(delta < inf)) return false;

  // Use NTL to choose random primes that are good w.h.p.,
  // as well as random alpha s.t. f(alpha x) is diverse w.h.p.
  // goodevals holds black box evaluations diverse with sparsity t.
  // Once the product of primes is larger than D, we move on.
  NTL::ZZ goodprod;
  conv(goodprod,goodevals[0].first); // product of good primes

  while (goodprod <= D) {
    do { p = NTL::GenPrime_long (goodplen); } while (divide(goodprod,p));
    ambb.eval(fp, p);
    sfp = sparsity(fp, delta);
    if (sfp > t) {
      // This will probably never happen.
      return false;
    }
    if (sfp == t) {
      goodevals.resize (goodevals.size()+1);
      goodevals.back().first = p;
      goodevals.back().second = fp;
      goodprod *= p;
    }
#ifdef VERBOSE
    else ++nbad;
#endif
  }
#ifdef VERBOSE
  std::cout << "Sampled mod x^p-1 for " 
            << goodevals.size() << " good primes and "
            << nbad << " bad primes" << std::endl
            << "in the range " << (1L<<(goodplen-1)) << " <= p <= "
            << ((1L<<goodplen)-1) << std::endl;
#endif

  // For some reason vector CRT in NTL is not defined for a vector type,
  // so we use a build a ZZX object to hold the exponents of f
  // (as coefficients, not as roots).
  NTL::ZZX expons;
  NTL::zz_pX expons_p;
  expons_p.rep.SetLength(t);
  std::vector<long> expons_vec(t);
  NTL::zz_pBak bak;
  NTL::ZZ pprod;
  typename std::vector< std::pair<long, std::vector<typename BBT::BaseT> > >::const_iterator iter = goodevals.begin();

  // Use the first good evaluation to build an IndMap of
  // coefficients to indices. This will be used to match up the terms
  // in different evaluations.
  // We also initialize expons with this first image modulo p.
  IndMap <typename BBT::BaseT> cmap (2*bb.epsilon);
  i=0;
  for (j=0; j<((long)iter->second.size()); ++j) {
    if (std::abs(iter->second[j]) >= delta) {
      NTL::SetCoeff(expons,i,j);
      cmap[iter->second[j]] = i++;
    }
  }
  NTL::conv(pprod,iter->first);

  while (++iter != goodevals.end()) {
    for (j=0; j<((long)iter->second.size()); ++j) {
      if (std::abs(iter->second[j]) >= delta)
        expons_vec[ cmap[iter->second[j]] ] = j;
    }
    // Now we have to change the global NTL prime by "pushing" onto the "stack"
    bak.save();
    NTL::zz_p::init(iter->first);
    // Now copy expons_vec to expons_p
    for (i=0; i<t; ++i)
      NTL::conv (expons_p.rep[i], expons_vec[i]);
    // Update expons with the new modular image at each coefficient
    NTL::CRT (expons, pprod, expons_p);
    bak.restore(); // "pop" our prime off the "stack"
  }

  // We want to get the exponents in sorted order.
  // So make a vector of pointers to each coefficient of expons, then sort it.
  // NTL::coeff is used instead of expons.rep[i] because NTL keeps expons
  // normalized: a zero top coefficient (a single term with exponent 0) is
  // not stored in expons.rep, and indexing it would read out of bounds.
  std::vector<const NTL::ZZ*> exptrs(t);
  for (i=0; i<t; ++i) exptrs[i] = &(NTL::coeff(expons, i));
  std::sort (exptrs.begin(), exptrs.end(), PointerSort<NTL::ZZ>() );

  // Now extract the exponents from factors and store them in f.
  // Simultaneously, use a single good-prime evaluation to
  // get the coefficients.
  // Each coefficient must be divided by alpha^i.

  f.rep.resize(t);
  typename SparsePoly<typename BBT::FloatT>::RepT::iterator 
    fiter = f.rep.begin();
  for (std::vector<const NTL::ZZ*>::const_iterator eiter = exptrs.begin();
       eiter != exptrs.end(); ++eiter) {
    fiter->second = **eiter;
    fiter->first = goodevals.front().second
      [NTL::rem(fiter->second,goodevals.front().first)];
    
    // Multiply by the norm
    fiter->first *= norm;
    // Multiply by alpha^-(e_i)
    fiter->first *= 
      ru<typename BBT::FloatT> ((-(fiter->second))*spbb.wnum, spbb.wdenom);
    ++fiter;
  }

  return true;
}

