/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "DenseReconstruction.h"
#include "cudaKernelUtils/WarpLinAlg.hpp"
#include "cudaKernelUtils/WarpLoadStore.hpp"

using cuUtils::WarpVector;
using cuUtils::WarpMatrix;

// Layered 2D surface written by convertRGBImage: one 8-bit luminance value per texel,
// layer selected by kernelParams.layer. Binding to an array happens host-side (not visible here).
// NOTE(review): surface/texture *references* are deprecated since CUDA 11 and removed in
// CUDA 12; building with CUDA 12 would require migrating to cudaSurfaceObject_t / cudaTextureObject_t.
surface<void, cudaSurfaceType2DLayered> outputImages;
// Layered 2D texture sampled by generateVotes (it scales pixel positions by rcpImageWidth /
// rcpImageHeight, so it presumably uses normalized coordinates). Because of
// cudaReadModeNormalizedFloat, uchar texels are returned as floats in [0, 1].
// Presumably bound to the same layered array that outputImages writes into — TODO confirm host side.
texture<uchar1, cudaTextureType2DLayered, cudaReadModeNormalizedFloat> images;


template<unsigned width, unsigned height>
/*
 * A width x height image patch distributed across the 32 lanes of one warp.
 *
 * The patch is stored row-major as width*height values. Linear patch index
 * k = valueIndex*32 + threadIdx.x lives in values[valueIndex] of lane threadIdx.x.
 * This assumes blockDim.x == 32, so that threadIdx.x is the lane id (TODO confirm at every launch site).
 *
 * If width*height is not a multiple of 32, the last slot of some lanes is a padding slot that
 * maps outside the patch. Padding slots are ignored by computeMeanVar and
 * computeNormalizedCrossCorrelation.
 *
 * computeMeanVar and computeNormalizedCrossCorrelation reduce across the warp with shuffles and
 * must be called by all 32 lanes. Their results are identical on every lane: each butterfly step
 * adds the same two operands on both partner lanes.
 */
struct WarpPatch
{
    enum {
        Width = width,
        Height = height,
        // Number of patch slots each lane holds (rounded up to cover the whole patch).
        ValuesPerThread = (width*height + 31) / 32
    };

    // Column / row inside the patch of this lane's valueIndex-th slot.
    __device__ static int getPatchX(unsigned valueIndex) { return (valueIndex * 32 + threadIdx.x) % width; }
    __device__ static int getPatchY(unsigned valueIndex) { return (valueIndex * 32 + threadIdx.x) / width; }

    // True if this lane's valueIndex-th slot lies inside the patch. The first operand is a
    // compile-time constant, so the check compiles away when width*height is a multiple of 32.
    __device__ static bool isValidSlot(unsigned valueIndex) {
        return ((width*height) % 32 == 0) || (valueIndex * 32 + threadIdx.x < width*height);
    }

    // Butterfly all-reduce: returns the sum of `value` over all 32 lanes, on every lane.
    // CUDA 9 introduced the *_sync shuffles; the mask-less variant is unavailable for sm_70+.
    __device__ static float warpAllReduceSum(float value) {
        #pragma unroll
        for (int laneMask = 16; laneMask >= 1; laneMask /= 2) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
            value += __shfl_xor_sync(0xffffffffu, value, laneMask, 32);
#else
            value += __shfl_xor(value, laneMask, 32);
#endif
        }
        return value;
    }

    float values[ValuesPerThread];

    // Computes the mean and the population variance (E[x^2] - E[x]^2) over all width*height values.
    // var may come out slightly negative or near zero through cancellation. Callers are expected
    // to reject small variances before calling normalize().
    __device__ void computeMeanVar(float &mean, float &var) {
        float X = 0.0f;
        float XX = 0.0f;
        #pragma unroll
        for (unsigned i = 0; i < ValuesPerThread; i++) {
            if (isValidSlot(i)) {
                X += values[i];
                XX += values[i]*values[i];
            }
        }

        X = warpAllReduceSum(X);
        XX = warpAllReduceSum(XX);

        const float rcpCount = 1.0f / (width*height);
        mean = X * rcpCount;
        var = XX * rcpCount - mean*mean;
    }

    // Rescales the values to zero mean and unit variance, using the statistics from computeMeanVar.
    // Requires var > 0.
    __device__ void normalize(float mean, float var) {
        float scale = 1.0f / sqrtf(var);
        float offset = -mean * scale;

        #pragma unroll
        for (unsigned i = 0; i < ValuesPerThread; i++) {
            values[i] = values[i] * scale + offset;
        }
    }

    // Returns the mean of the element-wise products with `other`. When both patches have been
    // normalize()d, this is their normalized cross-correlation (Pearson coefficient, in [-1, 1]).
    __device__ float computeNormalizedCrossCorrelation(const WarpPatch<width, height> &other) {
        float XY = 0.0f;
        #pragma unroll
        for (unsigned i = 0; i < ValuesPerThread; i++) {
            if (isValidSlot(i))
                XY += values[i]*other.values[i];
        }
        XY = warpAllReduceSum(XY);
        return XY * (1.0f / (width*height));
    }
};

// Patch type used for the NCC photo-consistency score in generateVotes: 8x8 = 64 values,
// i.e. exactly 2 slots per lane of a 32-lane warp (no padding slots).
typedef WarpPatch<8, 8> CorrelationPatch;

// Warp barrier plus shared-memory ordering among the lanes of a warp. CUDA 9+ needs it
// because of Volta's independent thread scheduling. Older toolkits only target GPUs whose warps
// execute in lockstep, so there it is a no-op.
static __device__ __forceinline__ void generateVotesSyncWarp()
{
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
    __syncwarp();
#endif
}

// Returns lane 0's `value` on every lane of the (fully active) warp.
static __device__ __forceinline__ unsigned generateVotesBroadcastFromLane0(unsigned value)
{
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
    return (unsigned)__shfl_sync(0xffffffffu, (int)value, 0);
#else
    return (unsigned)__shfl((int)value, 0);
#endif
}

// Appends one ImageVote to kernelParams.votes. Must be called by all 32 lanes of the warp.
// Lane 0 reserves a slot with atomicAdd on numVotes and stages the record in shared memory
// (one slot per threadIdx.y). The whole warp then copies it out with warpCopy.
// Records whose index is >= maxVotes are dropped, but numVotes is still incremented.
static __device__ __forceinline__ void generateVotesEmitVote(const GenerateVotesKernelParams &kernelParams,
                                                             int pixelX, int pixelY, float rcpDepth, float vote)
{
    __shared__ ImageVote voteStruct[4];

    unsigned index = 0;
    if (threadIdx.x == 0) {
        index = atomicAdd(kernelParams.numVotes, 1);

        voteStruct[threadIdx.y].pixelX = pixelX;
        voteStruct[threadIdx.y].pixelY = pixelY;
        voteStruct[threadIdx.y].rcpDepth = rcpDepth;
        voteStruct[threadIdx.y].vote = vote;
    }
    // Make lane 0's shared-memory writes visible before the other lanes read them in warpCopy.
    generateVotesSyncWarp();

    index = generateVotesBroadcastFromLane0(index);
    if (index < kernelParams.maxVotes) {
        cuUtils::warpCopy<ImageVote>(&kernelParams.votes[index], &voteStruct[threadIdx.y]);
    }
    // A later call from the same warp overwrites voteStruct, so the copy must finish on all lanes first.
    generateVotesSyncWarp();
}

/*
 * generateVotes
 *
 * Each warp handles one reference pixel (pixelX, pixelY). It sweeps the inverse depth rcpDepth
 * from 0 in steps of kernelParams.depthStepSize while rcpDepth <= 1. At every step it compares
 * the normalized 8x8 reference patch with the patch reprojected into each other image. The
 * per-image score is max(0, NCC)^4, and the scores are summed into `votes`.
 *
 * Every local maximum of `votes` along the sweep that exceeds kernelParams.voteThresh is appended
 * to kernelParams.votes, with the sample's rcpDepth and score.
 *
 * Launch layout (inferred from the indexing, __launch_bounds__ and the 4-entry shared arrays):
 *   - blockDim = (32, 4). threadIdx.x is the lane within a patch; threadIdx.y selects one of
 *     4 consecutive pixels in x. blockDim.y must not exceed 4.
 *   - Pixels start at (offsetX, offsetY); the block covers 4 pixels in x and one row in y.
 * Only static shared memory is used, and no __syncthreads(): warps work independently.
 * Every early exit and branch around warp collectives depends only on warp-uniform values,
 * so all 32 lanes stay active for the shuffles.
 */
extern "C" __global__ void __launch_bounds__(32*4, 2048/128) generateVotes(GenerateVotesKernelParams kernelParams)
{
    const int pixelX = kernelParams.offsetX + blockIdx.x * 4 + threadIdx.y;
    const int pixelY = kernelParams.offsetY + blockIdx.y;

    // Skip pixels whose correlation patch would not fit inside the image. These tests depend only
    // on blockIdx / threadIdx.y, so a whole warp leaves together.
    if ((pixelX >= kernelParams.imageWidth-(int)CorrelationPatch::Width/2) || (pixelY >= kernelParams.imageHeight-(int)CorrelationPatch::Height/2))
        return;

    if ((pixelX <= (int)CorrelationPatch::Width/2) || (pixelY <= (int)CorrelationPatch::Height/2))
        return;

#if 0
    // Debug path: emit exactly one vote per pixel, carrying the intensity of texture layer 1.
    {
        float x = (pixelX + 0.5f) * kernelParams.rcpImageWidth;
        float y = (pixelY + 0.5f) * kernelParams.rcpImageHeight;
        generateVotesEmitVote(kernelParams, pixelX, pixelY, 0.5f, tex2DLayeredLod(images, x, y, 1, 0.0f).x);
    }
#else
    // Load the reference patch centred on (pixelX, pixelY) and normalize it.
    CorrelationPatch refPatch;
    {
        #pragma unroll
        for (unsigned i = 0; i < CorrelationPatch::ValuesPerThread; i++) {
            // Normalized texture coordinates of the texel centre.
            float x = (pixelX + (CorrelationPatch::getPatchX(i) - (int)CorrelationPatch::Width/2) + 0.5f) * kernelParams.rcpImageWidth;
            float y = (pixelY + (CorrelationPatch::getPatchY(i) - (int)CorrelationPatch::Height/2) + 0.5f) * kernelParams.rcpImageHeight;
            refPatch.values[i] = tex2DLayeredLod(images, x, y, kernelParams.refImage, 0.0f).x;
        }

        float mean, var;
        refPatch.computeMeanVar(mean, var);
        // A nearly uniform patch makes NCC meaningless, so this pixel casts no votes.
        if (var < 0.0001f)
            return;

        refPatch.normalize(mean, var);
    }

    // A non-positive (or NaN) step would never leave the sweep loop below.
    if (!(kernelParams.depthStepSize > 0.0f))
        return;

    // Per-warp copy of the reprojection matrix of the image currently being compared.
    __shared__ ReprojectionMatrix reprojectionMatrix[4];
    ReprojectionMatrix &matrix = reprojectionMatrix[threadIdx.y];

    float rcpDepth = 0.0f;
    float lastRcpDepth = rcpDepth;  // inverse depth of the previous sample
    float lastVotes = 0.0f;         // score of the previous sample
    bool lastWasRising = false;     // previous sample's score was >= the one before it
    while (rcpDepth <= 1.0f) {

        // Sum of max(0, NCC)^4 over all non-reference images for this depth hypothesis.
        float votes = 0.0f;

        for (unsigned imgIndex = 0; imgIndex < kernelParams.numImages; imgIndex++) {
            if (imgIndex == kernelParams.refImage) continue;

            // Every lane must finish reading the previous matrix before it is overwritten, and
            // the new copy must be visible to every lane before it is read.
            generateVotesSyncWarp();
            cuUtils::warpCopy<ReprojectionMatrix>(&matrix, &kernelParams.reprojectionMatrices[imgIndex]);
            generateVotesSyncWarp();

            CorrelationPatch patch;
            {
                #pragma unroll
                for (unsigned i = 0; i < CorrelationPatch::ValuesPerThread; i++) {
                    float x = pixelX + (CorrelationPatch::getPatchX(i) - (int)CorrelationPatch::Width/2);
                    float y = pixelY + (CorrelationPatch::getPatchY(i) - (int)CorrelationPatch::Height/2);

                    // Transform (x, y, rcpDepth, 1) by rows 0, 1 and 3 of the matrix (row 2 is unused).
                    float clipSpaceX = x * matrix.rowMajor[0*4+0] +
                                       y * matrix.rowMajor[0*4+1] +
                                       rcpDepth * matrix.rowMajor[0*4+2] +
                                       matrix.rowMajor[0*4+3];

                    float clipSpaceY = x * matrix.rowMajor[1*4+0] +
                                       y * matrix.rowMajor[1*4+1] +
                                       rcpDepth * matrix.rowMajor[1*4+2] +
                                       matrix.rowMajor[1*4+3];

                    float clipSpaceW = x * matrix.rowMajor[3*4+0] +
                                       y * matrix.rowMajor[3*4+1] +
                                       rcpDepth * matrix.rowMajor[3*4+2] +
                                       matrix.rowMajor[3*4+3];

                    // Perspective divide. The result is used directly as texture coordinates.
                    float rcpW = 1.0f / clipSpaceW;
                    float rprjX = clipSpaceX * rcpW;
                    float rprjY = clipSpaceY * rcpW;

                    patch.values[i] = tex2DLayeredLod(images, rprjX, rprjY, imgIndex, 0.0f).x;
                }

                float mean, var;
                patch.computeMeanVar(mean, var);

                // A textureless reprojected patch means this image contributes nothing.
                if (var < 0.0001f)
                    continue;

                patch.normalize(mean, var);
            }
            float score = fmaxf(0.0f, refPatch.computeNormalizedCrossCorrelation(patch));
            // Fourth power: emphasise strong matches over weak ones.
            score = score*score;
            score = score*score;
            votes += score;
        }
#if 1
        // Emit the previous sample if it is a local maximum along the sweep: the score did not
        // drop into it, it drops right after it, and it exceeds the vote threshold. Tracking the
        // rising flank ensures that samples on a falling flank are never reported as peaks.
        if (lastWasRising && (votes < lastVotes) && (lastVotes > kernelParams.voteThresh))
            generateVotesEmitVote(kernelParams, pixelX, pixelY, lastRcpDepth, lastVotes);

        lastWasRising = (votes >= lastVotes);
        lastVotes = votes;
        lastRcpDepth = rcpDepth;
#else
        // Alternative: emit every depth sample whose score exceeds the threshold.
        if (votes > kernelParams.voteThresh)
            generateVotesEmitVote(kernelParams, pixelX, pixelY, rcpDepth, votes);
#endif
        rcpDepth += kernelParams.depthStepSize;
    }
#endif
}







/*
 * convertRGBImage
 *
 * Converts one packed 32-bit colour image to 8-bit luminance and writes it into layer
 * kernelParams.layer of the outputImages surface.
 *
 * One thread per pixel, on a 2D grid covering width x height (any block shape; out-of-range
 * threads exit). The source is read as a dense row-major array whose row stride is `width`
 * elements. Each word holds R in bits 0-7, G in bits 8-15 and B in bits 16-23; bits 24-31 are ignored.
 */
extern "C" __global__ void convertRGBImage(ConvertRGBImageKernelParams kernelParams)
{
    unsigned x = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned y = blockIdx.y * blockDim.y + threadIdx.y;

    if ((x >= kernelParams.width) || (y >= kernelParams.height))
        return;

    unsigned packedColor = kernelParams.source[x+y*kernelParams.width];
    int r = (packedColor >> 0) & 0xFF;
    int g = (packedColor >> 8) & 0xFF;
    int b = (packedColor >> 16) & 0xFF;

    // ITU-R BT.601 luma weights.
    float lum = 0.299f * r + 0.587f * g + 0.114f * b;

    // Clamp in single precision: fmin/fmax with int literals may resolve to the double overloads.
    // Then round to nearest; plain truncation would bias every pixel down by half a grey level on average.
    float clamped = fminf(fmaxf(lum, 0.0f), 255.0f);
    unsigned char c = (unsigned char)(clamped + 0.5f);

    // Surface x coordinates are in bytes; for 1-byte texels that equals the pixel column.
    surf2DLayeredwrite(c, outputImages, x, y, kernelParams.layer);
}
