/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef _WARPLINALG_HPP_
#define _WARPLINALG_HPP_

#include "cuUtilHelpers.hpp"
#include "../cub/util_ptx_reduced.cuh"

namespace cuUtils {


template<unsigned dimension>
class WarpVector
{
    /*
     * A vector of `dimension` floats distributed over the lanes of one warp.
     *
     * Element k is stored in lane (k % WARP_SIZE), slot (k / WARP_SIZE) of
     * m_elements. For dimension <= WARP_SIZE, load() additionally replicates
     * the vector so that lane l holds the value of lane
     * (l % NextPowerOfTwoRoundUp<dimension>::value).
     *
     * Slots that do not correspond to an element are kept at 0.0f. The
     * lane-wise reductions in dot() and WarpMatrix::operator* add those slots
     * into their sums.
     *
     * Members that shuffle (load, bcast) must be called by all 32 lanes of
     * the warp.
     */
    public:
        enum {
            // Number of float slots each lane needs to hold `dimension` elements.
            NUM_ELEMENTS_PER_LANE = (dimension + WARP_SIZE-1)/WARP_SIZE
        };

        // Creates a zero vector (every slot of every lane is 0.0f).
        __device__ WarpVector() {
            setZero();
        }

        /*
         * Loads `dimension` floats, element k from src[k].
         *
         * Padding slots are set to 0.0f regardless of the previous contents.
         * For dimension <= WARP_SIZE the first slot is then replicated across
         * the warp with a shuffle.
         */
        __device__ inline void load(const float *src) {
            #pragma unroll
            for (unsigned i = 0; i < NUM_ELEMENTS_PER_LANE; i++) {
                const unsigned offset = cub::LaneId() + i*WARP_SIZE;
                // Only the selected operand is evaluated, so src is never read out of range.
                m_elements[i] = (offset < dimension) ? src[offset] : 0.0f;
            }
            // Integer division: 0 (no replication) once the rounded-up dimension exceeds WARP_SIZE.
            const unsigned numRepititions = WARP_SIZE / NextPowerOfTwoRoundUp<dimension>::value;
            if (numRepititions > 0) {
                m_elements[0] = shuffleIndex(m_elements[0], cub::LaneId() % (WARP_SIZE / numRepititions));
            }
        }

        // Writes the `dimension` elements to dst[0 .. dimension-1].
        __device__ inline void store(float *dst) {
            #pragma unroll
            for (unsigned i = 0; i < NUM_ELEMENTS_PER_LANE; i++) {
                const unsigned offset = cub::LaneId() + i*WARP_SIZE;
                if (offset < dimension) {
                    dst[offset] = m_elements[i];
                }
            }
        }

        /*
         * Returns element `index`, read from lane (index % WARP_SIZE).
         *
         * NOTE(review): each lane contributes its own m_elements[index / WARP_SIZE]
         * as the shuffle source. index / WARP_SIZE should therefore be the same
         * in all lanes, and index must be < NUM_ELEMENTS_PER_LANE * WARP_SIZE.
         */
        __device__ inline float bcast(unsigned index) const {
            const unsigned lane = index % WARP_SIZE;
            const unsigned elem = index / WARP_SIZE;
            return shuffleIndex(m_elements[elem], lane);
        }

        // Sets every slot of every lane to 0.0f.
        __device__ inline void setZero() {
            #pragma unroll
            for (unsigned i = 0; i < NUM_ELEMENTS_PER_LANE; i++) {
                m_elements[i] = 0.0f;
            }
        }

        // Lane-local scaling; keeps zero padding zero.
        __device__ inline WarpVector<dimension> operator*(float scalar) const {
            WarpVector<dimension> result;
            #pragma unroll
            for (unsigned i = 0; i < NUM_ELEMENTS_PER_LANE; i++) {
                result.m_elements[i] = m_elements[i] * scalar;
            }
            return result;
        }

        // Lane-local element-wise sum.
        __device__ inline WarpVector<dimension> operator+(const WarpVector<dimension> &other) const {
            WarpVector<dimension> result;
            #pragma unroll
            for (unsigned i = 0; i < NUM_ELEMENTS_PER_LANE; i++) {
                result.m_elements[i] = m_elements[i] + other.m_elements[i];
            }
            return result;
        }

        // Lane-local element-wise difference.
        __device__ inline WarpVector<dimension> operator-(const WarpVector<dimension> &other) const {
            WarpVector<dimension> result;
            #pragma unroll
            for (unsigned i = 0; i < NUM_ELEMENTS_PER_LANE; i++) {
                result.m_elements[i] = m_elements[i] - other.m_elements[i];
            }
            return result;
        }

        // Per-lane storage; public because dot() and WarpMatrix access it directly.
        float m_elements[NUM_ELEMENTS_PER_LANE];

    private:
        // Full-warp indexed shuffle.
        // The mask-less __shfl is unavailable when targeting SM70+ with CUDA 9+,
        // so the *_sync form is used whenever the toolchain provides it.
        static __device__ __forceinline__ float shuffleIndex(float value, int srcLane) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
            return __shfl_sync(0xffffffffu, value, srcLane);
#else
            return __shfl(value, srcLane);
#endif
        }
};


template<unsigned dimension>
__device__ inline float dot(const WarpVector<dimension> &src1, const WarpVector<dimension> &src2)
{
    /*
     * Warp-cooperative dot product of two WarpVectors.
     *
     * Each lane first sums the products of the slots it holds. The partial
     * sums are then combined across lanes:
     *   - dimension 1 and 2: by broadcasting lanes 0 (and 1);
     *   - otherwise: by an XOR butterfly over the smallest power-of-two lane
     *     group covering `dimension`, capped at the full warp.
     *
     * Every lane returns the same value. Padding slots are included in the
     * sum, so they are expected to hold 0.0f. Must be called by all 32 lanes.
     */
    float sum = 0.0f;
    #pragma unroll
    for (unsigned i = 0; i < WarpVector<dimension>::NUM_ELEMENTS_PER_LANE; i++) {
        sum += src1.m_elements[i] * src2.m_elements[i];
    }

    // Largest lane mask of the butterfly. The distance halves each step, down to 1.
    const unsigned firstLaneMask = (dimension <= 4)  ? 2u :
                                   (dimension <= 8)  ? 4u :
                                   (dimension <= 16) ? 8u : 16u;

#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
    // CUDA 9+: mask-less shuffles are unavailable for SM70+, use the *_sync forms.
    const unsigned fullWarp = 0xffffffffu;
    if (dimension == 1)
        return __shfl_sync(fullWarp, sum, 0);
    if (dimension == 2)
        return __shfl_sync(fullWarp, sum, 0) + __shfl_sync(fullWarp, sum, 1);
    #pragma unroll
    for (unsigned laneMask = firstLaneMask; laneMask >= 1; laneMask /= 2)
        sum += __shfl_xor_sync(fullWarp, sum, laneMask, 32);
#else
    if (dimension == 1)
        return __shfl(sum, 0);
    if (dimension == 2)
        return __shfl(sum, 0) + __shfl(sum, 1);
    #pragma unroll
    for (unsigned laneMask = firstLaneMask; laneMask >= 1; laneMask /= 2)
        sum += __shfl_xor(sum, laneMask, 32);
#endif
    return sum;
}





template<unsigned rows, unsigned cols>
class WarpMatrix
{
    /*
     * A rows x cols float matrix distributed over the lanes of one warp.
     *
     * The warp covers a WARP_BLOCK_ROWS x WARP_BLOCK_COLS tile at a time.
     * Lane l stores row (l / WARP_BLOCK_COLS) + i*WARP_BLOCK_ROWS and column
     * (l % WARP_BLOCK_COLS) + j*WARP_BLOCK_COLS in
     * m_elements[i*NUM_COLS_PER_LANE + j].
     *
     * Slots outside the matrix are kept at 0.0f, because operator* sums them
     * into its result. Members that shuffle (operator*) must be called by all
     * 32 lanes of the warp.
     */
    public:
        enum {
            WARP_BLOCK_COLS = NextPowerOfTwoRoundUp<cols>::value < WARP_SIZE ? NextPowerOfTwoRoundUp<cols>::value : WARP_SIZE,
            WARP_BLOCK_ROWS = rows < (WARP_SIZE/WARP_BLOCK_COLS)?rows:(WARP_SIZE/WARP_BLOCK_COLS),
            WARP_BLOCK_ELEMS = WARP_BLOCK_COLS*WARP_BLOCK_ROWS,

            NUM_COLS_PER_LANE = (cols + WARP_BLOCK_COLS-1)/WARP_BLOCK_COLS,
            NUM_ROWS_PER_LANE = (rows + WARP_BLOCK_ROWS-1)/WARP_BLOCK_ROWS
        };

        // Creates a zero matrix.
        __device__ WarpMatrix() {
            setZero();
        }

        // Sets every slot of every lane to 0.0f.
        __device__ inline void setZero() {
            #pragma unroll
            for (unsigned i = 0; i < NUM_COLS_PER_LANE * NUM_ROWS_PER_LANE; i++)
                m_elements[i] = 0.0f;
        }

        /*
         * Loads the matrix from src, stored row-major with row stride `cols`
         * (element (r, c) at src[r * cols + c]).
         * Slots outside the matrix are set to 0.0f regardless of the previous
         * contents.
         */
        __device__ inline void load(const float *src) {
            const unsigned lane = cub::LaneId();
            #pragma unroll
            for (unsigned i = 0; i < NUM_ROWS_PER_LANE; i++) {
                #pragma unroll
                for (unsigned j = 0; j < NUM_COLS_PER_LANE; j++) {
                    const unsigned srcCol = j*WARP_BLOCK_COLS + (lane % WARP_BLOCK_COLS);
                    const unsigned srcRow = i*WARP_BLOCK_ROWS + (lane / WARP_BLOCK_COLS);
                    const bool inside = (lane < WARP_BLOCK_ELEMS) && (srcCol < cols) && (srcRow < rows);
                    // Only the selected operand is evaluated, so src is never read out of range.
                    m_elements[i*NUM_COLS_PER_LANE+j] = inside ? src[srcRow * cols + srcCol] : 0.0f;
                }
            }
        }

        // Stores the matrix row-major with row stride `cols` (element (r, c) to dst[r * cols + c]).
        __device__ inline void store(float *dst) {
            const unsigned lane = cub::LaneId();
            if (lane < WARP_BLOCK_ELEMS) {
                #pragma unroll
                for (unsigned i = 0; i < NUM_ROWS_PER_LANE; i++) {
                    #pragma unroll
                    for (unsigned j = 0; j < NUM_COLS_PER_LANE; j++) {
                        const unsigned dstCol = j*WARP_BLOCK_COLS + (lane % WARP_BLOCK_COLS);
                        const unsigned dstRow = i*WARP_BLOCK_ROWS + (lane / WARP_BLOCK_COLS);

                        if ((dstCol < cols) && (dstRow < rows)) {
                            dst[dstRow * cols + dstCol] = m_elements[i*NUM_COLS_PER_LANE+j];
                        }
                    }
                }
            }
        }

        /*
         * Matrix-vector product A * vec.
         *
         * Inputs and outputs may be longer than WARP_SIZE. Output row k ends up
         * in lane (k % WARP_SIZE), slot (k / WARP_SIZE) of the result, matching
         * WarpVector's layout. For rows <= WARP_SIZE the result is replicated
         * across the warp like WarpVector::load() does.
         *
         * vec is expected to use WarpVector's layout: replicated when
         * cols <= WARP_SIZE, lane-aligned slots otherwise.
         */
        __device__ WarpVector<rows> operator*(const WarpVector<cols> &vec) const {
            WarpVector<rows> result;
            result.setZero();

            const unsigned lane = cub::LaneId();
            #pragma unroll
            for (unsigned i = 0; i < NUM_ROWS_PER_LANE; i++) {
                // First output row of this row block, split into the result slot and lane holding it.
                // When NUM_ROWS_PER_LANE > 1, WARP_BLOCK_ROWS divides WARP_SIZE, so a row block
                // never straddles two result slots.
                const unsigned destinationIndex = i * WARP_BLOCK_ROWS;
                const unsigned destinationElement = destinationIndex / WARP_SIZE;
                const unsigned destinationLane = destinationIndex % WARP_SIZE;

                #pragma unroll
                for (unsigned j = 0; j < NUM_COLS_PER_LANE; j++) {
                    float sum = m_elements[i*NUM_COLS_PER_LANE+j] * vec.m_elements[j];

                    // Butterfly over the WARP_BLOCK_COLS lanes of one matrix row:
                    // afterwards every lane of the row group holds that row's partial sum.
                    #pragma unroll
                    for (unsigned laneMask = WARP_BLOCK_COLS/2; laneMask >= 1; laneMask /= 2)
                        sum += shuffleXor(sum, laneMask);

                    // Result lane destinationLane + k fetches row k of the tile from the first lane of its row group.
                    // The modulo only affects lanes whose value is discarded below.
                    const unsigned srcLane = ((lane - destinationLane) * WARP_BLOCK_COLS) % WARP_SIZE;
                    const float v = shuffleIndex(sum, srcLane);

                    if ((lane >= destinationLane) && (lane < destinationLane + WARP_BLOCK_ROWS)) {
                        result.m_elements[destinationElement] += v;
                    }
                }
            }

            // Replicate short results across the warp (0 repetitions = no replication for long vectors).
            const unsigned numRepititions = WARP_SIZE / NextPowerOfTwoRoundUp<rows>::value;
            if (numRepititions > 0) {
                result.m_elements[0] = shuffleIndex(result.m_elements[0], lane % (WARP_SIZE / numRepititions));
            }

            return result;
        }

        // Per-lane storage, see the layout description above.
        float m_elements[NUM_COLS_PER_LANE * NUM_ROWS_PER_LANE];

    private:
        // Full-warp shuffles.
        // The mask-less __shfl/__shfl_xor are unavailable when targeting SM70+ with CUDA 9+.
        static __device__ __forceinline__ float shuffleIndex(float value, int srcLane) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
            return __shfl_sync(0xffffffffu, value, srcLane);
#else
            return __shfl(value, srcLane);
#endif
        }

        static __device__ __forceinline__ float shuffleXor(float value, int laneMask) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
            return __shfl_xor_sync(0xffffffffu, value, laneMask, 32);
#else
            return __shfl_xor(value, laneMask, 32);
#endif
        }
};



template<unsigned dimension>
__device__ WarpVector<dimension> conjugateGradientSolve(const WarpMatrix<dimension, dimension> &A, const WarpVector<dimension> &b, unsigned iterations, float residualSqrThreshold = 1e-20f)
{
    /*
     * Approximately solves A x = b with at most `iterations` conjugate-gradient
     * steps, starting from x = 0. The whole warp cooperates.
     *
     * Stops early when the squared residual norm drops below
     * residualSqrThreshold (default 1e-20f, the previously hard-coded value).
     *
     * It also stops when p^T A p is not strictly positive. This happens when
     * b == 0 (p == 0) or when A is not positive definite along p; continuing
     * there would divide by zero and fill x with inf/NaN.
     *
     * dot() returns the same value in all lanes, so every early exit is
     * warp-uniform. That matters because the next iteration shuffles.
     */
    WarpVector<dimension> x;          // x0 = 0
    WarpVector<dimension> r = b;      // r0 = b - A*x0 = b
    WarpVector<dimension> p = r;

    float rSqrOld = dot(r, r);
    for (unsigned iter = 0; iter < iterations; iter++) {
        const WarpVector<dimension> Ap = A * p;
        const float pAp = dot(p, Ap);

        // Breakdown guard (also rejects NaN): alpha would be undefined or meaningless.
        if (!(pAp > 0.0f))
            break;

        const float alpha = rSqrOld / pAp;
        x = x + p * alpha;
        r = r - Ap * alpha;

        const float rSqrNew = dot(r, r);

        if (rSqrNew < residualSqrThreshold)
            break;

        const float beta = rSqrNew / rSqrOld;
        p = r + p * beta;
        rSqrOld = rSqrNew;
    }
    return x;
}

}

#endif // _WARPLINALG_HPP_
