/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#undef _GLIBCXX_USE_INT128

#include "ImagePatchCompression.h"
#include "cudaKernelUtils/WarpLoadStore.hpp"
#include "cudaKernelUtils/cuUtilHelpers.hpp"
#include "cub/util_ptx_reduced.cuh"

// Texture reference through which the kernels in this file sample the source
// image. Because it is declared with cudaReadModeNormalizedFloat, the uchar4
// texels are returned as float4 with each component normalized to [0, 1].
// The kernels fetch it with tex2DLod(), i.e. with an explicit mip level.
// NOTE(review): the binding (and addressing/filter mode) is set up by host
// code that is not visible in this file.
texture<uchar4, 2, cudaReadModeNormalizedFloat> sourceImage;

// Selects which colour channel readChannel<>() extracts from a texel.
// The numeric values are 0, 1 and 2 (in declaration order).
enum ReadChannelType {
    READ_LUMA,  // 0.25*x + 0.5*y + 0.25*z
    READ_CO,    // 0.5*x - 0.5*z
    READ_CG     // -0.25*x + 0.5*y - 0.25*z
};

/**
 * Samples one colour channel of a 64x64 patch from `sourceImage` into `dst`.
 *
 * Sample (col, row) is fetched at
 *   ((col - 31.5) * job.stepX + job.sourceX, (row - 31.5) * job.stepY + job.sourceY)
 * on mip level job.sourceBaseMipLevel. The selected channel is a weighted
 * combination of the texel's x/y/z components (YCoCg-style weights, assuming
 * x/y/z hold R/G/B -- TODO confirm) multiplied by the texel's w component.
 *
 * @tparam type  Channel to extract (READ_LUMA, READ_CO or READ_CG).
 * @param job    Patch placement and mip level.
 * @param dst    Output tile, 64 rows with a row pitch of 65 floats.
 *
 * The loops advance by 32 in both directions, so a 32x32 thread block is
 * expected; each thread then produces four samples. No barrier is issued here.
 */
template<ReadChannelType type>
__device__ void readChannel(const CompressionJob &job, float *dst)
{
    for (unsigned row = threadIdx.y; row < 64; row += 32) {
        // The vertical sample coordinate only depends on the row.
        const float sampleY = (row - 31.5f) * job.stepY + job.sourceY;
        float *dstRow = dst + row * 65;

        for (unsigned col = threadIdx.x; col < 64; col += 32) {
            const float sampleX = (col - 31.5f) * job.stepX + job.sourceX;

            const float4 texel = tex2DLod(sourceImage, sampleX, sampleY, job.sourceBaseMipLevel);

            // `type` is a template constant, so only one branch survives compilation.
            float value;
            if (type == READ_LUMA) {
                value = (0.25f * texel.x  + 0.5f * texel.y + 0.25f * texel.z) * texel.w;
            } else if (type == READ_CO) {
                value = (0.5f * texel.x                    - 0.5f * texel.z)  * texel.w;
            } else if (type == READ_CG) {
                value = (-0.25f * texel.x + 0.5f * texel.y - 0.25f * texel.z) * texel.w;
            }
            dstRow[col] = value;
        }
    }
}


// 64x64 float coefficient table in constant memory. Within this file it is
// only referenced by the disabled (commented-out) DCT routines below.
// NOTE(review): host code may still look up or fill this symbol -- confirm
// before removing it.
__constant__ float DCTCoefficients[64*64];

/*
__device__ void computeDCTAndTranspose(float * __restrict__ src, float *__restrict__ dst)
{
    float sumA = 0.0f;
    float sumB = 0.0f;
    float sumC = 0.0f;
    float sumD = 0.0f;
    for (unsigned j = 0; j < 64; j++) {
        float src1 = src[j + (threadIdx.x+0*32)*65];
        float src2 = src[j + (threadIdx.x+1*32)*65];
        sumA += src1 * DCTCoefficients[j + (threadIdx.y+0*32) * 64];
        sumB += src2 * DCTCoefficients[j + (threadIdx.y+0*32) * 64];
        sumC += src1 * DCTCoefficients[j + (threadIdx.y+1*32) * 64];
        sumD += src2 * DCTCoefficients[j + (threadIdx.y+1*32) * 64];
    }
    dst[(threadIdx.x+0*32) + (threadIdx.y+0*32)*65] = sumA;
    dst[(threadIdx.x+1*32) + (threadIdx.y+0*32)*65] = sumB;
    dst[(threadIdx.x+0*32) + (threadIdx.y+1*32)*65] = sumC;
    dst[(threadIdx.x+1*32) + (threadIdx.y+1*32)*65] = sumD;
}


__device__ void computeIDCTAndTranspose(float *src, float *dst)
{
    float sumA = 0.0f;
    float sumB = 0.0f;
    float sumC = 0.0f;
    float sumD = 0.0f;
    for (unsigned j = 0; j < 64; j++) {
        sumA += src[j + (threadIdx.x+0*32)*65] * DCTCoefficients[(threadIdx.y+0*32) + j * 64];
        sumB += src[j + (threadIdx.x+1*32)*65] * DCTCoefficients[(threadIdx.y+0*32) + j * 64];
        sumC += src[j + (threadIdx.x+0*32)*65] * DCTCoefficients[(threadIdx.y+1*32) + j * 64];
        sumD += src[j + (threadIdx.x+1*32)*65] * DCTCoefficients[(threadIdx.y+1*32) + j * 64];
    }
    dst[(threadIdx.x+0*32) + (threadIdx.y+0*32)*65] = sumA;
    dst[(threadIdx.x+1*32) + (threadIdx.y+0*32)*65] = sumB;
    dst[(threadIdx.x+0*32) + (threadIdx.y+1*32)*65] = sumC;
    dst[(threadIdx.x+1*32) + (threadIdx.y+1*32)*65] = sumD;
}

__device__ void computeWHTAndTranspose(float * __restrict__ src, float *__restrict__ dst)
{
    for (unsigned j = 0; j < 6; j++) {
        float srcValue1A = src[(threadIdx.y+0*32) + (threadIdx.x+0*32)*65];
        float srcValue1B = src[(threadIdx.y+1*32) + (threadIdx.x+0*32)*65];
        float srcValue1C = src[(threadIdx.y+0*32) + (threadIdx.x+1*32)*65];
        float srcValue1D = src[(threadIdx.y+1*32) + (threadIdx.x+1*32)*65];
        float srcValue2A = src[((threadIdx.y+0*32) ^ (32 >> j)) + (threadIdx.x+0*32)*65];
        float srcValue2B = src[((threadIdx.y+1*32) ^ (32 >> j)) + (threadIdx.x+0*32)*65];
        float srcValue2C = src[((threadIdx.y+0*32) ^ (32 >> j)) + (threadIdx.x+1*32)*65];
        float srcValue2D = src[((threadIdx.y+1*32) ^ (32 >> j)) + (threadIdx.x+1*32)*65];

        __syncthreads();

        srcValue1A = (threadIdx.y+0*32)&(32 >> j)?-srcValue1A:srcValue1A;
        srcValue1B = (threadIdx.y+1*32)&(32 >> j)?-srcValue1B:srcValue1B;
        srcValue1C = (threadIdx.y+0*32)&(32 >> j)?-srcValue1C:srcValue1C;
        srcValue1D = (threadIdx.y+1*32)&(32 >> j)?-srcValue1D:srcValue1D;

        src[(threadIdx.y+0*32) + (threadIdx.x+0*32)*65] = srcValue1A + srcValue2A;
        src[(threadIdx.y+1*32) + (threadIdx.x+0*32)*65] = srcValue1B + srcValue2B;
        src[(threadIdx.y+0*32) + (threadIdx.x+1*32)*65] = srcValue1C + srcValue2C;
        src[(threadIdx.y+1*32) + (threadIdx.x+1*32)*65] = srcValue1D + srcValue2D;

        __syncthreads();
    }
    dst[(threadIdx.x+0*32) + (threadIdx.y+0*32)*65] = src[(threadIdx.y+0*32) + (threadIdx.x+0*32)*65] * (1.0f/sqrtf(64.0f));
    dst[(threadIdx.x+1*32) + (threadIdx.y+0*32)*65] = src[(threadIdx.y+1*32) + (threadIdx.x+0*32)*65] * (1.0f/sqrtf(64.0f));
    dst[(threadIdx.x+0*32) + (threadIdx.y+1*32)*65] = src[(threadIdx.y+0*32) + (threadIdx.x+1*32)*65] * (1.0f/sqrtf(64.0f));
    dst[(threadIdx.x+1*32) + (threadIdx.y+1*32)*65] = src[(threadIdx.y+1*32) + (threadIdx.x+1*32)*65] * (1.0f/sqrtf(64.0f));
}
*/

/*
template<bool vertical>
__device__ void computeWHT(float *src)
{
    for (unsigned j = 0; j < 6; j++) {
        for (unsigned offset = 0; offset < 64; offset += 32) {
            float srcValue1A;
            float srcValue1B;
            float srcValue2A;
            float srcValue2B;

            if (vertical) {
                srcValue1A = src[(threadIdx.y+0*32) + (threadIdx.x+offset)*65];
                srcValue1B = src[(threadIdx.y+1*32) + (threadIdx.x+offset)*65];
                srcValue2A = src[((threadIdx.y+0*32) ^ (32 >> j)) + (threadIdx.x+offset)*65];
                srcValue2B = src[((threadIdx.y+1*32) ^ (32 >> j)) + (threadIdx.x+offset)*65];
            } else {
                srcValue1A = src[(threadIdx.x+offset) + (threadIdx.y+0*32)*65];
                srcValue1B = src[(threadIdx.x+offset) + (threadIdx.y+1*32)*65];
                srcValue2A = src[(threadIdx.x+offset) + ((threadIdx.y+0*32) ^ (32 >> j))*65];
                srcValue2B = src[(threadIdx.x+offset) + ((threadIdx.y+1*32) ^ (32 >> j))*65];
            }

            __syncthreads();

            srcValue1A = (threadIdx.y+0*32)&(32 >> j)?-srcValue1A:srcValue1A;
            srcValue1B = (threadIdx.y+1*32)&(32 >> j)?-srcValue1B:srcValue1B;

            if (vertical) {
                src[(threadIdx.y+0*32) + (threadIdx.x+offset)*65] = srcValue1A + srcValue2A;
                src[(threadIdx.y+1*32) + (threadIdx.x+offset)*65] = srcValue1B + srcValue2B;
            } else {
                src[(threadIdx.x+offset) + (threadIdx.y+0*32)*65] = srcValue1A + srcValue2A;
                src[(threadIdx.x+offset) + (threadIdx.y+1*32)*65] = srcValue1B + srcValue2B;
            }
            __syncthreads();
        }
    }
    src[(threadIdx.y+0*32) + (threadIdx.x+0*32)*65] *= (1.0f/sqrtf(64.0f));
    src[(threadIdx.y+1*32) + (threadIdx.x+0*32)*65] *= (1.0f/sqrtf(64.0f));
    src[(threadIdx.y+0*32) + (threadIdx.x+1*32)*65] *= (1.0f/sqrtf(64.0f));
    src[(threadIdx.y+1*32) + (threadIdx.x+1*32)*65] *= (1.0f/sqrtf(64.0f));
}
*/

/**
 * XOR warp shuffle that works on both old and current CUDA toolkits.
 *
 * The mask-less __shfl_xor() is not available when compiling for sm_70 and
 * newer, while __shfl_xor_sync() only exists from CUDA 9.0 onwards, so the
 * variant is selected on the toolkit version. All 32 lanes of the warp must
 * call this together (full participation mask).
 */
static __device__ __forceinline__ float whtShuffleXor(float value, int laneMask)
{
#if defined(CUDART_VERSION) && (CUDART_VERSION >= 9000)
    return __shfl_xor_sync(0xffffffffu, value, laneMask);
#else
    return __shfl_xor(value, laneMask);
#endif
}

/**
 * In-place 64-point Walsh-Hadamard transform along one axis of a 64x64 tile
 * stored with a row pitch of 65 floats.
 *
 * Every warp processes two 64-element lines (offset 0 and 32). Each lane holds
 * the two elements of a line that are 32 apart: the stride-32 butterfly is
 * done in registers, the stride 16, 8, 4, 2, 1 butterflies via warp shuffles.
 * The result is scaled by 1/sqrt(64). The debug kernels in this file reuse the
 * same routine for reconstruction.
 *
 * @tparam vertical  true:  a line is a tile row, i.e. elements
 *                          src[k + line*65] (contiguous in memory);
 *                   false: a line is a tile column, i.e. elements
 *                          src[line + k*65]. The pitch of 65 keeps the 32
 *                          lanes on distinct shared-memory banks here.
 * @param src        Tile to transform in place.
 *
 * Preconditions: 32x32 thread block so that a warp equals one value of
 * threadIdx.y and the lane id equals threadIdx.x; all threads of the block
 * call this function. No barrier is issued; callers synchronize before and
 * after, because the two axes are handled by different warps.
 */
template<bool vertical>
__device__ void computeWHT(float *src)
{
    const unsigned lane = cub::LaneId();
    const float normalization = 1.0f / sqrtf(64.0f);

    #pragma unroll
    for (unsigned offset = 0; offset < 64; offset += 32) {
        const unsigned line = threadIdx.y + offset;

        // Positions of this lane's two elements (k = threadIdx.x and k + 32).
        const unsigned idx1 = vertical ? (threadIdx.x + 0)  + line * 65
                                       : line + (threadIdx.x + 0)  * 65;
        const unsigned idx2 = vertical ? (threadIdx.x + 32) + line * 65
                                       : line + (threadIdx.x + 32) * 65;

        float srcValue1 = src[idx1];
        float srcValue2 = src[idx2];

        // Stride-32 butterfly: both operands live in this thread.
        {
            const float oldSrc1 = srcValue1;
            srcValue1 += srcValue2;
            srcValue2 = oldSrc1 - srcValue2;
        }

        // Remaining butterflies exchange values between lanes of the warp.
        #pragma unroll
        for (unsigned i = 1; i < 6; i++) {
            const unsigned stride = 32 >> i;
            // The lane with the stride bit set computes (partner - own),
            // the other one computes (own + partner).
            const bool negate = (lane & stride) != 0;
            {
                const float a = negate ? -srcValue1 : srcValue1;
                const float b = whtShuffleXor(srcValue1, stride);
                srcValue1 = a + b;
            }
            {
                const float a = negate ? -srcValue2 : srcValue2;
                const float b = whtShuffleXor(srcValue2, stride);
                srcValue2 = a + b;
            }
        }

        src[idx1] = srcValue1 * normalization;
        src[idx2] = srcValue2 * normalization;
    }
}

/**
 * Quantizes a 64x64 coefficient tile to signed 16-bit values and packs two
 * horizontally adjacent coefficients into one 32-bit word.
 *
 * @param quantizationScale  Factor applied to every coefficient except the
 *                           one at (0, 0), which is always scaled by 15.
 * @param src                Input tile, row pitch 65 floats.
 * @param packedDst          Output, 32 words per row (64 rows): the even
 *                           column goes to the low 16 bits, the odd column to
 *                           the high 16 bits.
 *
 * Scaled values are clamped to [-1000, 1000]. Rounding adds 0.4 away from zero
 * and then truncates, so fractional parts below 0.6 round towards zero.
 * Expects a 32x32 thread block (thread x handles columns 2x and 2x+1).
 */
__device__ void quantize(float quantizationScale, float *src, unsigned *packedDst)
{
    const unsigned firstColumn = threadIdx.x * 2;

    for (unsigned row = threadIdx.y; row < 64; row += 32) {
        const float *pair = src + firstColumn + row * 65;
        const bool isDC = (threadIdx.x == 0) && (row == 0);

        float even = pair[0] * (isDC ? 15.0f : quantizationScale);
        float odd  = pair[1] * quantizationScale;

        even = fminf(fmaxf(even, -1000.0f), 1000.0f);
        odd  = fminf(fmaxf(odd,  -1000.0f), 1000.0f);

#if 0
        // Alternative: plain round-to-nearest.
        const int evenQ = __float2int_rn(even);
        const int oddQ  = __float2int_rn(odd);
#else
        const int evenQ = __float2int_rz(even + copysignf(0.4f, even));
        const int oddQ  = __float2int_rz(odd  + copysignf(0.4f, odd));
#endif

        packedDst[threadIdx.x + row * 32] = (evenQ & 0xFFFF) | ((oddQ & 0xFFFF) << 16);
    }
}

/**
 * Unpacks a tile of 16-bit quantized coefficients back to floats.
 *
 * @param rcpQuantizationScale  Factor applied to every coefficient except the
 *                              one at (0, 0), which is always scaled by 1/15.
 * @param packedSrc             Input, 32 words per row (64 rows): low 16 bits
 *                              hold the even column, high 16 bits the odd one,
 *                              both interpreted as signed.
 * @param dst                   Output tile, row pitch 65 floats.
 *
 * Expects a 32x32 thread block (thread x handles columns 2x and 2x+1).
 */
__device__ void deQuantize(float rcpQuantizationScale, unsigned *packedSrc, float *dst)
{
    for (unsigned row = threadIdx.y; row < 64; row += 32) {
        const unsigned word = packedSrc[threadIdx.x + row * 32];

        // Reinterpret each half as a signed 16-bit integer.
        const int16_t even = (int16_t) (word >> 0);
        const int16_t odd  = (int16_t) (word >> 16);

        const bool isDC = (threadIdx.x == 0) && (row == 0);

        float *out = dst + threadIdx.x * 2 + row * 65;
        out[0] = even * (isDC ? (1.0f / 15.0f) : rcpQuantizationScale);
        out[1] = odd * rcpQuantizationScale;
    }
}


/**
 * Compresses one 64x64 image patch per thread block.
 *
 * For each of the three channels (luma, Co, Cg) the patch is sampled from
 * `sourceImage`, transformed with a separable Walsh-Hadamard transform (one
 * pass per axis) and quantized into the corresponding plane of
 * results[blockIdx.x].
 *
 * @param jobs     One CompressionJob per block, indexed by blockIdx.x.
 * @param results  One QuantizedData per block, indexed by blockIdx.x; the
 *                 dataY / dataCo / dataCg planes are written as packed
 *                 32-bit words.
 *
 * Launch configuration: 1D grid with one block per job, 32x32 threads per
 * block (the helper loops hard-code a stride of 32 in both directions).
 * Static shared memory: one 64x65 float tile plus the job descriptor.
 */
extern "C" __global__ void __launch_bounds__(1024, 2) transformAndQuantize(CompressionJob *jobs, QuantizedData *results)
{

    const unsigned jobIndex = blockIdx.x;

    __shared__ CompressionJob job;

    // The first warp stages the job descriptor in shared memory for the block.
    if (cuUtils::getWarpIdInBlock2D() == 0) {
        cuUtils::warpCopy(&job, jobs + jobIndex);
    }

    __shared__ float mem[64*65];

    // NOTE(review): all three planes are quantized with job.quantizationY --
    // confirm that the chroma planes are meant to share the luma scale.

    __syncthreads();    // make `job` visible to every warp
    readChannel<READ_LUMA>(job, mem);
    __syncthreads();
    computeWHT<true>(mem);
    __syncthreads();
    computeWHT<false>(mem);
    __syncthreads();
    quantize(job.quantizationY, mem, (unsigned*)results[jobIndex].dataY);
    // quantize() reads columns 2x/2x+1 while the following readChannel()
    // writes columns x/x+32 of the same rows from the same warp. Without a
    // barrier this write-after-read ordering relies on implicit warp
    // lockstep, which is not guaranteed with independent thread scheduling.
    __syncthreads();

    readChannel<READ_CO>(job, mem);
    __syncthreads();
    computeWHT<true>(mem);
    __syncthreads();
    computeWHT<false>(mem);
    __syncthreads();
    quantize(job.quantizationY, mem, (unsigned*)results[jobIndex].dataCo);
    __syncthreads();    // same hazard as above before `mem` is reused

    readChannel<READ_CG>(job, mem);
    __syncthreads();
    computeWHT<true>(mem);
    __syncthreads();
    computeWHT<false>(mem);
    __syncthreads();
    quantize(job.quantizationY, mem, (unsigned*)results[jobIndex].dataCg);

#if 0
    // Debug path: bypass the transform and store the raw luma samples
    // (scaled by 10000) in the luma plane.
    __syncthreads();
    readChannel<READ_LUMA>(job, mem);

    for (unsigned y = threadIdx.y; y < 64; y += 32) {
        float a = mem[(threadIdx.x*2+0) + y*65] * 10000.0f;
        float b = mem[(threadIdx.x*2+1) + y*65] * 10000.0f;
        a = fmin(fmax(a, -10000.0f), 10000.0f);
        b = fmin(fmax(b, -10000.0f), 10000.0f);

        unsigned packed;

        //cub::BFI(packed, (unsigned) a, (unsigned) b, 16, 16);
        packed = (__float2int_rn(a) & 0xFFFF) |
                 ((__float2int_rn(b) & 0xFFFF) << 16);

        ((unsigned*)results[jobIndex].dataY)[threadIdx.x + y*32] = packed;
    }

#endif
}


/**
 * Debug kernel: reconstructs the luma plane of one compressed patch per block.
 *
 * The packed luma coefficients of kernelParams.quantizedData[blockIdx.x] are
 * dequantized with kernelParams.rcpQuantScale, run through both passes of the
 * Walsh-Hadamard transform, and written as a dense 64x64 float image to
 * kernelParams.result at offset blockIdx.x * 64 * 64.
 *
 * Launch configuration: one block per patch, 32x32 threads per block.
 */
extern "C" __global__ void debugDequantizeAndTransform(DebugDequantizeAndTransformKernelParams kernelParams)
{
    const unsigned patch = blockIdx.x;

    __shared__ float tile[64*65];

    unsigned *packedLuma = (unsigned*)kernelParams.quantizedData[patch].dataY;

    deQuantize(kernelParams.rcpQuantScale, packedLuma, tile);
    __syncthreads();
    computeWHT<true>(tile);
    __syncthreads();
    computeWHT<false>(tile);
    __syncthreads();

    // Copy the padded tile (pitch 65) into the dense output image (pitch 64).
    for (unsigned row = threadIdx.y; row < 64; row += 32) {
        for (unsigned col = threadIdx.x; col < 64; col += 32) {
            kernelParams.result[patch * 64*64 + col + row*64] = tile[col + row*65];
        }
    }

#if 0
    // Debug path: bypass the transform and output the stored values
    // scaled by 0.0001.
    __syncthreads();
    unsigned *data = (unsigned*)kernelParams.quantizedData[patch].dataY;

    for (unsigned y = threadIdx.y; y < 64; y += 32) {
        unsigned packed = data[threadIdx.x + y*32];

        float a = ((int16_t) (packed >> 0)) * 0.0001f;
        float b = ((int16_t) (packed >> 16)) * 0.0001f;

        kernelParams.result[patch * 64*64 + (threadIdx.x*2+0) + y*64] = a;
        kernelParams.result[patch * 64*64 + (threadIdx.x*2+1) + y*64] = b;
    }
#endif
}


/**
 * Debug kernel: measures the luma reconstruction error of one patch per block.
 *
 * The luma plane of kernelParams.quantizedData[blockIdx.x] is dequantized and
 * inverse transformed, then compared against luma sampled from `sourceImage`
 * at the positions described by kernelParams.jobs[blockIdx.x].
 *
 * Outputs:
 *  - kernelParams.differences: per-pixel absolute difference, dense 64x64
 *    image at offset blockIdx.x * 64 * 64.
 *  - kernelParams.result[blockIdx.x]: sum of the squared differences over the
 *    patch. Despite the kernel's name this is not a PSNR value; any conversion
 *    presumably happens on the host (not visible here).
 *
 * Launch configuration: one block per patch, 32x32 threads per block.
 */
extern "C" __global__ void debugComputePSNR(DebugComputePSNRKernelParams kernelParams)
{
    const unsigned jobIndex = blockIdx.x;

    __shared__ float mem[64*65];

    __shared__ CompressionJob job;

    // The first warp stages the job descriptor in shared memory; it is first
    // read after the barriers below.
    if (cuUtils::getWarpIdInBlock2D() == 0) {
        cuUtils::warpCopy(&job, kernelParams.jobs + jobIndex);
    }


    deQuantize(kernelParams.rcpQuantScale, (unsigned*)kernelParams.quantizedData[jobIndex].dataY, mem);
    __syncthreads();
    computeWHT<true>(mem);
    __syncthreads();
    computeWHT<false>(mem);
    __syncthreads();
    for (unsigned y = threadIdx.y; y < 64; y += 32) {
        for (unsigned x = threadIdx.x; x < 64; x += 32) {
            float sx = (x - 31.5f) * job.stepX + job.sourceX;
            float sy = (y - 31.5f) * job.stepY + job.sourceY;

            float4 colors = tex2DLod(sourceImage, sx, sy, job.sourceBaseMipLevel);

            // Same luma weighting as readChannel<READ_LUMA>.
            float v = (0.25f * colors.x  + 0.5f * colors.y + 0.25f * colors.z) * colors.w;
            float a = mem[x + y * 65];
            float d = fabsf(a-v);

            kernelParams.differences[jobIndex * 64*64 + x + y*64] = d;

            // Each thread overwrites only the elements it has just read.
            mem[x + y * 65] = d*d;
        }
    }
    __syncthreads();

    // Reduce the squared errors with a single warp. The previous guard
    // (threadIdx.x < 32) was true for every thread of a 32-wide block, so all
    // 32 warps repeated the full summation and 32 threads stored the result.
    if (threadIdx.y == 0) {
        // Each lane sums columns threadIdx.x and threadIdx.x + 32 of all rows.
        float sum = 0;
        for (unsigned y = 0; y < 64; y++) {
            for (unsigned x = 0; x < 64; x += 32) {
                sum += mem[threadIdx.x + x + y*65];
            }
        }
        // Butterfly reduction across the warp; all 32 lanes participate.
        for (int i=16; i>=1; i/=2) {
#if defined(CUDART_VERSION) && (CUDART_VERSION >= 9000)
            sum += __shfl_xor_sync(0xffffffffu, sum, i, 32);
#else
            sum += __shfl_xor(sum, i, 32);
#endif
        }

        if (threadIdx.x == 0) {
            kernelParams.result[jobIndex] = sum;
        }
    }
}
