/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "../cudaInterface/PatchAtlasConstants.h"
#include "trackOptimization.h"

// Patch atlas: a layered 2D texture of uchar4 (RGBA) texels, read as floats
// normalized to [0, 1] (cudaReadModeNormalizedFloat). The layer is selected per
// patch via PatchAtlasPatchParams::layer in the fetches below.
// NOTE(review): texture references such as this one were removed in CUDA 12;
// migrating to cudaTextureObject_t needs matching changes in the host code that
// binds this texture (not visible in this file).
texture<uchar4, cudaTextureType2DLayered, cudaReadModeNormalizedFloat> patchAtlas;

/*
struct Camera
{
    float normProjectionViewMatrix[4*4];
};
*/


// Atlas-wide constants. Within this file they are only referenced from the
// commented-out debug code in dumpProjection. Presumably uploaded by the host
// (e.g. via cudaMemcpyToSymbol) — confirm.
__constant__ PatchAtlasConstants patchAtlasConstants;

/*
 * Projects the point (x, y, z) with the row-major 4x4 matrix P (column-vector
 * convention, translation in column 3) and applies the perspective divide.
 * Only rows 0, 1 and 3 of P are read; the projected depth is not needed.
 * Results are written to rx / ry.
 */
__device__ void project2D(float *P, float x, float y, float z, float &rx, float &ry)
{
    const float *rowX = P + 0*4;
    const float *rowY = P + 1*4;
    const float *rowW = P + 3*4;

    const float hx = rowX[0] * x + rowX[1] * y + rowX[2] * z + rowX[3];
    const float hy = rowY[0] * x + rowY[1] * y + rowY[2] * z + rowY[3];
    const float hw = rowW[0] * x + rowW[1] * y + rowW[2] * z + rowW[3];

    // One reciprocal, two multiplies (rather than two divisions).
    const float invW = 1.0f / hw;
    rx = hx * invW;
    ry = hy * invW;
}


/*
 * Writes into dstMat the upper-left 3x3 block of srcMat (row-major, stride 4)
 * rotated by alpha radians about its second (local y) axis. Column 1 is copied
 * and columns 0 and 2 are mixed.
 * Only the 3x3 block of dstMat is written; its fourth column is left untouched.
 * srcMat and dstMat must not alias: within a row, column 0 of dstMat is written
 * before column 0 of srcMat is read again for column 2.
 */
__device__ void rotate1(float *srcMat, float *dstMat, float alpha)
{
    const float c = cos(alpha);
    const float s = sin(alpha);

    for (unsigned row = 0; row < 3; row++) {
        const float *src = srcMat + row*4;
        float *dst = dstMat + row*4;

        dst[0] =  src[0]*c + src[2]*s;
        dst[1] =  src[1];
        dst[2] = -src[0]*s + src[2]*c;
    }
}

/*
 * Writes into dstMat the upper-left 3x3 block of srcMat (row-major, stride 4)
 * rotated by alpha radians about its first (local x) axis. Column 0 is copied
 * and columns 1 and 2 are mixed.
 * Only the 3x3 block of dstMat is written; its fourth column is left untouched.
 * srcMat and dstMat must not alias: within a row, column 1 of dstMat is written
 * before column 1 of srcMat is read again for column 2.
 */
__device__ void rotate2(float *srcMat, float *dstMat, float alpha)
{
    const float c = cos(alpha);
    const float s = sin(alpha);

    for (unsigned row = 0; row < 3; row++) {
        const float *src = srcMat + row*4;
        float *dst = dstMat + row*4;

        dst[0] =  src[0];
        dst[1] =  src[1]*c + src[2]*s;
        dst[2] = -src[1]*s + src[2]*c;
    }
}

/*
 * Re-orthonormalizes, in place, the rows of the upper-left 3x3 block of the
 * row-major matrix mat (stride 4) with modified Gram-Schmidt. Row 0 keeps its
 * direction, row 1 is made orthogonal to row 0, and row 2 is made orthogonal
 * to row 0 and then to the updated row 1. Each row is normalized to unit length.
 * The fourth column is not touched.
 */
__device__ void orthonormalizeSub3x3(float *mat)
{
    float *a = mat + 0*4;
    float *b = mat + 1*4;
    float *c = mat + 2*4;

    // Row 0: normalize.
    float invLen = 1.0f/sqrtf(a[0]*a[0] + a[1]*a[1] + a[2]*a[2]);
    for (unsigned i = 0; i < 3; i++)
        a[i] *= invLen;

    // Row 1: remove the component along row 0, then normalize.
    float proj = a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
    for (unsigned i = 0; i < 3; i++)
        b[i] -= proj * a[i];
    invLen = 1.0f/sqrtf(b[0]*b[0] + b[1]*b[1] + b[2]*b[2]);
    for (unsigned i = 0; i < 3; i++)
        b[i] *= invLen;

    // Row 2: remove the components along row 0 and then row 1, then normalize.
    proj = a[0]*c[0] + a[1]*c[1] + a[2]*c[2];
    for (unsigned i = 0; i < 3; i++)
        c[i] -= proj * a[i];
    proj = b[0]*c[0] + b[1]*c[1] + b[2]*c[2];
    for (unsigned i = 0; i < 3; i++)
        c[i] -= proj * b[i];
    invLen = 1.0f/sqrtf(c[0]*c[0] + c[1]*c[1] + c[2]*c[2]);
    for (unsigned i = 0; i < 3; i++)
        c[i] *= invLen;
}


/*
 * Photo-consistency cost of a track under its current surface frame and
 * observation offsets.
 *
 * Must be called by all threads of an 8x8 block (fullIndex = threadIdx.x +
 * threadIdx.y*8); it contains __syncthreads().
 *
 * Each thread evaluates 3x3 points of a 24x24 grid on the track's surface
 * plane. Local coordinates s, t lie in [-11.5, 11.5] and are scaled by
 * trackSize. Each point is mapped to world space with the row-major 3x4
 * trackSurfaceToWorld; column 2 is not used because the local z coordinate is 0.
 * The point is then projected into every observation's atlas patch, shifted by
 * that observation's screenSpaceOffset, and fetched from patchAtlas at a LOD
 * that grows with the squared distance from the grid center.
 *
 * Colors are premultiplied by alpha. Per point, the population variance across
 * observations (E[c^2] - E[c]^2) is summed over R, G and B. All points are then
 * summed over the block.
 *
 * accu must provide 64 floats of block-shared scratch. On return, accu[0]
 * holds the total and is visible to every thread.
 *
 * Requires CUDA 9+ for __shfl_xor_sync (the legacy __shfl_xor is unavailable
 * when compiling for sm_70 and newer).
 */
__device__ void computeObservationVariance(float trackSize, float *trackSurfaceToWorld,
                                           TrackObservation *obs, unsigned numObservations,
                                           float *accu, unsigned fullIndex, PatchAtlasPatchParams *patchAtlasPatchParams)
{

    float varAccu = 0.0f;
    float fac = 1.0f / numObservations;
    float facSqr = fac*fac;

    for (unsigned ix = 0; ix < 3; ix++)
        for (unsigned iy = 0; iy < 3; iy++) {
            float s = (threadIdx.x - 11.5f + ix*8.0f);
            float t = (threadIdx.y - 11.5f + iy*8.0f);
            float u = s * trackSize;
            float v = t * trackSize;

            // Sample coarser mip levels further away from the patch center.
            float r = s*s+t*t;
            float lodOffset = r / 250.0f;

            float x = trackSurfaceToWorld[0*4+0] * u + trackSurfaceToWorld[0*4+1] * v + trackSurfaceToWorld[0*4+3];
            float y = trackSurfaceToWorld[1*4+0] * u + trackSurfaceToWorld[1*4+1] * v + trackSurfaceToWorld[1*4+3];
            float z = trackSurfaceToWorld[2*4+0] * u + trackSurfaceToWorld[2*4+1] * v + trackSurfaceToWorld[2*4+3];

            float meanR = 0.0f;
            float meanG = 0.0f;
            float meanB = 0.0f;

            float varR = 0.0f;
            float varG = 0.0f;
            float varB = 0.0f;
            for (unsigned obsIndex = 0; obsIndex < numObservations; obsIndex++) {
                TrackObservation &observation = obs[obsIndex];
                PatchAtlasPatchParams &params = patchAtlasPatchParams[observation.patchAtlasIndex];

                float texelX, texelY;
                project2D(params.worldToAtlasPatch, x, y, z, texelX, texelY);
                texelX += observation.screenSpaceOffset[0] * params.sourceImageToAtlasPatchScaleOffsetX[0];
                texelY += observation.screenSpaceOffset[1] * params.sourceImageToAtlasPatchScaleOffsetY[0];

                float4 color = tex2DLayeredLod(patchAtlas, texelX, texelY, params.layer, lodOffset);

                // Premultiply by alpha so transparent texels contribute nothing.
                color.x *= color.w;
                color.y *= color.w;
                color.z *= color.w;

                meanR += color.x;
                meanG += color.y;
                meanB += color.z;

                varR += color.x * color.x;
                varG += color.y * color.y;
                varB += color.z * color.z;
            }
            // E[c^2] - E[c]^2, summed over the three channels.
            varAccu += (varR + varG + varB) * fac -
                       (meanR * meanR + meanG * meanG + meanB * meanB) * facSqr;
        }


    accu[fullIndex] = varAccu;
    __syncthreads();

    // For an 8x8 block the linear thread id equals fullIndex, so threads 0..31
    // are exactly warp 0 and every lane takes part in the shuffle (full mask).
    if (fullIndex < 32) {
        float sum = accu[fullIndex] + accu[fullIndex + 32];
        // Use XOR mode to perform butterfly reduction
        #pragma unroll
        for (int i=16; i>=1; i/=2)
            sum += __shfl_xor_sync(0xffffffffu, sum, i, 32);

        if (fullIndex == 0)
            accu[0] = sum;
    }
    // Make accu[0] visible to all threads before returning.
    __syncthreads();
}


/*
 * Debug helper: renders observation `obs` of a track into a 24x24 tile of
 * packed texels at dst (row stride 24).
 *
 * Expects an 8x8 thread block. Each thread fills 3x3 texels spaced 8 apart.
 * Sampling positions and LOD are the same as in computeObservationVariance.
 * The fullIndex parameter is not used, and the function contains no barrier.
 *
 * Each uint32_t is packed as follows. Byte 3 holds
 * w = min(int(m*255 + 1), 255), where m is the largest alpha-premultiplied
 * channel. Bytes 0..2 hold R, G and B multiplied by 255*255/w, so dark texels
 * keep more precision.
 */
__device__ void dumpProjection(float trackSize, float *trackSurfaceToWorld,
                                           TrackObservation *obs,
                                           unsigned fullIndex, PatchAtlasPatchParams *patchAtlasPatchParams, uint32_t *dst)
{
    const float *M = trackSurfaceToWorld;

    for (unsigned tile = 0; tile < 3*3; tile++) {
        const unsigned ix = tile / 3;
        const unsigned iy = tile % 3;

        // Local surface coordinates of this texel, centered on the track.
        const float s = (threadIdx.x - 11.5f + ix*8.0f);
        const float t = (threadIdx.y - 11.5f + iy*8.0f);
        const float u = s * trackSize;
        const float v = t * trackSize;

        // Coarser mip level further from the center.
        const float lodOffset = (s*s+t*t) / 250.0f;

        const float worldX = M[0*4+0] * u + M[0*4+1] * v + M[0*4+3];
        const float worldY = M[1*4+0] * u + M[1*4+1] * v + M[1*4+3];
        const float worldZ = M[2*4+0] * u + M[2*4+1] * v + M[2*4+3];

        PatchAtlasPatchParams &params = patchAtlasPatchParams[obs->patchAtlasIndex];

        float texelX, texelY;
        project2D(params.worldToAtlasPatch, worldX, worldY, worldZ, texelX, texelY);
        texelX += obs->screenSpaceOffset[0] * params.sourceImageToAtlasPatchScaleOffsetX[0];
        texelY += obs->screenSpaceOffset[1] * params.sourceImageToAtlasPatchScaleOffsetY[0];

        float4 color = tex2DLayeredLod(patchAtlas, texelX, texelY, params.layer, lodOffset);

/*
        unsigned patchAtlasIndex = obs->patchAtlasIndex;
        //patchAtlasIndex = blockIdx.x;

        color = tex2DLayeredLod(patchAtlas,
            patchAtlasPatchParams[patchAtlasIndex].patchX * patchAtlasConstants.patchSize + (ix*24+threadIdx.x*3 - 3.5f)*patchAtlasConstants.halfTexelSize*2.0f,
            patchAtlasPatchParams[patchAtlasIndex].patchY * patchAtlasConstants.patchSize + (iy*24+threadIdx.y*3 - 3.5f)*patchAtlasConstants.halfTexelSize*2.0f,
                                       patchAtlasPatchParams[patchAtlasIndex].layer, 0.0f);


        if ((ix == 0) && (iy == 0)) {
            color.x = (((patchAtlasIndex >> 0) % 16) * 16) / 256.0f;
            color.y = (((patchAtlasIndex >> 4) % 16) * 16) / 256.0f;
            color.z = (((patchAtlasIndex >> 8) % 16) * 16) / 256.0f;
            color.w = 1.0f;
        }
*/
        const float red   = color.x * color.w;
        const float green = color.y * color.w;
        const float blue  = color.z * color.w;
        const float peak  = fmax(red, fmax(green, blue));

        const unsigned w = min((int)(peak * 255+1), 255);
        const float rcpM = 255.0f*255.0f / w;

        uint32_t packed = w << 24;
        packed |= ((unsigned)(blue * rcpM)) << 16;
        packed |= ((unsigned)(green * rcpM)) << 8;
        packed |= (uint32_t)(red * rcpM);

        dst[(iy*8+threadIdx.y) * 24 + ix*8+threadIdx.x] = packed;
    }
}


/*
 * Refines each track's surface frame and observation offsets by minimizing the
 * photo-consistency cost from computeObservationVariance.
 *
 * Launch layout: one block per track (track = blockIdx.x + blockOffset) and
 * 8x8 threads per block (fullIndex = threadIdx.x + threadIdx.y*8). Thread 0
 * performs all scalar updates; the cost evaluations are block-cooperative.
 *
 * Each of kernelParams.numIterations iterations does two things:
 *   1. For every observation except observation 0, it estimates the cost change
 *      when screenSpaceOffset moves by +-0.5*screenSize (x, then y) and takes a
 *      descent step. Offsets are updated in place in the track record.
 *   2. It estimates the cost derivative for the two rotations rotate1 / rotate2
 *      (central differences, +-0.025 rad), applies a descent step and
 *      re-orthonormalizes the rotation part of trackSurfaceToWorld.
 * Finally it writes the remaining cost and the refined trackSurfaceToWorld back
 * into the track head.
 */
extern "C" __global__ void optimizeTracks(OptimizeTracksKernalParams kernelParams)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y * 8;
    __shared__ float trackSize;
    __shared__ float trackSurfaceToWorld[3*4];
    __shared__ unsigned numObservations;
    __shared__ TrackObservation *observations;

    if (fullIndex == 0) {
        unsigned char *ptr = kernelParams.index[blockIdx.x + kernelParams.blockOffset];
        TrackHead *head = (TrackHead*)ptr;

        trackSize = head->size;
        numObservations = head->numObservations;
        for (unsigned i = 0; i < 3*4; i++)
            trackSurfaceToWorld[i] = head->trackSurfaceToWorld[i];

        // Observations are stored directly behind the head.
        observations = (TrackObservation*)(ptr + sizeof(TrackHead));
    }

    __syncthreads();
#ifdef TRACK_ALIGN_OUTPUT_DEBUG_DATA
    // Two 24x24 debug tiles per track (observations 0 and 1).
    // - The offset is computed in size_t so it cannot wrap for large track counts.
    // - Tracks with too few observations skip the tile instead of reading past
    //   their observation array.
    if (numObservations > 0)
        dumpProjection(trackSize, trackSurfaceToWorld, observations+0, fullIndex, kernelParams.patchAtlasPatchParams,
                       kernelParams.debugOutput+(size_t)24*24*2*(blockIdx.x + kernelParams.blockOffset));
    if (numObservations > 1)
        dumpProjection(trackSize, trackSurfaceToWorld, observations+1, fullIndex, kernelParams.patchAtlasPatchParams,
                       kernelParams.debugOutput+(size_t)24*24*2*(blockIdx.x + kernelParams.blockOffset)+24*24);
#endif
    __shared__ float accu[8*8];


    __shared__ float modifiedTrackSurfaceToWorld[3*4];
    __shared__ float rotGrad1;
    __shared__ float rotGrad2;

    for (unsigned iteration = 0; iteration < kernelParams.numIterations; iteration++) {
#if 1
        // Observation 0 serves as the reference; only the others are shifted.
        for (unsigned obs = 1; obs < numObservations; obs++) {
            __shared__ float offsetXOrig;
            __shared__ float offsetYOrig;

            __shared__ float offsetXGrad;
            __shared__ float offsetYGrad;

            if (fullIndex == 0) {
                offsetXOrig = observations[obs].screenSpaceOffset[0];
                offsetYOrig = observations[obs].screenSpaceOffset[1];

                observations[obs].screenSpaceOffset[0] = offsetXOrig + observations[obs].screenSize * 0.5f;
            }
            __syncthreads();
            computeObservationVariance(trackSize, trackSurfaceToWorld, observations, numObservations,
                                                accu, fullIndex, kernelParams.patchAtlasPatchParams);
            if (fullIndex == 0) {
                offsetXGrad = accu[0];
                observations[obs].screenSpaceOffset[0] = offsetXOrig - observations[obs].screenSize * 0.5f;
            }
            __syncthreads();
            computeObservationVariance(trackSize, trackSurfaceToWorld, observations, numObservations,
                                                accu, fullIndex, kernelParams.patchAtlasPatchParams);
            if (fullIndex == 0) {
                offsetXGrad = (offsetXGrad - accu[0]) / 1.0f;
                observations[obs].screenSpaceOffset[0] = offsetXOrig;

                observations[obs].screenSpaceOffset[1] = offsetYOrig + observations[obs].screenSize * 0.5f;
            }
            __syncthreads();
            computeObservationVariance(trackSize, trackSurfaceToWorld, observations, numObservations,
                                                accu, fullIndex, kernelParams.patchAtlasPatchParams);
            if (fullIndex == 0) {
                offsetYGrad = accu[0];
                observations[obs].screenSpaceOffset[1] = offsetYOrig - observations[obs].screenSize * 0.5f;
            }
            __syncthreads();
            computeObservationVariance(trackSize, trackSurfaceToWorld, observations, numObservations,
                                                accu, fullIndex, kernelParams.patchAtlasPatchParams);
            if (fullIndex == 0) {
                offsetYGrad = (offsetYGrad - accu[0]) / 1.0f;
                observations[obs].screenSpaceOffset[0] = offsetXOrig - offsetXGrad * 0.05f * observations[obs].screenSize / numObservations * 2;
                observations[obs].screenSpaceOffset[1] = offsetYOrig - offsetYGrad * 0.05f * observations[obs].screenSize / numObservations * 2;
            }
            __syncthreads();
        }
#endif

        // --- Rotation 1: central difference at +-0.025 rad ---
        // rotate1 writes only the 3x3 block, so copy the translation column first.
        if (fullIndex < 3*4) {
            modifiedTrackSurfaceToWorld[fullIndex] = trackSurfaceToWorld[fullIndex];
        }
        __syncthreads();

        if (fullIndex == 0) {
            rotate1(trackSurfaceToWorld, modifiedTrackSurfaceToWorld, 0.025f);
        }
        __syncthreads();
        computeObservationVariance(trackSize, modifiedTrackSurfaceToWorld, observations, numObservations,
                                            accu, fullIndex, kernelParams.patchAtlasPatchParams);
        if (fullIndex == 0) {
            rotGrad1 = accu[0];
            rotate1(trackSurfaceToWorld, modifiedTrackSurfaceToWorld, -0.025f);
        }
        __syncthreads();
        computeObservationVariance(trackSize, modifiedTrackSurfaceToWorld, observations, numObservations,
                                            accu, fullIndex, kernelParams.patchAtlasPatchParams);
        if (fullIndex == 0) {
            rotGrad1 = (rotGrad1 - accu[0]) / 0.05f;
        }



        // --- Rotation 2: central difference at +-0.025 rad ---
        __syncthreads();
        if (fullIndex < 3*4) {
            modifiedTrackSurfaceToWorld[fullIndex] = trackSurfaceToWorld[fullIndex];
        }
        __syncthreads();

        if (fullIndex == 0) {
            rotate2(trackSurfaceToWorld, modifiedTrackSurfaceToWorld, 0.025f);
        }
        __syncthreads();
        computeObservationVariance(trackSize, modifiedTrackSurfaceToWorld, observations, numObservations,
                                            accu, fullIndex, kernelParams.patchAtlasPatchParams);
        if (fullIndex == 0) {
            rotGrad2 = accu[0];
            rotate2(trackSurfaceToWorld, modifiedTrackSurfaceToWorld, -0.025f);
        }
        __syncthreads();
        computeObservationVariance(trackSize, modifiedTrackSurfaceToWorld, observations, numObservations,
                                            accu, fullIndex, kernelParams.patchAtlasPatchParams);
        if (fullIndex == 0) {
            rotGrad2 = (rotGrad2 - accu[0]) / 0.05f;

            // Descent step: track -> modified via rotate1, modified -> track via
            // rotate2. The translation column of modifiedTrackSurfaceToWorld still
            // holds the copy made above.
            rotate1(trackSurfaceToWorld, modifiedTrackSurfaceToWorld, -rotGrad1*0.05f / numObservations * 2);
            rotate2(modifiedTrackSurfaceToWorld, trackSurfaceToWorld, -rotGrad2*0.05f / numObservations * 2);
            orthonormalizeSub3x3(trackSurfaceToWorld);
        }
        // Thread 0 has just rewritten trackSurfaceToWorld and
        // modifiedTrackSurfaceToWorld. Without this barrier, threads 1..11 can
        // reach the copy into modifiedTrackSurfaceToWorld at the top of the next
        // iteration while those writes are still in flight. This happens whenever
        // the observation loop is empty (numObservations <= 1).
        __syncthreads();
    }

    __syncthreads();
    computeObservationVariance(trackSize, trackSurfaceToWorld, observations, numObservations,
                                        accu, fullIndex, kernelParams.patchAtlasPatchParams);


    if (fullIndex == 0) {
        unsigned char *ptr = kernelParams.index[blockIdx.x + kernelParams.blockOffset];
        TrackHead *head = (TrackHead*)ptr;
        head->remainingError = accu[0];

        for (unsigned i = 0; i < 3*4; i++)
            head->trackSurfaceToWorld[i] = trackSurfaceToWorld[i];
    }
}






// Number of pre-warp matrices. ExtractionPreWarpMatrices holds NumPreWarps
// consecutive 4x4 matrices, indexed row-major by extractPreWarpedProjections
// (matrix w starts at offset w*16). Presumably uploaded by the host via
// cudaMemcpyToSymbol — confirm.
const unsigned NumPreWarps = 11;
__constant__ float ExtractionPreWarpMatrices[4*4*NumPreWarps];

/*
 * Block-cooperative product dst = src1 * src2 of two row-major 4x4 matrices.
 * Threads with fullIndex < 16 each compute one element (row = fullIndex / 4,
 * column = fullIndex % 4); all other threads return immediately.
 * No barrier is issued, so callers must __syncthreads() before reading dst.
 * dst must not alias src1 or src2, because other threads may still be reading
 * them.
 */
__device__ void mulMatrix4x4with4x4(float *dst, const float *src1, const float *src2, unsigned fullIndex)
{
    if (fullIndex >= 16)
        return;

    const unsigned row = fullIndex >> 2;
    const unsigned col = fullIndex & 3;
    const float *lhsRow = src1 + row*4;

    float acc = lhsRow[0] * src2[0*4+col];
    acc += lhsRow[1] * src2[1*4+col];
    acc += lhsRow[2] * src2[2*4+col];
    acc += lhsRow[3] * src2[3*4+col];
    dst[row*4+col] = acc;
}

/*
 * Block-cooperative product dst = src1 * src2. src1 is a row-major 3x4 affine
 * matrix whose implicit fourth row is [0 0 0 1]; src2 and dst are row-major 4x4.
 * Threads with fullIndex < 16 each produce one element of dst. Row 3 of dst is
 * a plain copy of row 3 of src2, which is what the implicit [0 0 0 1] row yields.
 * No barrier is issued, so callers must __syncthreads() before reading dst.
 */
__device__ void mulMatrix3x4with4x4(float *dst, const float *src1, const float *src2, unsigned fullIndex)
{
    if (fullIndex >= 16)
        return;

    const unsigned row = fullIndex >> 2;
    const unsigned col = fullIndex & 3;

    if (row == 3) {
        dst[3*4+col] = src2[3*4+col];
        return;
    }

    const float *lhsRow = src1 + row*4;
    float acc = lhsRow[0] * src2[0*4+col];
    acc += lhsRow[1] * src2[1*4+col];
    acc += lhsRow[2] * src2[2*4+col];
    acc += lhsRow[3] * src2[3*4+col];
    dst[row*4+col] = acc;
}

/*
 * Block-cooperative product dst = T * src2, where T = I + lastCol * e3^T
 * (the identity with lastCol[0..3] added to its last column). For a pure
 * translation, lastCol[3] should be 0.
 * Element-wise: dst[r][c] = src2[r][c] + lastCol[r] * src2[3][c].
 * Threads with fullIndex < 16 each write one element. No barrier is issued.
 */
__device__ void mulMatrixTranslationwith4x4(float *dst, const float *lastCol, const float *src2, unsigned fullIndex)
{
    if (fullIndex >= 16)
        return;

    const unsigned row = fullIndex >> 2;
    const unsigned col = fullIndex & 3;
    dst[row*4+col] = src2[row*4+col] + lastCol[row] * src2[3*4+col];
}

/*
 * Projects the surface-plane point (x, y, 0) with the row-major 4x4 matrix P
 * and applies the perspective divide. Column 2 of P is skipped because it
 * would multiply the implicit local z = 0. Results are written to rx / ry.
 */
__device__ void projectSurface2D(float *P, float x, float y, float &rx, float &ry)
{
    const float *rowX = P + 0*4;
    const float *rowY = P + 1*4;
    const float *rowW = P + 3*4;

    const float hx = rowX[0] * x + rowX[1] * y + rowX[3];
    const float hy = rowY[0] * x + rowY[1] * y + rowY[3];
    const float hw = rowW[0] * x + rowW[1] * y + rowW[3];

    const float invW = 1.0f / hw;
    rx = hx * invW;
    ry = hy * invW;
}

// 8x8 weight mask for evalPatchLocationScale, indexed [row*8 + col].
// Every row has the column sign profile (-, -, +, +, +, +, -, -), and the row
// magnitudes taper as 0.1, 0.3, 0.8, 1.0, 1.0, 0.8, 0.3, 0.1. The result is a
// center-positive / surround-negative response along x, weighted along y.
// evalPatchLocationScale also reads it transposed ([col*8 + row]) to obtain
// the corresponding response along y.
__constant__ float blobbResponseBig[8*8] = {
-0.1f, -0.1f, 0.1f, 0.1f, 0.1f, 0.1f, -0.1f, -0.1f,
-0.3f, -0.3f, 0.3f, 0.3f, 0.3f, 0.3f, -0.3f, -0.3f,
-0.8f, -0.8f, 0.8f, 0.8f, 0.8f, 0.8f, -0.8f, -0.8f,
-1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f,
-1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f,
-0.8f, -0.8f, 0.8f, 0.8f, 0.8f, 0.8f, -0.8f, -0.8f,
-0.3f, -0.3f, 0.3f, 0.3f, 0.3f, 0.3f, -0.3f, -0.3f,
-0.1f, -0.1f, 0.1f, 0.1f, 0.1f, 0.1f, -0.1f, -0.1f
};

/*
 * Blob-detector response of an 8x8 sample grid centered at surface position
 * (x, y). The grid has spacing `scale` and is rotated by the angle whose
 * cosine/sine are cosAngle/sinAngle.
 *
 * Each thread samples one grid point at LOD 0 and computes its luminance
 * (R+G+B)*alpha. It weights that luminance by blobbResponseBig (accuX) and by
 * its transpose (accuY). Both weighted sums are reduced over the block, and
 * thread 0 stores |sumX * sumY| in dst.
 *
 * Must be called by all threads of an 8x8 block (contains __syncthreads()).
 * accuX and accuY must hold 64 floats each and are overwritten. dst is visible
 * to every thread on return.
 */
__device__ void evalPatchLocationScale(float &cosAngle, float &sinAngle,
                                       float x, float y, float scale,
                                       float *matrix, float layer,
                                       float *accuX, float *accuY,
                                       float &dst,
                                       const unsigned fullIndex)
{
    const float localU = (threadIdx.x - 3.5f) * scale;
    const float localV = (threadIdx.y - 3.5f) * scale;

    // Rotate the grid offset into the surface frame and add the center.
    const float surfaceX = x + cosAngle * localU - sinAngle * localV;
    const float surfaceY = y + sinAngle * localU + cosAngle * localV;

    float texelX, texelY;
    projectSurface2D(matrix, surfaceX, surfaceY, texelX, texelY);

    const float4 texel = tex2DLayeredLod(patchAtlas, texelX, texelY, layer, 0.0f);
    const float luminance = (texel.x + texel.y + texel.z) * texel.w;

    accuX[fullIndex] = luminance * blobbResponseBig[threadIdx.y*8 + threadIdx.x];
    accuY[fullIndex] = luminance * blobbResponseBig[threadIdx.x*8 + threadIdx.y];

    // Shared-memory tree reduction over the 64 entries; a barrier precedes each step.
    for (unsigned stride = 32; stride != 0; stride >>= 1) {
        __syncthreads();
        if (fullIndex < stride) {
            accuX[fullIndex] += accuX[fullIndex + stride];
            accuY[fullIndex] += accuY[fullIndex + stride];
        }
    }
    __syncthreads();

    if (fullIndex == 0)
        dst = fabs(accuX[0] * accuY[0]);

    __syncthreads();
}


/*
 * Helper for evalPatchAngle: samples the atlas at surface position (u, v)
 * through `matrix` at LOD 1.2 and returns the luminance (R+G+B)*alpha.
 */
__device__ float sampleSurfaceLuminanceForAngle(float *matrix, float layer, float u, float v)
{
    float texelX, texelY;
    projectSurface2D(matrix, u, v, texelX, texelY);
    const float4 texel = tex2DLayeredLod(patchAtlas, texelX, texelY, layer, 1.2f);
    return (texel.x + texel.y + texel.z) * texel.w;
}

/*
 * Estimates the dominant orientation around surface position
 * (currentX, currentY).
 *
 * Thread k (k = fullIndex, 0..63) probes direction k*pi/64. Its contrast is
 * the luminance at radii 7 and 5 (in units of currentScale) minus the
 * luminance at radii -7 and -5 along that direction. The contrast flips sign
 * under a half-turn, so probing [0, pi) covers the full circle. The block sums
 * contrast*(cos, sin), and thread 0 stores the resulting angle and its
 * cosine/sine.
 *
 * Must be called by all threads of an 8x8 block (contains __syncthreads()).
 * accuX and accuY (64 floats each) are overwritten. The outputs are visible to
 * every thread on return.
 */
__device__ void evalPatchAngle(float &currentX, float &currentY, float &currentScale,
                               float *accuX, float *accuY,
                               float *matrix, float layer,
                               float &dstAngle, float &dstCosAngle, float &dstSinAngle,
                               const unsigned fullIndex)
{
    const float probeAngle = fullIndex * (float)M_PI / (8*8);
    const float dirX = cosf(probeAngle);
    const float dirY = sinf(probeAngle);

    float contrast = sampleSurfaceLuminanceForAngle(matrix, layer,
                                                    currentX + 7.0f * dirX * currentScale,
                                                    currentY + 7.0f * dirY * currentScale);
    contrast += sampleSurfaceLuminanceForAngle(matrix, layer,
                                               currentX + 5.0f * dirX * currentScale,
                                               currentY + 5.0f * dirY * currentScale);
    contrast -= sampleSurfaceLuminanceForAngle(matrix, layer,
                                               currentX - 7.0f * dirX * currentScale,
                                               currentY - 7.0f * dirY * currentScale);
    contrast -= sampleSurfaceLuminanceForAngle(matrix, layer,
                                               currentX - 5.0f * dirX * currentScale,
                                               currentY - 5.0f * dirY * currentScale);

    accuX[fullIndex] = contrast * dirX;
    accuY[fullIndex] = contrast * dirY;

    // Shared-memory tree reduction over the 64 entries; a barrier precedes each step.
    for (unsigned stride = 32; stride != 0; stride >>= 1) {
        __syncthreads();
        if (fullIndex < stride) {
            accuX[fullIndex] += accuX[fullIndex + stride];
            accuY[fullIndex] += accuY[fullIndex + stride];
        }
    }
    __syncthreads();

    if (fullIndex == 0) {
        dstAngle = atan2f(accuY[0], accuX[0]);
        dstCosAngle = cosf(dstAngle);
        dstSinAngle = sinf(dstAngle);
    }
    __syncthreads();
}


/*
 * Extracts, for each track, NumPreWarps normalized 16x16 patches from the
 * track's observation 0.
 *
 * Launch layout: one block per track (track = blockIdx.x + blockOffset) and
 * 8x8 threads per block (fullIndex = threadIdx.x + threadIdx.y*8).
 *
 * For each pre-warp w:
 *   1. Build the surface->atlas projection
 *        worldToAtlasPatch * trackSurfaceToWorld * ExtractionPreWarpMatrices[w]
 *      and add observation 0's screen-space offset (scaled to atlas units) as a
 *      translation.
 *   2. Refine an in-plane offset and scale for 100 iterations. Each step moves
 *      along finite differences of the blob response (evalPatchLocationScale).
 *      The orientation is re-estimated with evalPatchAngle each iteration and
 *      once more at the end.
 *   3. Sample a 16x16 grid in the refined frame, with rows flipped along the
 *      surface v axis. Colors are premultiplied by alpha and scaled so the mean
 *      luminance (R+G+B) over the patch is 0.2; the denominator is clamped at
 *      0.01.
 *   4. Store the packed texels (same encoding as dumpProjection) at
 *      dst[track*16*16*NumPreWarps + w*16*16 + row*16 + col].
 */
extern "C" __global__ void extractPreWarpedProjections(ExtractProjectionsKernelParams kernelParams)
{
    const unsigned trackIndex = blockIdx.x + kernelParams.blockOffset;

    const unsigned fullIndex = threadIdx.x + threadIdx.y*8;

    __shared__ float trackSize;
    __shared__ float trackSurfaceToWorld[3*4];
    __shared__ unsigned numObservations;
    __shared__ TrackObservation *observations;
    __shared__ uint32_t *dstPtr;

    if (fullIndex == 0) {
        unsigned char *ptr = kernelParams.index[blockIdx.x + kernelParams.blockOffset];
        TrackHead *head = (TrackHead*)ptr;

        trackSize = head->size;
        numObservations = head->numObservations;
        for (unsigned i = 0; i < 3*4; i++)
            trackSurfaceToWorld[i] = head->trackSurfaceToWorld[i];

        observations = (TrackObservation*)(ptr + sizeof(TrackHead));

        // Each track owns NumPreWarps consecutive 16x16 tiles. The offset is
        // computed in size_t; in 32-bit unsigned arithmetic it would wrap once
        // trackIndex exceeds roughly 1.5 million.
        dstPtr = kernelParams.dst + (size_t)trackIndex*16*16*NumPreWarps;

    }

    __syncthreads();

    __shared__ float warpedProjectionMatrix[4*4];
    __shared__ float fullProjectionMatrix[4*4];

    for (unsigned warpIndex = 0; warpIndex < NumPreWarps; warpIndex++) {

        // fullProjectionMatrix = trackSurfaceToWorld * preWarp[warpIndex]
        mulMatrix3x4with4x4(fullProjectionMatrix, trackSurfaceToWorld, ExtractionPreWarpMatrices+warpIndex*4*4, fullIndex);
        __syncthreads();

        // Only observation 0 is used for extraction.
        unsigned obsIndex = 0;
        TrackObservation *obs = observations + obsIndex;


        // warpedProjectionMatrix = worldToAtlasPatch * fullProjectionMatrix
        mulMatrix4x4with4x4(warpedProjectionMatrix, kernelParams.patchAtlasPatchParams[obs->patchAtlasIndex].worldToAtlasPatch,
                            fullProjectionMatrix, fullIndex);

        __shared__ float translation[4];
        if (fullIndex == 0) {
            // Observation's screen-space offset converted to atlas units.
            translation[0] = obs->screenSpaceOffset[0] * kernelParams.patchAtlasPatchParams[obs->patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetX[0];
            translation[1] = obs->screenSpaceOffset[1] * kernelParams.patchAtlasPatchParams[obs->patchAtlasIndex].sourceImageToAtlasPatchScaleOffsetY[0];
            translation[2] = 0.0f;
            translation[3] = 0.0f;
        }
        __syncthreads();
        // fullProjectionMatrix = (I + translation * e3^T) * warpedProjectionMatrix
        mulMatrixTranslationwith4x4(fullProjectionMatrix, translation, warpedProjectionMatrix, fullIndex);


        __shared__ float angle;
        __shared__ float cosAngle;
        __shared__ float sinAngle;
        __shared__ float offsetX;
        __shared__ float offsetY;
        __shared__ float scale;
        __shared__ float layer;
        if (fullIndex == 0) {
            offsetX = 0.0f;
            offsetY = 0.0f;
            /*
            cosAngle = 1.0f;
            sinAngle = 0.0f;
            */
            scale = trackSize;
            layer = kernelParams.patchAtlasPatchParams[obs->patchAtlasIndex].layer;
        }
        __syncthreads();
        __shared__ float accuX[8*8];
        {
            __shared__ float accuY[8*8];

            // samples[2*k+0] / samples[2*k+1] hold the response at the -/+ probe
            // for k = 0 (x offset), k = 1 (y offset) and k = 2 (scale).
            __shared__ float samples[3*2];
            for (unsigned iter = 0; iter < 100; iter++) {
                evalPatchAngle(offsetX, offsetY, scale, accuX, accuY, fullProjectionMatrix, layer,
                               angle, cosAngle, sinAngle, fullIndex);

                evalPatchLocationScale(cosAngle, sinAngle,
                                       offsetX-1.0f*scale, offsetY+0.0f*scale, scale+0.0f,
                                       fullProjectionMatrix, layer, accuX, accuY,
                                       samples[0*2+0], fullIndex);
                evalPatchLocationScale(cosAngle, sinAngle,
                                       offsetX+1.0f*scale, offsetY+0.0f*scale, scale+0.0f,
                                       fullProjectionMatrix, layer, accuX, accuY,
                                       samples[0*2+1], fullIndex);

                evalPatchLocationScale(cosAngle, sinAngle,
                                       offsetX+0.0f*scale, offsetY-1.0f*scale, scale+0.0f,
                                       fullProjectionMatrix, layer, accuX, accuY,
                                       samples[1*2+0], fullIndex);
                evalPatchLocationScale(cosAngle, sinAngle,
                                       offsetX+0.0f*scale, offsetY+1.0f*scale, scale+0.0f,
                                       fullProjectionMatrix, layer, accuX, accuY,
                                       samples[1*2+1], fullIndex);

                evalPatchLocationScale(cosAngle, sinAngle,
                                       offsetX+0.0f*scale, offsetY+0.0f*scale, scale/1.05f,
                                       fullProjectionMatrix, layer, accuX, accuY,
                                       samples[2*2+0], fullIndex);
                evalPatchLocationScale(cosAngle, sinAngle,
                                       offsetX+0.0f*scale, offsetY+0.0f*scale, scale*1.05f,
                                       fullProjectionMatrix, layer, accuX, accuY,
                                       samples[2*2+1], fullIndex);

                if (fullIndex == 0) {
                    float Dx = samples[0*2+1] - samples[0*2+0];
                    float Dy = samples[1*2+1] - samples[1*2+0];
                    float Ds = samples[2*2+1] - samples[2*2+0];
/*
                    Dx = fmin(fmax(Dx, -10.0f), 10.0f);
                    Dy = fmin(fmax(Dy, -10.0f), 10.0f);
                    Ds = fmin(fmax(Ds, -2.0f), 2.0f);
*/
                    offsetX += Dx * 0.001f * scale;
                    offsetY += Dy * 0.001f * scale;
                    scale += Ds * 0.001f * scale;
                }
                // evalPatchAngle reads offsetX/offsetY/scale before its first
                // barrier, so thread 0's update must be published first.
                __syncthreads();
            }
            evalPatchAngle(offsetX, offsetY, scale, accuX, accuY, fullProjectionMatrix, layer,
                           angle, cosAngle, sinAngle, fullIndex);
        }

        __syncthreads();

        // Pass 1: total luminance of the 16x16 patch (each thread covers 2x2 texels).
        accuX[fullIndex] = 0.0f;
        for (unsigned ix = 0; ix < 2; ix++)
            for (unsigned iy = 0; iy < 2; iy++) {
                float4 colors;

                float u = (threadIdx.x - 7.5f + ix*8.0f) * scale;
                float v = (threadIdx.y - 7.5f + iy*8.0f) * scale;

                float texelX, texelY;
                //projectSurface2D(fullProjectionMatrix, u, v, texelX, texelY);
                projectSurface2D(fullProjectionMatrix,
                                         offsetX + cosAngle * u - sinAngle * v,
                                         offsetY + sinAngle * u + cosAngle * v,
                                         texelX,
                                         texelY);

                colors = tex2DLayeredLod(patchAtlas, texelX, texelY, kernelParams.patchAtlasPatchParams[obs->patchAtlasIndex].layer, 0.0f);



                float lum = (colors.x + colors.y + colors.z) * colors.w;

                accuX[fullIndex] += lum;
            }


        for (unsigned i = 8*4; i > 0; i/=2) {
            __syncthreads();
            if (fullIndex < i) {
                accuX[fullIndex] += accuX[fullIndex + i];
            }
        }
        __syncthreads();
        __shared__ float lumScale;
        if (fullIndex == 0) {
            // Scale so that the mean luminance per texel becomes 0.2.
            lumScale = (0.2f*16.0f*16.0f) / fmax(accuX[0], 0.01f);
        }
        __syncthreads();


        // Pass 2: sample again (v flipped), normalize and pack.
        for (unsigned ix = 0; ix < 2; ix++)
            for (unsigned iy = 0; iy < 2; iy++) {
                float4 colors1;
                {
                    float u = (threadIdx.x - 7.5f + ix*8.0f) * scale;
                    float v = -(threadIdx.y - 7.5f + iy*8.0f) * scale;

                    float texelX, texelY;
                    //projectSurface2D(fullProjectionMatrix, u, v, texelX, texelY);
                    projectSurface2D(fullProjectionMatrix,
                                             offsetX + cosAngle * u - sinAngle * v,
                                             offsetY + sinAngle * u + cosAngle * v,
                                             texelX,
                                             texelY);

                    colors1 = tex2DLayeredLod(patchAtlas, texelX, texelY, kernelParams.patchAtlasPatchParams[obs->patchAtlasIndex].layer, 0.0f);

                    float fac = colors1.w * lumScale;

                    colors1.x *= fac;
                    colors1.y *= fac;
                    colors1.z *= fac;
                }
                float r = colors1.x;
                float g = colors1.y;
                float b = colors1.z;
                float m = fmax(r, fmax(g, b));

                unsigned w = min((int)(m * 255+1), 255);

                float rcpM = 255.0f*255.0f / w;

                uint32_t data = r * rcpM;
                data |= ((unsigned)(g * rcpM)) << 8;
                data |= ((unsigned)(b * rcpM)) << 16;
                data |= w << 24;

                dstPtr[warpIndex*16*16 + (iy*8+threadIdx.y) * 16 + ix*8+threadIdx.x] = data;
            }

        // The next pre-warp starts by overwriting fullProjectionMatrix (threads
        // 0..15, before any barrier). Wait until every warp has finished sampling
        // with it and with the shared frame parameters above.
        __syncthreads();
    }
}


