/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "ViewMatrixEstimator.h"

#include "../tools/TaskScheduler.h"
#include "../tools/PCVToolbox.hpp"
#include "../tools/SSEMath.h"
#include "../tools/AlignedAllocator.h"
#include <random>
#include <chrono>
#include <cmath>


namespace SFM {

/// Default constructor: no explicit set-up is performed; members are default-initialised.
ViewMatrixEstimator::ViewMatrixEstimator() = default;

/// Destructor: no explicit tear-down is performed; members clean up after themselves.
ViewMatrixEstimator::~ViewMatrixEstimator() = default;

void ViewMatrixEstimator::buildProbabilities(float lambda, const ControlPoint *controlPoints, unsigned count)
{
    m_cumulativeProbabilities.resize(count);
    m_probabilities.resize(count);

    double sum = 0.0;
    for (unsigned i = 0; i < count; i++) {
        float prob = std::exp(-controlPoints[i].normalizedDifference*controlPoints[i].normalizedDifference * lambda);
        m_probabilities[i] = prob;
        sum += prob;
        m_cumulativeProbabilities[i] = sum;
    }
    for (unsigned i = 0; i < count; i++) {
        m_cumulativeProbabilities[i] /= sum;
    }
}

/**
 * Maps a fraction to a control-point index via the cumulative probability table.
 *
 * Performs a binary search over m_cumulativeProbabilities for the first entry
 * that is not smaller than @p frac. If no entry qualifies, the last index is
 * returned.
 *
 * @param frac value to look up (callers pass a uniform random number).
 * @return index into the table; 0 if the table is empty, in which case the
 *         returned index must not be dereferenced by the caller.
 */
unsigned ViewMatrixEstimator::pickMatch(float frac) const
{
    // Without this guard size()-1 wraps around and the search reads out of bounds.
    if (m_cumulativeProbabilities.empty())
        return 0;

    unsigned min = 0;
    unsigned max = m_cumulativeProbabilities.size()-1;
    while (min < max) {
        // Overflow-safe midpoint; identical to (min+max)/2 whenever that sum fits.
        const unsigned center = min + (max-min)/2;
        if (m_cumulativeProbabilities[center] < frac) {
            min = center+1;
        } else {
            max = center;
        }
    }
    return min;
}



/**
 * Sums the squared reprojection error over a set of correspondences.
 *
 * Each 3D point is multiplied by @p projectionViewMatrix, divided by its third
 * component and compared with the matching 2D point.
 *
 * @param ransacPoints3D       homogeneous 3D points.
 * @param euclPoints2D         Euclidean 2D points, indexed like @p ransacPoints3D.
 * @param projectionViewMatrix 3x4 matrix applied to the 3D points.
 * @return sum of squared 2D distances (no division-by-zero protection is applied).
 */
float computeReprojError(const std::vector<LinAlg::Vector4f> &ransacPoints3D,
                   const std::vector<LinAlg::Vector2f> &euclPoints2D,
                   const LinAlg::Matrix3x4f &projectionViewMatrix)
{
    float totalSqrError = 0.0f;

    for (unsigned idx = 0; idx < ransacPoints3D.size(); ++idx) {
        const LinAlg::Vector3f projected = projectionViewMatrix * ransacPoints3D[idx];
        const LinAlg::Vector2f pixel = projected.StripHom() / projected[2];
        const LinAlg::Vector2f &observed = euclPoints2D[idx];

        const float diffX = pixel[0] - observed[0];
        const float diffY = pixel[1] - observed[1];

        totalSqrError += diffX*diffX + diffY*diffY;
    }

    return totalSqrError;
}

/**
 * Refines a view matrix with 20 fixed-step descent iterations on the summed
 * squared reprojection error.
 *
 * Per iteration, the error difference between a positive and a negative
 * perturbation is evaluated for translation along and rotation about each
 * axis. The differences are used directly as step sizes (they are not divided
 * by the perturbation width). Afterwards the upper-left 3x3 part is
 * re-orthogonalised with LinAlg::GramSchmidt and the last row is reset to
 * (0, 0, 0, 1).
 *
 * @param ransacPoints3D   homogeneous 3D points.
 * @param ransacPoints2D   homogeneous 2D observations, indexed like the 3D points.
 * @param projectionMatrix 3x4 matrix multiplied in front of the view matrix.
 * @param viewMatrix       in/out: initial estimate, overwritten with the refined one.
 */
void optimizeViewMatrix(const std::vector<LinAlg::Vector4f> &ransacPoints3D,
                        const std::vector<LinAlg::Vector3f> &ransacPoints2D,
                        const LinAlg::Matrix3x4f &projectionMatrix,
                        LinAlg::Matrix4x4f &viewMatrix)
{
    // Dehomogenise the observations once; they do not change between iterations.
    std::vector<LinAlg::Vector2f> euclPoints2D(ransacPoints2D.size());
    for (unsigned p = 0; p < ransacPoints2D.size(); ++p) {
        const LinAlg::Vector3f &homogeneous = ransacPoints2D[p];
        euclPoints2D[p] = homogeneous.StripHom() / homogeneous[2];
    }

    // Error of the fixed point set under a given combined projection*view matrix.
    const auto errorOf = [&](const LinAlg::Matrix3x4f &projectionView) -> float {
        return computeReprojError(ransacPoints3D, euclPoints2D, projectionView);
    };

    const float shift = 0.01f;   // translation perturbation
    const float angle = 0.001f;  // rotation perturbation
    const float rotationRate = 1.0f;
    const float translationRate = -0.1f;

    for (unsigned step = 0; step < 20; ++step) {
        // (debug) dump the rows of viewMatrix and errorOf(projectionMatrix * viewMatrix)
        // here to trace convergence.
        LinAlg::Vector3f translationGrad;
        LinAlg::Vector3f rotationGrad;

        translationGrad[0] =
            errorOf(projectionMatrix * LinAlg::Translation3D(LinAlg::Fill(shift, 0.0f, 0.0f)) * viewMatrix) -
            errorOf(projectionMatrix * LinAlg::Translation3D(LinAlg::Fill(-shift, 0.0f, 0.0f)) * viewMatrix);
        translationGrad[1] =
            errorOf(projectionMatrix * LinAlg::Translation3D(LinAlg::Fill(0.0f, shift, 0.0f)) * viewMatrix) -
            errorOf(projectionMatrix * LinAlg::Translation3D(LinAlg::Fill(0.0f, -shift, 0.0f)) * viewMatrix);
        translationGrad[2] =
            errorOf(projectionMatrix * LinAlg::Translation3D(LinAlg::Fill(0.0f, 0.0f, shift)) * viewMatrix) -
            errorOf(projectionMatrix * LinAlg::Translation3D(LinAlg::Fill(0.0f, 0.0f, -shift)) * viewMatrix);

        rotationGrad[0] =
            errorOf(projectionMatrix * LinAlg::RotateX(angle) * viewMatrix) -
            errorOf(projectionMatrix * LinAlg::RotateX(-angle) * viewMatrix);
        rotationGrad[1] =
            errorOf(projectionMatrix * LinAlg::RotateY(angle) * viewMatrix) -
            errorOf(projectionMatrix * LinAlg::RotateY(-angle) * viewMatrix);
        rotationGrad[2] =
            errorOf(projectionMatrix * LinAlg::RotateZ(angle) * viewMatrix) -
            errorOf(projectionMatrix * LinAlg::RotateZ(-angle) * viewMatrix);

        // Step against the measured differences.
        viewMatrix = LinAlg::RotateX(-rotationGrad[0] * rotationRate) *
                     LinAlg::RotateY(-rotationGrad[1] * rotationRate) *
                     LinAlg::RotateZ(-rotationGrad[2] * rotationRate) *
                     LinAlg::Translation3D(translationGrad * translationRate) *
                     viewMatrix;

        // Re-orthogonalise the rotation rows while keeping each row's fourth entry.
        LinAlg::Vector3f row0 = viewMatrix[0].StripHom();
        LinAlg::Vector3f row1 = viewMatrix[1].StripHom();
        LinAlg::Vector3f row2 = viewMatrix[2].StripHom();

        LinAlg::GramSchmidt(row0, row1, row2);

        viewMatrix[0] = row0.AddHom(viewMatrix[0][3]);
        viewMatrix[1] = row1.AddHom(viewMatrix[1][3]);
        viewMatrix[2] = row2.AddHom(viewMatrix[2][3]);
        viewMatrix[3] = LinAlg::Fill<float>(0.0f, 0.0f, 0.0f, 1.0f);
    }
}

/**
 * Estimates a view matrix from control points by running 32 RANSAC jobs and
 * keeping the best-scoring result.
 *
 * @param internalCalibration 3x4 calibration matrix; stored in m_internalCalibration.
 * @param controlPoints       array of @p count control points.
 * @param probs               optional output (may be null): one value per control
 *                            point, exp(-0.0001 * rcpSqrScreenSize * squared
 *                            reprojection error) times the point's sampling weight;
 *                            0 where the projected third component is (almost) zero.
 * @param count               number of control points.
 * @param iterations          total iteration budget; each job receives iterations/32.
 * @param bestViewMatrix      output: view matrix of the highest-scoring job.
 * @param bestScore           output: that job's score.
 */
void ViewMatrixEstimator::estimate(const LinAlg::Matrix3x4f &internalCalibration, const ControlPoint *controlPoints, float *probs, unsigned count, unsigned iterations, LinAlg::Matrix4x4f &bestViewMatrix, float &bestScore)
{
    // Number of independent RANSAC jobs the iteration budget is split across.
    const unsigned numJobs = 32;

    buildProbabilities(0.15f/(0.8f*0.8f), controlPoints, count);

    m_internalCalibration = internalCalibration;

    std::vector<LinAlg::Matrix4x4f> jobViewMatrices(numJobs);
    std::vector<float> jobScores(numJobs);

#if 1
    // Parallel path: one task per job, each with its own seed offset and output slot.
    TaskGroup group;
    for (unsigned job = 0; job < numJobs; ++job) {
        group.add(boost::bind(&ViewMatrixEstimator::doRansac, this, controlPoints, count,
                              1234*job, iterations/numJobs, &jobViewMatrices[job], &jobScores[job]), TaskScheduler::get());
    }
    TaskScheduler::get().waitFor(&group);
#else
    // Serial path, kept for debugging.
    for (unsigned job = 0; job < numJobs; ++job) {
        doRansac(controlPoints, count,
                 1234*job, iterations/numJobs, &jobViewMatrices[job], &jobScores[job]);
    }
#endif

    // Pick the job with the strictly highest score; ties keep the earliest job.
    unsigned winner = 0;
    for (unsigned job = 1; job < jobScores.size(); ++job) {
        if (jobScores[job] > jobScores[winner])
            winner = job;
    }
    bestViewMatrix = jobViewMatrices[winner];
    bestScore = jobScores[winner];

    if (probs == nullptr)
        return;

    // Per-point probabilities under the winning matrix.
    const LinAlg::Matrix3x4f projectionView = m_internalCalibration * bestViewMatrix;
    for (unsigned idx = 0; idx < count; ++idx) {
        const ControlPoint &point = controlPoints[idx];
        LinAlg::Vector3f homogeneous = projectionView * point.worldSpacePosition;

        if (std::abs(homogeneous[2]) < 1e-20f) {
            // Cannot dehomogenise: treat the point as not explained at all.
            probs[idx] = 0.0f;
        } else {
            LinAlg::Vector2f projected = homogeneous.StripHom() / homogeneous[2];
            float sqrError = (point.screenSpacePosition - projected).SQRLen();

            probs[idx] = std::exp(sqrError * (-0.0001f * point.rcpSqrScreenSize)) * m_probabilities[idx];
        }
    }
}

/**
 * Lane-wise absolute value of four packed floats.
 *
 * Clears the sign bit of every lane by AND-ing with 0x7FFFFFFF. The mask is
 * written out explicitly instead of ~(1 << 31), which shifts a 1 into the sign
 * bit of a signed int.
 *
 * @param v four packed single-precision values.
 * @return the same values with their sign bits cleared.
 */
__m128 _mm_abs_ps(__m128 v)
{
    return _mm_and_ps(v, _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF)));
}


/**
 * Twelve floats held in three SSE registers (A, B, C), with the lane-wise
 * arithmetic needed by the iterative solver below.
 */
struct StateVector {
    __m128 A, B, C;

    /// Sets all twelve lanes to zero.
    void setZero() {
        const __m128 zero = _mm_setzero_ps();
        A = zero;
        B = zero;
        C = zero;
    }

    /// Sets all twelve lanes to @p f.
    void set1(float f) {
        const __m128 value = _mm_set1_ps(f);
        A = value;
        B = value;
        C = value;
    }

    /// Lane-wise product with another state vector.
    StateVector operator*(const StateVector &rhs) const {
        StateVector out;
        out.A = _mm_mul_ps(A, rhs.A);
        out.B = _mm_mul_ps(B, rhs.B);
        out.C = _mm_mul_ps(C, rhs.C);
        return out;
    }

    /// Lane-wise sum.
    StateVector operator+(const StateVector &rhs) const {
        StateVector out;
        out.A = _mm_add_ps(A, rhs.A);
        out.B = _mm_add_ps(B, rhs.B);
        out.C = _mm_add_ps(C, rhs.C);
        return out;
    }

    /// Lane-wise difference.
    StateVector operator-(const StateVector &rhs) const {
        StateVector out;
        out.A = _mm_sub_ps(A, rhs.A);
        out.B = _mm_sub_ps(B, rhs.B);
        out.C = _mm_sub_ps(C, rhs.C);
        return out;
    }

    /// Multiplies each of the three registers lane-wise with the same register.
    StateVector operator*(const __m128 &broadcasted) const {
        StateVector out;
        out.A = _mm_mul_ps(A, broadcasted);
        out.B = _mm_mul_ps(B, broadcasted);
        out.C = _mm_mul_ps(C, broadcasted);
        return out;
    }

    /// Lane-wise reciprocal; lanes whose magnitude is below 1e-10 yield 0 instead.
    StateVector getRcp() const {
        StateVector out;
        out.A = guardedRcp(A);
        out.B = guardedRcp(B);
        out.C = guardedRcp(C);
        return out;
    }

private:
    /// 1/v per lane, masked to zero where |v| < 1e-10.
    static __m128 guardedRcp(const __m128 &v) {
        const __m128 reciprocal = _mm_div_ps(_mm_set1_ps(1.0f), v);
        const __m128 largeEnough = _mm_cmpge_ps(_mm_abs_ps(v), _mm_set1_ps(1e-10f));
        return _mm_and_ps(reciprocal, largeEnough);
    }
};

__m128 dot(const StateVector &lhs, const StateVector &rhs) {

    StateVector products = lhs * rhs;

    __m128 tmp = _mm_add_ps(_mm_add_ps(products.A, products.B), products.C);

    tmp = _mm_add_ps(tmp, _mm_permute_ps(tmp, 0b11100100 ^ 0b01010101));
    tmp = _mm_add_ps(tmp, _mm_permute_ps(tmp, 0b11100100 ^ 0b10101010));

    return tmp;
}

/**
 * Symmetric 12x12 matrix stored as 4x4 blocks, laid out as
 *
 *     [ A   0   B ]
 *     [ 0   C   D ]
 *     [ BT  DT  E ]
 *
 * where BT and DT are the transposes of B and D. It is accumulated from the
 * rows of a design matrix and can be multiplied with a StateVector.
 */
struct ProductMatrix {

    /// One 4x4 block, stored as four column registers.
    struct SubMatrix {
        __m128 column0;
        __m128 column1;
        __m128 column2;
        __m128 column3;

        void setZero() {
            const __m128 zero = _mm_setzero_ps();
            column0 = zero;
            column1 = zero;
            column2 = zero;
            column3 = zero;
        }

        /// Adds the outer product lhs * rhs^T: column k gains lhs scaled by lane k of rhs.
        void addOuterProduct(__m128 lhs, __m128 rhs) {
            column0 = _mm_add_ps(column0, _mm_mul_ps(lhs, _mm_permute_ps(rhs, _MM_SHUFFLE(0, 0, 0, 0))));
            column1 = _mm_add_ps(column1, _mm_mul_ps(lhs, _mm_permute_ps(rhs, _MM_SHUFFLE(1, 1, 1, 1))));
            column2 = _mm_add_ps(column2, _mm_mul_ps(lhs, _mm_permute_ps(rhs, _MM_SHUFFLE(2, 2, 2, 2))));
            column3 = _mm_add_ps(column3, _mm_mul_ps(lhs, _mm_permute_ps(rhs, _MM_SHUFFLE(3, 3, 3, 3))));
        }

        /// Block-times-vector: column0*v[0] + column1*v[1] + column2*v[2] + column3*v[3].
        __m128 times(__m128 v) const {
            __m128 acc = _mm_mul_ps(column0, _mm_permute_ps(v, _MM_SHUFFLE(0, 0, 0, 0)));
            acc = _mm_add_ps(acc, _mm_mul_ps(column1, _mm_permute_ps(v, _MM_SHUFFLE(1, 1, 1, 1))));
            acc = _mm_add_ps(acc, _mm_mul_ps(column2, _mm_permute_ps(v, _MM_SHUFFLE(2, 2, 2, 2))));
            acc = _mm_add_ps(acc, _mm_mul_ps(column3, _mm_permute_ps(v, _MM_SHUFFLE(3, 3, 3, 3))));
            return acc;
        }

        /// Adds block-times-vector onto @p acc, one column at a time.
        __m128 addTimes(__m128 acc, __m128 v) const {
            acc = _mm_add_ps(acc, _mm_mul_ps(column0, _mm_permute_ps(v, _MM_SHUFFLE(0, 0, 0, 0))));
            acc = _mm_add_ps(acc, _mm_mul_ps(column1, _mm_permute_ps(v, _MM_SHUFFLE(1, 1, 1, 1))));
            acc = _mm_add_ps(acc, _mm_mul_ps(column2, _mm_permute_ps(v, _MM_SHUFFLE(2, 2, 2, 2))));
            acc = _mm_add_ps(acc, _mm_mul_ps(column3, _mm_permute_ps(v, _MM_SHUFFLE(3, 3, 3, 3))));
            return acc;
        }

        /// Gathers lane k of column k into lane k of the result.
        __m128 diagonal() const {
            const __m128 lower = _mm_blend_ps(column0, column1, 0b10);
            const __m128 upper = _mm_blend_ps(column2, column3, 0b1000);
            return _mm_blend_ps(lower, upper, 0b1100);
        }
    };

    SubMatrix A, B, C, D, BT, DT, E;

    /// Clears every block.
    void setZero() {
        A.setZero();
        B.setZero();
        C.setZero();
        D.setZero();
        BT.setZero();
        DT.setZero();
        E.setZero();
    }

    /**
     * Accumulates the blocks from a design matrix whose rows come in pairs.
     *
     * From the first row of each pair only entries 0-3 and 8-11 are read, from
     * the second only entries 4-7 and 8-11; presumably the remaining entries
     * are zero by construction of the design matrix - TODO confirm. The row
     * count must be even, since row i+1 is read unconditionally, and the row
     * storage must be 16-byte aligned for the aligned loads.
     */
    void computeFromDesignMatrix(const std::vector<LinAlg::Vector<12, float>, AlignedAllocator<LinAlg::Vector<12, float> > > &designMatrix) {
        setZero();

        for (unsigned row = 0; row < designMatrix.size(); row += 2) {
            const __m128 firstLow   = _mm_load_ps(&designMatrix[row+0][0]);
            const __m128 firstHigh  = _mm_load_ps(&designMatrix[row+0][8]);
            const __m128 secondMid  = _mm_load_ps(&designMatrix[row+1][4]);
            const __m128 secondHigh = _mm_load_ps(&designMatrix[row+1][8]);

            A.addOuterProduct(firstLow, firstLow);
            B.addOuterProduct(firstLow, firstHigh);

            C.addOuterProduct(secondMid, secondMid);
            D.addOuterProduct(secondMid, secondHigh);

            // Both rows of the pair contribute to the lower-right block.
            E.addOuterProduct(firstHigh, firstHigh);
            E.addOuterProduct(secondHigh, secondHigh);
        }

        // Cache the transposed off-diagonal blocks for the third block row.
        BT = B;
        _MM_TRANSPOSE4_PS(BT.column0, BT.column1, BT.column2, BT.column3);
        DT = D;
        _MM_TRANSPOSE4_PS(DT.column0, DT.column1, DT.column2, DT.column3);
    }

    /// Matrix-vector product following the block layout described above.
    StateVector operator*(const StateVector &rhs) const {
        StateVector result;

        result.A = B.addTimes(A.times(rhs.A), rhs.C);
        result.B = D.addTimes(C.times(rhs.B), rhs.C);
        result.C = E.addTimes(DT.addTimes(BT.times(rhs.A), rhs.B), rhs.C);

        return result;
    }

    /// Returns the twelve diagonal entries (diagonals of blocks A, C and E).
    StateVector getDiagonal() const {
        StateVector result;

        result.A = A.diagonal();
        result.B = C.diagonal();
        result.C = E.diagonal();

        return result;
    }
};


/**
 * Estimates a 3x4 projection matrix from 2D/3D point correspondences using an
 * iterative eigenvector search instead of a full matrix decomposition.
 *
 * Steps visible here:
 *  1. Both point sets are conditioned (PCV::getCondition / PCV::transform).
 *  2. A design matrix with 12 columns is built from the conditioned points.
 *  3. The 12x12 product of the design matrix with itself is accumulated in a
 *     ProductMatrix.
 *  4. Up to 100 outer iterations each solve productMatrix * x = previous
 *     estimate approximately with a conjugate-gradient loop (at most 25
 *     steps) and renormalise x to unit length. This is an inverse power
 *     iteration; presumably it converges towards the eigenvector belonging
 *     to the smallest eigenvalue of the product matrix - TODO confirm against
 *     the design-matrix definition in PCV.
 *  5. The 12 entries are stored as the three matrix rows and the conditioning
 *     is undone with PCV::decondition.
 *
 * @param points2D homogeneous 2D points, one per correspondence.
 * @param points3D homogeneous 3D points, in the same order as @p points2D.
 * @return the deconditioned 3x4 projection matrix.
 */
LinAlg::Matrix<3, 4, float> computeProjectionMatrixFromWorldspacePointsPowerIterate(const std::vector<LinAlg::Vector<3, float> > &points2D, const std::vector<LinAlg::Vector<4, float> > &points3D)
{

    // Conditioning transforms for the 2D and 3D point sets.
    LinAlg::Matrix<3,3,float> T_2D = PCV::getCondition<2>(points2D);
    LinAlg::Matrix<4,4,float> T_3D = PCV::getCondition<3>(points3D);

    std::vector<LinAlg::Vector<3,float> > conditionedPoints2D;
    std::vector<LinAlg::Vector<4,float> > conditionedPoints3D;

    PCV::transform<2>(T_2D, points2D, conditionedPoints2D);
    PCV::transform<3>(T_3D, points3D, conditionedPoints3D);

    // Rows are stored through AlignedAllocator because ProductMatrix reads them with aligned SSE loads.
    std::vector<LinAlg::Vector<12, float>, AlignedAllocator<LinAlg::Vector<12, float> > > designMatrix =
                PCV::getDesignMatrix_projectionMatrixFromWSPoints<float, AlignedAllocator<LinAlg::Vector<12, float> > >(conditionedPoints2D, conditionedPoints3D);

    // Disabled reference path: explicit 12x12 product, Gauss-Jordan inversion and a
    // plain power iteration on the inverse. Kept for comparison with the SSE path below.
#if 0
    LinAlg::Matrix<12, 12, float> product(LinAlg::Matrix<12, 12, float>::NO_INITIALIZATION);
    for (unsigned i = 0; i < 12; i++)
        for (unsigned j = i; j < 12; j++) {
            float sum = 0.0f;
            for (unsigned k = 0; k < designMatrix.size(); k++)
                sum += designMatrix[k][i] * designMatrix[k][j];
            product[i][j] = product[j][i] = sum;
        }
/*
    std::cout << "Product = ";
    output<12, 12, float>(std::cout, product) << std::endl;
*/
    product.GaussJordanInvert();

    LinAlg::Vector<12, float> projectionMatrixAsVector;
    for (unsigned i = 0; i < 12; i++)
        projectionMatrixAsVector[i] = std::sqrt(1.0f / 12.0f);

    for (unsigned iteration = 0; iteration < 100; iteration++) {
        LinAlg::Vector<12, float> old = projectionMatrixAsVector;
        projectionMatrixAsVector = (product * old).normalized();

        if (old*projectionMatrixAsVector > 0.999999999999f) {
/*
            std::cout << "Breaking after " << iteration << std::endl
                      << "    " << old*projectionMatrixAsVector << std::endl
                      << "    " << (std::string) old << std::endl
                      << "    " << (std::string) projectionMatrixAsVector << std::endl;
*/
            break;
        }
    }







    LinAlg::Matrix<3, 4, float> projectionMatrix(LinAlg::Matrix<3, 4, float>::NO_INITIALIZATION);
    memcpy(&projectionMatrix[0][0], &projectionMatrixAsVector[0], 12*sizeof(float));

#else
    // Active path: block-structured product matrix, never inverted explicitly.
    ProductMatrix productMatrix;
    productMatrix.computeFromDesignMatrix(designMatrix);

    // NOTE(review): the reciprocal-diagonal preconditioner computed on the first line is
    // immediately overwritten with all ones, so the inner loop runs unpreconditioned.
    // Confirm whether this is intentional.
    StateVector rcpPreconditioner = productMatrix.getDiagonal().getRcp();
    rcpPreconditioner.set1(1.0f);

    // Start vector: all twelve entries equal, overall length 1.
    StateVector projMatrixVector;
    projMatrixVector.set1(std::sqrt(1.0f / 12.0f));

    for (unsigned powerIteration = 0; powerIteration < 100; powerIteration++) {
        // Solve productMatrix * x = b, where b is the previous estimate and also the initial guess.
        StateVector x = projMatrixVector;
        const StateVector &b = projMatrixVector;
        StateVector r = b - productMatrix * x; // residual
        StateVector z = rcpPreconditioner * r; // preconditioned residual
        StateVector p = z;                     // search direction

        __m128 oldRtimesZ = dot(r, z);
        for (unsigned pcgdIteration = 0; pcgdIteration < 25; pcgdIteration++) {
            StateVector AP = productMatrix * p;
            // Step length along p. There is no guard against dot(p, AP) being zero.
            __m128 alpha = _mm_div_ps(oldRtimesZ, dot(p, AP));

            x = x + p * alpha;
            r = r - AP * alpha;
            z = rcpPreconditioner * r;
            __m128 newRtimesZ = dot(r, z);

            // Converged once r.z drops below 1e-3 (only the lowest lane is compared; all lanes are equal).
            if (_mm_comilt_ss(newRtimesZ, _mm_set1_ps(1e-3f))) {
                oldRtimesZ = newRtimesZ;
              //  std::cout << "Breaking after pcgdIteration " << pcgdIteration << std::endl;
                break;
            }
            // New search direction from the updated residual.
            p = z + p * _mm_div_ps(newRtimesZ, oldRtimesZ);
            oldRtimesZ = newRtimesZ;
        }

        // normalize length
        // (uses an exact divide and square root; the approximate _mm_rsqrt_ps variant is left commented out)
        //x = x * _mm_rsqrt_ps(dot(x, x));
        x = x * _mm_div_ps(_mm_set1_ps(1.0f), _mm_sqrt_ps(dot(x, x)));

        // Both vectors have unit length, so this is the cosine of the angle between successive estimates.
        __m128 change = dot(x, projMatrixVector);
        projMatrixVector = x;

/*
            alignas(16) float tmp[4];
            _mm_store_ps(tmp, change);
            std::cout << "powerIteration " << powerIteration << " change: " << tmp[0] << std::endl;
*/
        // Stop once the direction has stabilised, regardless of sign.
        if (_mm_comigt_ss(_mm_abs_ps(change), _mm_set1_ps(0.9999f))) {
        //    std::cout << "Breaking after powerIteration " << powerIteration << std::endl;

/*
            std::cout << "Breaking after " << iteration << std::endl
                      << "    " << old*projectionMatrixAsVector << std::endl
                      << "    " << (std::string) old << std::endl
                      << "    " << (std::string) projectionMatrixAsVector << std::endl;
*/
            break;
        }

    }

    // Write the three registers out as matrix rows 0..2. The aligned stores assume each row
    // holds four contiguous floats starting at a 16-byte boundary - TODO confirm against LinAlg::Matrix.
    alignas(16) LinAlg::Matrix<3, 4, float> projectionMatrix(LinAlg::Matrix<3, 4, float>::NO_INITIALIZATION);
    _mm_store_ps(&projectionMatrix[0][0], projMatrixVector.A);
    _mm_store_ps(&projectionMatrix[1][0], projMatrixVector.B);
    _mm_store_ps(&projectionMatrix[2][0], projMatrixVector.C);

#endif


    // Undo the conditioning so the matrix maps the original 3D points to the original 2D points.
    PCV::decondition<float>(T_2D, T_3D, projectionMatrix);

    return projectionMatrix;
}


void ViewMatrixEstimator::doRansac(const ControlPoint *controlPoints, unsigned count, unsigned seed, unsigned iters, LinAlg::Matrix4x4f *dstViewMatrix, float *dstScore)
{
    LinAlg::Matrix4x4f bestViewMatrix;
    float bestProbability = 0.0f;

    const unsigned numPointsInSet = 8;

    std::vector<LinAlg::Vector4f> ransacWorldSpacePoints;
    std::vector<LinAlg::Vector3f> ransacScreenSpacePoints;
    ransacWorldSpacePoints.resize(numPointsInSet);
    ransacScreenSpacePoints.resize(numPointsInSet);

    std::mt19937 rne(std::chrono::system_clock::now().time_since_epoch().count() ^ seed);
    std::uniform_real_distribution<float> distribution(0.0, 1.0);

    for (unsigned iter = 0; iter < iters/4; iter++) {
        LinAlg::Matrix4x4f viewMatrices[4];
        LinAlg::Matrix3x4f projectionMatrices[4];
        for (int k = 0; k < 4; k++) {
        //    std::cout << "Iter " << iter << std::endl;
            for (unsigned i = 0; i < numPointsInSet; i++) {
                unsigned pointIndex = pickMatch(distribution(rne));
                //unsigned matchIndex = std::min((int)(distribution(rne) * m_matches.size()), (int)(m_matches.size()-1));
      //          std::cout << "Match "<<i<<": " << matchIndex << std::endl;

                ransacWorldSpacePoints[i] = controlPoints[pointIndex].worldSpacePosition;
                ransacScreenSpacePoints[i] = controlPoints[pointIndex].screenSpacePosition.AddHom(1.0f);
            }

            try {
                //LinAlg::Matrix3x4f Porig = PCV::computeProjectionMatrixFromWorldspacePoints<float>(ransacScreenSpacePoints, ransacWorldSpacePoints);
                LinAlg::Matrix3x4f Porig = computeProjectionMatrixFromWorldspacePointsPowerIterate(ransacScreenSpacePoints, ransacWorldSpacePoints);

                LinAlg::Matrix3x3f dummy3;
                LinAlg::Matrix4x4f dummy4;
                LinAlg::Matrix4x4f viewMatrix;
                LinAlg::Vector3f dummyVec;
                PCV::decomposeProjectionMatrix(Porig,
                                          dummy3,
                                          viewMatrix,
                                          dummy4,
                                          dummyVec);

                //optimizeViewMatrix(ransacWorldSpacePoints, ransacScreenSpacePoints, m_internalCalibration, viewMatrix);

                viewMatrices[k] = viewMatrix;
                projectionMatrices[k] = m_internalCalibration * viewMatrix;//Porig;
            } catch (...) {
               // k--;
            }
        }

        float probs[4];
        computeProbabilityOfFourProjectionMatrices(controlPoints, count, projectionMatrices, probs);

        for (unsigned k = 0; k < 4; k++) {
            if (probs[k] > bestProbability) {
                bestViewMatrix = viewMatrices[k];
                bestProbability = probs[k];
            }
        }
    }

    *dstViewMatrix = bestViewMatrix;
    *dstScore = bestProbability;
}



void ViewMatrixEstimator::computeProbabilityOfFourProjectionMatrices(const ControlPoint *controlPoints, unsigned count, LinAlg::Matrix3x4f *projectionMatrices, float *probs)
{
    float fourProjectionViewMatrices[3*4*4] __attribute__((aligned(16)));

    for (unsigned i = 0; i < 3; i++)
        for (unsigned j = 0; j < 4; j++)
            for (unsigned k = 0; k < 4; k++)
                fourProjectionViewMatrices[i*4*4+j*4+k] = projectionMatrices[k][i][j];


    __m128 sum = _mm_setzero_ps();
    for (unsigned i = 0; i < count; i++) {
        __m128 WS_X = _mm_set1_ps(controlPoints[i].worldSpacePosition[0]);
        __m128 WS_Y = _mm_set1_ps(controlPoints[i].worldSpacePosition[1]);
        __m128 WS_Z = _mm_set1_ps(controlPoints[i].worldSpacePosition[2]);
        __m128 WS_W = _mm_set1_ps(controlPoints[i].worldSpacePosition[3]);

        __m128 PSS_X = _mm_mul_ps(WS_X, _mm_load_ps(fourProjectionViewMatrices+4*(0*4+0)));
        __m128 PSS_Y = _mm_mul_ps(WS_X, _mm_load_ps(fourProjectionViewMatrices+4*(1*4+0)));
        __m128 PSS_W = _mm_mul_ps(WS_X, _mm_load_ps(fourProjectionViewMatrices+4*(2*4+0)));

        PSS_X = _mm_add_ps(PSS_X, _mm_mul_ps(WS_Y, _mm_load_ps(fourProjectionViewMatrices+4*(0*4+1))));
        PSS_Y = _mm_add_ps(PSS_Y, _mm_mul_ps(WS_Y, _mm_load_ps(fourProjectionViewMatrices+4*(1*4+1))));
        PSS_W = _mm_add_ps(PSS_W, _mm_mul_ps(WS_Y, _mm_load_ps(fourProjectionViewMatrices+4*(2*4+1))));

        PSS_X = _mm_add_ps(PSS_X, _mm_mul_ps(WS_Z, _mm_load_ps(fourProjectionViewMatrices+4*(0*4+2))));
        PSS_Y = _mm_add_ps(PSS_Y, _mm_mul_ps(WS_Z, _mm_load_ps(fourProjectionViewMatrices+4*(1*4+2))));
        PSS_W = _mm_add_ps(PSS_W, _mm_mul_ps(WS_Z, _mm_load_ps(fourProjectionViewMatrices+4*(2*4+2))));

        PSS_X = _mm_add_ps(PSS_X, _mm_mul_ps(WS_W, _mm_load_ps(fourProjectionViewMatrices+4*(0*4+3))));
        PSS_Y = _mm_add_ps(PSS_Y, _mm_mul_ps(WS_W, _mm_load_ps(fourProjectionViewMatrices+4*(1*4+3))));
        PSS_W = _mm_add_ps(PSS_W, _mm_mul_ps(WS_W, _mm_load_ps(fourProjectionViewMatrices+4*(2*4+3))));


        __m128 validMask = _mm_or_ps(_mm_cmpgt_ps(PSS_W, _mm_set1_ps(1e-10f)),
                                     _mm_cmplt_ps(PSS_W, _mm_set1_ps(-1e-10f)));


        __m128 rcpW = _mm_rcp_ps(PSS_W);

        __m128 ESS_X = _mm_mul_ps(PSS_X, rcpW);
        __m128 ESS_Y = _mm_mul_ps(PSS_Y, rcpW);

        __m128 DX = _mm_sub_ps(ESS_X, _mm_set1_ps(controlPoints[i].screenSpacePosition[0]));
        __m128 DY = _mm_sub_ps(ESS_Y, _mm_set1_ps(controlPoints[i].screenSpacePosition[1]));

        __m128 sqrRpjError = _mm_add_ps(_mm_mul_ps(DX, DX), _mm_mul_ps(DY, DY));


        __m128 d = _mm_mul_ps(sqrRpjError, _mm_set1_ps(0.0001f * controlPoints[i].rcpSqrScreenSize));
#if 0
        __m128 prob = exp_ps(negate(d));
#else
        __m128 prob = _mm_rcp_ps(_mm_add_ps(_mm_set1_ps(1.0f), _mm_mul_ps(_mm_mul_ps(d, d), _mm_set1_ps(2.0f))));
#endif

        prob = _mm_and_ps(prob, validMask);

        sum = _mm_add_ps(sum, _mm_mul_ps(prob, _mm_set1_ps(m_probabilities[i])));
    }
    _mm_storeu_ps(probs, sum);

}



}
