//-----------------------------------------------
// Copyright 2016 Guangxi University
// Written by Liang Zhao(S080011@e.ntu.edu.sg)
// Released under the GPL
//-----------------------------------------------
//
// KmerOverlaps - Overlap computation functions, seeded by exact kmer matches
// It was developed from SGA, originally written by Jared Simpson (js18@sanger.ac.uk)
//

#include "KmerOverlaps.h"
#include "HashMap.h"
#include "BWTAlgorithms.h"
#include "Profiler.h"
#include "Timer.h"
//#include "Verbosity.h"

//
// Retrieve the overlaps for `query` and, when their number is acceptable, record
// them as candidates in a new MultipleAlignment.
//
// Candidates are recorded only when at least one overlap was found and the number
// of overlaps is strictly below depthFilter; otherwise the MultipleAlignment is
// returned exactly as default-constructed.
//
// lz_index is forwarded to MultipleAlignment::getCandidate; the remaining
// arguments are forwarded unchanged to retrieveMatches.
MultipleAlignment KmerOverlaps::buildMultipleAlignment(size_t lz_index, int depthFilter, const std::string& query, //@ -- lzhao --
                                                       size_t k,
                                                       int min_overlap,
                                                       double min_identity,
                                                       int bandwidth,
                                                       const BWTIndexSet& indices)
{
    SequenceOverlapPairVector overlaps = retrieveMatches(query, k, min_overlap, min_identity, bandwidth, indices);
    MultipleAlignment alignment;

    const int overlap_count = overlaps.size();

    // Nothing to record: either no overlap at all, or the depth limit is reached.
    if (overlap_count <= 0 || overlap_count >= depthFilter)
        return alignment;

    alignment.getCandidate(overlaps, lz_index);
    return alignment;
}

// Struct to hold a partial match in the FM-index
// The position field is the location in the query sequence of this kmer.
// The index field is an index into the BWT. 
// The is_reverse flag indicates the strand of the partial match
struct KmerMatch
{
    // The three fields are declared as size_t bit-fields totalling 64 bits.
    size_t position:16;   // location of the kmer in the query
    size_t index:47;      // index value of the match
    size_t is_reverse:1;  // strand flag

    // Order by index first, then by strand. The position does not take part
    // in the ordering.
    friend bool operator<(const KmerMatch& lhs, const KmerMatch& rhs)
    {
        if(lhs.index != rhs.index)
            return lhs.index < rhs.index;
        return lhs.is_reverse < rhs.is_reverse;
    }

    // Two matches are equal when index and strand agree; the position is ignored.
    friend bool operator==(const KmerMatch& lhs, const KmerMatch& rhs)
    {
        if(lhs.index != rhs.index)
            return false;
        return lhs.is_reverse == rhs.is_reverse;
    }
};

// Return a hash key for a KmerMatch
struct KmerMatchKey
{
    size_t operator()(const KmerMatch& a) const { return a.index; }
};

// Ordered set of matches. Ordering and uniqueness follow KmerMatch's operator<,
// i.e. (index, is_reverse); the position field is not part of the key.
typedef std::set<KmerMatch> KmerMatchSet;
// Map from a KmerMatch to a bool flag, keyed through KmerMatchKey.
// retrieveMatches uses the flag to mark an entry as visited.
typedef HashMap<KmerMatch, bool, KmerMatchKey> KmerMatchMap;

//
// Find the reads in the index that overlap `query`, seeded by shared exact kmers.
//
// Parameters:
//   query        - sequence to find overlaps for
//   k            - kmer length used for seeding
//   min_overlap  - minimum overlap.getOverlapLength() for a pair to be kept;
//                  also passed to ExtendAlignment::BandAlign
//   min_identity - minimum overlap.getPercentIdentity() for a pair to be kept
//   bandwidth    - passed through to ExtendAlignment::BandAlign / ExtendAlign
//   indices      - index set; pBWT and pSSA must be non-NULL (asserted),
//                  pReadTable is optional
//
// Returns one SequenceOverlapPair for every matched read/strand whose overlap
// passes both thresholds. The result is empty when query is shorter than k.
SequenceOverlapPairVector KmerOverlaps::retrieveMatches(const std::string& query, size_t k, 
                                                        int min_overlap, double min_identity,
                                                        int bandwidth, const BWTIndexSet& indices)
{
    PROFILE_FUNC("OverlapHaplotypeBuilder::retrieveMatches")
    assert(indices.pBWT != NULL);
    assert(indices.pSSA != NULL);

    // Statistics accumulated across all calls of this function. They are only
    // consumed by the commented-out printf at the end of the function.
    // NOTE(review): these function-local statics are updated without any locking
    // in this function - confirm it is never called concurrently.
    static size_t n_calls = 0;
    static size_t n_candidates = 0;
    static size_t n_output = 0;
    static double t_time = 0;
    Timer timer("test", true);

    n_calls++;

    // Kmer intervals with this many entries or more are skipped (see the
    // interval.size() checks below).
    int64_t max_interval_size = 200;
    SequenceOverlapPairVector overlap_vector;
    // A query shorter than k has no kmers. Note that this early return happens
    // before t_time is updated.
    if(query.size() < k)
        return overlap_vector;

    // Use the FM-index to look up intervals for each kmer of the read. Each index
    // in the interval is stored individually in the KmerMatchMap. We then
    // backtrack to map these kmer indices to read IDs. As reads can share
    // multiple kmers, we use the map to avoid redundant lookups.
    // There is likely a faster algorithm which performs direct decompression
    // of the read sequences without having to expand the intervals to individual
    // indices. The current algorithm suffices for now.
    KmerMatchMap prematchMap;
    size_t num_kmers = query.size() - k + 1;
    for(size_t i = 0; i < num_kmers; ++i)
    {
        // Forward strand: every index in the kmer's interval becomes a
        // not-yet-visited entry tagged with the kmer's position in the query.
        std::string kmer = query.substr(i, k);
        BWTInterval interval = BWTAlgorithms::findInterval(indices, kmer);
        if(interval.isValid() && interval.size() < max_interval_size) 
        {
            for(int64_t j = interval.lower; j <= interval.upper; ++j)
            {
                KmerMatch match = { i, static_cast<size_t>(j), false };
                prematchMap.insert(std::make_pair(match, false));
            }
        }

        // Reverse strand: same lookup for the reverse complement of the kmer,
        // stored with is_reverse set.
        kmer = reverseComplement(kmer);
        interval = BWTAlgorithms::findInterval(indices, kmer);
        if(interval.isValid() && interval.size() < max_interval_size) 
        {
            for(int64_t j = interval.lower; j <= interval.upper; ++j)
            {
                KmerMatch match = { i, static_cast<size_t>(j), true };
                prematchMap.insert(std::make_pair(match, false));
            }
        }
    }

    // Backtrack through the kmer indices to turn them into read indices.
    // This mirrors the calcSA function in SampledSuffixArray except we mark each entry
    // as visited once it is processed.
    KmerMatchSet matches;
    for(KmerMatchMap::iterator iter = prematchMap.begin(); iter != prematchMap.end(); ++iter)
    {
        // This index has been visited
        if(iter->second)
            continue;

        // Mark this as visited
        iter->second = true;

        // Backtrack the index until we hit the starting symbol.
        // out_match is a copy: its position and is_reverse stay those of the
        // seeding entry while its index is rewritten at every step.
        KmerMatch out_match = iter->first;
        while(1) 
        {
            char b = indices.pBWT->getChar(out_match.index);
            out_match.index = indices.pBWT->getPC(b) + indices.pBWT->getOcc(b, out_match.index - 1);

            // Check if the hash indicates we have visited this index. If so, stop the backtrack
            KmerMatchMap::iterator find_iter = prematchMap.find(out_match);
            if(find_iter != prematchMap.end())
            {
                // We have processed this index already
                if(find_iter->second)
                    break;
                else
                    find_iter->second = true;
            }

            if(b == '$')
            {
                // We've found the lexicographic index for this read. Turn it into a proper ID
                out_match.index = indices.pSSA->lookupLexoRank(out_match.index);
                // The set is keyed on (index, is_reverse), so an insert for a
                // read/strand that is already present leaves the stored entry
                // (and its position) unchanged.
                matches.insert(out_match);
                break;
            }
        }
    }

    // Refine the matches by computing proper overlaps between the sequences
    // Use the overlaps that meet the thresholds to build a multiple alignment
    for(KmerMatchSet::iterator iter = matches.begin(); iter != matches.end(); ++iter)
    {
        // If a read table is available in the index, use it to get the match sequence
        // Otherwise get it from the BWT, which is slower
        std::string match_sequence;
        if(indices.pReadTable != NULL)
            match_sequence = indices.pReadTable->getRead(iter->index).seq.toString();
        else
            match_sequence = BWTAlgorithms::extractString(indices.pBWT, iter->index);

        // Reverse-strand matches are compared in the query's orientation.
        if(iter->is_reverse)
            match_sequence = reverseComplement(match_sequence);
        
        // Ignore identical matches
        if(match_sequence == query)
            continue;

        // Compute the overlap. pos_0 / pos_1 are the first occurrences of the
        // seeding kmer in the query and in the match sequence.
        // If the kmer occurs exactly once in each sequence, ExtendAlign is
        // seeded with those two positions. If either sequence contains a second
        // occurrence, BandAlign is called instead, without seed positions.
        SequenceOverlap overlap;
        std::string match_kmer = query.substr(iter->position, k);
        size_t pos_0 = query.find(match_kmer);
        size_t pos_1 = match_sequence.find(match_kmer);
        assert(pos_0 != std::string::npos && pos_1 != std::string::npos);

        // Check for secondary occurrences
        if(query.find(match_kmer, pos_0 + 1) != std::string::npos || 
           match_sequence.find(match_kmer, pos_1 + 1) != std::string::npos) {
            overlap = ExtendAlignment::BandAlign(query, match_sequence, bandwidth, 0, min_overlap);
        } else {
            overlap = ExtendAlignment::ExtendAlign(query, match_sequence, pos_0, pos_1, bandwidth);
        }

        // Every aligned pair counts as a candidate; only pairs passing both
        // thresholds are output.
        n_candidates += 1;
        bool bPassedOverlap = overlap.getOverlapLength() >= min_overlap;
        bool bPassedIdentity = overlap.getPercentIdentity() >= min_identity;

        if(bPassedOverlap && bPassedIdentity)
        {
            SequenceOverlapPair op;
            op.sequence[0] = query;
            op.sequence[1] = match_sequence;
            op.match_idx = iter->index;
            op.overlap = overlap;
            op.is_reversed = iter->is_reverse;
            overlap_vector.push_back(op);
            n_output += 1;
        }
    }

    t_time += timer.getElapsedCPUTime();

    //if(Verbosity::Instance().getPrintLevel() > 6 && n_calls % 100 == 0)
    //    printf("[kmer overlaps] n: %zu candidates: %zu valid: %zu (%.2lf) time: %.2lfs\n", 
    //        n_calls, n_candidates, n_output, (double)n_output / n_candidates, t_time);
    return overlap_vector;
}

// Pack an index, a position and a strand flag into a single 64-bit value.
//
// Layout of the result: bit 0 holds the strand flag (1 when isReverse),
// the position is shifted left by 1 and the index by 10. The three parts are
// added, not masked, so a position of 512 or more carries into the index bits.
uint64_t MultipleAlignment::encode(size_t index, size_t position, bool isReverse)
{
    const size_t index_part = index << 10;
    const size_t position_part = position << 1;
    const size_t strand_part = isReverse ? 1 : 0;

    const size_t packed = index_part + position_part + strand_part;
    return packed;
}

//
// Record the candidate groups and read indices for one query and its overlaps.
//
// Parameters:
//   sopv         - overlap pairs for the query; sequence[0] of the first pair
//                  supplies the query length
//   lz_baseIndex - index of the query read
//
// For every offset i in [0, query length) a group is built that starts with
// encode(lz_baseIndex, i, 0) and then holds encode(match_idx, position, is_reversed)
// for each pair whose position = overlap.start + i falls inside [0, query length).
// Groups with more than one entry are appended to lz_candidate.
// Afterwards lz_baseIndex is appended to lz_indices, followed by the match_idx of
// every pair that differs from it.
//
// An empty sopv records nothing: the query length is read from the first pair,
// so there is nothing to iterate over, and dereferencing begin() of an empty
// container would be undefined behaviour.
void MultipleAlignment::getCandidate(SequenceOverlapPairVector sopv, size_t lz_baseIndex)
{
    if (sopv.empty())
        return;

    SequenceOverlapPairVector::iterator itr;
    int queryLen = sopv.begin()->sequence[0].length();
    for (size_t i = 0; i < size_t(queryLen); ++i) 
    {
        // Each group starts with the query's own entry at offset i (forward strand).
        V_UI64 candidate;
        uint64_t index = encode(lz_baseIndex, i, 0);
        candidate.push_back(index);
        for (itr = sopv.begin(); itr != sopv.end(); ++itr)
        {
            // NOTE(review): the upper bound is the query length, not the length
            // of the matched read (sequence[1]) - confirm this is intended.
            int position = itr->overlap.start + i;
            if ((position >= 0) && (position < queryLen))
            {
                index = encode(itr->match_idx, position, itr->is_reversed);
                candidate.push_back(index);
            }
        }
        // A group holding only the query's own entry carries no information.
        if (candidate.size() > 1)
        {
            lz_candidate.push_back(candidate);
        }
    }

    // Record the query read first, then every matched read other than the query itself.
    lz_indices.push_back(uint64_t(lz_baseIndex));
    for (itr = sopv.begin(); itr != sopv.end(); ++itr)
    {
        if (int(lz_baseIndex) != itr->match_idx)
            lz_indices.push_back(uint64_t(itr->match_idx));
    }
}

void MultipleAlignment::write(std::ostream& out) 
{
    VV_UI64::iterator itr1;
    V_UI64::iterator itr2;
    for (itr1 = lz_candidate.begin(); itr1 != lz_candidate.end(); ++itr1)
    {
        for (itr2 = itr1->begin(); itr2 != itr1->end(); ++itr2)
        {
            out << *itr2 << " ";
        }
        out << "\n";
    }
}
