package sg.edu.nus;

import java.io.IOException;
import java.util.BitSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.*;

/*
 * author: wangzhengkui
 * National University of Singapore
 */


/**
 * Old-API ({@code org.apache.hadoop.mapred}) reducer that pairs SNPs and emits,
 * for every pair, one fixed-width record per (status, genotype-pair) cell of a
 * 2 x 9 table, together with a scaled association statistic.
 *
 * <p>Keys are '|'-separated. Field 0 equal to {@code "0"} marks a SNP that is
 * only buffered in {@link #arraylist}. Any other value marks a SNP that is
 * paired with every SNP buffered so far. Field 1 is the SNP id. Field 2, read
 * once from the first key, is passed to {@link #isDiagonalOrNot}. On a diagonal
 * reducer only pairs with {@code current.snp < buffered.snp} are emitted
 * (presumably so each pair is emitted once).
 * NOTE(review): this relies on every "0" key reaching reduce() before any
 * other key. Confirm against the job's sort comparator.
 *
 * <p>Instances keep mutable state in fields across reduce() calls.
 */
public class TwoSnpsReducerSquareChopping extends MapReduceBase implements
		Reducer<Text, Text, NullWritable, Text> {

	/* Offsets of the trailer fields of an output record, relative to snpBitLength. */
	private static final int FIRST_GT_OFFSET = 0;
	private static final int SECOND_GT_OFFSET = 1;
	private static final int PT_OFFSET = 2;
	private static final int STAT_OFFSET = 3;
	private static final int SECOND_SNP_OFFSET = 8;
	private static final int FIRST_SNP_OFFSET = 13;
	/* Number of trailer bytes written after the sample-id bits. */
	private static final int TRAILER_LENGTH = 18;
	/* Number of bytes copied from each Converter.intToBytes2 result. */
	private static final int ENCODED_INT_LENGTH = 5;

	/** Reusable output value; Text.set copies the bytes it is given. */
	public Text newOutputValue = new Text("");
	/* Genotype codes {first SNP, second SNP} for each cell, where cell = firstGt * 3 + secondGt. */
	byte[][] locusValue = { { 1 << 0, 1 << 0 }, { 1 << 0, 1 << 1 },
			{ 1 << 0, 1 << 2 }, { 1 << 1, 1 << 0 }, { 1 << 1, 1 << 1 },
			{ 1 << 1, 1 << 2 }, { 1 << 2, 1 << 0 }, { 1 << 2, 1 << 1 },
			{ 1 << 2, 1 << 2 } };
	/* SNPs whose key field 0 was "0"; every later SNP is paired with each of them. */
	ArrayList<SingleSnpInfor> arraylist = new ArrayList<SingleSnpInfor>();
	SingleSnpInfor oldSnp; // buffered SNP of the pair currently being emitted
	SingleSnpInfor newSingleSnp; // SNP built from the current reduce() call
	int[][] idListNum; // [status][cell] sample counts passed to the statistic
	String[] keyArray; // current key split on '|'
	int newSnpValue = 0; // SNP id taken from key field 1
	/* Result of isDiagonalOrNot: [0] diagonal flag, [1] (row + 1) * partitionNum - 1, [2] row. */
	int[] diagonalAndEndsnp = new int[3];
	int checkOrNot = 0; // set to 1 once diagonalAndEndsnp has been computed
	int partitionNum = 0; // "partition.num" from the job configuration
	byte[] tmpValue = null; // raw bytes of the value being decoded
	byte[] dieaArray = { 1 << 0, 1 << 1 }; // status markers expected at input byte snpBitLength + 1
	byte[] gtArray = { 1 << 0, 1 << 1, 1 << 2 }; // genotype markers expected at input byte snpBitLength
	int snpBitLength = 334; // number of sample-id bytes per record
	byte[] outputArray; // scratch buffer for one output record
	byte[] firstSnp; // encoded id of the buffered SNP
	byte[] secondSnp; // encoded id of the current SNP
	byte[] calValue; // encoded, scaled statistic of the current pair
	byte[][][] newSampleList; // [status][cell] intersected sample-id bits of the current pair
	int ptNum = 2;
	int gtNum = 3;
	int statisticMarker; // "statistic.method": 2 LHR, 3 NMI, 4 UC, anything else CS
	double x2Value = 0; // unscaled statistic of the current pair

	@Override
	public void configure(JobConf job) {

		this.partitionNum = job.getInt("partition.num", 0);
		this.statisticMarker = job.getInt("statistic.method", 1);
	}

	/***************************************************************************
	 * Decodes all values of one SNP, then either buffers the SNP (key field 0
	 * is "0") or pairs it with every buffered SNP and emits the pair records.
	 *
	 * Input value layout (only the first snpBitLength + 2 bytes are read):
	 *   [0, snpBitLength)   sample-id bits
	 *   snpBitLength        genotype marker, matched against gtArray
	 *   snpBitLength + 1    status marker, matched against dieaArray
	 * An unmatched marker falls back to index 0.
	 *
	 * Output value layout (snpBitLength + 18 bytes, 352 with the default):
	 *   [0, snpBitLength)   intersected sample-id bits of this cell
	 *   snpBitLength        genotype code of the buffered (first) SNP
	 *   snpBitLength + 1    genotype code of the current (second) SNP
	 *   snpBitLength + 2    status bit, 1 << status
	 *   snpBitLength + 3    5 bytes: statistic scaled by 100 (CS, LHR) or
	 *                       1000 (NMI, UC), truncated to int
	 *   snpBitLength + 8    5 bytes: id of the current (second) SNP
	 *   snpBitLength + 13   5 bytes: id of the buffered (first) SNP
	 * The comments say 6 bits of each byte carry sample ids.
	 *
	 * @throws IOException if the output collector fails; it is propagated so
	 *         the task fails instead of silently losing records
	 **************************************************************************/
	public void reduce(Text key, Iterator<Text> values,
			OutputCollector<NullWritable, Text> output, Reporter reporter)
			throws IOException {
		keyArray = key.toString().split("\\|");
		newSnpValue = Integer.parseInt(keyArray[1]);
		if (checkOrNot == 0) {
			// Computed once, from the first key this reducer sees.
			// NOTE(review): assumes every key on this reducer shares field 2.
			diagonalAndEndsnp = isDiagonalOrNot(partitionNum,
					Integer.parseInt(keyArray[2]));
			checkOrNot = 1;
		}

		newSingleSnp = new SingleSnpInfor(newSnpValue);
		while (values.hasNext()) {
			// getBytes() returns Text's backing buffer, which may be longer than
			// getLength(); only the first snpBitLength + 2 bytes are read here.
			tmpValue = values.next().getBytes();
			int status = indexOf(dieaArray, tmpValue[snpBitLength + 1]);
			int genotype = indexOf(gtArray, tmpValue[snpBitLength]);
			System.arraycopy(tmpValue, 0,
					newSingleSnp.sampleIdBits[status][genotype], 0, snpBitLength);
		}

		if ("0".equals(keyArray[0])) {
			arraylist.add(newSingleSnp);
			return;
		}

		boolean diagonal = diagonalAndEndsnp[0] != 0;
		for (SingleSnpInfor buffered : arraylist) {
			// On a diagonal reducer only one ordering of each pair is emitted.
			if (!diagonal || newSingleSnp.snp < buffered.snp) {
				emitPair(buffered, newSingleSnp, output);
			}
		}
	}

	/*
	 * Builds the 2 x 9 intersection table for (buffered, current), computes the
	 * configured statistic and emits one output record per table cell.
	 */
	private void emitPair(SingleSnpInfor buffered, SingleSnpInfor current,
			OutputCollector<NullWritable, Text> output) throws IOException {
		int cells = gtNum * gtNum;
		this.oldSnp = buffered;
		this.idListNum = new int[ptNum][cells];
		// Inner arrays are supplied by compareTwoSingleSnp.
		this.newSampleList = new byte[ptNum][cells][];
		compareTwoSingleSnp(idListNum, newSampleList, buffered, current);

		this.calValue = computeStatistic(idListNum);
		this.firstSnp = Converter.intToBytes2(buffered.snp);
		this.secondSnp = Converter.intToBytes2(current.snp);
		this.outputArray = new byte[512];

		// These fields are identical for all records of the pair, so they are
		// written once. Reusing outputArray is safe because Text.set copies it.
		// NOTE(review): assumes intToBytes2 returns at least 5 bytes.
		System.arraycopy(calValue, 0, outputArray, snpBitLength + STAT_OFFSET,
				ENCODED_INT_LENGTH);
		System.arraycopy(secondSnp, 0, outputArray, snpBitLength
				+ SECOND_SNP_OFFSET, ENCODED_INT_LENGTH);
		System.arraycopy(firstSnp, 0, outputArray, snpBitLength
				+ FIRST_SNP_OFFSET, ENCODED_INT_LENGTH);

		for (int status = 0; status < ptNum; status++) {
			for (int cell = 0; cell < cells; cell++) {
				System.arraycopy(newSampleList[status][cell], 0, outputArray, 0,
						snpBitLength);
				outputArray[snpBitLength + FIRST_GT_OFFSET] = locusValue[cell][0];
				outputArray[snpBitLength + SECOND_GT_OFFSET] = locusValue[cell][1];
				outputArray[snpBitLength + PT_OFFSET] = (byte) (1 << status);

				newOutputValue.set(outputArray, 0, snpBitLength + TRAILER_LENGTH);
				output.collect(NullWritable.get(), newOutputValue);
			}
		}
	}

	/*
	 * Evaluates the statistic selected by statisticMarker on the cell counts.
	 * It stores the raw value in x2Value and returns it encoded, after scaling:
	 * x100 for CS and LHR, x1000 for NMI and UC, truncated to int.
	 */
	private byte[] computeStatistic(int[][] counts) {
		double scale;
		switch (this.statisticMarker) {
		case 2:
			x2Value = StatisticCollection.caculateLHR(counts, ptNum, gtNum, 2);
			scale = 100;
			break;
		case 3:
			x2Value = StatisticCollection.caculateNMI(counts, ptNum, gtNum, 2);
			scale = 1000;
			break;
		case 4:
			x2Value = StatisticCollection.caculateUC(counts, ptNum, gtNum, 2);
			scale = 1000;
			break;
		case 1:
		default:
			x2Value = StatisticCollection.caculateCS(counts, ptNum, gtNum, 2);
			scale = 100;
			break;
		}
		return Converter.intToBytes2((int) (x2Value * scale));
	}

	/* Returns the index of marker in codes, or 0 when it is absent. */
	private static int indexOf(byte[] codes, byte marker) {
		for (int i = 0; i < codes.length; i++) {
			if (codes[i] == marker) {
				return i;
			}
		}
		return 0;
	}

	/*
	 * Decides whether reducer number reducernum handles a diagonal block.
	 * The diagonal numbers are the running sums 0, P, P + (P - 1), ... with
	 * P = partitionnum. Returns [0] = 1 when diagonal (else all zeros),
	 * [1] = (row + 1) * P - 1 (described by the original author as the biggest
	 * SNP in this partition) and [2] = row.
	 */
	private int[] isDiagonalOrNot(int partitionnum, int reducernum) {
		int[] result = new int[3];
		int rowStart = 0;
		for (int row = 0; row < partitionnum; row++) {
			if (reducernum == rowStart) {
				result[0] = 1;
				result[1] = (row + 1) * partitionnum - 1;
				result[2] = row;
				return result;
			}
			rowStart += partitionnum - row;
		}
		return result;
	}

	/*
	 * Fills newsamplelist[status][oldGt * gtNum + newGt] with the intersection of
	 * the two SNPs' sample-id bits. It also fills idlistnum with the number of
	 * samples in each intersection.
	 */
	private void compareTwoSingleSnp(int[][] idlistnum,
			byte[][][] newsamplelist, SingleSnpInfor oldsnp,
			SingleSnpInfor newsnp) {
		for (int status = 0; status < ptNum; status++) {
			for (int oldGt = 0; oldGt < gtNum; oldGt++) {
				for (int newGt = 0; newGt < gtNum; newGt++) {
					int cell = oldGt * gtNum + newGt;
					newsamplelist[status][cell] = getIntersection(
							oldsnp.sampleIdBits[status][oldGt],
							newsnp.sampleIdBits[status][newGt]);
					idlistnum[status][cell] = countIntersection(newsamplelist[status][cell]);
				}
			}
		}
	}

	// Bitwise AND of two lists; bs2 must be at least as long as bs.
	private byte[] getIntersection(byte[] bs, byte[] bs2) {
		byte[] tmpBytes = new byte[bs.length];
		for (int i = 0; i < bs.length; i++) {
			tmpBytes[i] = (byte) (bs[i] & bs2[i]);
		}
		return tmpBytes;
	}

	/*
	 * Counts the set bits in the intersection minus snpBitLength.
	 * NOTE(review): the subtraction presumably removes one non-sample marker bit
	 * per byte (6 of the bits carry sample ids); confirm against the encoder.
	 * Bytes are masked with 0xFF so a byte with bit 7 set is still counted
	 * correctly. The previous signed loop skipped such bytes.
	 */
	private int countIntersection(byte[] intersection) {
		int num = 0;
		for (byte b : intersection) {
			num += Integer.bitCount(b & 0xFF);
		}
		return num - snpBitLength;
	}
}
