/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "../cudaInterface/PatchAtlasConstants.h"
#include "newObservationTest.h"

// Layered 2D texture reference for the patch atlas. Texels are stored as uchar4
// and returned as normalized floats (cudaReadModeNormalizedFloat). The kernel
// in this file samples it with tex2DLayeredLod.
// NOTE(review): the host-side binding is not visible here — confirm where the
// texture is bound before launch.
texture<uchar4, cudaTextureType2DLayered, cudaReadModeNormalizedFloat> patchAtlas;

// Patch atlas constants placed in __constant__ memory. Not referenced by the
// device code visible in this file; presumably uploaded by the host — confirm.
__constant__ PatchAtlasConstants patchAtlasConstants;

// Projects the point (x, y, z, 1) with the row-major 4x4 matrix P and performs
// the perspective divide. Only rows 0, 1 and 3 of P are used (the projected
// depth is not computed).
//   P       16 floats, row-major
//   x,y,z   point to project
//   rx,ry   receive the projected 2D coordinates
__device__ void project2D(float *P, float x, float y, float z, float &rx, float &ry)
{
    const float *rowX = P + 0*4;
    const float *rowY = P + 1*4;
    const float *rowW = P + 3*4;

    const float homX = rowX[0] * x + rowX[1] * y + rowX[2] * z + rowX[3];
    const float homY = rowY[0] * x + rowY[1] * y + rowY[2] * z + rowY[3];
    const float homW = rowW[0] * x + rowW[1] * y + rowW[2] * z + rowW[3];

    // One reciprocal shared by both coordinates.
    const float invW = 1.0f/homW;

    rx = homX*invW;
    ry = homY*invW;
}


// Packs an RGB color into one 32-bit word using a shared 8-bit scale and
// stores it at dst[idx].
// Layout: bits 0-7 red, 8-15 green, 16-23 blue, 24-31 shared scale w.
// w is derived from the largest channel, (int)(max*255+1) capped at 255, and
// every channel is stored as channel * 255*255 / w truncated to an integer.
__device__ void outputColor(float r, float g, float b, uint32_t *dst, unsigned idx)
{
    const float maxChannel = fmax(r, fmax(g, b));

    const unsigned sharedScale = min((int)(maxChannel * 255+1), 255);

    const float channelFactor = 255.0f*255.0f / sharedScale;

    const uint32_t red   = r * channelFactor;
    const uint32_t green = (unsigned)(g * channelFactor);
    const uint32_t blue  = (unsigned)(b * channelFactor);

    dst[idx] = red | (green << 8) | (blue << 16) | (sharedScale << 24);
}


// Sums one value per thread over the 8x8 thread block with a shared-memory
// tree reduction and returns the total to every thread.
//   scratch    shared buffer with room for 64 floats
//   fullIndex  linear thread index inside the block (0..63)
//   value      this thread's contribution
// Contains barriers, so every thread of the block has to call it.
static __device__ float newObsBlockSum64(float *scratch, unsigned fullIndex, float value)
{
    scratch[fullIndex] = value;
    __syncthreads();
    for (unsigned i = 32; i > 0; i >>= 1) {
        if (fullIndex < i) {
            scratch[fullIndex] += scratch[i+fullIndex];
        }
        __syncthreads();
    }
    const float total = scratch[0];
    // All threads must have fetched the total before scratch is written again.
    __syncthreads();
    return total;
}

// Computes, for tile (ix, iy) of the calling thread in the 24x24 sample grid,
// the world-space position on the track surface and the texture LOD offset.
//   surfaceToWorld  3x4 row-major matrix (column 2 is unused because the
//                   surface coordinate has no third component)
//   trackSize       scale applied to the grid coordinates
//   x,y,z           receive the world-space position
//   lodOffset       receives (squared distance from the grid center) / 250
static __device__ void newObsSurfaceSample(const float *surfaceToWorld, float trackSize,
                                           unsigned ix, unsigned iy,
                                           float &x, float &y, float &z, float &lodOffset)
{
    // Grid coordinates centered on the 24x24 patch: -11.5 .. 11.5.
    float s = (threadIdx.x - 11.5f + ix*8.0f);
    float t = (threadIdx.y - 11.5f + iy*8.0f);
    float u = s * trackSize;
    float v = t * trackSize;

    float r = s*s+t*t;
    lodOffset = r / 250.0f;

    x = surfaceToWorld[0*4+0] * u + surfaceToWorld[0*4+1] * v + surfaceToWorld[0*4+3];
    y = surfaceToWorld[1*4+0] * u + surfaceToWorld[1*4+1] * v + surfaceToWorld[1*4+3];
    z = surfaceToWorld[2*4+0] * u + surfaceToWorld[2*4+1] * v + surfaceToWorld[2*4+3];
}

// Scores one new-observation candidate per thread block against the existing
// observations of its track.
//
// Launch layout: one block per candidate; the candidate is
// kernelParams.newCandidates[blockIdx.x + kernelParams.blockOffset]. The code
// assumes 8x8 threads per block: each thread handles a 3x3 set of samples so
// the block covers a 24x24 sample grid on the track surface.
//
// Steps:
//   1. Thread 0 loads the candidate and track header into shared memory.
//   2. For every grid sample, the colors of all existing observations are
//      averaged into avg (three 24x24 channel planes).
//   3. A least-squares fit avg_luminance ~ scale * candidate_luminance + offset
//      is computed over the whole grid (luminance = sum of the three channels).
//   4. The candidate colors are adjusted with scale/offset, clamped to [0, 1],
//      and the sum of squared differences to avg is written to
//      obsCandidate->error.
//
// With NEWOBS_OUTPUT_DEBUG_PAIRS defined, the averaged patch and the adjusted
// candidate patch are additionally written to kernelParams.debugOutput.
extern "C" __global__ void computeNewObservationProbability(NewObservationTestParams kernelParams)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y * 8;
    __shared__ float trackSize;
    __shared__ float trackSurfaceToWorld[3*4];
    __shared__ unsigned numObservations;
    __shared__ TrackObservation *obs;

    __shared__ uint32_t patchAtlasIndex;
    __shared__ float screenSpaceOffset[2];
    __shared__ NewObservationCandidate *obsCandidate;


    // Thread 0 fetches the per-block inputs once; everyone else reads them
    // from shared memory after the barrier.
    if (fullIndex == 0) {
        obsCandidate = kernelParams.newCandidates + blockIdx.x + kernelParams.blockOffset;

        patchAtlasIndex = obsCandidate->patchAtlasIndex;
        screenSpaceOffset[0] = obsCandidate->screenSpaceOffset[0];
        screenSpaceOffset[1] = obsCandidate->screenSpaceOffset[1];

        TrackHead *head = obsCandidate->track;

        trackSize = head->size;
        numObservations = head->numObservations;
        for (unsigned i = 0; i < 3*4; i++)
            trackSurfaceToWorld[i] = head->trackSurfaceToWorld[i];

        // The observation array starts directly behind the track header.
        obs = (TrackObservation*)(head + 1);
    }

    __syncthreads();

    // Mean color of the existing observations, stored as three 24x24 planes.
    __shared__ float avg[24*24*3];
    // Scratch for all block reductions. Kept separate from avg on purpose:
    // reusing avg would let one warp overwrite entries another warp is still
    // reading in its sample loop.
    __shared__ float reduceScratch[8*8];

    for (unsigned ix = 0; ix < 3; ix++)
        for (unsigned iy = 0; iy < 3; iy++) {

            unsigned idx = (iy*8+threadIdx.y)*24 + (ix*8+threadIdx.x);

            float x, y, z, lodOffset;
            newObsSurfaceSample(trackSurfaceToWorld, trackSize, ix, iy, x, y, z, lodOffset);

            float sumR = 0.0f;
            float sumG = 0.0f;
            float sumB = 0.0f;

            for (unsigned obsIndex = 0; obsIndex < numObservations; obsIndex++) {
                const TrackObservation &observation = obs[obsIndex];

                float texelX, texelY;

                project2D(kernelParams.patchAtlasPatchParams[observation.patchAtlasIndex].worldToAtlasPatch, x, y, z, texelX, texelY);
                texelX += observation.screenSpaceOffset[0] * kernelParams.patchAtlasPatchParams[observation.patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetX[0];
                texelY += observation.screenSpaceOffset[1] * kernelParams.patchAtlasPatchParams[observation.patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetY[0];

                float4 color = tex2DLayeredLod(patchAtlas, texelX, texelY, kernelParams.patchAtlasPatchParams[observation.patchAtlasIndex].layer, lodOffset);

                // Each channel is scaled by the texel's w component.
                sumR += color.x * color.w;
                sumG += color.y * color.w;
                sumB += color.z * color.w;
            }

            // NOTE(review): numObservations == 0 makes fac non-finite;
            // presumably every track has at least one observation — confirm.
            float fac = 1.0f / numObservations;
            avg[0*24*24+idx] = sumR * fac;
            avg[1*24*24+idx] = sumG * fac;
            avg[2*24*24+idx] = sumB * fac;
        }

    __syncthreads();


    #ifdef NEWOBS_OUTPUT_DEBUG_PAIRS
    // First half of this candidate's debug slot: the averaged reference patch.
    if (blockIdx.x + kernelParams.blockOffset < maxDebugOutputPairs)
        for (unsigned ix = 0; ix < 3; ix++)
            for (unsigned iy = 0; iy < 3; iy++) {
                unsigned idx = (iy*8+threadIdx.y)*24 + (ix*8+threadIdx.x);

                outputColor(avg[0*24*24+idx],
                            avg[1*24*24+idx],
                            avg[2*24*24+idx],
                            kernelParams.debugOutput+24*24*2*(blockIdx.x + kernelParams.blockOffset),
                            idx);
            }

    #endif


    // Brightness normalization of the candidate (developer toggle; must match
    // the "#if 1" around the scale/offset application further down).
#if 1
    __shared__ float offset;
    __shared__ float scale;
    {
        // Per-thread partial sums; X is the reference (avg) luminance, Y the
        // candidate luminance.
        float accuX = 0.0f;
        float accuY = 0.0f;
        float accuSQRY = 0.0f;
        float accuXY = 0.0f;

        for (unsigned ix = 0; ix < 3; ix++)
            for (unsigned iy = 0; iy < 3; iy++) {
                unsigned idx = (iy*8+threadIdx.y)*24 + (ix*8+threadIdx.x);

                float x, y, z, lodOffset;
                newObsSurfaceSample(trackSurfaceToWorld, trackSize, ix, iy, x, y, z, lodOffset);

                float texelX, texelY;
                project2D(kernelParams.patchAtlasPatchParams[patchAtlasIndex].worldToAtlasPatch, x, y, z, texelX, texelY);
                texelX += screenSpaceOffset[0] * kernelParams.patchAtlasPatchParams[patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetX[0];
                texelY += screenSpaceOffset[1] * kernelParams.patchAtlasPatchParams[patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetY[0];

                float4 color = tex2DLayeredLod(patchAtlas, texelX, texelY, kernelParams.patchAtlasPatchParams[patchAtlasIndex].layer, lodOffset);

                float lum = (color.x + color.y + color.z) * color.w;

                float lum2 = (avg[0*24*24+idx] + avg[1*24*24+idx] + avg[2*24*24+idx]);
                accuX += lum2;
                accuY += lum;
                accuSQRY += lum*lum;
                accuXY += lum*lum2;
            }

        // Block-wide totals over all 24*24 samples.
        accuX    = newObsBlockSum64(reduceScratch, fullIndex, accuX);
        accuY    = newObsBlockSum64(reduceScratch, fullIndex, accuY);
        accuSQRY = newObsBlockSum64(reduceScratch, fullIndex, accuSQRY);
        accuXY   = newObsBlockSum64(reduceScratch, fullIndex, accuXY);

        if (fullIndex == 0) {
            // Normal equations of the linear least-squares fit X ~ scale*Y + offset.
            float det = accuY*accuY - (24*24)*accuSQRY;
            if (fabs(det) < 1e-9f) {
                // Degenerate fit: leave the candidate unchanged.
                offset = 0.0f;
                scale = 1.0f;
            } else {
                scale = (accuY*accuX  - (24*24)*accuXY)/det;
                offset = (-accuSQRY * accuX + accuY * accuXY)/det;

                // The fit is on luminance, which sums three channels; the
                // offset is applied per channel, presumably hence the 1/3.
                offset = offset * 0.333f;
            }
        }
        __syncthreads();
    }
#endif


    // Per-thread sum of squared color differences between candidate and avg.
    float sum = 0.0f;

    for (unsigned ix = 0; ix < 3; ix++)
        for (unsigned iy = 0; iy < 3; iy++) {

            unsigned idx = (iy*8+threadIdx.y)*24 + (ix*8+threadIdx.x);

            float x, y, z, lodOffset;
            newObsSurfaceSample(trackSurfaceToWorld, trackSize, ix, iy, x, y, z, lodOffset);

            float texelX, texelY;
            project2D(kernelParams.patchAtlasPatchParams[patchAtlasIndex].worldToAtlasPatch, x, y, z, texelX, texelY);
            texelX += screenSpaceOffset[0] * kernelParams.patchAtlasPatchParams[patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetX[0];
            texelY += screenSpaceOffset[1] * kernelParams.patchAtlasPatchParams[patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetY[0];

            float4 color = tex2DLayeredLod(patchAtlas, texelX, texelY, kernelParams.patchAtlasPatchParams[patchAtlasIndex].layer, lodOffset);

            color.x *= color.w;
            color.y *= color.w;
            color.z *= color.w;

#if 1
            // Apply the fitted brightness correction and clamp to [0, 1].
            color.x = fmin(fmax(color.x * scale + offset, 0.0f), 1.0f);
            color.y = fmin(fmax(color.y * scale + offset, 0.0f), 1.0f);
            color.z = fmin(fmax(color.z * scale + offset, 0.0f), 1.0f);
#endif

#ifdef NEWOBS_OUTPUT_DEBUG_PAIRS
            // Second half of this candidate's debug slot: the adjusted candidate patch.
            if (blockIdx.x + kernelParams.blockOffset < maxDebugOutputPairs)
                outputColor(color.x,
                            color.y,
                            color.z,
                            kernelParams.debugOutput+24*24*2*(blockIdx.x + kernelParams.blockOffset) + 24*24,
                            idx);

#endif


            float dx = color.x - avg[0*24*24+idx];
            float dy = color.y - avg[1*24*24+idx];
            float dz = color.z - avg[2*24*24+idx];

            sum += dx*dx + dy*dy + dz*dz;
        }

    // Reduce into the dedicated scratch buffer rather than into avg, which
    // slower threads may still be reading in the loop above.
    const float totalError = newObsBlockSum64(reduceScratch, fullIndex, sum);

    if (fullIndex == 0) {
        obsCandidate->error = totalError;
    }


}


