package sg.edu.nus;

/*
 * @author: Zhengkui Wang
 * 
 * National University of Singapore
 */

import java.io.IOException;
import java.util.BitSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.*;

/**
 * Reducer for two-locus analysis in the greedy distribution model.
 *
 * <p>
 * Every {@code reduce} call assembles one SNP from its input values and pairs
 * it with every SNP this reducer instance had stored before the call. The
 * first {@code processNum} SNPs seen by the instance are stored for later
 * pairing; SNPs arriving afterwards are only paired, never stored.
 *
 * <p>
 * For each pair, 18 records are emitted (2 phenotype classes x 9 genotype
 * combinations). The bytes of one record are laid out as follows:
 *
 * <pre>
 * | sample-id bits | first GT | second GT | PT | statistic | second SNP | first SNP |
 * | snpBitLength   | 1        | 1         | 1  | 5         | 5          | 5         |
 * </pre>
 *
 * "first" refers to the previously stored SNP and "second" to the SNP of the
 * current {@code reduce} call. The original author notes that 6 bits of every
 * sample-id byte are used to store sample ids.
 */
public class TwoSnpsReducerGreedy extends MapReduceBase implements
		Reducer<DoubleWritable, Text, NullWritable, Text> {
	/** Number of phenotype classes (first dimension of the pairing tables). */
	private static final int PHENOTYPE_CLASSES = 2;
	/** Number of genotype combinations of two loci (3 x 3). */
	private static final int GENOTYPE_PAIRS = 9;
	/** Bytes taken from each Converter.intToBytes2 result when building a record. */
	private static final int ENCODED_INT_BYTES = 5;
	/** Bytes after the sample-id bits: two genotype codes, one phenotype code, three encoded ints. */
	private static final int TRAILER_BYTES = 3 + 3 * ENCODED_INT_BYTES;
	/** Minimum size of the scratch buffer a record is assembled in. */
	private static final int MIN_BUFFER_BYTES = 512;

	public Text newOutputValue = new Text("");
	int snpTotalNum = 0;
	int totalReducer = 0;
	int hashMapLength = 0; // number of SNPs currently stored in arraylist
	int processNum = 0; // number of SNPs this reducer stores for pairing
	// genotype codes {first locus, second locus} for each of the 9 combinations
	byte[][] locusValue = { { 1 << 0, 1 << 0 }, { 1 << 0, 1 << 1 },
			{ 1 << 0, 1 << 2 }, { 1 << 1, 1 << 0 }, { 1 << 1, 1 << 1 },
			{ 1 << 1, 1 << 2 }, { 1 << 2, 1 << 0 }, { 1 << 2, 1 << 1 },
			{ 1 << 2, 1 << 2 } };
	String splitString = null;
	int[] splittor = null;
	ArrayList<SingleSnpInfor> arraylist = new ArrayList<SingleSnpInfor>();
	int endNum = 0;
	int mark = 0; // 0 until processNum has been derived from the first key
	int iternum = 0;
	int putnum = 0;
	SingleSnpInfor oldSnp; // the stored SNP currently being paired
	SingleSnpInfor newSingleSnp; // the SNP of the current reduce call
	int[][] idListNum; // contingency table of the current pair
	byte[][][] newSampleList; // sample-id bit lists of the current pair
	String tuple; // store the tuple get from values
	String[] strTupleSplit; // store the split strings from tuple
	byte[] dieaArray = { 1 << 0, 1 << 1 };
	byte[] gtArray = { 1 << 0, 1 << 1, 1 << 2 };
	int tmpDiea = 0; // phenotype index of the most recently read value
	int tmpGT = 0; // genotype index of the most recently read value
	int snpBitLength = 334;
	byte[] outputArray;
	byte[] firstSnp;
	byte[] secondSnp;
	byte[] calValue;
	byte[] tmp = null;
	double x2Value;
	int ptNum = 2;
	int gtNum = 3;
	int statisticMarker = 0;

	/**
	 * Reads the job parameters and parses the comma-separated load-balancing
	 * split points from the "splittor" property.
	 *
	 * @param job the job configuration
	 * @throws IllegalArgumentException if the "splittor" property is not set
	 * @throws NumberFormatException if a split point is not an integer
	 */
	@Override
	public void configure(JobConf job) {
		this.statisticMarker = job.getInt("statistic.method", 0);
		this.snpTotalNum = job.getInt("snp.num", 0);
		this.totalReducer = job.getInt("reducer.num", 0);
		this.splitString = job.get("splittor");
		if (this.splitString == null) {
			throw new IllegalArgumentException(
					"Job property 'splittor' is not set");
		}
		String[] split = this.splitString.split(",");
		splittor = new int[split.length];
		for (int i = 0; i < split.length; i++) {
			splittor[i] = Integer.parseInt(split[i].trim());
		}
	}

	/**
	 * Builds the SNP identified by {@code key} from {@code values}, pairs it
	 * with every SNP stored before this call and emits the records of each
	 * pair. The SNP is itself stored while fewer than {@code processNum} SNPs
	 * are held.
	 *
	 * @param key SNP number; truncated to an int
	 * @param values the encoded sample lists of this SNP
	 * @param output collector receiving the pair records
	 * @param reporter unused
	 * @throws RuntimeException wrapping the IOException when a record cannot
	 *             be collected, so the task fails instead of silently losing
	 *             output
	 * @throws IllegalArgumentException if a value is too short to carry the
	 *             genotype and phenotype bytes
	 */
	public void reduce(DoubleWritable key, Iterator<Text> values,
			OutputCollector<NullWritable, Text> output, Reporter reporter) {
		int newSnpValue = (int) key.get();
		if (mark == 0) {
			// the first key decides which split range this reducer serves
			processNum = getProcessNum(newSnpValue, this.splittor);
			mark = 1;
		}

		newSingleSnp = new SingleSnpInfor(newSnpValue);
		loadSampleLists(newSingleSnp, values);

		// Only SNPs stored before this call are paired with the new one;
		// the new SNP is never paired with itself.
		int storedBefore = hashMapLength;
		if (hashMapLength < processNum) {
			this.arraylist.add(newSingleSnp);
			hashMapLength++;
		}

		try {
			for (int i = 0; i < storedBefore; i++) {
				emitPair(this.arraylist.get(i), newSingleSnp, output);
			}
		} catch (IOException exc) {
			throw new RuntimeException(
					"Failed to emit pair records for SNP " + newSnpValue, exc);
		}
	}

	/**
	 * Copies the sample-id bits of every value into the slot of {@code snp}
	 * selected by the value's phenotype and genotype bytes.
	 */
	private void loadSampleLists(SingleSnpInfor snp, Iterator<Text> values) {
		while (values.hasNext()) {
			Text value = values.next();
			// getBytes() exposes the backing array; only getLength() bytes are valid
			if (value.getLength() < snpBitLength + 2) {
				throw new IllegalArgumentException("Value of SNP " + snp.snp
						+ " has " + value.getLength()
						+ " bytes, expected at least " + (snpBitLength + 2));
			}
			tmp = value.getBytes();
			tmpDiea = indexOfCode(dieaArray, tmp[snpBitLength + 1]);
			tmpGT = indexOfCode(gtArray, tmp[snpBitLength]);
			System.arraycopy(tmp, 0, snp.sampleIdBits[tmpDiea][tmpGT], 0,
					snpBitLength);
		}
	}

	/**
	 * Returns the index of {@code code} in {@code codes}, or 0 when it is
	 * not present.
	 */
	private int indexOfCode(byte[] codes, byte code) {
		for (int i = 0; i < codes.length; i++) {
			if (codes[i] == code) {
				return i;
			}
		}
		return 0;
	}

	/**
	 * Pairs {@code storedSnp} with {@code incomingSnp}, computes the
	 * configured statistic and collects one record per phenotype class and
	 * genotype combination.
	 */
	private void emitPair(SingleSnpInfor storedSnp,
			SingleSnpInfor incomingSnp,
			OutputCollector<NullWritable, Text> output) throws IOException {
		idListNum = new int[PHENOTYPE_CLASSES][GENOTYPE_PAIRS];
		newSampleList = new byte[PHENOTYPE_CLASSES][GENOTYPE_PAIRS][snpBitLength];
		oldSnp = storedSnp;
		compareTwoSingleSnp(idListNum, newSampleList, oldSnp, incomingSnp);

		this.calValue = scaledStatistic(idListNum);
		this.firstSnp = Converter.intToBytes2(oldSnp.snp);
		this.secondSnp = Converter.intToBytes2(incomingSnp.snp);

		int trailer = this.snpBitLength;
		int recordLength = trailer + TRAILER_BYTES;
		this.outputArray = new byte[Math.max(MIN_BUFFER_BYTES, recordLength)];

		// statistic and SNP numbers are the same for all records of the pair
		System.arraycopy(this.calValue, 0, this.outputArray, trailer + 3,
				ENCODED_INT_BYTES);
		System.arraycopy(this.secondSnp, 0, this.outputArray, trailer + 3
				+ ENCODED_INT_BYTES, ENCODED_INT_BYTES);
		System.arraycopy(this.firstSnp, 0, this.outputArray, trailer + 3 + 2
				* ENCODED_INT_BYTES, ENCODED_INT_BYTES);

		for (int pt = 0; pt < PHENOTYPE_CLASSES; pt++) {
			for (int pair = 0; pair < GENOTYPE_PAIRS; pair++) {
				System.arraycopy(newSampleList[pt][pair], 0,
						this.outputArray, 0, this.snpBitLength);
				this.outputArray[trailer] = this.locusValue[pair][0];
				this.outputArray[trailer + 1] = this.locusValue[pair][1];
				this.outputArray[trailer + 2] = (byte) (1 << pt);
				// Text.set copies the bytes, so the buffer can be reused
				newOutputValue.set(this.outputArray, 0, recordLength);
				output.collect(NullWritable.get(), newOutputValue);
			}
		}
	}

	/**
	 * Computes the statistic selected by {@code statisticMarker} on the
	 * contingency table, stores it in {@code x2Value} and returns it encoded
	 * as a scaled integer (x100 for methods 2 and the default, x1000 for
	 * methods 3 and 4).
	 */
	private byte[] scaledStatistic(int[][] table) {
		double scale;
		switch (this.statisticMarker) {
		case 2:
			x2Value = StatisticCollection.caculateLHR(table, ptNum, gtNum, 2);
			scale = 100;
			break;
		case 3:
			x2Value = StatisticCollection.caculateNMI(table, ptNum, gtNum, 2);
			scale = 1000;
			break;
		case 4:
			x2Value = StatisticCollection.caculateUC(table, ptNum, gtNum, 2);
			scale = 1000;
			break;
		default: // method 1 and any unknown value
			x2Value = StatisticCollection.caculateCS(table, ptNum, gtNum, 2);
			scale = 100;
			break;
		}
		return Converter.intToBytes2((int) (x2Value * scale));
	}

	/**
	 * Fills the sample lists and the contingency table of a SNP pair. Cell
	 * {@code [diea][oldGt * 3 + newGt]} receives the intersection of the two
	 * SNPs' sample-id bits for that phenotype and genotype combination, and
	 * the matching table cell receives its count.
	 */
	private void compareTwoSingleSnp(int[][] idlistnum,
			byte[][][] newsamplelist, SingleSnpInfor oldsnp,
			SingleSnpInfor newsnp) {
		for (int diea = 0; diea < 2; diea++) {
			for (int oldGt = 0; oldGt < 3; oldGt++) {
				for (int newGt = 0; newGt < 3; newGt++) {
					int cell = oldGt * 3 + newGt;
					newsamplelist[diea][cell] = getIntersection(
							oldsnp.sampleIdBits[diea][oldGt],
							newsnp.sampleIdBits[diea][newGt]);
					idlistnum[diea][cell] = countIntersection(newsamplelist[diea][cell]);
				}
			}
		}
	}

	/**
	 * Returns the byte-wise AND of the two bit lists, with the length of
	 * {@code bs}.
	 */
	private byte[] getIntersection(byte[] bs, byte[] bs2) {
		byte[] tmpBytes = new byte[bs.length];
		for (int i = 0; i < bs.length; i++) {
			tmpBytes[i] = (byte) (bs[i] & bs2[i]);
		}
		return tmpBytes;
	}

	/**
	 * Counts the set bits of the intersection and subtracts one per byte.
	 *
	 * NOTE(review): the per-byte subtraction (originally the literal 334,
	 * the value of snpBitLength) presumably removes one always-set marker
	 * bit carried by every byte -- confirm against the mapper's encoding.
	 */
	private int countIntersection(byte[] intersection) {
		int num = 0;
		for (int i = 0; i < intersection.length; i++) {
			// mask to the unsigned value so a set high bit is counted too
			num += Integer.bitCount(intersection[i] & 0xFF);
		}
		return num - intersection.length;
	}

	/**
	 * Returns the size of the split range that contains {@code newSnpValue}:
	 * {@code splittor[0]} for values up to the first split point, otherwise
	 * the distance between the two neighbouring split points. Returns 0 when
	 * the value lies beyond the last split point.
	 */
	private int getProcessNum(int newSnpValue, int[] splittor) {
		for (int i = 0; i < splittor.length; i++) {
			if (i == 0) {
				if (newSnpValue <= splittor[0]) {
					return splittor[0];
				}
			} else if (newSnpValue <= splittor[i]
					&& newSnpValue > splittor[i - 1]) {
				return splittor[i] - splittor[i - 1];
			}
		}
		return 0;
	}

}
