#include <stdio.h>
#include <string.h>
#include <math.h>

#include "Util.h"
#include "Codebook.h"
#include "Lattice.h"

#include "Crop.h"

#include "VLC.h"

/* ROUND1(a): rounds a to the nearest int, with halves rounded away from zero
 * (e.g. ROUND1(2.5) == 3, ROUND1(-2.5) == -3). The argument is evaluated more than
 * once, so it must not have side effects. */
#define ROUND1(a)	((a)>=0 ? (int)((a)+0.5) : -(int)(0.5-(a))) 
/* Earlier variant of TRUNC, kept for reference only: */
/* #define TRUNC(a)	((a)>0 ? (int)(a) : -(int)(1-(a))) */
/* TRUNC(a): 0 for a == 0; (int)(a), i.e. truncation, for a > 0; -(int)(1-(a)) for
 * a < 0. For non-integer negative a that equals floor(a) (TRUNC(-1.5) == -2), but for
 * negative integers it yields a-1 (TRUNC(-2.0) == -3).
 * NOTE(review): despite the name this is not truncation toward zero for negative
 * input; confirm the negative-integer behavior is intended. Not referenced by the
 * code visible in this chunk. The argument is evaluated more than once. */
#define TRUNC(a)	((a)==0 ? 0 : ((a)>0) ? (int)(a) : -(int)(1-(a)))
/* NNI(x,f): (f)-1 when (x)-(f) < 0, otherwise (f)+1 (including x == f), i.e. the
 * integer adjacent to f on x's side. Presumably "next nearest integer" when f is x
 * rounded to the nearest integer -- TODO confirm. Arguments are evaluated more than
 * once. */
#define NNI(x,f)        (((x)-(f))<0 ? ((f)-1): ((f)+1))

/*====================================================================================
 * Interprets the eight entries of index as the digits of a base-r number, index[0]
 * being the most significant digit, and returns that number's value. Each entry is
 * converted to int before use.
 *====================================================================================*/
static int index_to_decimal (INDEX index, int r) 
{
  int value = 0;
  int pos;

  /* Horner's scheme: shift the accumulated value one digit left, then add the next digit. */
  for (pos = 0; pos < 8; pos++)
    value = value * r + (int) index[pos];

  return value;
}

/*====================================================================================
 * Returns the SQUARED L2 (Euclidean) norm of an 8-component point, i.e. the sum of
 * the squared components, accumulated in double and truncated toward zero to int.
 *====================================================================================*/
int l2_norm (VECTOR x) 
{
  double acc = 0.0;
  int k = 0;

  while (k < 8) {
    acc += x[k] * x[k];
    k++;
  }

  return (int) acc;
}

/* Integer counterpart of l2_norm: returns the sum of squares of the eight ints
 * x[0..7] (the squared Euclidean norm), computed entirely in int arithmetic. */
int l2_norm_int (int *x) 
{
  int total = 0;
  const int *p;

  for (p = x; p != x + 8; ++p)
    total += *p * *p;

  return total;
}

/* Returns the L1 norm (sum of absolute component values) of an 8-component point,
 * accumulated in double from x[0] to x[7] and truncated toward zero to int. */
int l1_norm (VECTOR x) 
{
  double total = 0.0;
  int k = 0;

  do {
    total += ABS(x[k]);
  } while (++k < 8);

  return (int) total;
}

/*====================================================================================
 * Replaces each of the eight entries of index by its non-negative remainder modulo r
 * (range [0, r-1] for r > 0). Entries are converted to int before the reduction.
 *====================================================================================*/
static void take_mod (INDEX index, int r) 
{
  int pos;

  for (pos = 0; pos < 8; pos++) {
    int rem = (int) index[pos] % r;

    /* C's % keeps the sign of the dividend, so lift negative remainders by r. */
    if (rem < 0)
      rem += r;
    index[pos] = rem;
  }
}

/*====================================================================================
 * Returns the group number (0..255) of an E8 lattice point x. Eight linear
 * combinations of the coordinates are formed, each is reduced modulo 2 by take_mod,
 * and the resulting bits are read by index_to_decimal as an 8-bit binary number
 * with z[0] as the most significant bit.
 *====================================================================================*/
int LatticePointGN (VECTOR x) 
{
  INDEX z;
  int j;

  /* Leading digit mixes all coordinates (kept as one expression so the evaluation
   * order of the floating-point terms is unchanged). */
  z[0] = 0.5*x[0] - 0.5*x[1] - 0.5*x[2] - 0.5*x[3]
    - 0.5*x[4] - 0.5*x[5] - 0.5*x[6] + 2.5*x[7];

  /* Digits 1..6: offset of each of x[1..6] from the last coordinate. */
  for (j = 1; j <= 6; j++)
    z[j] = x[j] - x[7];

  /* Last digit: twice the last coordinate. */
  z[7] = 2*x[7];

  take_mod (z, 2);
  return index_to_decimal (z, 2);
}

/* Returns the inner (dot) product of two 8-component vectors, accumulated in double
 * from component 0 to component 7. */
double ScalarMul(VECTOR x, VECTOR y) {
  double acc = 0.0;
  int k = 0;

  while (k < 8) {
    acc += x[k] * y[k];
    ++k;
  }
  return acc;
}

/* Index permutations used by FindSymmetry and ApplySymmetry. Entry [s][i] is the
 * source component that ends up at position i under symmetry s:
 *   row 0: swaps the two halves (i <-> i+4 mod 8),
 *   row 1: reverses each half separately,
 *   row 2: reverses the whole vector.
 * Each permutation is its own inverse. The table is only ever read, so it is
 * const-qualified to prevent accidental modification. */
static const int SymIndices[3][8] = {
  { 4, 5, 6, 7, 0, 1, 2, 3 },
  { 3, 2, 1, 0, 7, 6, 5, 4 },
  { 7, 6, 5, 4, 3, 2, 1, 0 }
};

/* Returns the first symmetry type s in 0..2 under which x is invariant, i.e.
 * x[i] == x[SymIndices[s][i]] for every i, or -1 if x is invariant under none of
 * them. Components are compared with exact equality. */
int FindSymmetry(VECTOR x) {
  int s, j;

  for (s = 0; s < 3; s++) {
    /* Advance j while the components still agree with their permuted partners. */
    for (j = 0; j < 8 && x[j] == x[SymIndices[s][j]]; j++) {
      /* keep scanning */
    }
    if (j == 8)
      return s;
  }
  return -1;
}

/* Writes into y the vector x multiplied by polarity. When type > 0 the components
 * are also permuted: y[i] = polarity * x[SymIndices[type-1][i]]. SymIndices has three
 * rows, so type is expected to be in 0..3; any type <= 0 means "no permutation". */
void ApplySymmetry(VECTOR x, VECTOR y, int type, int polarity) {
  int k;

  for (k = 0; k < 8; k++) {
    int src = (type > 0) ? SymIndices[type - 1][k] : k;
    y[k] = polarity * (x[src]);
  }
}


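/*====================================================================================
 * Nearest-point search on the E8 lattice. Each NearestE8LatticePoint* routine below
 * compares a candidate from the D8 lattice with a candidate from the translated
 * lattice D8 + (1/2,...,1/2) and returns the chosen point for input x in y; the
 * variants differ only in how they choose between the candidates.
 * NextD8LatticePoint is the helper that performs a single D8 search.
 *====================================================================================*/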


/*------------------------------------------------------------------------------------
 * NextD8LatticePoint: finds the point of the D8 lattice (integer vectors whose
 * component sum is even) nearest to x, writes it to f, and returns the squared
 * Euclidean distance between x and f.
 *
 * Method (the "article" referred to below is presumably Conway & Sloane, "Fast
 * quantizing and decoding algorithms for lattice quantizers and codes" -- TODO
 * confirm):
 *   1. Round every component to its nearest integer f[i]; w[i] holds the neighbouring
 *      integer on the other side. Exact .5 fractions round toward zero in f[i]. For
 *      x[i] == 0 the alternative is w[i] = 1; for any other integral x[i] it is one
 *      step further away from zero.
 *   2. If the sum of f is odd, replace the component with the largest rounding error
 *      (the first such index on ties; index 0 if every error is 0) by w[k], which
 *      makes the sum even again.
 *
 * x: input point (8 components). f: output, 8 ints.
 * NOTE(review): (int) x[i] assumes every |x[i]| fits in an int.
 *------------------------------------------------------------------------------------*/
static double NextD8LatticePoint(double x[8], int f[8]) {

  int i, k;
  int fsum;
  double maxdist, se, e[8];
  int w[8];

  maxdist = 0; k = 0;
  fsum = 0; se = 0;

  for (i=0; i<8; i++) {

    /* First, perform the two ways of rounding described in the article.
     * The results are held in f[i] (nearest integer) and w[i] (the other
     * neighbouring integer), respectively.
     */

    if (x[i]==0) {
      f[i] = 0; w[i] = 1;
    } else if (x[i]>0) {
      /* Fraction <= 0.5 rounds down (toward zero), otherwise up. */
      if ((x[i] - ((int) x[i])) <= 0.5) {
	f[i] = (int) x[i]; w[i] = f[i] + 1;
      } else {
	w[i] = (int) x[i]; f[i] = w[i] + 1;
      }
    } else {
      /* (int) truncates toward zero, so the difference lies in (-1, 0]; a fraction
       * of -0.5 or closer to zero rounds toward zero, otherwise away from it. */
      if ((x[i] - ((int) x[i])) >= -0.5) {
	f[i] = (int) x[i]; w[i] = f[i] - 1;
      } else {
	w[i] = (int) x[i]; f[i] = w[i] - 1;
      }
    }
    /* Compute component sum */
    fsum += f[i];

    /* Compute rounding error and keep it in memory */
    e[i] = x[i] - (double) f[i];

    /* Compute overall squared error (SE) */
    se += (e[i]*e[i]);

    /* Remember the largest rounding error for f; the strict '>' keeps the first
     * index on ties, and k stays 0 if every error is 0. */
    if ( ABS(e[i]) > maxdist ) {
      maxdist = ABS(x[i] - f[i]);
      k = i;
    }
  }
  
  /* Now, if f has an odd sum, we replace the component f[k] having the 
   * largest rounding error with w[k] for rounding to the second best integer.
   * (The low-bit test also flags odd negative sums on two's-complement targets.)
   */

  if (fsum & 1) {
    se -= (e[k]*e[k]); /* The error for f[k] is no longer valid */
    f[k] = w[k];       /* f[k] is replaced by w[k] */
    e[k] = x[k] - (double) f[k];/* Recompute the rounding error ... */
    se += (e[k]*e[k]); /* ... and the overall error  */
  }

  /* Return the overall SE */
  return(se);
}

/* Writes into y the E8 lattice point chosen for input x. Two candidates are built:
 * d8, the D8 point from NextD8LatticePoint(x), and half, a point of
 * D8 + (1/2,...,1/2) obtained by searching D8 around x + 1/2 and shifting back by
 * 1/2. Selection (all distances are squared Euclidean distances to x):
 *   - if the origin is no farther than both candidates, y is the zero vector;
 *   - otherwise the candidate with the smaller error wins;
 *   - on an exact error tie, the candidate with the smaller squared norm
 *     (l2_norm_int / l2_norm) wins, and half is taken when the norms are equal too.
 */
void NearestE8LatticePoint(double x[8], double y[8]) {
  int d8[8], shifted[8];
  double half[8];          /* first x + 1/2, then the half-integer candidate */
  double err_d8 = 0, err_half = 0, err_zero = 0;
  int k;

  /* Candidate on D8; its returned error is recomputed below, so it is discarded. */
  NextD8LatticePoint(x, d8);

  /* Candidate on the translated lattice. */
  for (k = 0; k < 8; k++) half[k] = x[k] + 0.5;
  NextD8LatticePoint(half, shifted);

  for (k = 0; k < 8; k++) {
    half[k] = (double) shifted[k] - 0.5;
    err_d8   += (x[k] - d8[k]) * (x[k] - d8[k]);
    err_half += (x[k] - half[k]) * (x[k] - half[k]);
    err_zero += (x[k] * x[k]);
  }

  if (err_zero <= err_d8 && err_zero <= err_half) {
    for (k = 0; k < 8; k++) y[k] = 0;
    return;
  }

  if (err_half < err_d8 ||
      (err_half == err_d8 && l2_norm_int(d8) >= l2_norm(half))) {
    for (k = 0; k < 8; k++) y[k] = half[k];
  } else {
    for (k = 0; k < 8; k++) y[k] = (double) d8[k];
  }
}

/* Variant of the E8 search: writes into y the candidate from D8 + (1/2,...,1/2)
 * whenever its squared distance to x is at most 16, and otherwise the D8 candidate
 * from NextD8LatticePoint(x). The two candidates' errors are NOT compared with each
 * other.
 * NOTE(review): the fixed bound 16 looks like a codebook-specific threshold; confirm
 * it is intended rather than a comparison against the D8 error. */
void NearestE8LatticePoint16(double x[8], double y[8]) {
  int d8[8], shifted[8];
  double half[8];          /* first x + 1/2, then the half-integer candidate */
  double err_half = 0;
  int use_half;
  int k;

  NextD8LatticePoint(x, d8);

  for (k = 0; k < 8; k++) half[k] = x[k] + 0.5;
  NextD8LatticePoint(half, shifted);

  for (k = 0; k < 8; k++) {
    half[k] = (double) shifted[k] - 0.5;
    err_half += (x[k] - half[k]) * (x[k] - half[k]);
  }

  use_half = (err_half <= 16);
  for (k = 0; k < 8; k++)
    y[k] = use_half ? half[k] : (double) d8[k];
}

/* Variant of the E8 search: writes into y whichever candidate is closer to x. The
 * D8 candidate's squared error is the value returned by NextD8LatticePoint; the
 * half-integer candidate (from D8 + (1/2,...,1/2)) is used only when its squared
 * error is strictly smaller, so ties go to the D8 candidate. */
void NearestE8LatticePoint1(double x[8], double y[8]) {
  int d8[8], shifted[8];
  double half[8];          /* first x + 1/2, then the half-integer candidate */
  double err_d8, err_half = 0;
  int k;

  err_d8 = NextD8LatticePoint(x, d8);

  for (k = 0; k < 8; k++) half[k] = x[k] + 0.5;
  NextD8LatticePoint(half, shifted);

  for (k = 0; k < 8; k++) {
    half[k] = (double) shifted[k] - 0.5;
    err_half += (x[k] - half[k]) * (x[k] - half[k]);
  }

  if (err_half < err_d8) {
    for (k = 0; k < 8; k++) y[k] = half[k];
    return;
  }

  for (k = 0; k < 8; k++) y[k] = (double) d8[k];
}

/* ============================================================================== */
