package sg.edu.nus;

/*
 * author: Zhengkui Wang
 * 
 * National University of Singapore
 */
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;

import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.*;

/****************************************
 * The reducer for three-locus analysis
 *******************************************/

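/***************************************************************************
 * Output value format (358 bytes), listed by byte offset:
 *   [0, 334)    sample-id bit list for this genotype/phenotype cell
 *   [334]       genotype flag of the first SNP
 *   [335]       genotype flag of the second SNP
 *   [336]       genotype flag of the third SNP
 *   [337]       phenotype (PT) flag
 *   [338, 343)  x^2 value
 *   [343, 348)  third SNP id
 *   [348, 353)  second SNP id
 *   [353, 358)  first SNP id
 **************************************************************************/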
public class ThreeSnpsReducer extends MapReduceBase implements
		Reducer<DoubleWritable, Text, NullWritable, Text> {

	byte[][] locusValue = { { 1 << 0, 1 << 0, 1 << 0 },
			{ 1 << 0, 1 << 0, 1 << 1 }, { 1 << 0, 1 << 0, 1 << 2 },
			{ 1 << 0, 1 << 1, 1 << 0 }, { 1 << 0, 1 << 1, 1 << 1 },
			{ 1 << 0, 1 << 1, 1 << 2 }, { 1 << 0, 1 << 2, 1 << 0 },
			{ 1 << 0, 1 << 2, 1 << 1 }, { 1 << 0, 1 << 2, 1 << 2 },
			{ 1 << 1, 1 << 0, 1 << 0 }, { 1 << 1, 1 << 0, 1 << 1 },
			{ 1 << 1, 1 << 0, 1 << 2 }, { 1 << 1, 1 << 1, 1 << 0 },
			{ 1 << 1, 1 << 1, 1 << 1 }, { 1 << 1, 1 << 1, 1 << 2 },
			{ 1 << 1, 1 << 2, 1 << 0 }, { 1 << 1, 1 << 2, 1 << 1 },
			{ 1 << 1, 1 << 2, 1 << 2 }, { 1 << 2, 1 << 0, 1 << 0 },
			{ 1 << 2, 1 << 0, 1 << 1 }, { 1 << 2, 1 << 0, 1 << 2 },
			{ 1 << 2, 1 << 1, 1 << 0 }, { 1 << 2, 1 << 1, 1 << 1 },
			{ 1 << 2, 1 << 1, 1 << 2 }, { 1 << 2, 1 << 2, 1 << 0 },
			{ 1 << 2, 1 << 2, 1 << 1 }, { 1 << 2, 1 << 2, 1 << 2 } };
	public Text newOutputKey = new Text("");
	public Text newOutputValue = new Text("");
	String currentSnp = null;
	int snpTotalNum = 0;
	int totalReducer = 0;
	final static int firstSnp = 0;
	final static int hashMapLength = 0;
	int num = 0;
	ArrayList<TwoSnpsInfor> snpList = new ArrayList<TwoSnpsInfor>();
	TwoSnpsInfor recievedSnp;
	byte[] valueBytes;
	int tmpDiea = 0;
	int firstGT = 0;
	int secondGT = 0;
	int idsBitLength = 334; // change this when you have different
	int[][] pValueTable;
	byte[][][] samplesInTable;
	byte[] outputArray;
	byte[] snp1, snp2, snp3;
	int ptNum = 2;
	int gtNum = 3;
	double x2Value=0;
	int statisticMarker=0;
	String threeLocus;
	private byte[] x2ValueByte;

	@Override
	public void configure(JobConf job) {
		this.snpTotalNum = job.getInt("snp.number", 0);
		this.totalReducer = job.getInt("reducer.num", 0);
		this.statisticMarker=job.getInt("statistic.method", 1);

	}

	public void reduce(DoubleWritable key, Iterator<Text> values,
			OutputCollector<NullWritable, Text> output, Reporter reporter) {
		String[] twoLocus = Double.toString(key.get()).split("\\.");

		recievedSnp = new TwoSnpsInfor(twoLocus[1]);
		int total = 0;
		while (values.hasNext()) {
			this.valueBytes = values.next().getBytes();
			if (this.valueBytes.length != 352) {
				return;
			}
			tmpDiea = 0;
			firstGT = 0;
			secondGT = 0;

			for (int i = 0; i < 2; i++) {
				if (valueBytes[336] == (byte) 1 << i) {
					tmpDiea = i;
				}
			}
			for (int j = 0; j < 3; j++) {
				if (valueBytes[335] == (byte) 1 << j) {
					firstGT = j;
				}
				if (valueBytes[334] == (byte) 1 << j) {
					secondGT = j;
				}
			}
			for (int byteIndex = 0; byteIndex < this.idsBitLength; byteIndex++) {
				recievedSnp.sampleIdBits[tmpDiea][firstGT][secondGT][byteIndex] = valueBytes[byteIndex];

			}
		}
		/*************************************************************************
		 * Combine any two two-locus data from the same row to get three-locus
		 *************************************************************************/
		if (this.currentSnp != null) {
			if (this.currentSnp.equals(twoLocus[0])) {
				this.snpList.add(recievedSnp);
				for (int i = 0; i < this.snpList.size() - 1; i++) {
					this.outputArray = new byte[500];
					pValueTable = new int[2][27];
					samplesInTable = new byte[2][27][334];
					TwoSnpsInfor processingSnp = snpList.get(i);
					pValueTable = getTableValue(processingSnp, samplesInTable,
							recievedSnp);
					switch (this.statisticMarker) {
					case 1: {
						x2Value = StatisticCollection.caculateCS(pValueTable,
								ptNum, gtNum, 3);
						this.x2ValueByte = Converter.intToBytes2((int) (x2Value*(double)100));
						break;
					}
					case 2: {
						x2Value = StatisticCollection.caculateLHR(pValueTable,
								ptNum, gtNum, 3);
						this.x2ValueByte = Converter.intToBytes2((int) (x2Value*(double)100));
						break;
					}
					case 3: {
						x2Value = StatisticCollection.caculateNMI(pValueTable,
								ptNum, gtNum, 3);
						this.x2ValueByte = Converter
								.intToBytes2((int) (x2Value * (double) 1000));
						break;
					}
					case 4: {
						x2Value = StatisticCollection.caculateUC(pValueTable,
								ptNum, gtNum, 3);
						this.x2ValueByte = Converter
								.intToBytes2((int) (x2Value * (double) 1000));
						break;
					}
					default: {
						x2Value = StatisticCollection.caculateCS(pValueTable,
								ptNum, gtNum, 2);
						this.x2ValueByte = Converter.intToBytes2((int) (x2Value*(double)100));
						break;
					}
					}
					
					/***************************************************************
					 * output the three-locus snps data whose x^2 value is
					 * bigger than 15.5 if output all, comment the if condition
					 ***************************************************************/
					if (x2Value > 15.5) {
						this.snp1 = Converter.intToBytes2(Integer
								.valueOf(this.currentSnp));
						this.snp2 = Converter.intToBytes2(Integer
								.valueOf(processingSnp.snp));
						this.snp3 = Converter.intToBytes2(Integer
								.valueOf(recievedSnp.snp));
						this.x2ValueByte = Converter.intToBytes2((int) x2Value);
						for (int x = 0; x < 2; x++)
							for (int y = 0; y < 9; y++) {
								int a = (int) x2Value;
								for (int indexByte = 0; indexByte < this.idsBitLength; indexByte++) {
									this.outputArray[indexByte] = samplesInTable[x][y][indexByte];
								}
								for (int indexGT = 0; indexGT < 3; indexGT++) {
									this.outputArray[this.idsBitLength
											+ indexGT] = this.locusValue[y][indexGT];
								}
								this.outputArray[this.idsBitLength + 3] = (byte) (1 << x);
								for (int indexInt = 0; indexInt < 5; indexInt++) {
									this.outputArray[this.idsBitLength + 4
											+ indexInt] = this.x2ValueByte[indexInt];
									this.outputArray[this.idsBitLength + 9
											+ indexInt] = this.snp3[indexInt];
									this.outputArray[this.idsBitLength + 14
											+ indexInt] = this.snp2[indexInt];
									this.outputArray[this.idsBitLength + 19
											+ indexInt] = this.snp1[indexInt];
								}
								newOutputValue.set(this.outputArray, 0, 358);
								try {
									output.collect(NullWritable.get(),
											newOutputValue);
								} catch (IOException e) {
									// TODO Auto-generated catch block
									e.printStackTrace();
								}
							}

					} // if(x2Value > ?)
				}
			} else {
				this.snpList.clear();

				this.currentSnp = twoLocus[0];
				this.snpList.add(recievedSnp);

			}
		} else {

			this.currentSnp = twoLocus[0];
			this.snpList.add(recievedSnp);

		}

	}

	/**********************************************
	 * Get the contingency table information
	 ************************************************/
	public int[][] getTableValue(TwoSnpsInfor oldsnp,
			byte[][][] samplesintable, TwoSnpsInfor newsnp) {
		int[][] valueTable = new int[2][27];
		for (int diease = 0; diease < 2; diease++)
			for (int firstNum = 0; firstNum < 3; firstNum++)
				for (int secondNum = 0; secondNum < 3; secondNum++) {
					for (int newSecondSnp = 0; newSecondSnp < 3; newSecondSnp++) {
						int column = firstNum * 9 + secondNum * 3
								+ newSecondSnp;
						samplesintable[diease][column] = getIntersection(
								oldsnp.sampleIdBits[diease][firstNum][secondNum],
								newsnp.sampleIdBits[diease][firstNum][newSecondSnp]);
						valueTable[diease][column] = countIntersection(samplesintable[diease][column]);
					}
				}

		return valueTable;
	}

	// find the intersection from two lists
	private byte[] getIntersection(byte[] bs, byte[] bs2) {
		byte[] tmpBytes = new byte[bs.length];
		for (int i = 0; i < bs.length; i++) {
			tmpBytes[i] = (byte) (bs[i] & bs2[i]);
		}
		return tmpBytes;
	}

	// count the number of '1's in the intersection
	private int countIntersection(byte[] intersection) {
		int num = 0;
		for (int i = 0; i < intersection.length; i++) {
			int n = intersection[i];
			while (n > 0) {
				n = n & (n - 1);
				num++;
			}
		}
		return num;
	}

}
