#include <stdio.h>
#include <stdlib.h>
#include <malloc.h>

#include "Util.h"
#include "Codebook.h"
#include "Lattice.h"
#include "Training.h"

/* ****************************************************************************** */
/*
 * Default codebooks used to seed the training cells (see
 * ResetShell1TrainingsSet / ResetShell2TrainingsSet).
 *   Shell1Point: 120 vectors of 8 doubles.
 *   Shell2Point: 800 vectors of 8 doubles.
 * Each #include'd file is expected to contain the comma-separated initializer
 * list for its table (assumed from the brace-enclosed #include -- confirm
 * against the files under Codebooks/).
 */
static double Shell1Point[120][8] = {
#include "Codebooks/Shell1Default"
};

static double Shell2Point[800][8] = {
#include "Codebooks/Shell2Default"
};
/* ****************************************************************************** */

/*
 * Training accumulator for a single codebook cell.
 * The AddShell*CellEntry() functions add to it and
 * ComputeAllShell*Centros() finalizes it.
 */
typedef struct {

  double   mse;   /* running sum of vmse(x, y) over the entries added to this cell */
  double   r;     /* running sum of vradius(x), i.e. squared norms, of the added vectors */
  unsigned long hits; /* contribution count; ResetShell*TrainingsSet() starts it at 1 for the seed point */
  VECTOR   x;     /* running vector sum, seeded with the codebook point; divided by hits to form the centroid */

} CellStruct;

/*
 * Heap-allocated training cell arrays, allocated by InitTraining().
 * There are eight codebook families, each with two slots ([0] and [1]),
 * selected by the `ti` argument of the Add*CellEntry functions.
 * Shell-1 arrays hold 120 cells and Shell-2 arrays hold 800.
 * NOTE(review): the meaning of the SPY/TPY/SPC/TPC suffixes and of the two
 * slots is not visible in this file -- confirm with the callers.
 */
static CellStruct *Shell1SPYCells[2], *Shell1TPYCells[2], *Shell1SPCCells[2], *Shell1TPCCells[2];
static CellStruct *Shell2SPYCells[2], *Shell2TPYCells[2], *Shell2SPCCells[2], *Shell2TPCCells[2];

void InitTraining() {

  Shell1SPYCells[0] = (CellStruct *) malloc(120*sizeof(CellStruct));
  Shell1SPYCells[1] = (CellStruct *) malloc(120*sizeof(CellStruct));

  Shell2SPYCells[0] = (CellStruct *) malloc(800*sizeof(CellStruct));
  Shell2SPYCells[1] = (CellStruct *) malloc(800*sizeof(CellStruct));

  Shell1SPCCells[0] = (CellStruct *) malloc(120*sizeof(CellStruct));
  Shell1SPCCells[1] = (CellStruct *) malloc(120*sizeof(CellStruct));

  Shell1TPYCells[0] = (CellStruct *) malloc(120*sizeof(CellStruct));
  Shell1TPYCells[1] = (CellStruct *) malloc(120*sizeof(CellStruct));

  Shell1TPCCells[0] = (CellStruct *) malloc(120*sizeof(CellStruct));
  Shell1TPCCells[1] = (CellStruct *) malloc(120*sizeof(CellStruct));

  Shell2TPYCells[0] = (CellStruct *) malloc(800*sizeof(CellStruct));
  Shell2TPYCells[1] = (CellStruct *) malloc(800*sizeof(CellStruct));

  Shell2SPCCells[0] = (CellStruct *) malloc(800*sizeof(CellStruct));
  Shell2SPCCells[1] = (CellStruct *) malloc(800*sizeof(CellStruct));

  Shell2TPCCells[0] = (CellStruct *) malloc(800*sizeof(CellStruct));
  Shell2TPCCells[1] = (CellStruct *) malloc(800*sizeof(CellStruct));

}

/*
 * Seed each of the 120 Shell-1 cells from its codebook point.
 * Each cell gets x = point[n], one hit (the seed itself), and zeroed mse/r
 * accumulators.
 */
static void ResetShell1TrainingsSet(CellStruct cell[120], double point[120][8]) {
  int n, d;

  for (n = 0; n < 120; ++n) {
    CellStruct *c = &cell[n];

    for (d = 0; d < 8; ++d)
      c->x[d] = point[n][d];
    c->r    = 0;
    c->mse  = 0;
    c->hits = 1;
  }
}

/*
 * Seed each of the 800 Shell-2 cells from its codebook point.
 * The cell array and the point table are walked in lockstep; each cell gets
 * one hit, zeroed mse/r accumulators, and a copy of its point as x.
 */
static void ResetShell2TrainingsSet(CellStruct cell[800], double point[800][8]) {
  CellStruct *c = cell;
  CellStruct *const end = cell + 800;
  double (*p)[8] = point;
  int d;

  for (; c != end; ++c, ++p) {
    c->hits = 1;
    c->mse  = 0;
    c->r    = 0;
    for (d = 0; d < 8; ++d)
      c->x[d] = (*p)[d];
  }
}
  
/*
 * Re-seed every training cell array from the default codebooks.
 * Both slots of all eight families are reset: Shell-1 families from
 * Shell1Point and Shell-2 families from Shell2Point.
 * InitTraining() must have allocated the arrays first.
 */
void ResetTrainset() {
  int ti;

  for (ti = 0; ti < 2; ti++) {
    ResetShell1TrainingsSet(Shell1SPYCells[ti], Shell1Point);
    ResetShell1TrainingsSet(Shell1TPYCells[ti], Shell1Point);
    ResetShell1TrainingsSet(Shell1SPCCells[ti], Shell1Point);
    ResetShell1TrainingsSet(Shell1TPCCells[ti], Shell1Point);

    ResetShell2TrainingsSet(Shell2SPYCells[ti], Shell2Point);
    ResetShell2TrainingsSet(Shell2TPYCells[ti], Shell2Point);
    ResetShell2TrainingsSet(Shell2SPCCells[ti], Shell2Point);
    ResetShell2TrainingsSet(Shell2TPCCells[ti], Shell2Point);
  }
}

/*
 * Mean squared error between two 8-component vectors:
 *   (1/8) * sum_k (x[k] - y[k])^2
 * Terms are accumulated in index order 0..7.
 */
static double vmse (VECTOR x, VECTOR y) {
  double sum = 0;
  int k = 0;

  while (k < 8) {
    const double d = x[k] - y[k];
    sum += d * d;
    ++k;
  }
  return sum / 8.0;
}

/*
 * Squared Euclidean norm of an 8-component vector, sum_k x[k]^2.
 * Despite the name, no square root is taken.
 */
static double vradius (VECTOR x) {
  double sumsq = 0;
  int k = 0;

  do {
    sumsq += x[k] * x[k];
  } while (++k < 8);
  return sumsq;
}

/*
 * Turn the 120 Shell-1 accumulators in `cell` into centroids and save them.
 *
 * Each cell's x holds a running vector sum. ResetShell1TrainingsSet() seeds
 * it with the codebook point and the AddShell1*CellEntry() functions add to
 * it. Here x is divided in place by the cell's hit count. A cell with
 * hits == 0 (never reset) is left unchanged instead of dividing by zero.
 *
 * Each centroid is written to `name` (created or truncated) as one line of
 * eight "%f" values, each followed by a comma. The in-memory update happens
 * even if the file cannot be opened. Open and close failures are reported on
 * stderr via perror().
 *
 * name: output file path.
 * cell: array of 120 cells, updated in place.
 */
static void ComputeAllShell1Centros(const char *name, CellStruct cell[120]) {
  FILE *prot;
  int index, i;

  prot = fopen(name, "w");
  if (prot == NULL)
    perror(name);   /* still compute centroids; only the dump is skipped */

  for (index = 0; index < 120; index++) {
    if (cell[index].hits != 0) {
      for (i = 0; i < 8; i++)
        cell[index].x[i] = cell[index].x[i] / cell[index].hits;
    }
    if (prot != NULL) {
      for (i = 0; i < 7; i++)
        fprintf(prot, "%f, ", cell[index].x[i]);
      fprintf(prot, "%f,\n", cell[index].x[7]);
    }
  }

  /* fclose flushes buffered output; a failure here means the file is incomplete */
  if (prot != NULL && fclose(prot) != 0)
    perror(name);
}

/*
 * Turn the 800 Shell-2 accumulators in `cell` into centroids and save them.
 *
 * Each cell's x holds a running vector sum. ResetShell2TrainingsSet() seeds
 * it with the codebook point and the AddShell2*CellEntry() functions add to
 * it. Here x is divided in place by the cell's hit count. A cell with
 * hits == 0 (never reset) is left unchanged instead of dividing by zero.
 *
 * Each centroid is written to `name` (created or truncated) as one line of
 * eight "%f" values, each followed by a comma. The in-memory update happens
 * even if the file cannot be opened. Open and close failures are reported on
 * stderr via perror().
 *
 * name: output file path.
 * cell: array of 800 cells, updated in place.
 */
static void ComputeAllShell2Centros(const char *name, CellStruct cell[800]) {
  FILE *prot;
  int index, i;

  prot = fopen(name, "w");
  if (prot == NULL)
    perror(name);   /* still compute centroids; only the dump is skipped */

  for (index = 0; index < 800; index++) {
    if (cell[index].hits != 0) {
      for (i = 0; i < 8; i++)
        cell[index].x[i] = cell[index].x[i] / cell[index].hits;
    }
    if (prot != NULL) {
      for (i = 0; i < 7; i++)
        fprintf(prot, "%f, ", cell[index].x[i]);
      fprintf(prot, "%f,\n", cell[index].x[7]);
    }
  }

  /* fclose flushes buffered output; a failure here means the file is incomplete */
  if (prot != NULL && fclose(prot) != 0)
    perror(name);
}
 
/* ****************************************************************************** */

void ComputeCentroids() {

  ComputeAllShell1Centros("Shell1SPY0", Shell1SPYCells[0]);
  ComputeAllShell1Centros("Shell1SPY1", Shell1SPYCells[1]);

  ComputeAllShell2Centros("Shell2SPY0", Shell2SPYCells[0]);
  ComputeAllShell2Centros("Shell2SPY1", Shell2SPYCells[1]);

  ComputeAllShell1Centros("Shell1TPY0", Shell1TPYCells[0]);
  ComputeAllShell1Centros("Shell1TPY1", Shell1TPYCells[1]);

  ComputeAllShell1Centros("Shell1SPC0", Shell1SPCCells[0]);
  ComputeAllShell1Centros("Shell1SPC1", Shell1SPCCells[1]);

  ComputeAllShell1Centros("Shell1TPC0", Shell1TPCCells[0]);
  ComputeAllShell1Centros("Shell1TPC1", Shell1TPCCells[1]);

  ComputeAllShell2Centros("Shell2TPY0", Shell2TPYCells[0]);
  ComputeAllShell2Centros("Shell2TPY1", Shell2TPYCells[1]);

  ComputeAllShell2Centros("Shell2SPC0", Shell2SPCCells[0]);
  ComputeAllShell2Centros("Shell2SPC1", Shell2SPCCells[1]);

  ComputeAllShell2Centros("Shell2TPC0", Shell2TPCCells[0]);
  ComputeAllShell2Centros("Shell2TPC1", Shell2TPCCells[1]);

} 
/* ****************************************************************************** */

void AddShell1SPYCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell1SPYCells[ti][index].x[i] += x[i];

  Shell1SPYCells[ti][index].r += vradius(x);
  Shell1SPYCells[ti][index].mse += vmse(x,y);
  Shell1SPYCells[ti][index].hits++;
}

void AddShell1TPYCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell1TPYCells[ti][index].x[i] += x[i];

  Shell1TPYCells[ti][index].r += vradius(x);
  Shell1TPYCells[ti][index].mse += vmse(x,y);
  Shell1TPYCells[ti][index].hits++;
}

void AddShell1SPCCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell1SPCCells[ti][index].x[i] += x[i];

  Shell1SPCCells[ti][index].r += vradius(x);
  Shell1SPCCells[ti][index].mse += vmse(x,y);
  Shell1SPCCells[ti][index].hits++;
}

void AddShell1TPCCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell1TPCCells[ti][index].x[i] += x[i];

  Shell1TPCCells[ti][index].r += vradius(x);
  Shell1TPCCells[ti][index].mse += vmse(x,y);
  Shell1TPCCells[ti][index].hits++;
}

void AddShell2SPYCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell2SPYCells[ti][index].x[i] += x[i];

  Shell2SPYCells[ti][index].r += vradius(x);
  Shell2SPYCells[ti][index].mse += vmse(x,y);
  Shell2SPYCells[ti][index].hits++;
}

void AddShell2TPYCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell2TPYCells[ti][index].x[i] += x[i];

  Shell2TPYCells[ti][index].r += vradius(x);
  Shell2TPYCells[ti][index].mse += vmse(x,y);
  Shell2TPYCells[ti][index].hits++;
}

void AddShell2SPCCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell2SPCCells[ti][index].x[i] += x[i];

  Shell2SPCCells[ti][index].r += vradius(x);
  Shell2SPCCells[ti][index].mse += vmse(x,y);
  Shell2SPCCells[ti][index].hits++;
}

void AddShell2TPCCellEntry ( int index, double * x, double * y, int ti) {
  int i;

  for (i=0; i<8; i++) 
    Shell2TPCCells[ti][index].x[i] += x[i];

  Shell2TPCCells[ti][index].r += vradius(x);
  Shell2TPCCells[ti][index].mse += vmse(x,y);
  Shell2TPCCells[ti][index].hits++;
}

