//-----------------------------------------------
// Copyright 2016 Guangxi University
// Written by Liang Zhao(S080011@e.ntu.edu.sg)
// Released under the GPL
//-----------------------------------------------
//

#ifndef EXTRADISJOINTSET_H
#define EXTRADISJOINTSET_H
#include <sys/stat.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

class ExtraDisjointSet
{
public:
    ExtraDisjointSet(std::string inputFile, 
                     std::string outputFile = "",
                     uint64_t chunkSize = 4294967296,
                     size_t comThreshold = 2,
                     std::string tmpPath = "./")
      : _inputFile(inputFile),
        _outputFile(outputFile),
        _chunkSize(chunkSize),
        _runCounter(0),
        _tmpPath(tmpPath),
        _comThreshold(comThreshold) 
    {
        std::cout << "ExtraDisjointSet: " << _inputFile << " " << _chunkSize << "\n";
    }

    void run();

private:
    std::string _inputFile;
    std::string _outputFile;
    size_t _chunkSize;
    size_t _runCounter;
    std::string _tmpPath;
    size_t _comThreshold;
    std::queue<std::string> _tmpFileNames;
    // merge the position together
    void merge(std::vector<std::unordered_set<size_t> >* pCandidate, std::ofstream* fo); 
    // merge the loaded position to existing file inputFile
    void merge(std::vector<std::unordered_set<size_t> >* pCandidate, std::string inputFile, std::string outputFile);
    // merge inputFile1 to inputFile2
    std::string merge(std::string inputFile1, std::string inputFile2); 
    size_t getSetIntersectSize(std::unordered_set<size_t>* ps1, std::unordered_set<size_t>* ps2);
    void takeUnion(std::unordered_set<size_t>* ps1, std::unordered_set<size_t>* ps2);
    inline bool fileExists(std::string name) {  struct stat buffer; return (stat (name.c_str(), &buffer) == 0); }
};

//
void ExtraDisjointSet::run()
{
    std::ifstream fi;
    fi.open(_inputFile.c_str());
    if (!fi.good()) 
    {
        std::cerr << "[ERROR]: can't open " << _inputFile << ".\n";
        exit(1);
    }
    uint64_t bufSize = 0;
    std::string line;
    std::stringstream ss;
    std::queue<std::string> pathes;
    std::string path;
    uint64_t v;
    // merge each chunk
    std::vector<std::unordered_set<size_t> > candidate;
    while (std::getline(fi, line))
    {
        ss.str(""); ss.clear(); ss.str(line);
        std::unordered_set<size_t> pos;
        while( ss >> v) pos.insert(v);
        candidate.push_back(pos);

        bufSize += line.size();
        if (bufSize >= _chunkSize)
        {
            path = _tmpPath + std::to_string(_runCounter);
            std::ofstream fo;
            fo.open(path.c_str());
            merge(&candidate, &fo);
            pathes.push(path);
            candidate.clear();
            bufSize = 0;
            ++ _runCounter;
        }
    }

    if (bufSize > 0)
    {
        path = _tmpPath + std::to_string(_runCounter);
        std::ofstream fo;
        fo.open(path.c_str());
        merge(&candidate, &fo);
        pathes.push(path);
        candidate.clear();
        bufSize = 0;
        ++ _runCounter;
    }
    // merge multiple chunks
    std::string inputFile1, inputFile2, outputFile;
    while (pathes.size() > 1)
    {
        inputFile1 = pathes.front(); pathes.pop();
        inputFile2 = pathes.front(); pathes.pop();
        outputFile = merge(inputFile1, inputFile2);
        pathes.push(outputFile);
        //std::cout << "inputFile1: " << inputFile1 << ", inputFile2: " << inputFile2 << ", outputFile: " << outputFile << "\n";
        std::cout <<inputFile1 << ", " << inputFile2 << ": " << outputFile << "\n";
    }
    
    if (_outputFile == "")
    {
        if (fileExists(_inputFile)) remove(_inputFile.c_str());
        _outputFile = pathes.front();
        rename(_outputFile.c_str(), _inputFile.c_str());
    } else 
    {    
        rename(pathes.front().c_str(), _outputFile.c_str());
    }
}

//
size_t ExtraDisjointSet::getSetIntersectSize(std::unordered_set<size_t>* ps1, std::unordered_set<size_t>* ps2)
{
    size_t com = 0;
    std::unordered_set<size_t>::iterator itr1, itr2;
    for (itr1 = ps1->begin(); itr1 != ps1->end(); ++itr1)
    {
        itr2 = ps2->find(*itr1);
        if (itr2 != ps2->end())
            ++com;
    }
    return com;
}

//
void ExtraDisjointSet::takeUnion(std::unordered_set<size_t>* ps1, std::unordered_set<size_t>* ps2)
{
    std::unordered_set<size_t>::iterator itr;
    for (itr = ps2->begin(); itr != ps2->end(); ++itr) ps1->insert(*itr);
}

//
void ExtraDisjointSet::merge(std::vector<std::unordered_set<size_t> >* pCandidate, std::ofstream* fo)
{
    // create mask
    size_t n = pCandidate->size();
    int* candMask = new int[n];
    size_t i;
    for (i = 0; i < n; ++i) candMask[i] = 1;

    // map position to index
    typedef std::unordered_map<size_t, std::unordered_set<size_t> > uMuS;
    typedef std::unordered_set<size_t> uS;
    uMuS pos2idx;
    uMuS::iterator itrms;
    uS::iterator itrs;

    std::vector<std::unordered_set<size_t> >::iterator itr;
    i = 0;
    for (itr = pCandidate->begin(); itr != pCandidate->end(); ++itr)
    {
        for (itrs = itr->begin(); itrs != itr->end(); ++itrs)
        {
            itrms = pos2idx.find(*itrs);
            if (itrms != pos2idx.end())
            {
                itrms->second.insert(i);
            }
            else 
            {
                uS us;
                us.insert(i);
                pos2idx.emplace(*itrs, us);
            }
        }
        ++i;
    }

    // merge
    size_t com, m;
    i = 0;
    uS *pusm, *pusi;
    uS::iterator itrs2;

    size_t numCand = pCandidate->size();
    for (i = 0; i < numCand; ++i) {
        if (!candMask[i]) continue;
        candMask[i] = 0;
        pusi = &(pCandidate->at(i));

        uS us;
        us.insert(pusi->begin(), pusi->end());
        for (itrs = pusi->begin(); itrs != pusi->end(); ++itrs)
        {
            itrms = pos2idx.find(*itrs);
            n = itrms->second.size();
            if (n > 1) {
                for (itrs2 = itrms->second.begin(); itrs2 != itrms->second.end(); ++itrs2) {
                    m = (*itrs2);
                    if (i == m) continue;
                    pusm = &(pCandidate->at(m));
                    com = getSetIntersectSize(pusi, pusm);
                    if (com >= _comThreshold) {
                        takeUnion(&us, pusm);
                        candMask[m] = 0;
                    }
                }
            }
        }
        for (itrs = us.begin(); itrs != us.end(); ++itrs)
            (*fo) << (*itrs) << " ";
        (*fo) << "\n";
    }

    delete [] candMask;
}

//
void ExtraDisjointSet::merge(std::vector<std::unordered_set<size_t> >* pCandidate, std::string inputFile, std::string outputFile)
{
    std::ifstream fi;
    fi.open(inputFile.c_str());
    if(fi.good() == false) 
    {
        std::cerr << "[ERROR]: can't open " << inputFile << ".\n";
        exit(1);
    }
    std::ofstream fo;
    fo.open(outputFile.c_str());
    if(fo.good() == false) 
    {
        std::cerr << "[ERROR]: can't open " << outputFile << ".\n";
        exit(1);
    }

    // create mask
    int n = pCandidate->size();
    size_t* candMask = new size_t[n];
    int i;
    for (i = 0; i < n; ++i) candMask[i] = 1;


    // transform
    typedef std::unordered_map<size_t, std::unordered_set<size_t> > uMuS;
    typedef std::unordered_set<size_t> uS;
    uMuS pos2idx;
    uMuS::iterator itrms;
    uS::iterator itrs;

    std::vector<std::unordered_set<size_t> >::iterator itr;
    i = 0;
    for (itr = pCandidate->begin(); itr != pCandidate->end(); ++itr)
    {
        for (itrs = itr->begin(); itrs != itr->end(); ++itrs)
        {
            itrms = pos2idx.find(*itrs);
            if (itrms != pos2idx.end())
            {
                itrms->second.insert(i);
            }
            else 
            {
                uS us;
                us.insert(i);
                pos2idx.emplace(*itrs, us);
            }
        }
        ++i;
    }

    // merge
    size_t com, m;
    uS* pus;
    std::string line; 
    std::stringstream ss;
    uS::iterator itrs2;
    while(std::getline(fi,line))
    {
        uS us, us1;
        ss.str(""); ss.clear();
        ss.str(line);
        while(ss >> m) 
        { 
            us.insert(m);
            us1.insert(m);
        }

        for (itrs = us.begin(); itrs != us.end(); ++itrs)
        {
            itrms = pos2idx.find(*itrs);
            if (itrms == pos2idx.end()) continue;
            
            for (itrs2 = itrms->second.begin(); itrs2 != itrms->second.end(); ++itrs2)
            {
                m = (*itrs2);
                pus = &(pCandidate->at(m));
                com = getSetIntersectSize(&us1, pus);
                if (com >= _comThreshold)
                {
                    takeUnion(&us, pus);
                    candMask[m] = 0;
                }
            }
        }
        for (itrs = us.begin(); itrs != us.end(); ++itrs)
            (fo) << *itrs << " ";
        (fo) << "\n";
    }

    n = pCandidate->size();
    for (i = 0; i < n; ++i)
    {
        if (!candMask[i]) continue;
        pus = &(pCandidate->at(i));
        for (itrs = pus->begin(); itrs != pus->end(); ++itrs)
            (fo) << *itrs << " ";
        (fo) << "\n";
    }

    delete [] candMask;
    fi.close();
    fo.close();
}

//
std::string ExtraDisjointSet::merge(std::string inputFile1, std::string inputFile2)
{
    std::ifstream fi;
    std::ofstream fo;
    fi.open(inputFile1.c_str());
    if (fi.good() == false)
    {
        std::cerr << "[ERROR]: can't open " << inputFile1 << ".\n";
        exit(1);
    }
    std::string line;
    std::string outFile;
    size_t bufsize = 0;
    std::vector<std::unordered_set<size_t> > candidate;
    std::stringstream ss;
    size_t v;

    std::string outputFile = _tmpPath + std::to_string(_runCounter);
    std::string inputFile = inputFile2;
    while(std::getline(fi, line))
    {
        ss.str(""); ss.clear(); ss.str(line);
        std::unordered_set<size_t> pos;
        while(ss >> v) pos.insert(v);
        candidate.push_back(pos);
        bufsize += line.size();
        if (bufsize >= _chunkSize) 
        {

            merge(&candidate, inputFile, outputFile);

            ++ _runCounter;
            candidate.clear();
            bufsize = 0;
            if(fileExists(inputFile)) remove(inputFile.c_str());
            inputFile = outputFile;
            outFile = outputFile;
            outputFile = _tmpPath + std::to_string(_runCounter);
        }
    }
    if (bufsize > 0) // handle the rest small part
    {
        merge(&candidate, inputFile, outputFile);
        ++ _runCounter;
        candidate.clear();
        bufsize = 0;
        outFile = outputFile;
    }
    if (fileExists(inputFile1)) remove(inputFile1.c_str());
    if (fileExists(inputFile2)) remove(inputFile2.c_str());
    fi.close();
    return outFile;
}

#endif
