/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "BifocalRANSACFilter.h"


#include "../tools/PCVToolbox.hpp"
#include "../tools/TaskScheduler.h"
#include <random>
#include <chrono>
#include "../tools/SSEMath.h"

namespace SFM {

/// Default construction only; no explicit member initialization is performed here.
BifocalRANSACFilter::BifocalRANSACFilter() = default;

/// No explicit cleanup is performed here; members are destroyed implicitly.
BifocalRANSACFilter::~BifocalRANSACFilter() = default;


/**
 * Builds the per-match sampling tables.
 *
 * Every match i receives the weight exp(-lambda * normalizedDifference_i^2).
 * The raw weights are stored in m_probabilities; their running sum, normalized
 * so that the last entry is 1, is stored in m_cumulativeProbabilities (the
 * table searched by pickMatch()).
 *
 * If all weights are zero (e.g. every exponent underflows), the normalization
 * would divide by zero and fill the table with NaN. In that case the cumulative
 * table falls back to a uniform distribution; m_probabilities keeps the raw
 * (zero) weights.
 *
 * @param lambda   Factor applied to the squared normalized difference.
 * @param matches  Array holding at least @p count matches.
 * @param count    Number of matches to process; both tables are resized to it.
 */
void BifocalRANSACFilter::buildProbabilities(float lambda, const Match *matches, unsigned count)
{
    m_cumulativeProbabilities.resize(count);
    m_probabilities.resize(count);

    double sum = 0.0;
    for (unsigned i = 0; i < count; i++) {
        float prob = std::exp(-matches[i].normalizedDifference*matches[i].normalizedDifference * lambda);
        assert(std::isfinite(prob));

        m_probabilities[i] = prob;
        sum += prob;
        m_cumulativeProbabilities[i] = sum;
    }

    if (sum > 0.0) {
        for (unsigned i = 0; i < count; i++) {
            m_cumulativeProbabilities[i] /= sum;
        }
    } else {
        // Zero (or, with asserts disabled, NaN) total weight: dividing would
        // poison the table with NaN, so sample uniformly instead.
        for (unsigned i = 0; i < count; i++) {
            m_cumulativeProbabilities[i] = static_cast<double>(i + 1) / count;
        }
    }
}

/**
 * Maps a fraction to a match index via the cumulative probability table.
 *
 * Performs a binary search for the first entry of m_cumulativeProbabilities
 * that is not smaller than @p frac. If every entry is smaller, the last index
 * is returned.
 *
 * @param frac  Value to look up in the cumulative table.
 * @return Index into the table, or 0 if the table is empty. In the empty case
 *         the returned index is NOT valid and must not be dereferenced.
 */
unsigned BifocalRANSACFilter::pickMatch(float frac) const
{
    // Without this guard size()-1 wraps around and the search reads out of bounds.
    if (m_cumulativeProbabilities.empty())
        return 0;

    unsigned min = 0;
    unsigned max = m_cumulativeProbabilities.size()-1;
    while (min < max) {
        // min + (max-min)/2 cannot overflow, unlike (min+max)/2.
        unsigned center = min + (max-min)/2;
        if (m_cumulativeProbabilities[center] < frac) {
            min = center+1;
        } else {
            max = center;
        }
    }
    return min;
}



void BifocalRANSACFilter::filter(const Match *matches, float *probs, unsigned count, unsigned iterations, LinAlg::Matrix3x3f &bestFundamental, float scaleFactor)
{
    buildProbabilities(0.15f/(0.5f*0.5f), matches, count);

    std::vector<LinAlg::Matrix3x3f> fundamentals;
    std::vector<float> scores;

    fundamentals.resize(32);
    scores.resize(32);

#if 1
    TaskGroup group;
    for (unsigned i = 0; i < 32; i++) {
        group.add(boost::bind(&BifocalRANSACFilter::doRansac, this, matches, count,
                              1234*i, iterations/32, &fundamentals[i], &scores[i]), TaskScheduler::get());
    }
    TaskScheduler::get().waitFor(&group);
#else
    for (unsigned i = 0; i < 32; i++) {
        doRansac(1234*i, iterations/32, &fundamentals[i], &projections[i], &scores[i]);
    }
#endif


    float bestScore;
    bestFundamental = fundamentals[0];
    bestScore = scores[0];

    for (unsigned i = 1; i < scores.size(); i++) {
        if (scores[i] > bestScore) {
            bestScore = scores[i];
            bestFundamental = fundamentals[i];
        }
    }

    computeProbabilitiesOfFundamental(matches, count, bestFundamental, probs);

}


/**
 * Runs one independent RANSAC search for a fundamental matrix.
 *
 * Works in batches of four candidates: for each candidate eight matches are
 * drawn (with replacement) through pickMatch(), a fundamental matrix is
 * estimated from them, and the four candidates are scored together by
 * computeProbabilityOfFourFundamentals(). The best-scoring candidate over all
 * batches is returned.
 *
 * The random generator is seeded from the system clock XOR @p seed, so results
 * are not reproducible between runs.
 *
 * A sample for which PCV::getFundamentalMatrix() throws is redrawn. If too many
 * consecutive samples fail (degenerate input), the search stops early instead
 * of retrying forever and reports the best candidate found so far.
 *
 * @param matches         Array of at least @p count matches.
 * @param count           Number of matches; with zero matches no sampling is done.
 * @param seed            Value mixed into the random seed.
 * @param iters           Candidate budget; iters/4 batches are evaluated.
 * @param dstFundamental  Receives the best candidate. If no candidate scored
 *                        above zero this is a default-constructed matrix.
 * @param dstScore        Receives the score of the best candidate (0 if none).
 */
void BifocalRANSACFilter::doRansac(const Match *matches, unsigned count, unsigned seed, unsigned iters, LinAlg::Matrix3x3f *dstFundamental, float *dstScore)
{
    LinAlg::Matrix3x3f bestFundamental;
    float bestProbability = 0.0f;

    const unsigned numMatchesInSet = 8;
    const unsigned maxConsecutiveFailures = 1000;

    std::vector<LinAlg::Vector3f> firstSet;
    std::vector<LinAlg::Vector3f> secondSet;
    firstSet.resize(numMatchesInSet);
    secondSet.resize(numMatchesInSet);

    std::mt19937 rne(std::chrono::system_clock::now().time_since_epoch().count() ^ seed);
    std::uniform_real_distribution<float> distribution(0.0, 1.0);

    // With no matches there is nothing valid to index, so skip sampling entirely.
    const unsigned numBatches = (count == 0) ? 0 : iters/4;

    unsigned consecutiveFailures = 0;
    bool giveUp = false;

    for (unsigned iter = 0; iter < numBatches; iter++) {
        LinAlg::Matrix3x3f fundamentals[4];

        unsigned k = 0;
        while (k < 4) {
            for (unsigned i = 0; i < numMatchesInSet; i++) {
                unsigned matchIndex = pickMatch(distribution(rne));

                firstSet[i] = matches[matchIndex].pos[0].AddHom(1.0f);
                secondSet[i] = matches[matchIndex].pos[1].AddHom(1.0f);
            }

            try {
                fundamentals[k] = PCV::getFundamentalMatrix(firstSet, secondSet);
                consecutiveFailures = 0;
                k++;
            } catch (...) {
                // Estimation failed for this sample: redraw the same slot, but
                // bail out if the input never yields a usable sample.
                if (++consecutiveFailures >= maxConsecutiveFailures) {
                    giveUp = true;
                    break;
                }
            }
        }

        // An incomplete batch must not be scored; keep what was found so far.
        if (giveUp)
            break;

        float probs[4];
        computeProbabilityOfFourFundamentals(matches, count, fundamentals, probs);

        for (unsigned j = 0; j < 4; j++) {
            if (probs[j] > bestProbability) {
                bestFundamental = fundamentals[j];
                bestProbability = probs[j];
            }
        }
    }

    *dstFundamental = bestFundamental;
    *dstScore = bestProbability;
}


/**
 * Rates one match against four fundamental matrices at once (one per SSE lane).
 *
 * @param fourFundamentals  Four 3x3 matrices stored interleaved: the four values
 *                          of element (i, j) start at offset 4*(i*3+j). The
 *                          pointer must be 16-byte aligned because it is read
 *                          with _mm_load_ps.
 * @param match             Match to rate; pos[0]/pos[1] and size[0]/size[1] are read.
 * @param result            Receives the four ratings, lane k belonging to matrix k.
 *
 * Unlike computeMatchProbabilityGivenFundamental(), the final rating is
 * 1 / (1 + 2*d^2) rather than exp(d) (the exp_ps variant is compiled out), and
 * a near-zero line normal is clamped instead of yielding a rating of zero.
 */
void BifocalRANSACFilter::computeMatchProbabilityGivenFourFundamentals(const float *fourFundamentals, const Match *match, __m128 &result)
{
    // Broadcast the 2D positions of the match to all four lanes.
    __m128 pos1X = _mm_set1_ps(match->pos[0][0]);
    __m128 pos1Y = _mm_set1_ps(match->pos[0][1]);

    __m128 pos2X = _mm_set1_ps(match->pos[1][0]);
    __m128 pos2Y = _mm_set1_ps(match->pos[1][1]);


    // epipolarLine1 accumulates F^T * (pos2, 1), epipolarLine2 accumulates F * (pos1, 1).
    __m128 epipolarLine1X;
    __m128 epipolarLine1Y;
    __m128 epipolarLine1Z;

    __m128 epipolarLine2X;
    __m128 epipolarLine2Y;
    __m128 epipolarLine2Z;

    // Diagonal elements F00, F11, F22 initialize the accumulators.
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(0*3+0));
        epipolarLine2X = _mm_mul_ps(pos1X, fElem);
        epipolarLine1X = _mm_mul_ps(pos2X, fElem);
    }
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(1*3+1));
        epipolarLine2Y = _mm_mul_ps(pos1Y, fElem);
        epipolarLine1Y = _mm_mul_ps(pos2Y, fElem);
    }
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(2*3+2));
        epipolarLine2Z = fElem;
        epipolarLine1Z = fElem;
    }



    // Off-diagonal elements: each F(i,j) contributes to row i of line 2 and to
    // row j of line 1 (transposed use). The homogeneous coordinate is 1, so
    // terms from column/row 2 are added without a multiplication.
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(0*3+1));
        epipolarLine2X = _mm_add_ps(epipolarLine2X, _mm_mul_ps(pos1Y, fElem));
        epipolarLine1Y = _mm_add_ps(epipolarLine1Y, _mm_mul_ps(pos2X, fElem));
    }

    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(0*3+2));
        epipolarLine2X = _mm_add_ps(epipolarLine2X, fElem);
        epipolarLine1Z = _mm_add_ps(epipolarLine1Z, _mm_mul_ps(pos2X, fElem));
    }

    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(1*3+0));
        epipolarLine2Y = _mm_add_ps(epipolarLine2Y, _mm_mul_ps(pos1X, fElem));
        epipolarLine1X = _mm_add_ps(epipolarLine1X, _mm_mul_ps(pos2Y, fElem));
    }
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(1*3+2));
        epipolarLine2Y = _mm_add_ps(epipolarLine2Y, fElem);
        epipolarLine1Z = _mm_add_ps(epipolarLine1Z, _mm_mul_ps(pos2Y, fElem));
    }
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(2*3+0));
        epipolarLine2Z = _mm_add_ps(epipolarLine2Z, _mm_mul_ps(pos1X, fElem));
        epipolarLine1X = _mm_add_ps(epipolarLine1X, fElem);
    }
    {
        __m128 fElem = _mm_load_ps(fourFundamentals + 4*(2*3+1));
        epipolarLine2Z = _mm_add_ps(epipolarLine2Z, _mm_mul_ps(pos1Y, fElem));
        epipolarLine1Y = _mm_add_ps(epipolarLine1Y, fElem);
    }



    // Despite the name, denom1/denom2 hold the (approximate) RECIPROCAL length
    // of each line's normal (x, y). The squared length is clamped to 1e-20 so
    // that rsqrt never sees zero.
    __m128 denom1, denom2;
    {
        __m128 sqrLen = _mm_add_ps(_mm_mul_ps(epipolarLine1X, epipolarLine1X), _mm_mul_ps(epipolarLine1Y, epipolarLine1Y));
        denom1 = _mm_rsqrt_ps(_mm_max_ps(sqrLen, _mm_set1_ps(1e-20f)));
    }
    {
        __m128 sqrLen = _mm_add_ps(_mm_mul_ps(epipolarLine2X, epipolarLine2X), _mm_mul_ps(epipolarLine2Y, epipolarLine2Y));
        denom2 = _mm_rsqrt_ps(_mm_max_ps(sqrLen, _mm_set1_ps(1e-20f)));
    }


    // Signed distance of each position to the epipolar line in its own image.
    __m128 distance1 = _mm_mul_ps(_mm_add_ps(_mm_add_ps(_mm_mul_ps(pos1X, epipolarLine1X), _mm_mul_ps(pos1Y, epipolarLine1Y)), epipolarLine1Z), denom1);
    __m128 distance2 = _mm_mul_ps(_mm_add_ps(_mm_add_ps(_mm_mul_ps(pos2X, epipolarLine2X), _mm_mul_ps(pos2Y, epipolarLine2Y)), epipolarLine2Z), denom2);


    // Express the distances relative to match->size[0] / match->size[1].
    float scale1 = 1.0f / match->size[0];
    float scale2 = 1.0f / match->size[1];

    __m128 patchScaleDistance1 = _mm_mul_ps(distance1, _mm_set1_ps(scale1));
    __m128 patchScaleDistance2 = _mm_mul_ps(distance2, _mm_set1_ps(scale2));

    const float rcpSigma = 4.0f;

    // d = -rcpSigma * (dist1^2 + dist2^2), i.e. the exponent used by the scalar version.
    __m128 d = _mm_mul_ps(_mm_add_ps(_mm_mul_ps(patchScaleDistance1, patchScaleDistance1), _mm_mul_ps(patchScaleDistance2, patchScaleDistance2)), _mm_set1_ps(-rcpSigma));

#if 0
    result = exp_ps(d);
#else
    // Active variant: 1 / (1 + 2*d^2) using the approximate reciprocal.
    result = _mm_rcp_ps(_mm_add_ps(_mm_set1_ps(1.0f), _mm_mul_ps(_mm_mul_ps(d, d), _mm_set1_ps(2.0f))));
#endif
}

/**
 * Scores four candidate fundamental matrices against all matches in one pass.
 *
 * For every candidate the score is the sum over all matches of the per-match
 * rating from computeMatchProbabilityGivenFourFundamentals() multiplied by the
 * match's weight in m_probabilities.
 *
 * @param matches       Array of at least @p count matches.
 * @param count         Number of matches to accumulate.
 * @param fundamentals  Array of exactly four candidate matrices.
 * @param probs         Receives four scores, probs[k] belonging to fundamentals[k].
 */
void BifocalRANSACFilter::computeProbabilityOfFourFundamentals(const Match *matches, unsigned count, const LinAlg::Matrix3x3f *fundamentals, float *probs)
{
    // Interleaved copy: element (row, col) of candidate "lane" is stored at
    // 4*(row*3 + col) + lane, so one aligned load fetches it for all four.
    alignas(16) float interleaved[3*3*4];

    for (unsigned lane = 0; lane < 4; lane++) {
        const LinAlg::Matrix3x3f &candidate = fundamentals[lane];
        for (unsigned row = 0; row < 3; row++) {
            for (unsigned col = 0; col < 3; col++) {
                interleaved[4*(row*3 + col) + lane] = candidate[row][col];
            }
        }
    }

    __m128 accumulated = _mm_setzero_ps();
    const Match *current = matches;
    for (unsigned idx = 0; idx < count; ++idx, ++current) {
        __m128 rating;
        computeMatchProbabilityGivenFourFundamentals(interleaved, current, rating);

        const __m128 weight = _mm_set1_ps(m_probabilities[idx]);
        accumulated = _mm_add_ps(accumulated, _mm_mul_ps(rating, weight));
    }

    // probs carries no alignment guarantee, hence the unaligned store.
    _mm_storeu_ps(probs, accumulated);
}


/**
 * Rates a single match against one fundamental matrix (scalar version).
 *
 * Both positions are mapped to their epipolar line in the other image, the
 * distance of each position to the line in its own image is divided by the
 * corresponding match->size entry, and the rating is
 * exp(-rcpSigma * (d1^2 + d2^2)).
 *
 * @param fundamental  Fundamental matrix to test.
 * @param match        Match to rate.
 * @return The rating, or 0 if either epipolar line has a (near) zero normal.
 */
float BifocalRANSACFilter::computeMatchProbabilityGivenFundamental(const LinAlg::Matrix3x3f &fundamental, const Match *match)
{
    const float rcpSigma = 4.0f;

    // Line in the second image induced by pos[0], and vice versa.
    LinAlg::Vector3f lineInSecond(fundamental * match->pos[0].AddHom(1.0f));
    LinAlg::Vector3f lineInFirst(fundamental.T() * match->pos[1].AddHom(1.0f));

    // Length of each line's 2D normal, needed to normalize the distances.
    float normFirst = std::sqrt(lineInFirst.StripHom().SQRLen());
    float normSecond = std::sqrt(lineInSecond.StripHom().SQRLen());

    const bool degenerate = (normFirst < 1e-20f) || (normSecond < 1e-20f);
    if (degenerate)
        return 0.0f;

    float offsetFirst = match->pos[0].AddHom(1.0f) * lineInFirst / normFirst;
    float offsetSecond = match->pos[1].AddHom(1.0f) * lineInSecond / normSecond;

    float scaledFirst = offsetFirst / match->size[0];
    float scaledSecond = offsetSecond / match->size[1];

    return std::exp(-(scaledFirst*scaledFirst + scaledSecond*scaledSecond) * rcpSigma);
}

/**
 * Fills @p probs with the weighted rating of every match under @p fundamental.
 *
 * probs[i] is the result of computeMatchProbabilityGivenFundamental() for
 * match i multiplied by m_probabilities[i].
 *
 * @param matches      Array of at least @p count matches.
 * @param count        Number of matches to rate.
 * @param fundamental  Fundamental matrix to rate against.
 * @param probs        Output array with room for @p count values.
 */
void BifocalRANSACFilter::computeProbabilitiesOfFundamental(const Match *matches, unsigned count, const LinAlg::Matrix3x3f &fundamental, float *probs)
{
    const Match *current = matches;
    for (unsigned idx = 0; idx != count; ++idx, ++current) {
        const float rating = computeMatchProbabilityGivenFundamental(fundamental, current);
        probs[idx] = rating * m_probabilities[idx];
    }
}





}
