/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef PFJACOBI_H
#define PFJACOBI_H

#include <vector>

#include <ostream>

#include <immintrin.h>

#include "../../tools/AlignedAllocator.h"

/**
 * @brief AVX emulation of an NVidia warp.
 * @details Emulates 32-wide SIMD operations by performing 4 AVX operations apiece. The 32 values
 * are spread over four 8-float registers: values 0..7 live in @c a, 8..15 in @c b, 16..23 in @c c
 * and 24..31 in @c d. Rudimentary gather and reduction operations are also implemented.
 * The class has no constructor, so a freshly declared register holds unspecified values until
 * it is filled by setZero(), one of the loads/gathers, or an assignment.
 * @ingroup BundleAdjustment_Group
 */
class CpuWarpReg
{
    private:
        __m256 a;
        __m256 b;
        __m256 c;
        __m256 d;

        /// Gathers the eight floats base[offset[i]*stride], i = 0..7, into a single AVX register.
        template<unsigned stride>
        static inline __m256 gather8(const float *base, const unsigned *offset) {
            return _mm256_setr_ps(base[offset[0]*stride],
                                  base[offset[1]*stride],
                                  base[offset[2]*stride],
                                  base[offset[3]*stride],
                                  base[offset[4]*stride],
                                  base[offset[5]*stride],
                                  base[offset[6]*stride],
                                  base[offset[7]*stride]);
        }

        /// Builds a register whose lower four lanes all hold *lower and whose upper four lanes all hold *upper.
        static inline __m256 broadcastPair(const float *lower, const float *upper) {
            return _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_broadcast_ss(lower)),
                                        _mm_broadcast_ss(upper), 1);
        }

        /**
         * @brief Sums each of the eight quadruples of consecutive values.
         * @details Horizontal adds work per 128-bit half, so the sums come out interleaved:
         * lanes 0..3 hold the sums of the lower halves of a, b, c, d and lanes 4..7 hold the
         * sums of their upper halves. broadcastQuadruples() undoes exactly this ordering.
         */
        inline __m256 sumQuadruples() const {
            const __m256 pairsAB = _mm256_hadd_ps(a, b);
            const __m256 pairsCD = _mm256_hadd_ps(c, d);
            return _mm256_hadd_ps(pairsAB, pairsCD);
        }

        /// Replicates each of the eight values (in the lane order of sumQuadruples()) four times into its quadruple.
        inline void broadcastQuadruples(__m256 perQuadruple) {
            const __m256 doubledAB = _mm256_unpacklo_ps(perQuadruple, perQuadruple);
            const __m256 doubledCD = _mm256_unpackhi_ps(perQuadruple, perQuadruple);

            a = _mm256_unpacklo_ps(doubledAB, doubledAB);
            b = _mm256_unpackhi_ps(doubledAB, doubledAB);
            c = _mm256_unpacklo_ps(doubledCD, doubledCD);
            d = _mm256_unpackhi_ps(doubledCD, doubledCD);
        }
    public:
        /// Sets all values to zero
        inline void setZero() {
            const __m256 zero = _mm256_setzero_ps();
            a = zero;
            b = zero;
            c = zero;
            d = zero;
        }

        /**
         * @brief Performs a gather load of all 32 values.
         * @details Value i is read from base[offset[i]*stride].
         * @tparam stride Factor applied to every offset (in units of floats).
         * @param base Start of the float array to read from.
         * @param offset Array of 32 element offsets.
         */
        template<unsigned stride = 1>
        inline void gather(const float *base, const unsigned *offset) {
            a = gather8<stride>(base, offset + 0*8);
            b = gather8<stride>(base, offset + 1*8);
            c = gather8<stride>(base, offset + 2*8);
            d = gather8<stride>(base, offset + 3*8);
        }

        /// @brief Gather load and broadcast of the 8 values base[offset[index]], index = 0..7.
        /// @details Every value is duplicated four times so the final 32 values are aaaabbbbccccddddeeeeffffgggghhhh.
        inline void gatherBroadcast4(const float *base, const unsigned *offset) {
            a = broadcastPair(base + offset[0], base + offset[1]);
            b = broadcastPair(base + offset[2], base + offset[3]);
            c = broadcastPair(base + offset[4], base + offset[5]);
            d = broadcastPair(base + offset[6], base + offset[7]);
        }
/*
        Disabled loader, kept for reference only.

        inline void loadAndBroadcast16Aligned(const float *src) {
            __m256 low = _mm256_load_ps(src + 0);
            __m256 high = _mm256_load_ps(src + 8);

            __m256 helper1 = _mm256_unpacklo_ps(low, low);
            a = _mm256_unpacklo_ps(helper1, helper1);
            b = _mm256_unpackhi_ps(helper1, helper1);
            __m256 helper2 = _mm256_unpackhi_ps(high, high);
            c = _mm256_unpacklo_ps(helper2, helper2);
            d = _mm256_unpackhi_ps(helper2, helper2);
        }
*/
        /// Loads a block of 32 values from a 32-byte aligned address.
        inline void loadAligned(const float *src) {
            a = _mm256_load_ps(src + 0*8);
            b = _mm256_load_ps(src + 1*8);
            c = _mm256_load_ps(src + 2*8);
            d = _mm256_load_ps(src + 3*8);
        }
        /// Stores a block of 32 values to a 32-byte aligned address.
        inline void storeAligned(float *dst) const {
            _mm256_store_ps(dst + 0*8, a);
            _mm256_store_ps(dst + 1*8, b);
            _mm256_store_ps(dst + 2*8, c);
            _mm256_store_ps(dst + 3*8, d);
        }

        /// Pairwise multiplication
        CpuWarpReg operator*(const CpuWarpReg &rhs) const {
            CpuWarpReg result;

            result.a = _mm256_mul_ps(a, rhs.a);
            result.b = _mm256_mul_ps(b, rhs.b);
            result.c = _mm256_mul_ps(c, rhs.c);
            result.d = _mm256_mul_ps(d, rhs.d);

            return result;
        }

        /// Multiplication of all values with a single float. The pointer is dereferenced, so it must point to a valid float.
        CpuWarpReg operator*(const float *scalar) const {
            CpuWarpReg result;

            const __m256 factor = _mm256_broadcast_ss(scalar);

            result.a = _mm256_mul_ps(a, factor);
            result.b = _mm256_mul_ps(b, factor);
            result.c = _mm256_mul_ps(c, factor);
            result.d = _mm256_mul_ps(d, factor);

            return result;
        }

        /// Pairwise addition
        CpuWarpReg operator+(const CpuWarpReg &rhs) const {
            CpuWarpReg result;

            result.a = _mm256_add_ps(a, rhs.a);
            result.b = _mm256_add_ps(b, rhs.b);
            result.c = _mm256_add_ps(c, rhs.c);
            result.d = _mm256_add_ps(d, rhs.d);

            return result;
        }

        /// Pairwise addition with 8 values repeated four times.
        CpuWarpReg operator+(const __m256 &rhs) const {
            CpuWarpReg result;

            result.a = _mm256_add_ps(a, rhs);
            result.b = _mm256_add_ps(b, rhs);
            result.c = _mm256_add_ps(c, rhs);
            result.d = _mm256_add_ps(d, rhs);

            return result;
        }

        /// Pairwise subtraction
        CpuWarpReg operator-(const CpuWarpReg &rhs) const {
            CpuWarpReg result;

            result.a = _mm256_sub_ps(a, rhs.a);
            result.b = _mm256_sub_ps(b, rhs.b);
            result.c = _mm256_sub_ps(c, rhs.c);
            result.d = _mm256_sub_ps(d, rhs.d);

            return result;
        }

        /// Pairwise in-place addition
        CpuWarpReg& operator+=(const CpuWarpReg &rhs) {
            a = _mm256_add_ps(a, rhs.a);
            b = _mm256_add_ps(b, rhs.b);
            c = _mm256_add_ps(c, rhs.c);
            d = _mm256_add_ps(d, rhs.d);
            return *this;
        }

        /// Stores the sum of all 32 values at dst
        void reduce(float *dst) const {
            // 32 -> 8 partial sums
            const __m256 sum8 = _mm256_add_ps(_mm256_add_ps(a, b), _mm256_add_ps(c, d));
            // 8 -> 4 by folding the upper half onto the lower half
            __m128 sum4 = _mm_add_ps(_mm256_extractf128_ps(sum8, 0), _mm256_extractf128_ps(sum8, 1));
            // 4 -> 2 -> 1 through two horizontal adds
            sum4 = _mm_hadd_ps(sum4, sum4);
            sum4 = _mm_hadd_ps(sum4, sum4);
            _mm_store_ss(dst, sum4);
        }

        /**
         * @brief For all eight groups of quadruples, sums the four consecutive values, computes the
         * approximate reciprocal square roots and broadcasts.
         * @details For all eight groups of quadruples, the four consecutive values all store the value afterwards.
         * This operation is useful for vector normalization, when the warp holds on to 8 vectors of 4 components each.
         */
        void reduce4ApproxRsqrtAndBroadcast() {
            broadcastQuadruples(_mm256_rsqrt_ps(sumQuadruples()));
        }

        /**
         * @brief For all eight groups of quadruples, sums the four consecutive values, computes the
         * precise reciprocal square roots and broadcasts.
         * @details For all eight groups of quadruples, the four consecutive values all store the value afterwards.
         * This operation is useful for vector normalization, when the warp holds on to 8 vectors of 4 components each.
         */
        void reduce4RsqrtAndBroadcast() {
            const __m256 root = _mm256_sqrt_ps(sumQuadruples());
            broadcastQuadruples(_mm256_div_ps(_mm256_set1_ps(1.0f), root));
        }

        /// Returns the absolute of each value.
        CpuWarpReg abs() const {
            CpuWarpReg result;

            // -0.0f has only the sign bit set; and-not with it clears the sign of every value.
            const __m256 signBit = _mm256_set1_ps(-0.0f);

            result.a = _mm256_andnot_ps(signBit, a);
            result.b = _mm256_andnot_ps(signBit, b);
            result.c = _mm256_andnot_ps(signBit, c);
            result.d = _mm256_andnot_ps(signBit, d);

            return result;
        }

        /**
         * @brief Performs a pairwise compare operation which is specified by the operation parameter.
         * @tparam operation Predicate immediate handed to _mm256_cmp_ps (one of the _CMP_* constants).
         * @return The per-value comparison masks as produced by _mm256_cmp_ps.
         */
        template<unsigned operation>
        CpuWarpReg compare(const CpuWarpReg &rhs) const {
            CpuWarpReg result;

            result.a = _mm256_cmp_ps(a, rhs.a, operation);
            result.b = _mm256_cmp_ps(b, rhs.b, operation);
            result.c = _mm256_cmp_ps(c, rhs.c, operation);
            result.d = _mm256_cmp_ps(d, rhs.d, operation);

            return result;
        }

        /**
         * @brief Performs a pairwise compare operation with a single float value.
         * @tparam operation Predicate immediate handed to _mm256_cmp_ps (one of the _CMP_* constants).
         * @return The per-value comparison masks as produced by _mm256_cmp_ps.
         */
        template<unsigned operation>
        CpuWarpReg compare(const float rhs) const {
            CpuWarpReg result;

            const __m256 reference = _mm256_set1_ps(rhs);

            result.a = _mm256_cmp_ps(a, reference, operation);
            result.b = _mm256_cmp_ps(b, reference, operation);
            result.c = _mm256_cmp_ps(c, reference, operation);
            result.d = _mm256_cmp_ps(d, reference, operation);

            return result;
        }

        /// Performs pairwise binary and operations.
        CpuWarpReg operator&(const CpuWarpReg &rhs) const {
            CpuWarpReg result;

            result.a = _mm256_and_ps(a, rhs.a);
            result.b = _mm256_and_ps(b, rhs.b);
            result.c = _mm256_and_ps(c, rhs.c);
            result.d = _mm256_and_ps(d, rhs.d);

            return result;
        }

        /// Performs pairwise in-place binary and operations.
        CpuWarpReg& operator&=(const CpuWarpReg &rhs) {
            a = _mm256_and_ps(a, rhs.a);
            b = _mm256_and_ps(b, rhs.b);
            c = _mm256_and_ps(c, rhs.c);
            d = _mm256_and_ps(d, rhs.d);

            return *this;
        }

        /// Computes the approximate reciprocal of each value.
        CpuWarpReg approxRcp() const {
            CpuWarpReg result;

            result.a = _mm256_rcp_ps(a);
            result.b = _mm256_rcp_ps(b);
            result.c = _mm256_rcp_ps(c);
            result.d = _mm256_rcp_ps(d);

            return result;
        }

        /// Computes the approximate square root of each value as rcp(rsqrt(x)).
        CpuWarpReg approxSqrt() const {
            CpuWarpReg result;

            result.a = _mm256_rcp_ps(_mm256_rsqrt_ps(a));
            result.b = _mm256_rcp_ps(_mm256_rsqrt_ps(b));
            result.c = _mm256_rcp_ps(_mm256_rsqrt_ps(c));
            result.d = _mm256_rcp_ps(_mm256_rsqrt_ps(d));

            return result;
        }

        /// Computes the precise reciprocal of each value.
        CpuWarpReg rcp() const {
            CpuWarpReg result;

            const __m256 one = _mm256_set1_ps(1.0f);

            result.a = _mm256_div_ps(one, a);
            result.b = _mm256_div_ps(one, b);
            result.c = _mm256_div_ps(one, c);
            result.d = _mm256_div_ps(one, d);

            return result;
        }
};

/// @}

namespace SFM {

namespace PFBundleAdjustment {

/** @addtogroup BundleAdjustment_Group
 *  @{
 */

/// Number of parameters per track; used as the per-entry factor for the track derivative array sizes.
enum {
    NUM_TRACK_PARAMS = 4
};

/// Track part of one block of the transposed Jacobian: 32 indices plus 32*NUM_TRACK_PARAMS
/// x- and y-derivatives. Every array starts on a 32-byte boundary.
struct PFJacobiTransposedTrackBlock
{
    /// 32 track parameter indices (presumably one per observation slot of the block — confirm against the implementation).
    alignas(32) unsigned int trackParameterIndices[32];
    /// Derivatives of the x component.
    alignas(32) float xDerivatives[NUM_TRACK_PARAMS * 32];
    /// Derivatives of the y component.
    alignas(32) float yDerivatives[NUM_TRACK_PARAMS * 32];
};

/// Header of a transposed camera block; it only carries the camera parameter index.
struct PFJacobiTransposedCameraBlockHeader
{
    unsigned int cameraParameterIndex;
    // Disabled member:
    // unsigned count;
};

/// Camera derivatives of one transposed block: 32*NumCameraUpdateParameters floats each for
/// x and y, both arrays starting on a 32-byte boundary.
template <unsigned int NumCameraUpdateParameters>
struct PFJacobiTransposedCameraBlockData
{
    /// Derivatives of the x component.
    alignas(32) float xDerivatives[NumCameraUpdateParameters * 32];
    /// Derivatives of the y component.
    alignas(32) float yDerivatives[NumCameraUpdateParameters * 32];
};

/// Header of a transposed calibration block; it only carries the calibration parameter index.
struct PFJacobiTransposedCalibBlockHeader
{
    unsigned int calibParameterIndex;
    // Disabled member:
    // unsigned count;
};

/// Calibration derivatives of one transposed block. The arrays hold
/// 32*NumCalibrationUpdateParameters floats each; a parameter count of zero is clamped to
/// one so that the arrays never have zero length.
template <unsigned int NumCalibrationUpdateParameters>
struct PFJacobiTransposedCalibBlockData
{
    /// Derivatives of the x component.
    alignas(32) float xDerivatives[(NumCalibrationUpdateParameters == 0 ? 1 : NumCalibrationUpdateParameters) * 32];
    /// Derivatives of the y component.
    alignas(32) float yDerivatives[(NumCalibrationUpdateParameters == 0 ? 1 : NumCalibrationUpdateParameters) * 32];
};

/**
 * @brief Interface for the transposed Jacobi matrix.
 * @details Exposing resizing and the matrix-vector product through this abstract
 * base lets the algorithms that use the transposed Jacobian stay independent of the
 * concrete template instantiation.
 */
class PFJacobiTransposedInterface
{
    public:
        virtual ~PFJacobiTransposedInterface() = default;

        /// Resizes the matrix for the given number of observation blocks.
        virtual void resize(unsigned numObservationBlocks) = 0;

        /**
         * @brief Multiplies the matrix with @p src and writes the product to @p dst.
         * @param batchSize Defaults to the largest unsigned value. Its exact meaning is defined
         * by the implementation (NOTE(review): presumably the work granularity — confirm).
         */
        virtual void multiplyWithVector(
                const std::vector<float, AlignedAllocator<float>> &src,
                std::vector<float, AlignedAllocator<float>> &dst,
                unsigned batchSize = ~0u) const = 0;
};


/**
 * @brief Templated implementation of the transposed Jacobi matrix.
 * @details The template parameters allow
 * the number of update parameters for each camera and calibration
 * to change depending on the employed camera or radial distortion model.
 */
template<unsigned NumCameraUpdateParameters, unsigned NumCalibrationUpdateParameters>
class PFJacobiTransposed : public PFJacobiTransposedInterface
{
    public:
        void resize(unsigned numObservationBlocks) override;

        void setupObservationBlock(unsigned observationBlockIndex, unsigned trackParameterIndices[32], unsigned cameraParameterIndex, unsigned calibParameterIndex);

        void setSingleTrackData(unsigned observationIdx, float *data);
        void setSingleCameraData(unsigned observationIdx, float *data);
        void setSingleInternalCalibData(unsigned observationIdx, float *data);

        void multiplyWithVector(const std::vector<float, AlignedAllocator<float>> &src,
                                std::vector<float, AlignedAllocator<float>> &dst,
                                unsigned batchSize = -1) const override;

        void writeLayoutToImage(const char *filename);

        typedef PFJacobiTransposedCameraBlockData<NumCameraUpdateParameters> SizedPFJacobiTransposedCameraBlockData;
        typedef PFJacobiTransposedCalibBlockData<NumCalibrationUpdateParameters> SizedPFJacobiTransposedCalibBlockData;
    private:
        void multiplyWithVectorSubrange(const float *src, float *dst, unsigned start, unsigned count) const;


        std::vector<PFJacobiTransposedTrackBlock, AlignedAllocator<PFJacobiTransposedTrackBlock>> m_tracks;
        std::vector<PFJacobiTransposedCameraBlockHeader> m_cameraHeaders;
        std::vector<SizedPFJacobiTransposedCameraBlockData, AlignedAllocator<SizedPFJacobiTransposedCameraBlockData>> m_cameraData;
        std::vector<PFJacobiTransposedCalibBlockHeader> m_calibHeaders;
        std::vector<SizedPFJacobiTransposedCalibBlockData, AlignedAllocator<SizedPFJacobiTransposedCalibBlockData>> m_calibData;

//        friend std::ostream &operator<<(std::ostream &stream, const PFJacobiTransposed<NumCameraUpdateParameters, NumCalibrationUpdateParameters> &jacobi);
};




/// Header of a track's data: a start index and a block count
/// (presumably addressing a range of PFJacobiTrackBlockData entries — confirm against the implementation).
struct PFJacobiTrackBlockHeader
{
    unsigned int blockStart;
    unsigned int blockCount;
};


/// Track data block of the Jacobian: 8 residual indices plus 8*NUM_TRACK_PARAMS
/// x- and y-derivatives. Every array starts on a 32-byte boundary.
struct PFJacobiTrackBlockData
{
    /// Indices of the 8 residuals covered by this block.
    alignas(32) unsigned int residualIndices[8];
    /// Derivatives of the x component.
    alignas(32) float xDerivatives[NUM_TRACK_PARAMS * 8];
    /// Derivatives of the y component.
    alignas(32) float yDerivatives[NUM_TRACK_PARAMS * 8];
};

/// Header of a camera's data: a start index and a block count
/// (presumably addressing a range of camera data blocks — confirm against the implementation).
struct PFJacobiCameraBlockHeader
{
    unsigned int blockStart;
    unsigned int blockCount;
};

/// Camera derivatives of one Jacobian block: 32*NumCameraUpdateParameters floats each for
/// x and y, both arrays starting on a 32-byte boundary.
template <unsigned int NumCameraUpdateParameters>
struct PFJacobiCameraBlockData
{
    /// Derivatives of the x component.
    alignas(32) float xDerivatives[NumCameraUpdateParameters * 32];
    /// Derivatives of the y component.
    alignas(32) float yDerivatives[NumCameraUpdateParameters * 32];
};

/// Header of a calibration's data: a start index and a block count
/// (presumably addressing a range of calibration data blocks — confirm against the implementation).
struct PFJacobiCalibBlockHeader
{
    unsigned int blockStart;
    unsigned int blockCount;
};

/// Calibration derivatives of one Jacobian block. The arrays hold
/// 32*NumCalibrationUpdateParameters floats each; a parameter count of zero is clamped to
/// one so that the arrays never have zero length.
template <unsigned int NumCalibrationUpdateParameters>
struct PFJacobiCalibBlockData
{
    /// Derivatives of the x component.
    alignas(32) float xDerivatives[(NumCalibrationUpdateParameters == 0 ? 1 : NumCalibrationUpdateParameters) * 32];
    /// Derivatives of the y component.
    alignas(32) float yDerivatives[(NumCalibrationUpdateParameters == 0 ? 1 : NumCalibrationUpdateParameters) * 32];
};

/**
 * @brief Interface for the Jacobi matrix.
 * @details Exposing resizing, the matrix-vector product and the diagonal of J^T J
 * through this abstract base lets the algorithms that use the Jacobian stay independent
 * of the concrete template instantiation.
 */
class PFJacobiInterface
{
    public:
        virtual ~PFJacobiInterface() = default;

        /**
         * @brief Resizes the matrix.
         * @details NOTE(review): the three pointer arguments presumably hold one block
         * count per track block row / camera / calibration — confirm against the implementation.
         */
        virtual void resize(unsigned numTracksBlockRows,
                            unsigned numCameras,
                            unsigned numCalibs,
                            unsigned *numTrackBlocks,
                            unsigned *numCameraBlocks,
                            unsigned *numCalibBlocks) = 0;

        /**
         * @brief Multiplies the matrix with @p src and writes the product to @p dst.
         * @param batchSize Defaults to the largest unsigned value. Its exact meaning is defined
         * by the implementation (NOTE(review): presumably the work granularity — confirm).
         */
        virtual void multiplyWithVector(
                const std::vector<float, AlignedAllocator<float>> &src,
                std::vector<float, AlignedAllocator<float>> &dst,
                unsigned batchSize = ~0u) const = 0;

        /// Computes the diagonal of J^T J and writes it to @p dst.
        virtual void computeDiagonalOfJTJ(std::vector<float, AlignedAllocator<float>> &dst) const = 0;
};


/**
 * @brief Templated implementation of the Jacobi matrix.
 * @details The template parameters allow
 * the number of update parameters for each camera and calibration
 * to change depending on the employed camera or radial distortion model.
 */
template<unsigned NumCameraUpdateParameters, unsigned NumCalibrationUpdateParameters>
class PFJacobi : public PFJacobiInterface
{
    public:
        void resize(unsigned numTracksBlockRows, unsigned numCameras, unsigned numCalibs,
                    unsigned *numTrackBlocks, unsigned *numCameraBlocks, unsigned *numCalibBlocks) override;

        void setupSingleTrack(unsigned trackIndex, unsigned intraTrackObsIndex, unsigned obsResidualIndex);

        void setSingleTrackData(unsigned trackIndex, unsigned intraTrackObsIndex, float *data);
        void setSingleCameraData(unsigned cameraIdx, unsigned intraCameraObsIndex, float *data);
        void setSingleCalibData(unsigned calibIdx, unsigned intraCalibObsIndex, float *data);
        void multiplyWithVector(const std::vector<float, AlignedAllocator<float>> &src,
                                std::vector<float, AlignedAllocator<float>> &dst,
                                unsigned batchSize = -1) const override;

        void computeDiagonalOfJTJ(std::vector<float, AlignedAllocator<float>> &dst) const override;

        void writeLayoutToImage(const char *filename);

        typedef PFJacobiCameraBlockData<NumCameraUpdateParameters> SizedPFJacobiCameraBlockData;
        typedef PFJacobiCalibBlockData<NumCalibrationUpdateParameters> SizedPFJacobiCalibBlockData;
    protected:
        std::vector<PFJacobiTrackBlockHeader> m_trackHeaders;
        std::vector<PFJacobiTrackBlockData, AlignedAllocator<PFJacobiTrackBlockData> > m_trackData;
        std::vector<PFJacobiCameraBlockHeader> m_cameraHeaders;
        std::vector<SizedPFJacobiCameraBlockData, AlignedAllocator<SizedPFJacobiCameraBlockData>> m_cameraData;
        std::vector<PFJacobiCalibBlockHeader> m_calibHeaders;
        std::vector<SizedPFJacobiCalibBlockData, AlignedAllocator<SizedPFJacobiCalibBlockData>> m_calibData;

//        friend std::ostream &operator<<(std::ostream &stream, const PFJacobi<NumCameraUpdateParameters, NumCalibrationUpdateParameters> &jacobi);
};


// Explicit instantiation declarations: including translation units do not instantiate
// these specializations implicitly, so the matching explicit instantiation definitions
// have to be provided elsewhere. All of them use 6 camera update parameters; the
// calibration parameter counts are 0, 4, 5, 7 and 8.
extern template class PFJacobiTransposed<6u, 0u>;
extern template class PFJacobiTransposed<6u, 4u>;
extern template class PFJacobiTransposed<6u, 5u>;
extern template class PFJacobiTransposed<6u, 7u>;
extern template class PFJacobiTransposed<6u, 8u>;

extern template class PFJacobi<6u, 0u>;
extern template class PFJacobi<6u, 4u>;
extern template class PFJacobi<6u, 5u>;
extern template class PFJacobi<6u, 7u>;
extern template class PFJacobi<6u, 8u>;

/// @}

}

}


#endif // PFJACOBI_H
