/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#undef _GLIBCXX_USE_INT128

#include "initialTrackOptimizer.h"

// Source images of the two views of a track. sourceImagePyramid1 is sampled
// with coordinates projected through P1norm (and level lod[0]),
// sourceImagePyramid2 through P2norm (and level lod[1]); both via tex2DLod(),
// so they are presumably bound host-side to mip-mapped arrays — TODO confirm.
// cudaReadModeNormalizedFloat returns the uchar4 texels as float4 in [0, 1].
// NOTE(review): texture references are deprecated since CUDA 11 and removed in
// CUDA 12; porting requires cudaTextureObject_t on both host and device side.
texture<uchar4, 2, cudaReadModeNormalizedFloat> sourceImagePyramid1;
texture<uchar4, 2, cudaReadModeNormalizedFloat> sourceImagePyramid2;


// Launch parameters read by every kernel in this file: numTracks (bounds check
// on the per-block track index) and the projection matrices P1norm / P2norm.
__constant__ ConstantParams optimizeInitialTrackParams;


/*
 * Projects the world point (x, y, z, 1) through the row-major 4x4 matrix P and
 * returns the dehomogenized result (row0 / row3, row1 / row3) in (rx, ry).
 * Row 2 of P is not needed for a 2D result and is never read.
 */
__device__ void project2D(float *P, float x, float y, float z, float &rx, float &ry)
{
    const float *row0 = P + 0*4;
    const float *row1 = P + 1*4;
    const float *row3 = P + 3*4;

    const float hx = row0[0] * x + row0[1] * y + row0[2] * z + row0[3];
    const float hy = row1[0] * x + row1[1] * y + row1[2] * z + row1[3];
    const float hw = row3[0] * x + row3[1] * y + row3[2] * z + row3[3];

    // One reciprocal shared by both coordinates.
    const float invW = 1.0f/hw;
    rx = hx*invW;
    ry = hy*invW;
}

/*
 * Block-wide SUM (despite the name, not an average; callers use it as a
 * divisor) of the alpha-weighted brightness (r+g+b)*a of image 1 over a 16x16
 * grid of surface samples.
 *
 * Thread (tx, ty) takes the four sub-samples (ix, iy) in {0,1}^2 at
 *   u = (tx - 7.5 + 8*ix + epipolarOffsetHalf) * size
 *   v = (ty - 7.5 + 8*iy) * size
 * maps (u, v) to world space with surfaceToWorld (column 2 unused, i.e. the
 * surface plane z = 0), projects it with P1norm and fetches
 * sourceImagePyramid1 at level lod[0].
 *
 * Expects an 8x8 thread block (fullIndex = threadIdx.x + threadIdx.y*8) and a
 * 64-entry shared accu buffer; the total is left in accu[0]. Contains
 * __syncthreads(), so every thread of the block must call it.
 */
__device__ void computeAvgLumImage1(TrackData *trackData, float *accu, unsigned fullIndex)
{
    const float *S = trackData->surfaceToWorld;
    float total = 0.0f;

    for (unsigned ix = 0; ix < 2; ix++) {
        const float u = (threadIdx.x - 7.5f + ix*8.0f + trackData->epipolarOffsetHalf) * trackData->size;
        for (unsigned iy = 0; iy < 2; iy++) {
            const float v = (threadIdx.y - 7.5f + iy*8.0f) * trackData->size;

            const float wx = S[0*4+0] * u + S[0*4+1] * v + S[0*4+3];
            const float wy = S[1*4+0] * u + S[1*4+1] * v + S[1*4+3];
            const float wz = S[2*4+0] * u + S[2*4+1] * v + S[2*4+3];

            float texU, texV;
            project2D(optimizeInitialTrackParams.P1norm, wx, wy, wz, texU, texV);

            const float4 texel = tex2DLod(sourceImagePyramid1, texU, texV, trackData->lod[0]);
            total += (texel.x+texel.y+texel.z)*texel.w;
        }
    }

    // Tree reduction of the 64 per-thread partial sums into accu[0].
    accu[fullIndex] = total;
    for (unsigned stride = 32; stride != 0; stride >>= 1) {
        __syncthreads();
        if (fullIndex < stride)
            accu[fullIndex] += accu[fullIndex + stride];
    }
    __syncthreads();
}

/*
 * Block-wide SUM (despite the name, not an average; callers use it as a
 * divisor) of the alpha-weighted brightness (r+g+b)*a of image 2 over a 16x16
 * grid of surface samples.
 *
 * Identical to the image-1 variant except that the u offset is shifted the
 * other way along the epipolar direction:
 *   u = (tx - 7.5 + 8*ix - epipolarOffsetHalf) * size
 *   v = (ty - 7.5 + 8*iy) * size
 * and that P2norm, sourceImagePyramid2 and lod[1] are used.
 *
 * Expects an 8x8 thread block (fullIndex = threadIdx.x + threadIdx.y*8) and a
 * 64-entry shared accu buffer; the total is left in accu[0]. Contains
 * __syncthreads(), so every thread of the block must call it.
 */
__device__ void computeAvgLumImage2(TrackData *trackData, float *accu, unsigned fullIndex)
{
    const float *S = trackData->surfaceToWorld;
    float total = 0.0f;

    for (unsigned ix = 0; ix < 2; ix++) {
        const float u = (threadIdx.x - 7.5f + ix*8.0f - trackData->epipolarOffsetHalf) * trackData->size;
        for (unsigned iy = 0; iy < 2; iy++) {
            const float v = (threadIdx.y - 7.5f + iy*8.0f) * trackData->size;

            const float wx = S[0*4+0] * u + S[0*4+1] * v + S[0*4+3];
            const float wy = S[1*4+0] * u + S[1*4+1] * v + S[1*4+3];
            const float wz = S[2*4+0] * u + S[2*4+1] * v + S[2*4+3];

            float texU, texV;
            project2D(optimizeInitialTrackParams.P2norm, wx, wy, wz, texU, texV);

            const float4 texel = tex2DLod(sourceImagePyramid2, texU, texV, trackData->lod[1]);
            total += (texel.x+texel.y+texel.z)*texel.w;
        }
    }

    // Tree reduction of the 64 per-thread partial sums into accu[0].
    accu[fullIndex] = total;
    for (unsigned stride = 32; stride != 0; stride >>= 1) {
        __syncthreads();
        if (fullIndex < stride)
            accu[fullIndex] += accu[fullIndex + stride];
    }
    __syncthreads();
}


/*
 * Block-wide sum of squared RGB differences between image 1 and image 2 over
 * a 32x32 grid of surface samples (each thread takes a 4x4 set spaced 8
 * apart). This is the error that optimizeInitialTrack descends on.
 *
 * For s = tx - 15.5 + 8*ix and t = ty - 15.5 + 8*iy the images are sampled at
 *   image 1: u = (s + epipolarOffsetHalf) * size, v = t * size  (P1norm, lod[0])
 *   image 2: u = (s - epipolarOffsetHalf) * size, v = t * size  (P2norm, lod[1])
 * and each texel's channels are weighted by alpha * lumScaleN before
 * differencing.
 *
 * Expects an 8x8 thread block (fullIndex = threadIdx.x + threadIdx.y*8) and a
 * 64-entry shared accu buffer; the total is left in accu[0]. Contains
 * __syncthreads(), so every thread of the block must call it.
 */
__device__ void computeAvgDifference(TrackData *trackData, float *accu, unsigned fullIndex, float lumScale1, float lumScale2)
{
    const float *S = trackData->surfaceToWorld;
    float sumSq = 0.0f;

    for (unsigned ix = 0; ix < 4; ix++) {
        // Column offset shared by both views; they are displaced in opposite
        // directions along u by epipolarOffsetHalf.
        const float s = threadIdx.x - 15.5f + ix*8.0f;
        const float u1 = (s + trackData->epipolarOffsetHalf) * trackData->size;
        const float u2 = (s - trackData->epipolarOffsetHalf) * trackData->size;

        for (unsigned iy = 0; iy < 4; iy++) {
            const float v = (threadIdx.y - 15.5f + iy*8.0f) * trackData->size;
            float texU, texV;

            // View 1.
            project2D(optimizeInitialTrackParams.P1norm,
                      S[0*4+0] * u1 + S[0*4+1] * v + S[0*4+3],
                      S[1*4+0] * u1 + S[1*4+1] * v + S[1*4+3],
                      S[2*4+0] * u1 + S[2*4+1] * v + S[2*4+3],
                      texU, texV);
            const float4 c1 = tex2DLod(sourceImagePyramid1, texU, texV, trackData->lod[0]);
            const float w1 = c1.w * lumScale1;

            // View 2.
            project2D(optimizeInitialTrackParams.P2norm,
                      S[0*4+0] * u2 + S[0*4+1] * v + S[0*4+3],
                      S[1*4+0] * u2 + S[1*4+1] * v + S[1*4+3],
                      S[2*4+0] * u2 + S[2*4+1] * v + S[2*4+3],
                      texU, texV);
            const float4 c2 = tex2DLod(sourceImagePyramid2, texU, texV, trackData->lod[1]);
            const float w2 = c2.w * lumScale2;

            // Accumulate channel by channel (r, g, b) to keep the summation order.
            const float dr = c1.x * w1 - c2.x * w2;
            sumSq += dr*dr;
            const float dg = c1.y * w1 - c2.y * w2;
            sumSq += dg*dg;
            const float db = c1.z * w1 - c2.z * w2;
            sumSq += db*db;
        }
    }

    // Tree reduction of the 64 per-thread partial sums into accu[0].
    accu[fullIndex] = sumSq;
    for (unsigned stride = 32; stride != 0; stride >>= 1) {
        __syncthreads();
        if (fullIndex < stride)
            accu[fullIndex] += accu[fullIndex + stride];
    }
    __syncthreads();
}


/*
 * Cooperatively copies a TrackData record word by word (as 4-byte floats).
 * Thread fullIndex copies words fullIndex, fullIndex + N, fullIndex + 2N, ...
 * where N = blockDim.x * blockDim.y. Records larger than the block are
 * therefore copied completely instead of being silently truncated; for
 * records of at most N words this is exactly one word per thread, as before.
 *
 * Performs no synchronization: the caller must __syncthreads() before any
 * thread reads or further modifies dst.
 * Assumes fullIndex is a unique index in [0, N) and that sizeof(TrackData) is
 * a multiple of 4 — TODO confirm against the TrackData definition.
 */
__device__ void copyState(const TrackData &src, TrackData &dst, unsigned fullIndex)
{
    const unsigned numWords = sizeof(TrackData)/4;
    const unsigned numThreads = blockDim.x * blockDim.y;

    const float *fptrSrc = (const float*)&src;
    float *fptrDst = (float*)&dst;
    for (unsigned i = fullIndex; i < numWords; i += numThreads)
        fptrDst[i] = fptrSrc[i];
}

/*
 * Right-multiplies the upper-left 3x3 block of the row-major 4x4 matrix
 * srcMat by a rotation of angle alpha (radians) that mixes columns 0 and 2
 * and keeps column 1. This is a rotation of the frame about its column-1
 * (v) axis:
 *   dst.col0 =  src.col0*cos(alpha) + src.col2*sin(alpha)
 *   dst.col1 =  src.col1
 *   dst.col2 = -src.col0*sin(alpha) + src.col2*cos(alpha)
 * Only the upper-left 3x3 block of dstMat is written; column 3 and row 3 keep
 * whatever dstMat already held.
 * Each row of srcMat is fully read before that row of dstMat is written, so
 * srcMat and dstMat may be the same matrix (in-place rotation).
 */
__device__ void rotate1(float *srcMat, float *dstMat, float alpha)
{
    float sinAlpha, cosAlpha;
    sincosf(alpha, &sinAlpha, &cosAlpha);

    for (unsigned row = 0; row < 3; row++) {
        const float c0 = srcMat[row*4+0];
        const float c1 = srcMat[row*4+1];
        const float c2 = srcMat[row*4+2];

        dstMat[row*4+0] =  c0*cosAlpha + c2*sinAlpha;
        dstMat[row*4+1] =  c1;
        dstMat[row*4+2] = -c0*sinAlpha + c2*cosAlpha;
    }
}

/*
 * Right-multiplies the upper-left 3x3 block of the row-major 4x4 matrix
 * srcMat by a rotation of angle alpha (radians) that mixes columns 1 and 2
 * and keeps column 0. This is a rotation of the frame about its column-0
 * (u) axis:
 *   dst.col0 =  src.col0
 *   dst.col1 =  src.col1*cos(alpha) + src.col2*sin(alpha)
 *   dst.col2 = -src.col1*sin(alpha) + src.col2*cos(alpha)
 * Only the upper-left 3x3 block of dstMat is written; column 3 and row 3 keep
 * whatever dstMat already held.
 * Each row of srcMat is fully read before that row of dstMat is written, so
 * srcMat and dstMat may be the same matrix (in-place rotation).
 */
__device__ void rotate2(float *srcMat, float *dstMat, float alpha)
{
    float sinAlpha, cosAlpha;
    sincosf(alpha, &sinAlpha, &cosAlpha);

    for (unsigned row = 0; row < 3; row++) {
        const float c0 = srcMat[row*4+0];
        const float c1 = srcMat[row*4+1];
        const float c2 = srcMat[row*4+2];

        dstMat[row*4+0] =  c0;
        dstMat[row*4+1] =  c1*cosAlpha + c2*sinAlpha;
        dstMat[row*4+2] = -c1*sinAlpha + c2*cosAlpha;
    }
}

/*
 * Re-orthonormalizes, in place, the upper-left 3x3 block of the row-major 4x4
 * matrix mat by Gram-Schmidt on its ROWS.
 *   - Row 0 keeps its direction.
 *   - Row 1 is made orthogonal to row 0.
 *   - Row 2 is made orthogonal to row 0 and then to the updated row 1.
 * Each row is then scaled to unit length.
 *
 * Column 3 and row 3 are not touched. A row that becomes zero-length yields
 * inf/NaN (there is no guard). optimizeInitialTrack calls it after applying
 * rotate1/rotate2.
 */
__device__ void orthonormalize(float *mat)
{
    float *r0 = mat + 0*4;
    float *r1 = mat + 1*4;
    float *r2 = mat + 2*4;

    // Row 0: unit length.
    const float inv0 = 1.0f/sqrtf(r0[0]*r0[0] + r0[1]*r0[1] + r0[2]*r0[2]);
    for (unsigned c = 0; c < 3; c++)
        r0[c] *= inv0;

    // Row 1: drop the row-0 component, then unit length.
    const float d01 = r0[0]*r1[0] + r0[1]*r1[1] + r0[2]*r1[2];
    for (unsigned c = 0; c < 3; c++)
        r1[c] -= d01 * r0[c];
    const float inv1 = 1.0f/sqrtf(r1[0]*r1[0] + r1[1]*r1[1] + r1[2]*r1[2]);
    for (unsigned c = 0; c < 3; c++)
        r1[c] *= inv1;

    // Row 2: drop the row-0 component, then the row-1 component of the
    // remainder, then unit length.
    const float d02 = r0[0]*r2[0] + r0[1]*r2[1] + r0[2]*r2[2];
    for (unsigned c = 0; c < 3; c++)
        r2[c] -= d02 * r0[c];
    const float d12 = r1[0]*r2[0] + r1[1]*r2[1] + r1[2]*r2[2];
    for (unsigned c = 0; c < 3; c++)
        r2[c] -= d12 * r1[c];
    const float inv2 = 1.0f/sqrtf(r2[0]*r2[0] + r2[1]*r2[1] + r2[2]*r2[2]);
    for (unsigned c = 0; c < 3; c++)
        r2[c] *= inv2;
}

/*
 * Refines each track's surface patch by photometric gradient descent between
 * the two source images, and writes the result back to trackData[trackIndex].
 *
 * Launch layout (implied by the indexing):
 *   - one thread block per track, trackIndex = blockIdx.x + gridDim.x * blockIdx.y;
 *   - 8x8 threads per block (fullIndex = threadIdx.x + threadIdx.y*8, with a
 *     64-entry reduction buffer).
 * The early-exit test depends only on blockIdx, so a block either leaves
 * entirely or not at all, and the barriers below are reached uniformly.
 *
 * Each of the 100 iterations estimates d(error)/d(param) by central
 * differences of computeAvgDifference() for four parameters:
 *   - epipolarOffsetHalf                               (h = +-0.1),
 *   - depth: translation along column 2 of surfaceToWorld (h = +-0.1*size),
 *   - rotate1 angle                                    (h = +-0.05),
 *   - rotate2 angle                                    (h = +-0.05).
 * It then steps each parameter by -0.4 * derivative and re-orthonormalizes
 * the rotation part.
 *
 * Shared-state protocol for every probe:
 *   1. All threads cooperatively copy currentState into modifiedState.
 *   2. A barrier publishes the copy.
 *   3. Only then does thread 0 perturb modifiedState.
 * Without the barrier in step 2, thread 0's read-modify-write could read
 * words other threads have not copied yet, or be overwritten by their late
 * copies.
 */
extern "C" __global__ void optimizeInitialTrack(TrackData *trackData)
{
    const unsigned trackIndex = blockIdx.x + gridDim.x * blockIdx.y;
    if (trackIndex >= optimizeInitialTrackParams.numTracks)
        return;

    const unsigned fullIndex = threadIdx.x + threadIdx.y*8;

    __shared__ TrackData currentState;
    if (fullIndex < sizeof(currentState)/4) {
        const float *fptrSrc = (float*)&trackData[trackIndex];
        float *fptr = (float*)&currentState;
        fptr[fullIndex] = fptrSrc[fullIndex];
    }
    __syncthreads();

    __shared__ TrackData modifiedState;

    __shared__ float accu[8*8];
    __shared__ float lumScale1;
    __shared__ float lumScale2;
    // Per-view brightness normalization: after scaling, (r+g+b)*a averages 0.2
    // per sample over the 16x16 grid (the sum is clamped below at 0.01).
    computeAvgLumImage1(&currentState, accu, fullIndex);
    if (fullIndex == 0) {
        lumScale1 = (0.2f * 16.0f * 16.0f) / fmax(accu[0], 0.01f);
    }
    __syncthreads();
    computeAvgLumImage2(&currentState, accu, fullIndex);
    if (fullIndex == 0) {
        lumScale2 = (0.2f * 16.0f * 16.0f) / fmax(accu[0], 0.01f);
    }
    __syncthreads();

    for (unsigned iterations = 0; iterations < 100; iterations++) {
        __shared__ float depthGrad;
        __shared__ float epipolarOffsetGrad;
        __shared__ float angle1Grad;
        __shared__ float angle2Grad;

        // d(error)/d(epipolarOffsetHalf), central difference with h = 0.1.
        {
            copyState(currentState, modifiedState, fullIndex);
            __syncthreads(); // copy must be complete before thread 0 perturbs it
            if (fullIndex == 0) {
                modifiedState.epipolarOffsetHalf += 0.1f;
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                epipolarOffsetGrad = accu[0];
                modifiedState.epipolarOffsetHalf -= 0.2f;
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                epipolarOffsetGrad = (epipolarOffsetGrad - accu[0]) / 0.2f;
            }
            __syncthreads();
        }

        // d(error)/d(depth): translate along column 2 of surfaceToWorld by
        // +-0.1*size.
        {
            copyState(currentState, modifiedState, fullIndex);
            __syncthreads(); // copy must be complete before thread 0 perturbs it
            if (fullIndex == 0) {
                modifiedState.surfaceToWorld[0*4+3] += currentState.surfaceToWorld[0*4+2] * currentState.size * 0.1f;
                modifiedState.surfaceToWorld[1*4+3] += currentState.surfaceToWorld[1*4+2] * currentState.size * 0.1f;
                modifiedState.surfaceToWorld[2*4+3] += currentState.surfaceToWorld[2*4+2] * currentState.size * 0.1f;
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                depthGrad = accu[0];
                modifiedState.surfaceToWorld[0*4+3] -= currentState.surfaceToWorld[0*4+2] * currentState.size * 0.2f;
                modifiedState.surfaceToWorld[1*4+3] -= currentState.surfaceToWorld[1*4+2] * currentState.size * 0.2f;
                modifiedState.surfaceToWorld[2*4+3] -= currentState.surfaceToWorld[2*4+2] * currentState.size * 0.2f;
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                depthGrad = (depthGrad - accu[0]) / 0.2f;
            }
            __syncthreads();
        }

        // d(error)/d(angle1): rotate1 by +-0.05.
        {
            copyState(currentState, modifiedState, fullIndex);
            __syncthreads(); // copy must be complete before thread 0 overwrites the 3x3 block
            if (fullIndex == 0) {
                rotate1(currentState.surfaceToWorld, modifiedState.surfaceToWorld, 0.05f);
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                angle1Grad = accu[0];
                rotate1(currentState.surfaceToWorld, modifiedState.surfaceToWorld, -0.05f);
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                angle1Grad = (angle1Grad - accu[0]) / 0.1f;
            }
            __syncthreads();
        }
        // d(error)/d(angle2): rotate2 by +-0.05.
        {
            copyState(currentState, modifiedState, fullIndex);
            __syncthreads(); // copy must be complete before thread 0 overwrites the 3x3 block
            if (fullIndex == 0) {
                rotate2(currentState.surfaceToWorld, modifiedState.surfaceToWorld, 0.05f);
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                angle2Grad = accu[0];
                rotate2(currentState.surfaceToWorld, modifiedState.surfaceToWorld, -0.05f);
            }
            __syncthreads();
            computeAvgDifference(&modifiedState, accu, fullIndex, lumScale1, lumScale2);
            if (fullIndex == 0) {
                angle2Grad = (angle2Grad - accu[0]) / 0.1f;
            }
            __syncthreads();
        }

        // Fixed-rate (0.4) descent step on all four parameters. The rotations
        // go currentState -> modifiedState -> currentState. rotate1/rotate2
        // only write the 3x3 block, so the translation just updated in
        // currentState is kept.
        if (fullIndex == 0) {
            currentState.epipolarOffsetHalf -= epipolarOffsetGrad * 0.4f;
            currentState.surfaceToWorld[0*4+3] -= currentState.surfaceToWorld[0*4+2] * currentState.size * 0.4f * depthGrad;
            currentState.surfaceToWorld[1*4+3] -= currentState.surfaceToWorld[1*4+2] * currentState.size * 0.4f * depthGrad;
            currentState.surfaceToWorld[2*4+3] -= currentState.surfaceToWorld[2*4+2] * currentState.size * 0.4f * depthGrad;

            rotate1(currentState.surfaceToWorld, modifiedState.surfaceToWorld, -0.4f * angle1Grad);
            rotate2(modifiedState.surfaceToWorld, currentState.surfaceToWorld, -0.4f * angle2Grad);

            orthonormalize(currentState.surfaceToWorld);
        }
        __syncthreads();
    }

    if (fullIndex < sizeof(currentState)/4) {
        float *fptrDst = (float*)&trackData[trackIndex];
        const float *fptr = (float*)&currentState;
        fptrDst[fullIndex] = fptr[fullIndex];
    }
}


/*
 * For every track, writes the photometric error of its current TrackData to
 * dst[trackIndex]. The error is the block-wide sum of squared,
 * luminance-normalized RGB differences from computeAvgDifference.
 *
 * Launch layout: one 8x8 thread block per track,
 * trackIndex = blockIdx.x + gridDim.x * blockIdx.y. Blocks past numTracks exit
 * as a whole (the test depends only on blockIdx).
 */
extern "C" __global__ void computeErrors(TrackData *trackData, float *dst)
{
    const unsigned trackIndex = blockIdx.x + gridDim.x * blockIdx.y;
    if (trackIndex >= optimizeInitialTrackParams.numTracks)
        return;

    const unsigned fullIndex = threadIdx.x + threadIdx.y*8;
    // Target brightness sum for the 16x16 normalization grid.
    const float lumTarget = 0.2f * 16.0f * 16.0f;

    __shared__ TrackData currentState;
    __shared__ float accu[8*8];
    __shared__ float lumScale1;
    __shared__ float lumScale2;

    // Stage the track record in shared memory.
    copyState(trackData[trackIndex], currentState, fullIndex);
    __syncthreads();

    computeAvgLumImage1(&currentState, accu, fullIndex);
    if (fullIndex == 0)
        lumScale1 = lumTarget / fmax(accu[0], 0.01f);
    __syncthreads();

    computeAvgLumImage2(&currentState, accu, fullIndex);
    if (fullIndex == 0)
        lumScale2 = lumTarget / fmax(accu[0], 0.01f);
    __syncthreads();

    computeAvgDifference(&currentState, accu, fullIndex, lumScale1, lumScale2);
    if (fullIndex == 0)
        dst[trackIndex] = accu[0];
}

/*
 * For every track, extracts the two 32x32 surface patches seen by image 1 and
 * image 2 and stores them as packed 32-bit texels. The sampling grid is the
 * same 4x4-per-thread grid that computeAvgDifference uses.
 *
 * Launch layout: one 8x8 thread block per track,
 * trackIndex = blockIdx.x + gridDim.x * blockIdx.y.
 * Output: kernelParams.dst holds 2*32*32 values per track. The row-major 32x32
 * patch from image 1 comes first, then the one from image 2.
 *
 * Each texel's channels are multiplied by alpha * lumScaleN and then packed:
 *   byte 3    = w = min(255*max(r,g,b) + 1, 255)
 *   bytes 0..2 = r, g, b scaled by 255*255/w
 * so byte * w / (255*255) approximates the channel value. Because w is
 * capped at 255, channels above 1.0 would produce bytes > 255 that spill into
 * the neighbouring channel. Each channel byte is therefore clamped to 255,
 * i.e. such values saturate at 1.0.
 */
extern "C" __global__ void extractProjections(ExtractProjectionsKernelParams kernelParams)
{
    const unsigned trackIndex = blockIdx.x + gridDim.x * blockIdx.y;
    if (trackIndex >= optimizeInitialTrackParams.numTracks)
        return;

    const unsigned fullIndex = threadIdx.x + threadIdx.y*8;

    __shared__ TrackData currentState;
    if (fullIndex < sizeof(currentState)/4) {
        const float *fptrSrc = (float*)&kernelParams.trackData[trackIndex];
        float *fptr = (float*)&currentState;
        fptr[fullIndex] = fptrSrc[fullIndex];
    }
    __syncthreads();

    __shared__ float lumScale1;
    __shared__ float lumScale2;

    __shared__ float accu[8*8];

    computeAvgLumImage1(&currentState, accu, fullIndex);
    if (fullIndex == 0) {
        lumScale1 = (0.2f * 16.0f * 16.0f) / fmax(accu[0], 0.01f);
    }
    __syncthreads();
    computeAvgLumImage2(&currentState, accu, fullIndex);
    if (fullIndex == 0) {
        lumScale2 = (0.2f * 16.0f * 16.0f) / fmax(accu[0], 0.01f);
    }
    __syncthreads(); // publish lumScale2 to all threads

    // Output patches of this track. Every thread computes the same pointers,
    // so no shared-memory broadcast is needed; size_t keeps the offset from
    // wrapping for very large track counts.
    uint32_t *leftPtr = kernelParams.dst + (size_t)trackIndex*32*32*2;
    uint32_t *rightPtr = leftPtr + 32*32;

    for (unsigned ix = 0; ix < 4; ix++)
        for (unsigned iy = 0; iy < 4; iy++) {
            float4 colors1, colors2;
            {
                float u = (threadIdx.x - 15.5f + ix*8.0f + currentState.epipolarOffsetHalf) * currentState.size;
                float v = (threadIdx.y - 15.5f + iy*8.0f) * currentState.size;

                float x = currentState.surfaceToWorld[0*4+0] * u + currentState.surfaceToWorld[0*4+1] * v + currentState.surfaceToWorld[0*4+3];
                float y = currentState.surfaceToWorld[1*4+0] * u + currentState.surfaceToWorld[1*4+1] * v + currentState.surfaceToWorld[1*4+3];
                float z = currentState.surfaceToWorld[2*4+0] * u + currentState.surfaceToWorld[2*4+1] * v + currentState.surfaceToWorld[2*4+3];

                float texelX, texelY;
                project2D(optimizeInitialTrackParams.P1norm, x, y, z, texelX, texelY);
                colors1 = tex2DLod(sourceImagePyramid1, texelX, texelY, currentState.lod[0]);

                float fac = lumScale1 * colors1.w;

                colors1.x *= fac;
                colors1.y *= fac;
                colors1.z *= fac;
                float m = fmax(colors1.x, fmax(colors1.y, colors1.z));

                unsigned w = min((int)(m * 255+1), 255);

                float rcpM = 255.0f*255.0f / w;

                // Clamp each channel byte so values above 1.0 cannot overflow
                // into the neighbouring byte.
                uint32_t data = min((unsigned)(colors1.x * rcpM), 255u);
                data |= min((unsigned)(colors1.y * rcpM), 255u) << 8;
                data |= min((unsigned)(colors1.z * rcpM), 255u) << 16;
                data |= w << 24;


                leftPtr[(threadIdx.y+iy*8)*32+threadIdx.x+ix*8] = data;
            }
            {
                float u = (threadIdx.x - 15.5f + ix*8.0f - currentState.epipolarOffsetHalf) * currentState.size;
                float v = (threadIdx.y - 15.5f + iy*8.0f) * currentState.size;

                float x = currentState.surfaceToWorld[0*4+0] * u + currentState.surfaceToWorld[0*4+1] * v + currentState.surfaceToWorld[0*4+3];
                float y = currentState.surfaceToWorld[1*4+0] * u + currentState.surfaceToWorld[1*4+1] * v + currentState.surfaceToWorld[1*4+3];
                float z = currentState.surfaceToWorld[2*4+0] * u + currentState.surfaceToWorld[2*4+1] * v + currentState.surfaceToWorld[2*4+3];


                float texelX, texelY;
                project2D(optimizeInitialTrackParams.P2norm, x, y, z, texelX, texelY);
                colors2 = tex2DLod(sourceImagePyramid2, texelX, texelY, currentState.lod[1]);

                float fac = lumScale2 * colors2.w;

                colors2.x *= fac;
                colors2.y *= fac;
                colors2.z *= fac;
                float m = fmax(colors2.x, fmax(colors2.y, colors2.z));

                unsigned w = min((int)(m * 255+1), 255);

                float rcpM = 255.0f*255.0f / w;

                // Clamp each channel byte so values above 1.0 cannot overflow
                // into the neighbouring byte.
                uint32_t data = min((unsigned)(colors2.x * rcpM), 255u);
                data |= min((unsigned)(colors2.y * rcpM), 255u) << 8;
                data |= min((unsigned)(colors2.z * rcpM), 255u) << 16;
                data |= w << 24;


                rightPtr[(threadIdx.y+iy*8)*32+threadIdx.x+ix*8] = data;
            }
        }
}

// Number of pre-warp matrices; extractPreWarpedProjections loops over this
// many warps and writes one 16x16 patch per warp and track.
const unsigned NumPreWarps = 11;
// NumPreWarps row-major 4x4 matrices. extractPreWarpedProjections
// right-multiplies each one onto a track's surfaceToWorld before projecting.
// Filled by the host — presumably via cudaMemcpyToSymbol, TODO confirm.
__constant__ float ExtractionPreWarpMatrices[4*4*NumPreWarps];

/*
 * dst = src1 * src2 for row-major 4x4 matrices, computed cooperatively.
 * The thread with fullIndex k < 16 produces element (row k/4, column k%4);
 * threads with fullIndex >= 16 do nothing.
 * dst should not alias src1/src2, since other threads may still be reading
 * them. The caller must __syncthreads() before reading dst.
 */
__device__ void mulMatrix4x4with4x4(float *dst, const float *src1, const float *src2, unsigned fullIndex)
{
    if (fullIndex >= 16)
        return;

    const unsigned col = fullIndex % 4;
    const unsigned row = fullIndex / 4;
    const float *lhsRow = src1 + row*4;

    dst[row*4 + col] = lhsRow[0] * src2[0*4 + col] +
                       lhsRow[1] * src2[1*4 + col] +
                       lhsRow[2] * src2[2*4 + col] +
                       lhsRow[3] * src2[3*4 + col];
}

/*
 * dst = src1 * src2 for row-major 4x4 storage, where src1 is treated as an
 * affine matrix with an implicit bottom row (0, 0, 0, 1). Row 3 of src1 is
 * never read, and row 3 of dst is simply row 3 of src2.
 * Computed cooperatively: the thread with fullIndex k < 16 produces element
 * (row k/4, column k%4); other threads do nothing. The caller must
 * __syncthreads() before reading dst.
 */
__device__ void mulMatrix3x4with4x4(float *dst, const float *src1, const float *src2, unsigned fullIndex)
{
    if (fullIndex >= 16)
        return;

    const unsigned col = fullIndex % 4;
    const unsigned row = fullIndex / 4;

    if (row == 3) {
        // (0, 0, 0, 1) * src2 selects src2's bottom row.
        dst[3*4 + col] = src2[3*4 + col];
        return;
    }

    dst[row*4 + col] = src1[row*4+0] * src2[0*4 + col] +
                       src1[row*4+1] * src2[1*4 + col] +
                       src1[row*4+2] * src2[2*4 + col] +
                       src1[row*4+3] * src2[3*4 + col];
}

/*
 * Projects the surface point (x, y, 0, 1) through the row-major 4x4 matrix P
 * and returns the dehomogenized coordinates (row0 / row3, row1 / row3).
 * Column 2 of P is skipped because z = 0; row 2 is not read at all.
 */
__device__ void projectSurface2D(float *P, float x, float y, float &rx, float &ry)
{
    const float invW = 1.0f/(P[3*4+0] * x + P[3*4+1] * y + P[3*4+3]);
    rx = (P[0*4+0] * x + P[0*4+1] * y + P[0*4+3])*invW;
    ry = (P[1*4+0] * x + P[1*4+1] * y + P[1*4+3])*invW;
}


/*
 * For every track and every pre-warp matrix W_k (k < NumPreWarps), extracts
 * a 16x16 patch whose texels average the two views:
 *   - view 1 is sampled through P1norm * surfaceToWorld * W_k,
 *   - view 2 through P2norm * surfaceToWorld * W_k,
 * with the two views offset by +-epipolarOffsetHalf along u.
 *
 * Launch layout: one 8x8 thread block per track,
 * trackIndex = blockIdx.x + gridDim.x * blockIdx.y. Each thread covers 2x2
 * texels spaced 8 apart.
 * Output: kernelParams.dst holds NumPreWarps consecutive row-major 16x16
 * patches per track.
 *
 * Packing, per channel:
 *   byte 3    = w = min(255*max(r,g,b) + 1, 255)
 *   bytes 0..2 = r, g, b scaled by 255*255/w
 * Channel bytes are clamped to 255 so values above 1.0 saturate instead of
 * spilling into the neighbouring byte.
 */
extern "C" __global__ void extractPreWarpedProjections(ExtractProjectionsKernelParams kernelParams)
{
    const unsigned trackIndex = blockIdx.x + gridDim.x * blockIdx.y;
    if (trackIndex >= optimizeInitialTrackParams.numTracks)
        return;

    const unsigned fullIndex = threadIdx.x + threadIdx.y*8;

    __shared__ TrackData currentState;
    if (fullIndex < sizeof(currentState)/4) {
        const float *fptrSrc = (float*)&kernelParams.trackData[trackIndex];
        float *fptr = (float*)&currentState;
        fptr[fullIndex] = fptrSrc[fullIndex];
    }
    __syncthreads();

    __shared__ float lumScale1;
    __shared__ float lumScale2;

    __shared__ float accu[8*8];

    computeAvgLumImage1(&currentState, accu, fullIndex);
    if (fullIndex == 0) {
        lumScale1 = (0.2f * 16.0f * 16.0f) / fmax(accu[0], 0.01f);
    }
    __syncthreads();
    computeAvgLumImage2(&currentState, accu, fullIndex);
    if (fullIndex == 0) {
        lumScale2 = (0.2f * 16.0f * 16.0f) / fmax(accu[0], 0.01f);
    }
    __syncthreads(); // publish lumScale2 to all threads

    // NumPreWarps consecutive 16x16 patches per track; using NumPreWarps keeps
    // the stride consistent with the loop below. size_t keeps the offset from
    // wrapping for very large track counts.
    uint32_t *dstPtr = kernelParams.dst + (size_t)trackIndex*16*16*NumPreWarps;

    __shared__ float helperMatrix[4*4];
    __shared__ float matrix1[4*4];
    __shared__ float matrix2[4*4];

    for (unsigned warpIndex = 0; warpIndex < NumPreWarps; warpIndex++) {

        // matrixN = PNnorm * (surfaceToWorld * W_warpIndex), built by 16 threads.
        mulMatrix3x4with4x4(helperMatrix, currentState.surfaceToWorld, ExtractionPreWarpMatrices+warpIndex*4*4, fullIndex);
        __syncthreads();
        mulMatrix4x4with4x4(matrix1, optimizeInitialTrackParams.P1norm, helperMatrix, fullIndex);
        mulMatrix4x4with4x4(matrix2, optimizeInitialTrackParams.P2norm, helperMatrix, fullIndex);
        __syncthreads();


        for (unsigned ix = 0; ix < 2; ix++)
            for (unsigned iy = 0; iy < 2; iy++) {
                float4 colors1, colors2;
                {
                    float u = (threadIdx.x - 7.5f + ix*8.0f + currentState.epipolarOffsetHalf) * currentState.size;
                    float v = (threadIdx.y - 7.5f + iy*8.0f) * currentState.size;

                    float texelX, texelY;
                    projectSurface2D(matrix1, u, v, texelX, texelY);
                    colors1 = tex2DLod(sourceImagePyramid1, texelX, texelY, currentState.lod[0]);

                    float fac = lumScale1 * colors1.w;

                    colors1.x *= fac;
                    colors1.y *= fac;
                    colors1.z *= fac;
                }
                {
                    float u = (threadIdx.x - 7.5f + ix*8.0f - currentState.epipolarOffsetHalf) * currentState.size;
                    float v = (threadIdx.y - 7.5f + iy*8.0f) * currentState.size;

                    float texelX, texelY;
                    projectSurface2D(matrix2, u, v, texelX, texelY);
                    colors2 = tex2DLod(sourceImagePyramid2, texelX, texelY, currentState.lod[1]);

                    float fac = lumScale2 * colors2.w;

                    colors2.x *= fac;
                    colors2.y *= fac;
                    colors2.z *= fac;
                }
                float r = (colors1.x + colors2.x) * 0.5f;
                float g = (colors1.y + colors2.y) * 0.5f;
                float b = (colors1.z + colors2.z) * 0.5f;
                float m = fmax(r, fmax(g, b));

                unsigned w = min((int)(m * 255+1), 255);

                float rcpM = 255.0f*255.0f / w;

                // Clamp each channel byte so values above 1.0 cannot overflow
                // into the neighbouring byte.
                uint32_t data = min((unsigned)(r * rcpM), 255u);
                data |= min((unsigned)(g * rcpM), 255u) << 8;
                data |= min((unsigned)(b * rcpM), 255u) << 16;
                data |= w << 24;

                dstPtr[warpIndex*16*16 + (iy*8+threadIdx.y) * 16 + ix*8+threadIdx.x] = data;
            }
    }
}

