/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "SingleTrackBundleAdjustment.h"

#include <string.h>
#include <assert.h>


// Horizontal 4-lane dot product (adapted from a snippet found on the internet).
// With p = v1 * v2 (lane-wise), every lane of the result holds
// (p0 + p1) + (p2 + p3), i.e. the dot product broadcast to all four lanes.
inline __m128 _mm_dot_ps(__m128 v1, __m128 v2)
{
    const __m128 products = _mm_mul_ps(v1, v2);

    // Swap neighbours -> [p1, p0, p3, p2]; adding gives [s01, s01, s23, s23].
    const __m128 neighbours = _mm_shuffle_ps(products, products, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128 pairSums = _mm_add_ps(products, neighbours);

    // Swap halves -> [s23, s23, s01, s01]; adding leaves s01 + s23 in each lane.
    const __m128 otherHalf = _mm_shuffle_ps(pairSums, pairSums, _MM_SHUFFLE(1, 0, 3, 2));
    return _mm_add_ps(pairSums, otherHalf);
}


namespace SFM {


/**
 * Recomputes m_residuals for the given (homogeneous) track position.
 *
 * For observation j, entries 2*j and 2*j+1 receive
 *     weight * (observed screen position - projected screen position),
 * where the projected position is projectionViewMatrix * trackPosition after
 * division by its third component, optionally scaled by the Polynomial_234
 * radial distortion factor 1 + k0*r^2 + k1*r^3 + k2*r^4.
 * Observations whose projected third component is (nearly) zero get a zero
 * residual.
 *
 * @throws std::runtime_error for any other distortion type (only raised for
 *         observations that pass the depth check).
 */
void SingleTrackBundleAdjustment::computeResiduals(const std::vector<Observation> &observations, const LinAlg::Vector4f &trackPosition)
{
    m_residuals.resize(observations.size()*2);

    for (unsigned obsIdx = 0; obsIdx < observations.size(); obsIdx++) {
        const Observation &obs = observations[obsIdx];
        const PFBundleAdjustment::PFBundleAdjustment::RadialDistortionParametrization &distortion = *obs.distortion;

        LinAlg::Vector3f homogeneousPos = (*obs.projectionViewMatrix) * trackPosition;
        const float depth = homogeneousPos[2];

        // The perspective division is unusable for a (nearly) zero third
        // component, so such an observation contributes nothing.
        if (depth * depth <= 1e-30f) {
            m_residuals[obsIdx*2+0] = 0.0f;
            m_residuals[obsIdx*2+1] = 0.0f;
            continue;
        }

        LinAlg::Vector2f projected = homogeneousPos.StripHom() * (1.0f / depth);

        /// @todo share the distortion model with computeJacobian
        if (distortion.type == config::BundleAdjustmentStructureConfig::RadialDistortionType::Polynomial_234) {
            const float r2 = projected.SQRLen();
            const float r3 = r2 * std::sqrt(r2);
            const float r4 = r2 * r2;
            projected *= 1.0f + r2 * distortion.polynomial234.kappa[0] +
                                r3 * distortion.polynomial234.kappa[1] +
                                r4 * distortion.polynomial234.kappa[2];
        } else if (distortion.type != config::BundleAdjustmentStructureConfig::RadialDistortionType::NoRadialDistortion) {
            throw std::runtime_error("unsupported distortion type!");
        }

        LinAlg::Vector2f weightedError = (obs.screenSpacePosition - projected) * obs.weight;
        m_residuals[obsIdx*2+0] = weightedError[0];
        m_residuals[obsIdx*2+1] = weightedError[1];
    }
}

/**
 * Recomputes m_jacobian: the derivative of each observation's projected
 * (and, for Polynomial_234, radially distorted) screen position with respect
 * to the four homogeneous components of trackPosition, scaled by the
 * observation's weight.
 *
 * Row 2*j holds d(x)/dX and row 2*j+1 holds d(y)/dX for observation j.
 * Rows of observations whose projected third component is (nearly) zero stay
 * zero, matching the zero residuals computeResiduals() writes for them.
 *
 * @throws std::runtime_error for unsupported distortion types.
 */
void SingleTrackBundleAdjustment::computeJacobian(const std::vector<Observation> &observations, const LinAlg::Vector4f &trackPosition)
{
    m_jacobian.resize(observations.size()*2);
    // Taking &m_jacobian[0] of an empty vector is undefined behaviour (and
    // memset requires a valid pointer even for size 0), so only clear when
    // there are rows, e.g. not for a track without observations.
    if (!m_jacobian.empty())
        memset(m_jacobian.data(), 0, sizeof(LinAlg::Vector4f)*m_jacobian.size());

    for (unsigned j = 0; j < observations.size(); j++) {
        const LinAlg::Matrix3x4f &projectionViewMatrix = *observations[j].projectionViewMatrix;
        const float &weight = observations[j].weight;
        const PFBundleAdjustment::PFBundleAdjustment::RadialDistortionParametrization &distortionParameters = *observations[j].distortion;

        LinAlg::Vector3f projectedPos = projectionViewMatrix * trackPosition;
        if (projectedPos[2]*projectedPos[2] <= 1e-30f)
            continue;

        LinAlg::Vector2f euclideanProjectedPos = projectedPos.StripHom() * (1.0f / projectedPos[2]);

        /// @todo refactor this (the distortion model is duplicated in computeResiduals)
        // Jacobian of the distorted position w.r.t. the undistorted one; it is
        // only filled (and only used) when a distortion model is active.
        LinAlg::Matrix2x2f radialDistJacobi(LinAlg::Matrix2x2f::NO_INITIALIZATION);
        switch (distortionParameters.type) {
            case config::BundleAdjustmentStructureConfig::RadialDistortionType::NoRadialDistortion:
            break;
            case config::BundleAdjustmentStructureConfig::RadialDistortionType::Polynomial_234: {
                // p_d = p * f(r) with f(r) = 1 + A r^2 + B r^3 + C r^4, hence
                // d(p_d)/dp = f(r) * I + p p^T * f'(r)/r, f'(r)/r = 2A + 3B r + 4C r^2.
                const float r2 = euclideanProjectedPos.SQRLen();
                const float r1 = std::sqrt(r2);
                const float r3 = r2 * r1;
                const float r4 = r2 * r2;

                const float radialA = distortionParameters.polynomial234.kappa[0];
                const float radialB = distortionParameters.polynomial234.kappa[1];
                const float radialC = distortionParameters.polynomial234.kappa[2];

                const float radialDerivOverR = 2.0f*radialA + 3.0f*radialB*r1 + 4.0f*radialC*r2;

                radialDistJacobi[0][0] =
                    euclideanProjectedPos[0] * euclideanProjectedPos[0] * radialDerivOverR +
                    r2*radialA +
                    r3*radialB +
                    r4*radialC +
                    1.0f;

                radialDistJacobi[1][0] =
                radialDistJacobi[0][1] =
                    euclideanProjectedPos[0] * euclideanProjectedPos[1] * radialDerivOverR;

                radialDistJacobi[1][1] =
                    euclideanProjectedPos[1] * euclideanProjectedPos[1] * radialDerivOverR +
                    r2*radialA +
                    r3*radialB +
                    r4*radialC +
                    1.0f;
            }break;
            default:
                throw std::runtime_error("unsupported distortion type!");
        }

        const bool hasRadialDistortion =
            distortionParameters.type != config::BundleAdjustmentStructureConfig::RadialDistortionType::NoRadialDistortion;

        // One Jacobian column per homogeneous component X_c:
        //   d(P_r.X / P_2.X)/dX_c = P_rc / P_2.X - P_2c * (P_r.X / P_2.X) / P_2.X
        for (unsigned c = 0; c < 4; c++) {
            LinAlg::Vector2f pointGrad;
            pointGrad[0] = projectionViewMatrix[0][c] / projectedPos[2] -
                           projectionViewMatrix[2][c] * euclideanProjectedPos[0] / projectedPos[2];
            pointGrad[1] = projectionViewMatrix[1][c] / projectedPos[2] -
                           projectionViewMatrix[2][c] * euclideanProjectedPos[1] / projectedPos[2];

            // Chain rule through the radial distortion.
            if (hasRadialDistortion)
                pointGrad = radialDistJacobi * pointGrad;

            assert(std::isfinite(pointGrad[0]));
            assert(std::isfinite(pointGrad[1]));

            m_jacobian[j*2+0][c] = pointGrad[0] * weight;
            m_jacobian[j*2+1][c] = pointGrad[1] * weight;
        }
    }
}


float SingleTrackBundleAdjustment::sumResiduals()
{
    float sum = 0.0f;
    for (unsigned i = 0; i < m_residuals.size(); i++) {
        sum += m_residuals[i] * m_residuals[i];
    }
    return sum;
}


/**
 * Builds the 4x4 matrix J^T * J from the stacked Jacobian rows and multiplies
 * its diagonal by (1 + lambda) (Marquardt-style damping).
 *
 * H0..H3 receive the columns of the result; since J^T * J is symmetric they
 * are also its rows.
 *
 * @param jacobian 16-byte aligned Jacobian rows (loaded with _mm_load_ps).
 * @param lambda   Damping factor applied to the diagonal.
 */
void SingleTrackBundleAdjustment::Hessian::computeHessianFromJacobi(const std::vector<LinAlg::Vector4f, AlignedAllocator<LinAlg::Vector4f> > &jacobian, float lambda)
{
    // Each row contributes the outer product row * row^T, whose c-th column
    // is the row scaled by row[c].
    __m128 col0 = _mm_setzero_ps();
    __m128 col1 = col0;
    __m128 col2 = col0;
    __m128 col3 = col0;

    for (const LinAlg::Vector4f &jacobianRow : jacobian) {
        const __m128 row = _mm_load_ps(&jacobianRow[0]);

        col0 = _mm_add_ps(col0, _mm_mul_ps(row, _mm_shuffle_ps(row, row, _MM_SHUFFLE(0, 0, 0, 0))));
        col1 = _mm_add_ps(col1, _mm_mul_ps(row, _mm_shuffle_ps(row, row, _MM_SHUFFLE(1, 1, 1, 1))));
        col2 = _mm_add_ps(col2, _mm_mul_ps(row, _mm_shuffle_ps(row, row, _MM_SHUFFLE(2, 2, 2, 2))));
        col3 = _mm_add_ps(col3, _mm_mul_ps(row, _mm_shuffle_ps(row, row, _MM_SHUFFLE(3, 3, 3, 3))));
    }

    // Scale only the diagonal element of each column by (1 + lambda).
    const __m128 ones = _mm_set1_ps(1.0f);
    const __m128 damped = _mm_set1_ps(1.0f + lambda);
    H0 = _mm_mul_ps(col0, _mm_blend_ps(ones, damped, 0b0001));
    H1 = _mm_mul_ps(col1, _mm_blend_ps(ones, damped, 0b0010));
    H2 = _mm_mul_ps(col2, _mm_blend_ps(ones, damped, 0b0100));
    H3 = _mm_mul_ps(col3, _mm_blend_ps(ones, damped, 0b1000));
}

/// Returns H * v, formed as H0*v[0] + H1*v[1] + H2*v[2] + H3*v[3] with H0..H3
/// the columns of H (which are also its rows, as the hessian is symmetric).
__m128 SingleTrackBundleAdjustment::Hessian::operator*(const __m128 &v) const
{
    const __m128 vx = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 0, 0));
    const __m128 vy = _mm_shuffle_ps(v, v, _MM_SHUFFLE(1, 1, 1, 1));
    const __m128 vz = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 2, 2, 2));
    const __m128 vw = _mm_shuffle_ps(v, v, _MM_SHUFFLE(3, 3, 3, 3));

    __m128 product = _mm_mul_ps(H0, vx);
    product = _mm_add_ps(product, _mm_mul_ps(H1, vy));
    product = _mm_add_ps(product, _mm_mul_ps(H2, vz));
    product = _mm_add_ps(product, _mm_mul_ps(H3, vw));
    return product;
}

/// Returns the diagonal of the hessian: lane i is taken from column Hi.
__m128 SingleTrackBundleAdjustment::Hessian::getDiagonal() const
{
    __m128 diagonal = H0;                           // lane 0 comes from H0
    diagonal = _mm_blend_ps(diagonal, H1, 0b0010);  // lane 1 from H1
    diagonal = _mm_blend_ps(diagonal, H2, 0b0100);  // lane 2 from H2
    diagonal = _mm_blend_ps(diagonal, H3, 0b1000);  // lane 3 from H3
    return diagonal;
}

/// Lane-wise absolute value: clears the IEEE-754 sign bit of every lane.
/// -0.0f has only the sign bit set, so andnot with it keeps every other bit;
/// this avoids building the mask with a signed shift into the sign bit.
inline __m128 _mm_abs_ps(__m128 v)
{
    const __m128 signMask = _mm_set1_ps(-0.0f);
    return _mm_andnot_ps(signMask, v);
}


/**
 * Refines a single homogeneous track position with Levenberg-Marquardt.
 *
 * Each iteration:
 * - linearises the weighted residuals r = weight * (observed - projected);
 * - builds J^T * J with its diagonal scaled by (1 + lambda);
 * - solves (J^T * J) * x = J^T * r with at most 10 steps of
 *   Jacobi-preconditioned conjugate gradient;
 * - accepts trackPosition + x (rescaled to roughly unit length) only if it
 *   lowers the sum of squared residuals.
 *
 * Accepted steps halve lambda (floored at 1e-10); rejected steps multiply it
 * by 10. The search stops after 100 iterations or once lambda exceeds 1e5.
 *
 * @param observations  Observations of the track.
 * @param trackPosition In: start position. Out: best position found; only
 *                      overwritten by accepted steps.
 *
 * On return, m_residuals holds the residuals of the returned trackPosition.
 */
void SingleTrackBundleAdjustment::optimize(const std::vector<Observation> &observations, LinAlg::Vector4f &trackPosition)
{
    float lambda = 1e-2f;

    computeResiduals(observations, trackPosition);
    float lastError = sumResiduals();
    //std::cout << "Error before: " << lastError << std::endl;

    for (unsigned lmIteration = 0; lmIteration < 100; lmIteration++) {

        computeJacobian(observations, trackPosition);
        alignas(16) Hessian hessian;
        hessian.computeHessianFromJacobi(m_jacobian, lambda);

        // Right-hand side of the normal equations: y = J^T * r.
        __m128 y = _mm_setzero_ps();
        for (unsigned j = 0; j < m_jacobian.size(); j++) {
            __m128 residual = _mm_broadcast_ss(&m_residuals[j]);
            __m128 jacobianRow = _mm_load_ps(&m_jacobian[j][0]);
            y = _mm_add_ps(y, _mm_mul_ps(jacobianRow, residual));
        }

        // Solve hessian * x = y with Jacobi-preconditioned conjugate gradient.
        __m128 x = _mm_setzero_ps();
        __m128 r = y;

        // Inverse diagonal as preconditioner; lanes whose diagonal is (nearly)
        // zero get a factor of 0 instead of an infinite one.
        __m128 preconditioner = hessian.getDiagonal();
        __m128 rcpPreconditioner = _mm_and_ps(_mm_div_ps(_mm_set1_ps(1.0f), preconditioner),
                                              _mm_cmpge_ps(_mm_abs_ps(preconditioner), _mm_set1_ps(1e-10f)));

        __m128 z = _mm_mul_ps(rcpPreconditioner, r);
        __m128 p = z;

        // _mm_dot_ps broadcasts the dot product to all lanes, so these
        // "scalars" can be used directly in packed arithmetic below.
        __m128 oldRtimesZ = _mm_dot_ps(r, z);

        for (unsigned pcgdIteration = 0; pcgdIteration < 10; pcgdIteration++) {
            __m128 AP = hessian * p;
            __m128 pAP = _mm_dot_ps(p, AP);
            // The step length needs strictly positive curvature along p. It
            // is zero e.g. when the gradient y is zero (perfect fit, or every
            // observation skipped); alpha = 0/0 would then fill x with NaNs.
            if (!_mm_comigt_ss(pAP, _mm_setzero_ps()))
                break;
            __m128 alpha = _mm_div_ps(oldRtimesZ, pAP);
            x = _mm_add_ps(x, _mm_mul_ps(p, alpha));
            r = _mm_sub_ps(r, _mm_mul_ps(AP, alpha));
            z = _mm_mul_ps(rcpPreconditioner, r);
            __m128 newRtimesZ = _mm_dot_ps(r, z);
            if (_mm_comilt_ss(newRtimesZ, _mm_set1_ps(1e-15f)))
                break;
            p = _mm_add_ps(z, _mm_mul_ps(p, _mm_div_ps(newRtimesZ, oldRtimesZ)));
            oldRtimesZ = newRtimesZ;
        }

        // Apply the step and rescale the homogeneous position to roughly unit
        // length (_mm_rsqrt_ps is an approximation). The perspective division
        // in the projection cancels this positive scale.
        __m128 pos = _mm_add_ps(_mm_loadu_ps(&trackPosition[0]), x);
        pos = _mm_mul_ps(pos, _mm_rsqrt_ps(_mm_dot_ps(pos, pos)));

        alignas(16) LinAlg::Vector4f newTrackPosition;
        _mm_store_ps(&newTrackPosition[0], pos);

        computeResiduals(observations, newTrackPosition);
        float newError = sumResiduals();
        //std::cout << "iteration " << lmIteration << " lambda " << lambda << " error: " << newError << std::endl;
        if (newError < lastError) {
            trackPosition = newTrackPosition;
            lastError = newError;

            lambda *= 0.5f;
            lambda = std::max(lambda, 1e-10f);

        } else {
            lambda *= 10.0f;

            // Rejected: restore the residuals of the current position, both for
            // the next iteration and so that m_residuals matches trackPosition
            // if the search ends here.
            computeResiduals(observations, trackPosition);

            if (lambda > 1e5f)
                break;
        }
    }
    //std::cout << "Error after: " << lastError << std::endl;
}


}
