//-----------------------------------------------
// Copyright 2016 Guangxi University
// Written by Liang Zhao(S080011@e.ntu.edu.sg)
// Released under the GPL
//-----------------------------------------------
//

#include "ErrorCorrect.h"
#include <omp.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <cmath>
#include "ExtraDisjointSet.h"


// Builds a corrector from the run configuration.
//
// Each argument is copied into the member named in the initialiser list below;
// afterwards getReadsDimension() is called, which reads m_numReads and
// m_readLen from the file named by readsDim.
//
//   readsFile     - path of the reads file (opened by importReads/postprocess)
//   candidateFile - path of the candidate (erroneous bases) file
//   bufferSize    - forwarded to ExtraDisjointSet in combineErrorneousBases()
//   numThreads    - requested thread count, stored in m_numThreads
//   outFormat     - postprocess() compares this against "fastq"
//   readsDim      - path of the file holding "<numReads> <readLen>"
//   logRatio      - initial m_rtThreshold; learnRatio() assigns m_rtThreshold
//                   again each time it is called
//   outputFile    - path postprocess() writes to
//   e             - stored in m_E and used by learnRatio()
//   maxFreq       - stored in m_maxFreq; correctError() only scores bases whose
//                   count does not exceed it
Corrector::Corrector(std::string readsFile, 
                     std::string candidateFile, 
                     size_t bufferSize, 
                     size_t numThreads, 
                     std::string outFormat, 
                     std::string readsDim, 
                     double logRatio, 
                     std::string outputFile,
                     double e,
                     size_t maxFreq)
    : m_inputReadsFile(readsFile),
      m_erroneousBasesFile(candidateFile),
      m_bufferSize(bufferSize),
      m_numThreads(numThreads),
      m_outFormat(outFormat),
      m_readsDim(readsDim),
      m_rtThreshold(logRatio),
      m_outputReadsFile(outputFile),
      m_E(e),
      m_maxFreq(maxFreq)
{
    // Needs m_readsDim, which the initialiser list has set by now.
    getReadsDimension();
}

void Corrector::combineErrorneousBases()
{
    ExtraDisjointSet eds(m_erroneousBasesFile, "", m_bufferSize);
    eds.run();
}

// Recomputes m_rtThreshold from m_E and the observation count n:
//
//     log10(1 - (1 - m_E)^n)  -  n * log10(1 - m_E)  -  0.477
//
// The previous threshold is overwritten. For n == 0 the first term is
// log10(0). NOTE(review): 0.477 is close to log10(3) -- presumably
// intentional, confirm with the author.
void Corrector::learnRatio(size_t n)
{
    // `auto` keeps whatever arithmetic type m_E yields, so the intermediate
    // values have the same types as in a single-expression formulation.
    const auto oneMinusE = 1 - m_E;
    const auto firstTerm = log10(1 - pow(oneMinusE, n));
    const auto secondTerm = n * log10(oneMinusE);
    m_rtThreshold = firstTerm - secondTerm - 0.477;
}

// Loads every read of m_inputReadsFile into m_reads in packed 2-bit form.
//
// A line starting with '@' is treated as a FASTQ header: the next line is the
// sequence and the two lines after it are skipped. A line starting with '>' is
// treated as a FASTA header followed by a single sequence line. Any other line
// is ignored. Terminates the process with exit(1) if the file cannot be opened.
//
// Each read gets its own heap buffer of ceil(length / 4) bytes, appended to
// m_reads in file order; this function does not free those buffers.
void Corrector::importReads()
{
    std::ifstream fi;
    fi.open(m_inputReadsFile.c_str());
    if (fi.good() == false)
    {
        std::cerr << "[ERROR]: Can't open " << m_inputReadsFile << "\n";
        exit(1);
    }
    std::string line;
    while (std::getline(fi, line))
    {
        const char tag = line.empty() ? '\0' : line[0];
        if (tag != '@' && tag != '>')
            continue;

        // The sequence is the line that follows the header.
        std::getline(fi, line);

        // The trailing "()" value-initialises the buffer to zero. string2bits
        // ORs the base codes into it, so an uninitialised buffer would mix
        // heap garbage into every read.
        uint8_t *bits = new uint8_t[(line.length()+3)>>2]();
        string2bits(line, bits);
        m_reads.push_back(bits);

        if (tag == '@')
        {
            // Skip the separator and quality lines of the FASTQ record.
            std::getline(fi, line);
            std::getline(fi, line);
        }
    }
    fi.close();
}

// Packs the bases of seq into pbits, four bases per byte.
//
// Base i lands in byte i/4, with the first base of each group of four in the
// two most significant bits. The 2-bit code of a character is looked up in
// ErrorCorrect::Base2Bits.
//
//   seq   - sequence to encode
//   pbits - destination; must hold at least ceil(seq.length() / 4) bytes
//
// Every byte that receives a base is cleared before its first base is written,
// so the destination does not have to be zeroed by the caller.
void Corrector::string2bits(std::string seq, uint8_t *pbits)
{
    const size_t n = seq.length();
    for (size_t i = 0; i < n; ++i)
    {
        const size_t slot = i % 4;
        if (slot == 0)
            pbits[i/4] = 0;  // first base of this byte: drop stale contents
        // Go through unsigned char so that a character with the high bit set
        // cannot become a negative table index where plain char is signed.
        const unsigned char symbol = static_cast<unsigned char>(seq[i]);
        pbits[i/4] |= (ErrorCorrect::Base2Bits[symbol] << ((3 - slot) * 2));
    }
}

// Reads the number of reads and the read length from the file m_readsDim.
//
// The file is expected to start with two whitespace-separated numbers, which
// are stored in m_numReads and m_readLen in that order. Terminates the process
// with exit(1) when the file cannot be opened or the two values cannot be
// parsed, instead of carrying on with unset dimensions (m_readLen is used to
// mirror positions of reverse-strand candidates).
void Corrector::getReadsDimension()
{
    std::ifstream fi;
    fi.open(m_readsDim.c_str());
    if (fi.good() == false)
    {
        std::cerr << "[ERROR]: can't open " << m_readsDim << ".\n";
        exit(1);
    }
    if (!(fi >> m_numReads >> m_readLen))
    {
        std::cerr << "[ERROR]: can't read the number of reads and the read length from "
                  << m_readsDim << ".\n";
        exit(1);
    }
    fi.close();
}

// Returns the base letter addressed by a packed candidate index.
//
// Layout of index, as decoded here:
//   bit 0                        - strand flag (1 = reverse complement)
//   bits above glz_READFLAG      - position within the read, masked with 0x1ff
//   bits above glz_READLENBIT +
//              glz_READFLAG      - index of the read in m_reads
//
// For a reverse-complement candidate the position is mirrored within the read
// (m_readLen - pos - 1) and the stored 2-bit code is complemented before it is
// translated through ErrorCorrect::Bits2Base.
// NOTE(review): the 0x1ff mask is hard-coded; presumably it matches
// glz_READLENBIT == 9 -- confirm.
char Corrector::retrieveBase(uint64_t index)
{
    const size_t readIdx = index >> (ErrorCorrect::glz_READLENBIT + ErrorCorrect::glz_READFLAG);
    size_t posIdx = (index >> ErrorCorrect::glz_READFLAG) & 0x1ff;
    const bool reverse = (index & 0x1) != 0;
    if (reverse)
        posIdx = m_readLen - posIdx - 1;

    // Pull the 2-bit code out of the byte that holds this position.
    const int packedByte = m_reads[readIdx][posIdx >> 2];
    int code = (packedByte >> ((3 - posIdx % 4) * 2)) & 0x3;
    if (reverse)
        code = ~code & 0x3;
    return ErrorCorrect::Bits2Base[code];
}

void Corrector::correctError(std::vector<uint64_t> *pcandidateBases)
{
    std::map<char, size_t> freq;
    std::vector<uint64_t>::iterator itr;
    char base;
    std::map<char, size_t>::iterator itrm;
    for (itr = pcandidateBases->begin(); itr != pcandidateBases->end(); ++itr)
    {
        base = retrieveBase(*itr);
        itrm = freq.find(base);
        if (itrm != freq.end())
            itrm->second ++;
        else 
            freq.emplace(base, 1);
    }
    // determine which is the reference, and correct the rest based on the reference
    size_t maxfreq = 0;
    size_t allfreq = 0;
    char maxbase = 'A';
    for (itrm = freq.begin(); itrm != freq.end(); ++itrm)
    {
        if (itrm->second > maxfreq)
        {
            maxfreq = itrm->second;
            maxbase = itrm->first;
        }
        allfreq += itrm->second;
    }
    learnRatio(allfreq);
    float rt[4] = {0,0,0,0};
    if (!maxfreq) return;
    for (itrm = freq.begin(); itrm != freq.end(); ++itrm)
        if (itrm->second <= m_maxFreq)
            rt[ErrorCorrect::Base2Bits[itrm->first]] = log2(float(itrm->second)/maxfreq);
    bool hasError = 0;
    for (size_t i = 0; i < 4; ++i) if (rt[i] <= m_rtThreshold) hasError = 1;
    if (!hasError) return;
    size_t readIdx, posIdx, flag;
    for (itr = pcandidateBases->begin(); itr != pcandidateBases->end(); ++itr)
    {
        readIdx = (*itr) >> (ErrorCorrect::glz_READLENBIT + ErrorCorrect::glz_READFLAG);
        posIdx = ((*itr) >> ErrorCorrect::glz_READFLAG) & 0x1ff;
        flag = (*itr) & 0x1;
        if (flag)
        {
            posIdx = m_readLen - posIdx - 1;
            base = ErrorCorrect::Bits2Base[~((m_reads[readIdx][posIdx>>2]) >> ((3 - posIdx % 4) * 2)) & 0x3];
        }
        else
            base = ErrorCorrect::Bits2Base[((m_reads[readIdx][posIdx>>2]) >> ((3 - posIdx % 4) * 2)) & 0x3];
        if (rt[ErrorCorrect::Base2Bits[base]] <= m_rtThreshold)
        {
            if (flag)
            {
                posIdx = m_readLen - posIdx - 1;
                m_reads[readIdx][posIdx>>2] &= ((~(0x3 << ((3 - posIdx % 4) * 2))) & 0xff);
                m_reads[readIdx][posIdx>>2] |= ((((~ErrorCorrect::Base2Bits[maxbase]) & 0x3) << ((3 - posIdx % 4) * 2)) & 0xff);
            }
            else
            {
                m_reads[readIdx][posIdx>>2] &= ((~(0x3 << ((3 - posIdx % 4) * 2))) & 0xff);
                m_reads[readIdx][posIdx>>2] |= ((ErrorCorrect::Base2Bits[maxbase] << ((3 - posIdx % 4) * 2)) & 0xff);
            }
        }
    }
}

void Corrector::correctError(std::string line)
{
    uint64_t v;
    std::stringstream ss;
    std::vector<uint64_t> errbases;
    ss.str(line);
    while(ss >> v) errbases.push_back(v);
    correctError(&errbases);
}

// Applies the corrections described by m_erroneousBasesFile.
//
// Each line of that file is one group of candidate bases and is passed to
// correctError(std::string) in file order. Terminates the process with exit(1)
// if the file cannot be opened.
//
// Lines are processed as they are read. Without that, nothing would be
// corrected: a batch vector that is never filled stays empty, so a
// batch-processing loop over it never runs.
//
// The groups are processed serially on purpose, so m_numThreads is not
// consulted here. correctError() overwrites m_rtThreshold (through
// learnRatio) and read-modify-writes packed bytes of m_reads, so running
// groups on several threads would race on both. Make correctError()
// thread-safe before parallelising this loop.
void Corrector::correct()
{
    std::ifstream fi;
    fi.open(m_erroneousBasesFile.c_str());
    if (fi.good() == false)
    {
        std::cerr << "[ERROR]: can't open " << m_erroneousBasesFile << ".\n";
        exit(1);
    }
    std::string line;
    while (std::getline(fi, line))
        correctError(line);
    fi.close();
}

// Decodes the first n bases of a packed read into text.
//
//   pbits - packed read, four bases per byte with the first base of each byte
//           in the two most significant bits
//   n     - number of bases to decode
//
// Returns the n letters obtained by translating each 2-bit code through
// ErrorCorrect::Bits2Base.
std::string Corrector::toString(uint8_t *pbits, size_t n)
{
    std::string decoded;
    for (size_t pos = 0; pos < n; ++pos)
    {
        const size_t shift = (3 - pos % 4) * 2;
        const int code = (pbits[pos / 4] >> shift) & 0x3;
        decoded += ErrorCorrect::Bits2Base[code];
    }
    return decoded;
}

// Writes the (corrected) reads held in m_reads to m_outputReadsFile.
//
// The input reads file is walked again so that headers, and for FASTQ output
// the separator and quality lines, can be copied through. Every input line
// read by the outer loop is echoed to the output. After a header line ('@' or
// '>') the original sequence line is replaced by the decoded read from
// m_reads, using the original line's length as the number of bases. For '@'
// records the two following lines are copied when m_outFormat is "fastq" and
// skipped otherwise.
// NOTE(review): the header itself is copied verbatim, so '@' headers stay '@'
// even when the output format is not "fastq" -- confirm this is intended.
//
// Terminates the process with exit(1) if either file cannot be opened, or if
// the input contains more read headers than there are entries in m_reads.
void Corrector::postprocess()
{
    std::ifstream fi;
    fi.open(m_inputReadsFile.c_str());
    if (fi.good() == false)
    {
        std::cerr << "[ERROR]: can't open " << m_inputReadsFile << ".\n";
        exit(1);
    }
    std::ofstream fo;
    fo.open(m_outputReadsFile.c_str());
    if (fo.good() == false)
    {
        std::cerr << "[ERROR]: can't create " << m_outputReadsFile << ".\n";
        exit(1);
    }
    size_t i = 0;  // index of the next read in m_reads
    std::string line;
    while (std::getline(fi, line))
    {
        fo << line << "\n";
        const char tag = line.empty() ? '\0' : line[0];
        if (tag != '@' && tag != '>')
            continue;  // not a header: no read is consumed, so i must not move

        if (i >= m_reads.size())
        {
            std::cerr << "[ERROR]: " << m_inputReadsFile
                      << " contains more reads than were imported.\n";
            exit(1);
        }

        // Only the length of the original sequence line is used.
        std::getline(fi, line);
        fo << toString(m_reads[i], line.size()) << "\n";
        ++i;

        if (tag == '@')
        {
            if (m_outFormat == "fastq")
            {
                std::getline(fi, line); fo << line << "\n";
                std::getline(fi, line); fo << line << "\n";
            }
            else
            {
                std::getline(fi, line); std::getline(fi, line);
            }
        }
    }
    fi.close();
    fo.close();
}
