/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

// NOTE(review): presumably a workaround for nvcc not accepting the host
// compiler's __int128 support when this was written — confirm before removing.
#undef _GLIBCXX_USE_INT128

#include <stdint.h>

#include "LensCalibration.h"

// Edge-filter pipeline bindings. These are texture/surface *references*, so
// they are bound by host code (not in the kernels shown here).
//  - inputImageTexture: one 8-bit channel, read as normalized float by runEdgeFilter.
//  - filterResponseTexture: read as float2 by tensorVoteEnhanceEdges
//    (presumably bound to the packed half2 data written through
//    filterResponseOutputSurface — confirm on the host side).
//  - filterResponseOutputSurface: written by runEdgeFilter and
//    tensorVoteEnhanceEdges, 32 bits (two packed 16-bit halves) per texel.
// NOTE(review): texture/surface references are deprecated and were removed in
// CUDA 12; porting to texture/surface objects needs matching host-side changes.
texture<uchar1, 2, cudaReadModeNormalizedFloat> inputImageTexture;
texture<float2, 2, cudaReadModeElementType> filterResponseTexture;
surface<void, 2> filterResponseOutputSurface;


// Undistortion pipeline bindings: a four-channel 8-bit source sampled as
// normalized float, and the 32-bit-per-texel output surface written by
// undistortVSFM and undistortRadial.
texture<uchar4, 2, cudaReadModeNormalizedFloat> distortedImageTexture;
surface<void, 2> undistortedOutputSurface;




// 9x9 edge-detection kernel, row-major (index = row*9 + col).
// Every row is constant, so the weights vary only with the row index, following
// the profile -1, -1, 0.5, 1, 1, 1, 0.5, -1, -1.
// That profile sums to 0, so a flat (constant) region produces zero response.
// runEdgeFilter applies it both as stored and transposed.
__constant__ float filterResponse[9*9] = {
-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
0.5f,   0.5f,  0.5f,  0.5f,  0.5f,  0.5f,  0.5f,  0.5f,  0.5f,
1.0f,   1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,
1.0f,   1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,
1.0f,   1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,  1.0f,
0.5f,   0.5f,  0.5f,  0.5f,  0.5f,  0.5f,  0.5f,  0.5f,  0.5f,
-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f
};

/*
 * Convolves inputImageTexture with the 9x9 filterResponse table twice per
 * pixel:
 *   - first result: the table as stored (row offset -> y);
 *   - second result: the table transposed (row offset -> x).
 *
 * Launch: 2D grid covering width x height, one thread per pixel. Threads
 * outside the image exit immediately.
 *
 * Sampling: the position x + col - 3.5 is the texel center of pixel
 * (x + col - 4), so the 9x9 window is centered on (x, y). This assumes the
 * texture uses unnormalized coordinates.
 *
 * Output: both sums are rounded to half precision and packed into one 32-bit
 * word (first result in the low 16 bits). The word is written to
 * filterResponseOutputSurface at byte offset x*4 of row y.
 */
extern "C" __global__ void runEdgeFilter(unsigned width, unsigned height)
{
    const unsigned x = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned y = blockIdx.y * blockDim.y + threadIdx.y;

    if (!(x < width && y < height))
        return;

    float accumStored = 0.0f;
    float accumTransposed = 0.0f;
    for (int row = 0; row < 9; ++row) {
        const float v = y + row - 3.5f;
        for (int col = 0; col < 9; ++col) {
            const float u = x + col - 3.5f;
            const float texel = tex2D(inputImageTexture, u, v).x;
            accumStored     += texel * filterResponse[row * 9 + col];
            accumTransposed += texel * filterResponse[col * 9 + row];
        }
    }

    // Kept as a single expression: the shift/or relies on the integer return
    // type of __float2half_rn in the CUDA version this file targets.
    const uint32_t packed = __float2half_rn(accumStored) | (__float2half_rn(accumTransposed) << 16);
    surf2Dwrite(packed, filterResponseOutputSurface, x * 4, y);
}


/*
 * dst = left * right for 2x2 matrices stored row-major ([r*2 + c]).
 * The __restrict__ qualifiers require that dst does not alias left or right.
 */
__device__ void matMul2x2(float * __restrict__ dst, const float * __restrict__ left, const float * __restrict__ right)
{
    // Row 0 of the product.
    dst[0] = left[0] * right[0] + left[1] * right[2];
    dst[1] = left[0] * right[1] + left[1] * right[3];
    // Row 1 of the product.
    dst[2] = left[2] * right[0] + left[3] * right[2];
    dst[3] = left[2] * right[1] + left[3] * right[3];
}

/*
 * Adds one tensor vote from a sender at offset (dX, dY) into the receiver
 * tensor. Both tensors are symmetric 2x2, given by their 11/12/22 entries.
 *
 * Falloff: the vote is scaled by exp(-|d|^2 * scaleFactor).
 *
 * Vote: with r = d / |d| and R = I - 2 r r^T, the vote for sender tensor S is
 *     R * S * ((I - 0.5 r r^T) * R).
 * The off-diagonal of that (not necessarily symmetric) product is averaged
 * before being accumulated.
 *
 * (dX, dY) must not be (0, 0): the direction is computed as d / |d|.
 */
__device__ void cfTensorVote(float dX, float dY, float senderTensor11, float senderTensor12, float senderTensor22,
                                float &recieverTensor11, float &recieverTensor12, float &recieverTensor22,
                                float scaleFactor)
{
    const float sqrD = dX*dX + dY*dY;
    const float falloff = expf(-sqrD * scaleFactor);

    const float rcpD = 1.0f/sqrtf(sqrD);
    const float rX = dX * rcpD;
    const float rY = dY * rcpD;

    // R = I - 2 r r^T (row-major; symmetric, so both off-diagonals share a value).
    const float rCross = 0.0f - 2.0f * rX*rY;
    const float R[4] = {
        1.0f - 2.0f * rX*rX, rCross,
        rCross,              1.0f - 2.0f * rY*rY
    };

    // Q = I - 0.5 r r^T, then Rdash = Q * R.
    const float qCross = 0.0f - 0.5f * rX*rY;
    const float Q[4] = {
        1.0f - 0.5f * rX*rX, qCross,
        qCross,              1.0f - 0.5f * rY*rY
    };
    float Rdash[4];
    matMul2x2(Rdash, Q, R);

    // Sender tensor as a full row-major matrix.
    const float S[4] = {
        senderTensor11, senderTensor12,
        senderTensor12, senderTensor22
    };

    float RS[4];
    matMul2x2(RS, R, S);
    float vote[4];
    matMul2x2(vote, RS, Rdash);

    recieverTensor11 += vote[0] * falloff;
    recieverTensor12 += (vote[1]+vote[2])*0.5f * falloff;
    recieverTensor22 += vote[3] * falloff;
}


/*
 * Enhances edges by accumulating a 2x2 structure tensor from the responses in
 * filterResponseTexture over a 31x31 window. The center sample is excluded.
 *
 * Voting: each neighbour at offset d votes along the unit direction r = d/|d|.
 * The vote strength is |r . g| * exp(-|d|^2 / 100), where g is the neighbour
 * response with its first channel clamped to >= 0. The second channel is
 * currently forced to 0.
 *
 * Launch: 2D grid covering width x height, one thread per pixel.
 *
 * Output: the dominant eigenvector of the tensor, scaled by 1e-4 * bigEv^2.
 * It is written to filterResponseOutputSurface as two packed halves at byte
 * offset x*4 of row y, with x in the low 16 bits.
 */
extern "C" __global__ void tensorVoteEnhanceEdges(unsigned width, unsigned height)
{
    const unsigned x = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned y = blockIdx.y * blockDim.y + threadIdx.y;

    if ((x >= width) || (y >= height))
        return;

    // Small positive diagonal so the tensor is never exactly zero.
    float tensor11 = 0.00001f;
    float tensor12 = 0.0f;
    float tensor22 = 0.00001f;
    for (int i = 0; i < 31; i++)
        for (int j = 0; j < 31; j++) {
            // The center has no direction (0/0 below), so it does not vote.
            if ((i == 15) && (j == 15))
                continue;

            // Texel center of pixel (x + j - 15, y + i - 15), assuming
            // unnormalized texture coordinates.
            const float sampleX = x + j - 14.5f;
            const float sampleY = y + i - 14.5f;

            const float2 grad = tex2D(filterResponseTexture, sampleX, sampleY);
            // Named senderX/senderY so they no longer shadow this thread's
            // pixel coordinates x/y.
            const float senderX = fmax(grad.x, 0.0f);
            const float senderY = 0.0f; // second channel disabled: fmax(grad.y, 0.0f)
            // cfTensorVote() implements an alternative full tensor vote.

            const float dX = j-15.0f;
            const float dY = i-15.0f;
            const float sqrD = dX*dX + dY*dY;
            const float c = expf(-sqrD/100.0f);

            const float rX = dX / sqrtf(sqrD);
            const float rY = dY / sqrtf(sqrD);

            const float voteStrength = fabs(rX*senderX + rY*senderY)*c;
            const float vX = rX * voteStrength;
            const float vY = rY * voteStrength;

            tensor11 += vX*vX;
            tensor12 += vX*vY;
            tensor22 += vY*vY;
        }

    // Eigen-decomposition of the symmetric tensor [[t11, t12], [t12, t22]].
    // The discriminant (trace/2)^2 - det equals ((t11 - t22)/2)^2 + t12^2
    // algebraically. Evaluating the latter keeps it non-negative under rounding.
    // The subtraction form can round slightly below zero when the eigenvalues
    // are nearly equal, making sqrtf return NaN and poisoning the output.
    const float halfTrace = (tensor11 + tensor22) * 0.5f;
    const float halfDiff = (tensor11 - tensor22) * 0.5f;
    const float root = sqrtf(halfDiff*halfDiff + tensor12*tensor12);
    const float bigEv = halfTrace + root;   // root >= 0, so this is the larger eigenvalue

    float dirX, dirY;
    if (fabs(tensor12) < 1e-5f) {
        // (Nearly) diagonal tensor: the eigenvectors are the coordinate axes.
        if (tensor11 > tensor22) {
            dirX = 1.0f;
            dirY = 0.0f;
        } else {
            dirX = 0.0f;
            dirY = 1.0f;
        }
    } else {
        // (bigEv - t22, t12) satisfies the second row of (T - bigEv*I) v = 0.
        dirX = bigEv - tensor22;
        dirY = tensor12;

        const float rcpL = 1.0f/sqrtf(dirX*dirX + dirY*dirY);
        dirX *= rcpL;
        dirY *= rcpL;
    }

    // Weight the unit direction by the squared dominant eigenvalue.
    dirX *= bigEv * bigEv * 1e-4f;
    dirY *= bigEv * bigEv * 1e-4f;

    uint32_t output = __float2half_rn(dirX) | (__float2half_rn(dirY) << 16);
    surf2Dwrite(output, filterResponseOutputSurface, x * 4, y);
}



/*
 * Residual of the forward distortion model
 *     undistorted = distorted * (1 + distortion * |distorted|^2)
 * evaluated at the candidate (distortedX, distortedY). The residual is the
 * target (undistortedX, undistortedY) minus the model prediction.
 */
__device__ void VSFMDist_computeCoordResidual(float undistortedX,
                                              float undistortedY,
                                              float distortedX,
                                              float distortedY,
                                              float distortion,
                                              float &residualX,
                                              float &residualY)
{
    const float radiusSq = distortedX*distortedX + distortedY*distortedY;
    const float gain = 1.0f + radiusSq * distortion;

    residualX = undistortedX - gain * distortedX;
    residualY = undistortedY - gain * distortedY;
}

/*
 * Jacobian (row-major 2x2, written to J[0..3]) of the forward model
 * f(d) = d * (1 + distortion * |d|^2) with respect to d.
 *
 * Diagonal entries:  1 + distortion*(other^2) + 3*distortion*(own^2).
 * Off-diagonal entries: 2*distortion*dx*dy, identical in both positions,
 * so the matrix is symmetric.
 */
__device__ void VSFMDist_computeCoordJacobi(float distortedX,
                                            float distortedY,
                                            float distortion,
                                            float *J)
{
    const float xSq = distortedX*distortedX;
    const float ySq = distortedY*distortedY;
    const float offDiagonal = 2.0f * distortion * distortedX * distortedY;

    J[0] = 1.0f + distortion * ySq + 3.0f * distortion * xSq;
    J[1] = offDiagonal;
    J[2] = offDiagonal;
    J[3] = 1.0f + distortion * xSq + 3.0f * distortion * ySq;
}


/*
 * Levenberg-Marquardt normal matrix h = J^T J, with its diagonal multiplied by
 * (1 + lambda). J and h are row-major 2x2.
 *
 * The previous version computed J*J, which equals J^T J only when J is
 * symmetric. The Jacobian from VSFMDist_computeCoordJacobi is symmetric, so
 * results for that caller are mathematically unchanged. This form is correct
 * for any Jacobian and matches the J^T r right-hand side produced by
 * VSFMDist_mul2x2Transposed.
 */
__device__ void VSFMDist_computeDampenedHessian(const float *J, float lambda, float *h)
{
    const float damping = 1.0f + lambda;

    // Entry (a, b) is the dot product of columns a and b of J.
    h[0*2+0] = (J[0*2+0] * J[0*2+0] + J[1*2+0] * J[1*2+0]) * damping;
    h[0*2+1] = J[0*2+0] * J[0*2+1] + J[1*2+0] * J[1*2+1];
    h[1*2+0] = h[0*2+1];   // J^T J is symmetric
    h[1*2+1] = (J[0*2+1] * J[0*2+1] + J[1*2+1] * J[1*2+1]) * damping;
}

/*
 * y = mat * x, where mat is row-major 2x2. Each output entry is one row of
 * mat dotted with x. Inputs are re-read per row, exactly as before.
 */
__device__ void VSFMDist_mul2x2(const float *mat, const float *x, float *y)
{
    for (int row = 0; row < 2; ++row)
        y[row] = x[0] * mat[row*2+0] + x[1] * mat[row*2+1];
}

/*
 * y = mat^T * x, where mat is row-major 2x2. Each output entry is one column
 * of mat dotted with x. Inputs are re-read per column, exactly as before.
 */
__device__ void VSFMDist_mul2x2Transposed(const float *mat, const float *x, float *y)
{
    for (int col = 0; col < 2; ++col)
        y[col] = x[0] * mat[0*2+col] + x[1] * mat[1*2+col];
}

/*
 * Solves mat * y = x for row-major 2x2 mat via the explicit inverse
 *     inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / (a*d - b*c).
 * There is no singularity check: a zero determinant yields inf/NaN in y.
 */
__device__ void VSFMDist_solve2x2(const float *mat, const float *x, float *y)
{
    const float rcpDet = 1.0f / (mat[0*2+0]*mat[1*2+1] - mat[0*2+1]*mat[1*2+0]);

    y[0] = rcpDet * ( x[0] * mat[1*2+1] - x[1] * mat[0*2+1]);
    y[1] = rcpDet * (-x[0] * mat[1*2+0] + x[1] * mat[0*2+0]);
}

/*
 * Undistorts an image under the one-parameter radial model
 *     u = d * (1 + distortion * |d|^2).
 *
 * For each output pixel center u (relative to the image center), the model
 * is inverted numerically for d with 50 Levenberg-Marquardt iterations:
 *   - start at d = u with lambda = 1;
 *   - on an accepted step, lambda *= 0.7;
 *   - on a rejected step, lambda *= 2.
 * distortedImageTexture is then sampled at d, and the four channels are
 * written as bytes to undistortedOutputSurface.
 *
 * Launch: 2D grid covering params.width x params.height, one thread per pixel.
 */
extern "C" __global__ void undistortVSFM(UndistortVSFMKernelParams params)
{
    const unsigned x = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned y = blockIdx.y * blockDim.y + threadIdx.y;

    if (!(x < params.width && y < params.height))
        return;

    // Output pixel center, relative to the image center.
    const float targetX = x + 0.5f - params.imageCenterX;
    const float targetY = y + 0.5f - params.imageCenterY;

    // Current estimate of the distorted position, seeded with the target.
    float estX = targetX;
    float estY = targetY;

    float res[2];
    VSFMDist_computeCoordResidual(targetX, targetY, estX, estY, params.distortion,
                                  res[0], res[1]);
    float bestErr = res[0]*res[0] + res[1]*res[1];
    float damping = 1.0f;

    for (unsigned step = 0; step < 50; step++) {
        // Linearize at the current estimate and solve the damped normal equations.
        float jac[4];
        VSFMDist_computeCoordJacobi(estX, estY, params.distortion, jac);

        float rhs[2];
        VSFMDist_mul2x2Transposed(jac, res, rhs);

        float normalMat[4];
        VSFMDist_computeDampenedHessian(jac, damping, normalMat);

        float delta[2];
        VSFMDist_solve2x2(normalMat, rhs, delta);

        const float candX = estX + delta[0];
        const float candY = estY + delta[1];

        float candRes[2];
        VSFMDist_computeCoordResidual(targetX, targetY, candX, candY, params.distortion,
                                      candRes[0], candRes[1]);
        const float candErr = candRes[0]*candRes[0] + candRes[1]*candRes[1];

        // Reject any step that does not strictly reduce the error. This
        // includes NaN errors, because NaN comparisons are false. On rejection,
        // increase the damping instead.
        if (!(candErr < bestErr)) {
            damping *= 2.0f;
            continue;
        }

        damping *= 0.7f;
        bestErr = candErr;
        res[0] = candRes[0];
        res[1] = candRes[1];
        estX = candX;
        estY = candY;
    }

    const float4 color = tex2D(distortedImageTexture, estX + params.imageCenterX, estY + params.imageCenterY);

    // Normalized channels -> bytes. Scaling by 256 and clamping to 255 maps
    // 1.0 to 255 and returns an unfiltered 8-bit value unchanged.
    const int chanX = min(255, int(color.x * 256.0f));
    const int chanY = min(255, int(color.y * 256.0f));
    const int chanZ = min(255, int(color.z * 256.0f));
    const int chanW = min(255, int(color.w * 256.0f));
    const uint32_t output = (chanX << 0) | (chanY << 8) | (chanZ << 16) | (chanW << 24);
    surf2Dwrite(output, undistortedOutputSurface, x * 4, y);
}


/*
 * Undistorts an image with a radial polynomial model, mapping each output
 * pixel directly (no iterative inversion):
 *   - u is the pixel-center offset from the image center, scaled by
 *     params.rcpFac, and r = |u|;
 *   - u is mapped to u * (1 + k0*r^2 + k1*r^3 + k2*r^4), with
 *     k = params.distortion[0..2];
 *   - the result is scaled by params.fac, offset back to the image center,
 *     and used to sample distortedImageTexture.
 * (Presumably rcpFac == 1/fac — confirm on the host side.)
 *
 * Launch: 2D grid covering params.width x params.height, one thread per pixel.
 *
 * Output: four channels as bytes packed into 32 bits, written to
 * undistortedOutputSurface at byte offset x*4 of row y.
 */
extern "C" __global__ void undistortRadial(UndistortRadialKernelParams params)
{
    const unsigned x = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned y = blockIdx.y * blockDim.y + threadIdx.y;

    if ((x >= params.width) || (y >= params.height))
        return;

    const float undistortedX = (x + 0.5f - params.imageCenterX) * params.rcpFac;
    const float undistortedY = (y + 0.5f - params.imageCenterY) * params.rcpFac;

    const float r2 = undistortedX*undistortedX + undistortedY*undistortedY;
    // sqrtf, as used throughout this file. The previous std::sqrt relied on
    // nvcc's implicit std:: math overloads without including <cmath>.
    const float r3 = r2 * sqrtf(r2);
    const float r4 = r2*r2;

    const float R = 1.0f + r2 * params.distortion[0] + r3 * params.distortion[1] + r4 * params.distortion[2];

    const float distortedX = undistortedX * R;
    const float distortedY = undistortedY * R;

    const float4 color = tex2D(distortedImageTexture, distortedX * params.fac + params.imageCenterX, distortedY * params.fac + params.imageCenterY);

    // Normalized channels -> bytes. Scaling by 256 and clamping to 255 maps
    // 1.0 to 255 and returns an unfiltered 8-bit value unchanged.
    uint32_t output = (min(255, int(color.x * 256.0f)) << 0) |
                      (min(255, int(color.y * 256.0f)) << 8) |
                      (min(255, int(color.z * 256.0f)) << 16) |
                      (min(255, int(color.w * 256.0f)) << 24);
    surf2Dwrite(output, undistortedOutputSurface, x * 4, y);
}
