#include "global.h"

#include <cstdio>
#include <fstream>
#include <iostream>

/*
// for detecting memory leak
#include <afx.h>
#define new DEBUG_NEW
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
//*/


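/*
input files:
szmerged_dataset.rules

szoutput_name.attribute2item.txt
szoutput_name.tgtvalues.txt

szrule_dataset.rulenum.txt szrule_dataset.rules.txt szrule_dataset.tidlist.txt

output files:
szrule_dataset.pvalue.sum.txt
szrule_dataset.pvalue-FP.sum.txt
szrule_dataset.sets.sum.txt
szrule_dataset.rand.sum.txt (used instead of the three above when no true rules are loaded)
szrule_dataset.prob.sum.txt (currently disabled in the code)

*/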

void holdout_match_rules(char *szmerged_dataset, char *szrule_dataset, int num_of_clmns, int num_of_rules, char *szoutput_name,
						 double dmatch_thres)
{
	ASSOCRULE *pmined_rules, *ptrue_rules;
	int *ptrue_itemset_buf, *ptrue_tidlist_buf;
	int num_of_true_rules, num_of_mined_rules, nmax_true_rule_len, nmax_mined_rule_len;
	int *ptgt_values, i, ntgt_class;
	SIGN_RULE_NUM thesign_rule_nums;
	MATCH_STAT the_match_stat;
	map<string, int> tgtvalue_map;
	map<string, int>::iterator map_it;
	
	// read true rules
	ConvertRules(szmerged_dataset, num_of_clmns, num_of_rules, szoutput_name);
	//load embedded rules
	num_of_true_rules = LoadTrueRules(szmerged_dataset, ptrue_rules, ptrue_itemset_buf, ptrue_tidlist_buf, nmax_true_rule_len);
	std::cout << "true rules: " << num_of_true_rules << std::endl;

	//set the class label of embedded rules
	if(gnum_of_classes>2)
	{
		LoadTgtValueMap(szoutput_name, &tgtvalue_map);
		map_it = tgtvalue_map.find(gszembed_target_value);
		if(map_it==tgtvalue_map.end())
		{
			ntgt_class = -1;
			printf("Error: the class of the embedded rule is not find in target value map\n");
		}
		else
			ntgt_class = map_it->second;
	}
	else
		ntgt_class = 1;
	for(i=0;i<num_of_true_rules;i++)
		ptrue_rules[i].nclass_no = ntgt_class;


	// read mined rules
	num_of_mined_rules = LoadMinedRules(szrule_dataset, pmined_rules, nmax_mined_rule_len);
	LoadMinedTidList(szrule_dataset, pmined_rules, num_of_mined_rules, "holdout");

	ReadTreeStatis(szoutput_name);

	std::cout << "read tgt values... " << std::endl;
	// read positive target values
	LoadTgtValues(szoutput_name, ptgt_values, "holdout");

	std::cout << "load rulesum... " << std::endl;
	//load p-value thresholds and the number of signficant rules generated by different methods
	LoadSignRuleNums(szrule_dataset, &thesign_rule_nums);

	std::cout << "matching..." << std::endl;
	// match
	char szsum_name[200];
	
	if(num_of_true_rules>0)
	{
		sprintf(szsum_name,"%s.pvalue.sum.txt", szrule_dataset);
		MatchRules_pvalue(pmined_rules, num_of_mined_rules, ptrue_rules, MATCH_PVALUE, &thesign_rule_nums, ptgt_values, &the_match_stat);
		OutputMatchStat(szmerged_dataset, szrule_dataset, dmatch_thres, &thesign_rule_nums, &the_match_stat, szsum_name);

		sprintf(szsum_name,"%s.pvalue-FP.sum.txt", szrule_dataset);
		MatchRules_pvalue(pmined_rules, num_of_mined_rules, ptrue_rules, MATCH_PVALUE_FP, &thesign_rule_nums, ptgt_values, &the_match_stat);
		OutputMatchStat(szmerged_dataset, szrule_dataset, dmatch_thres, &thesign_rule_nums, &the_match_stat, szsum_name);

		sprintf(szsum_name,"%s.sets.sum.txt", szrule_dataset);
		MatchRules(pmined_rules, num_of_mined_rules, ptrue_rules, num_of_true_rules, MATCH_SET, dmatch_thres, &thesign_rule_nums, ptgt_values, &the_match_stat);
		OutputMatchStat(szmerged_dataset, szrule_dataset, dmatch_thres, &thesign_rule_nums, &the_match_stat, szsum_name);

		//sprintf(szsum_name,"%s.prob.sum.txt", szrule_dataset);
		//MatchRules(pmined_rules, num_of_mined_rules, ptrue_rules, num_of_true_rules, MATCH_PROBABILITY, dmatch_thres, &thesign_rule_nums, ptgt_values, &the_match_stat);
		//OutputMatchStat(szmerged_dataset, szrule_dataset, dmatch_thres, &thesign_rule_nums, &the_match_stat, szsum_name);

		//sprintf(szsum_name,"%s.binary.sum.txt", szrule_dataset);
		//MatchRules(pmined_rules, num_of_mined_rules, ptrue_rules, num_of_true_rules, MATCH_BINARY, dmatch_thres, &thesign_rule_nums, ptgt_values, &the_match_stat);
		//OutputMatchStat(szmerged_dataset, szrule_dataset, dmatch_thres, &thesign_rule_nums, &the_match_stat, szsum_name);

		//sprintf(szsum_name,"%s.exact.sum.txt", szrule_dataset);
		//MatchRules(pmined_rules, num_of_mined_rules, ptrue_rules, num_of_true_rules, MATCH_EXACT_BINARY, dmatch_thres, &thesign_rule_nums, ptgt_values, &the_match_stat);
		//OutputMatchStat(szmerged_dataset, szrule_dataset, dmatch_thres, &thesign_rule_nums, &the_match_stat, szsum_name);
	}
	else
	{
		sprintf(szsum_name,"%s.rand.sum.txt", szrule_dataset);
		OutputMatchStat(szmerged_dataset, szrule_dataset, &thesign_rule_nums, szsum_name);
	}

	// free space
	delete [] ptgt_values;
	for(int i=0; i<num_of_mined_rules; i++) {
		delete [] pmined_rules[i].pattern;
		delete [] pmined_rules[i].ptid_list;
	}
	delete [] pmined_rules;

	delete [] ptrue_rules;
	delete [] ptrue_itemset_buf;
	delete [] ptrue_tidlist_buf;
}

// Reads "<szoutputname>.tgtvalues.txt": a count followed by that many integer target values.
//
// Parameters:
//   szoutputname - base name of the target-value file
//   ptgt_values  - receives a new[]-allocated array holding the values; the caller must
//                  delete[] it. The array is always allocated, with zero length on error.
//   holdout      - label that is only echoed to stdout
//
// Side effects: sets the globals gndb_size (number of values) and gntgt_sup (number of
// values equal to 1).
//
// A missing/unreadable file or a negative count is reported and treated as zero values.
// Values missing from a short file are reported once and stored as 0.
void LoadTgtValues(char *szoutputname, int *&ptgt_values, char *holdout)
{
	gndb_size=0;
	gntgt_sup=0;

	std::cout << holdout << std::endl;
	char tgt_name[200];
	// bounded write: an over-long base name is truncated rather than overflowing tgt_name
	snprintf(tgt_name, sizeof(tgt_name), "%s.tgtvalues.txt", szoutputname);
	std::ifstream tgt_in(tgt_name, std::ios::in);

	// num must never be used uninitialized as an array size
	int num = 0;
	if(!(tgt_in >> num) || num<0)
	{
		std::cout << "ERROR: cannot read the number of target values from " << tgt_name << std::endl;
		num = 0;
	}

	ptgt_values = new int[num];
	bool bmissing = false;
	for(int i=0; i<num; i++) {
		if(!(tgt_in >> ptgt_values[i]))
		{
			// once the stream has failed, later extractions leave the slot untouched
			ptgt_values[i] = 0;
			bmissing = true;
		}
		if(ptgt_values[i]==1)
			gntgt_sup++;
	}
	if(bmissing)
		std::cout << "ERROR: fewer target values than announced in " << tgt_name << std::endl;
	std:: cout << "tgtvalues: " << num << std::endl;

	gndb_size=num;
}

int LoadMinedRules(char *szdataset_name, ASSOCRULE *&pmined_rules, int &nmax_mined_rule_len)
{
	int num_of_rules = 0;
	int max_len=0;
	std::vector<std::string> rules;

	char szrule_name[200];
	sprintf(szrule_name, "%s.rules.txt", szdataset_name);
	std::cout << szrule_name << std::endl;
	std::ifstream in(szrule_name, std::ios::in);
	while(!in.eof()) {
		std::string s;
		std::getline(in,s);
		if(s.length()==0) continue;

		rules.push_back(s);
		num_of_rules++;
	}
	in.close();

	pmined_rules = new ASSOCRULE[num_of_rules];
	for(int i=0; i<num_of_rules; i++) {
		std::vector<std::string> tokens;
		split(rules[i],tokens,' ');
		pmined_rules[i].npat_len = atoi(tokens[0].c_str());
		pmined_rules[i].pattern = new int[pmined_rules[i].npat_len];
		int k=0;
		for(; k<pmined_rules[i].npat_len; k++)
			pmined_rules[i].pattern[k] = atoi(tokens[1+k].c_str());
		if(gnum_of_classes>2)
		{
			pmined_rules[i].nclass_no = atof(tokens[k+1].c_str());
			pmined_rules[i].dpvalue = atof(tokens[k+2].c_str());
		}
		else
		{
			pmined_rules[i].nclass_no = 1;
			pmined_rules[i].dpvalue = atof(tokens[k+1].c_str());
		}

		if(pmined_rules[i].npat_len>max_len) max_len = pmined_rules[i].npat_len;
	}

	nmax_mined_rule_len = max_len;
	return num_of_rules;

}

// Loads the tid lists of the mined rules from "<szdataset_name>.tidlist.txt".
//
// Each non-empty line belongs to the rule at the same position in pmined_rules and is
// parsed as (space separated):
//   <nsup> <ntgt_sup> <tid 1> ... <tid nsup>
//
// Parameters:
//   szdataset_name     - base name of the tid-list file
//   pmined_rules       - array of num_of_mined_rules rules; nsup, ntgt_sup and a
//                        new[]-allocated ptid_list are stored into each rule that has a line
//   num_of_mined_rules - number of entries in pmined_rules
//   holdout            - label that is only echoed to stdout
//
// Lines beyond num_of_mined_rules are counted but never stored, so the rule array is not
// overrun. A line announcing more tids than it contains is reported and cut down to the
// tids actually present. A mismatch between line count and rule count is reported.
void LoadMinedTidList(char *szdataset_name, ASSOCRULE *pmined_rules, int num_of_mined_rules, char *holdout)
{
	std:: cout << holdout << std::endl;
	std::cout << num_of_mined_rules << std::endl;

	char sztidlist_name[200];
	// bounded write: an over-long base name is truncated rather than overflowing the buffer
	snprintf(sztidlist_name, sizeof(sztidlist_name), "%s.tidlist.txt", szdataset_name);
	std::cout << sztidlist_name << std::endl;
	std::ifstream in(sztidlist_name, std::ios::in);
	if(!in.is_open())
		std::cout << "ERROR: cannot open " << sztidlist_name << std::endl;

	int count = 0;
	// Looping on getline's result ends on end-of-file and on any stream failure.
	// Testing eof() alone never terminates when the file could not be opened.
	std::string line;
	while(std::getline(in, line)) {
		if(line.length()==0) continue;

		if(count>=num_of_mined_rules)
		{
			// surplus line: keep counting so the mismatch below is reported
			count++;
			continue;
		}

		std::vector<std::string> tokens;
		split(line,tokens, ' ');

		// number of tid fields actually present after the two leading counters
		const int navailable = tokens.size()>2 ? static_cast<int>(tokens.size()-2) : 0;
		int nsup = 0;
		int ntgt_sup = 0;
		if(tokens.size()>=2)
		{
			nsup = atoi(tokens[0].c_str());
			ntgt_sup = atoi(tokens[1].c_str());
		}
		if(tokens.size()<2 || nsup<0 || nsup>navailable)
		{
			std::cout << "ERROR: malformed tidlist line #" << (count+1) << " in " << sztidlist_name << std::endl;
			if(nsup<0) nsup = 0;
			if(nsup>navailable) nsup = navailable;
		}

		ASSOCRULE *cur = pmined_rules + count;
		cur->nsup = nsup;
		cur->ntgt_sup = ntgt_sup;
		cur->ptid_list = new int[nsup];

		for(int k=0; k<nsup; k++)
			cur->ptid_list[k] = atoi(tokens[2+k].c_str());

		count++;
	}
	in.close();

	if(count!=num_of_mined_rules) std::cout << "ERROR: rules and tidlist doesn't match!" << std::endl;
}