#include <vlcutils/error.h>
#include "freqdist.h"

/*
 * Add one to the count slot of q[] that corresponds to letter c.
 * Matching is case-insensitive:
 *   A -> q[0], C -> q[1], G -> q[2], T -> q[3], N -> q[4].
 * Any other character is reported through abort_error() and q is
 * left unchanged.
 */
void increment(char c, int q[])
{
     int slot;

     switch (c) {
     case 'a': case 'A': slot = 0; break;
     case 'c': case 'C': slot = 1; break;
     case 'g': case 'G': slot = 2; break;
     case 't': case 'T': slot = 3; break;
     case 'n': case 'N': slot = 4; break;
     default:
	  abort_error("Unexpected letter: `%c'", c);
	  /* Should abort_error ever return, touch nothing. */
	  return;
     }
     q[slot]++;
}

/*
 * Subtract one from the count slot of q[] that corresponds to letter c.
 * Matching is case-insensitive:
 *   A -> q[0], C -> q[1], G -> q[2], T -> q[3], N -> q[4].
 * Any other character is reported through abort_error() and q is
 * left unchanged.
 */
void decrement(char c, int q[])
{
     int slot;

     switch (c) {
     case 'a': case 'A': slot = 0; break;
     case 'c': case 'C': slot = 1; break;
     case 'g': case 'G': slot = 2; break;
     case 't': case 'T': slot = 3; break;
     case 'n': case 'N': slot = 4; break;
     default:
	  abort_error("Unexpected letter: `%c'", c);
	  /* Should abort_error ever return, touch nothing. */
	  return;
     }
     q[slot]--;
}

/*
 * Distance between the frequency vector q and the box `box`, computed
 * over the first ALPH_LEN - 1 dimensions only (the last one is skipped).
 *
 * For each examined dimension:
 *   - below:   total amount by which q lies under box.lower
 *   - above:   total amount by which q lies over box.higher
 *   - clamped: sum of q after clamping each coordinate into the box
 *
 * The side with the larger clamped/unclamped sum is scaled by 4.  The
 * other side is scaled by 4, plus 4, plus the sum difference.  The
 * larger of the two results is returned.
 *
 * NOTE(review): the "+ 4" term is added even when the two sums are
 * equal (the below side gets it then); confirm this is intended.
 */
int infdist(int q[], MBR box)
{
     int k, gap;
     int below = 0, above = 0;
     int q_total = 0, clamped = 0;

     for (k = 0; k < ALPH_LEN - 1; k++) {
	  q_total += q[k];

	  if (q[k] < box.lower[k]) {
	       below += box.lower[k] - q[k];
	       clamped += box.lower[k];
	  } else if (q[k] > box.higher[k]) {
	       above -= box.higher[k] - q[k];
	       clamped += box.higher[k];
	  } else {
	       clamped += q[k];
	  }
     }

     if (clamped > q_total) {
	  gap = clamped - q_total;
	  below = 4 * below;
	  above = 4 * above + 4 + gap;
     } else {
	  gap = q_total - clamped;
	  below = 4 * below + 4 + gap;
	  above = 4 * above;
     }

     return (below > above) ? below : above;
}

/*
 * Return 1 when boxes r1 and r2 overlap in every one of the ALPH_LEN
 * dimensions.  Bounds are inclusive, so touching boxes count as
 * overlapping.  Return 0 as soon as a separating dimension is found.
 */
int intersect(MBR r1, MBR r2)
{
     int d = 0;

     while (d < ALPH_LEN) {
	  if (r1.higher[d] < r2.lower[d] || r2.higher[d] < r1.lower[d])
	       return 0;
	  d++;
     }
     return 1;
}


/*
 * Distance between the portions of boxes b1 and b2 that lie on the data
 * plane for resolution w, i.e. the plane where a vector's coordinates
 * sum to w.  Only the first ALPH_LEN - 1 dimensions are examined.
 *
 * Per dimension:
 *   - b1 strictly above b2: the gap is added to posDist.  The facing
 *     bounds are added to sum1/sum2.
 *   - b1 strictly below b2: the gap is added to negDist.  The facing
 *     bounds are added to sum1/sum2.
 *   - overlap: the dimension is free.  Its lower/upper bounds are
 *     accumulated into min1/max1 (b1) and min2/max2 (b2).
 *
 * Each box's constructed vector can therefore reach any coordinate sum
 * in [sum + min, sum + max].  If w lies outside that range, the vector
 * is moved onto the plane by the smallest amount possible, i.e. to the
 * nearest end of the range, and that shift is added to the distance.
 *
 * Parameters: b1, b2 - the two boxes; w - the resolution (plane sum).
 * Returns 4 * max(posDist, negDist).
 */
int box_box_distance(MBR b1, MBR b2, int w)
{
     int i;
     int posDist = 0, negDist = 0;
     int sum1 = 0, sum2 = 0;
     int min1 = 0, max1 = 0, min2 = 0, max2 = 0;

     for (i = 0; i < ALPH_LEN - 1; i++) {
	  /* Case 1: b1 > b2 along dimension i */
	  if (b2.higher[i] < b1.lower[i]) {
	       posDist += (b1.lower[i] - b2.higher[i]);
	       sum1 += b1.lower[i];
	       sum2 += b2.higher[i];
	  }
	  /* Case 2: b1 < b2 along dimension i */
	  else if (b1.higher[i] < b2.lower[i]) {
	       negDist += (b2.lower[i] - b1.higher[i]);
	       sum1 += b1.higher[i];
	       sum2 += b2.lower[i];
	  }
	  /* Case 3: b1 & b2 overlap along dimension i */
	  else {
	       min1 += b1.lower[i];
	       max1 += b1.higher[i];
	       min2 += b2.lower[i];
	       max2 += b2.higher[i];
	  }
     }

     /* Check whether the constructed vectors lie on the data plane.
      * If not, shift each one to the plane by the minimal amount.  Above
      * the plane the nearest reachable sum is sum + min; below the plane
      * it is sum + max.  The original code used min in the "below" case,
      * which overstated the distance. */
     if (w < sum1 + min1)
	  negDist += sum1 + min1 - w;
     else if (sum1 + max1 < w)
	  posDist += w - sum1 - max1;

     if (w < sum2 + min2)
	  posDist += sum2 + min2 - w;
     else if (sum2 + max2 < w)
	  negDist += w - sum2 - max2;

     /* Multiply by 4 because of BLAST's scoring scheme for
      * nucleotides (+1 for match, -3 for mismatch).  There is no gap
      * penalty here since only vectors on the same data plane are
      * considered. */
     return (posDist > negDist) ? (4 * posDist) : (4 * negDist);
}
