/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#undef _GLIBCXX_USE_INT128

#include "computePairwisePatchAlignment.h"
#include "cudaKernelUtils/WarpLinAlg.hpp"
#include "cudaKernelUtils/WarpLoadStore.hpp"
#include <stdio.h>

using cuUtils::WarpVector;
using cuUtils::WarpMatrix;


// Offset subtracted from threadIdx.x / threadIdx.y to turn a thread's position
// inside the block into a patch-local coordinate: half the patch size minus
// half a texel.
const float patchTexelOffset = PairwisePatchAlignmentPatchSize/2 - 0.5f;


// Layered 2D texture that the template and alignment patches are sampled from
// (via tex2DLayeredLod). Texels are uchar1 and are returned as normalized
// floats (cudaReadModeNormalizedFloat).
texture<uchar1, cudaTextureType2DLayered, cudaReadModeNormalizedFloat> patchAtlas;



/**
 * Per-thread copy of the six alignment parameters p0..p5 of the 2D affine map
 *
 *     x' = (1 + p0) * x +      p2  * y + p4
 *     y' =      p1  * x + (1 + p3) * y + p5
 */
struct SharedHomography {
    float values[6];

    /**
     * Loads the six parameters from @p src and hands them to every lane of the
     * warp: lanes 0..5 each read one float, then the values are distributed
     * with six shuffles.
     *
     * Lanes 0..5 of the calling warp must take part in the call because they
     * are the shuffle sources.
     */
    __device__ void bcastLoad(const float *src) {
        const unsigned lane = cub::LaneId();
        // Lanes 6..31 do not read memory; give them a defined value so that no
        // indeterminate register is fed into the shuffle.
        const float f = (lane < 6) ? src[lane] : 0.0f;
        #pragma unroll
        for (unsigned i = 0; i < 6; i++)
            values[i] = __shfl(f, i);
    }

    /// X coordinate of the transformed point (x, y).
    __device__ float transformXCoord(float x, float y) const {
        return (1.0f + values[0]) * x + values[2] * y + values[4];
    }

    /// Y coordinate of the transformed point (x, y).
    __device__ float transformYCoord(float x, float y) const {
        return values[1] * x + (1.0f + values[3]) * y + values[5];
    }
};


/**
 * Composes two affine parameter vectors. With the parameterisation
 * [[1+p0, p2, p4], [p1, 1+p3, p5]] the result corresponds to the matrix
 * product M(p) * M(deltap), i.e. deltap is applied first and p second.
 *
 * Each vector element lives in lane (laneId % 8) of m_elements[0]; the two
 * padding lanes of every group of eight are set to zero in the result.
 */
__device__ WarpVector<6> composeTransformation(const WarpVector<6> &p, const WarpVector<6> &deltap)
{
    const unsigned lane = cub::LaneId();

    // Row of the linear part this lane contributes to (0 for even, 1 for odd elements) ...
    const unsigned rowSelect = lane & 1;
    // ... and the first lane of the (even, odd) element pair this lane belongs to.
    const unsigned pairStart = lane & (~1);

    const float pLeft      = __shfl(p.m_elements[0],      rowSelect);
    const float deltaLeft  = __shfl(deltap.m_elements[0], pairStart);
    const float pRight     = __shfl(p.m_elements[0],      rowSelect + 2);
    const float deltaRight = __shfl(deltap.m_elements[0], pairStart + 1);

    const float composed = p.m_elements[0] +
                           deltap.m_elements[0] +
                           pLeft * deltaLeft +
                           pRight * deltaRight;

    WarpVector<6> result;
    // Elements 6 and 7 of each group of eight lanes are padding.
    const bool isPaddingLane = (lane % 8) > 5;
    result.m_elements[0] = isPaddingLane ? 0.0f : composed;

    return result;
}


/**
 * Returns the parameters of the inverse of the affine map described by @p p
 * (parameterisation [[1+p0, p2, p4], [p1, 1+p3, p5]]):
 *
 *     q0 = (-p0 - p0*p3 + p1*p2) / det
 *     q1 =  -p1 / det
 *     q2 =  -p2 / det
 *     q3 = (-p3 - p0*p3 + p1*p2) / det
 *     q4 = (-p4 - p3*p4 + p2*p5) / det
 *     q5 = (-p5 - p0*p5 + p1*p4) / det
 *     det = (1 + p0) * (1 + p3) - p1 * p2
 *
 * Element i of the result is stored in the lanes with (laneId % 8) == i; the
 * two padding lanes of each group of eight are set to zero.
 *
 * NOTE(review): the previous version read p.m_elements[1] for the first factor
 * of the cross terms. Every other access to a WarpVector<6> in this file uses
 * m_elements[0], and the formula needs p1 / p2 of the same vector, so this is
 * treated as a typo - confirm against the WarpVector definition.
 *
 * No guard is applied for a singular map (det == 0).
 */
__device__ WarpVector<6> invertTransformation(const WarpVector<6> &p)
{
    // Broadcast all six parameters to every lane.
    const float p0 = __shfl(p.m_elements[0], 0);
    const float p1 = __shfl(p.m_elements[0], 1);
    const float p2 = __shfl(p.m_elements[0], 2);
    const float p3 = __shfl(p.m_elements[0], 3);
    const float p4 = __shfl(p.m_elements[0], 4);
    const float p5 = __shfl(p.m_elements[0], 5);

    const float denom = (1.0f + p0) * (1.0f + p3) - p1 * p2;
    const float rcpDenom = 1.0f / denom;

    // Each lane picks the numerator of the element it stores.
    float numerator = 0.0f; // padding lanes (6 and 7 of each group) stay zero
    switch (cub::LaneId() % 8) {
        case 0:
            numerator = -p0 - p0 * p3 + p1 * p2;
        break;
        case 1:
            numerator = -p1;
        break;
        case 2:
            numerator = -p2;
        break;
        case 3:
            numerator = -p3 - p0 * p3 + p1 * p2;
        break;
        case 4:
            numerator = -p4 - p3 * p4 + p2 * p5;
        break;
        case 5:
            numerator = -p5 - p0 * p5 + p1 * p4;
        break;
        default:
        break;
    }

    WarpVector<6> result;
    result.m_elements[0] = numerator * rcpDenom;

    return result;
}




/**
 * One thread's texel of an image with @p numChannels float channels.
 * The arithmetic operators work channel-wise and return a new Image; they are
 * const-qualified so they can also be applied to const images.
 */
template<unsigned numChannels>
struct Image
{
    float value[numChannels];

    /**
     * Samples the patch atlas texture at (x, y) in the given layer and mip
     * level and stores the result in channel @p channel.
     */
    template<unsigned channel = 0>
    __device__ void sampleFromAtlas(float x, float y, unsigned layer, float lod) {
        value[channel] = tex2DLayeredLod(patchAtlas, x, y, layer, lod).x;
    }

    /// Channel-wise difference (*this - rhs).
    __device__ Image<numChannels> operator-(const Image<numChannels> &rhs) const {
        Image<numChannels> result;
        #pragma unroll
        for (unsigned i = 0; i < numChannels; i++)
            result.value[i] = value[i] - rhs.value[i];

        return result;
    }

    /// Multiplies every channel by @p scalar.
    __device__ Image<numChannels> operator*(float scalar) const {
        Image<numChannels> result;
        #pragma unroll
        for (unsigned i = 0; i < numChannels; i++)
            result.value[i] = value[i] * scalar;

        return result;
    }

    /// Adds @p scalar to every channel.
    __device__ Image<numChannels> operator+(float scalar) const {
        Image<numChannels> result;
        #pragma unroll
        for (unsigned i = 0; i < numChannels; i++)
            result.value[i] = value[i] + scalar;

        return result;
    }
};



/**
 * Samples this thread's texel of the template patch and derives the six
 * steepest descent values for it.
 *
 * @param job                    supplies atlas location, size and layer of the template patch
 * @param templateImage          receives the sampled luminance in channel 0
 * @param steepestDescentImages  receives (x*dx, x*dy, y*dx, y*dy, dx, dy), where (x, y) is the
 *                               patch-local coordinate of the thread and (dx, dy) is the
 *                               luminance gradient with respect to that coordinate
 */
__device__ void loadTemplateImage(const PairwisePatchAlignmentJob &job, Image<1> &templateImage, Image<6> &steepestDescentImages)
{
    // Hook for stretching the sampling pattern; currently the identity.
    const float stretch = 1.0f;
    const float patchX = (threadIdx.x - patchTexelOffset) * stretch;
    const float patchY = (threadIdx.y - patchTexelOffset) * stretch;

    // Position of this texel inside the atlas.
    const float atlasX = job.templatePatchAtlasLocation[0] + patchX * job.templatePatchAtlasSize;
    const float atlasY = job.templatePatchAtlasLocation[1] + patchY * job.templatePatchAtlasSize;

    templateImage.sampleFromAtlas(atlasX, atlasY, job.templatePatchAtlasLayer, 0.0f);

    // Central differences over +-0.2 patch texels; the factor 1/0.4 turns the
    // difference into a derivative with respect to the patch-local coordinate.
    const float lumRight = tex2DLayeredLod(patchAtlas, atlasX + job.templatePatchAtlasSize * 0.2f, atlasY, job.templatePatchAtlasLayer, 0.0f).x;
    const float lumLeft  = tex2DLayeredLod(patchAtlas, atlasX - job.templatePatchAtlasSize * 0.2f, atlasY, job.templatePatchAtlasLayer, 0.0f).x;
    const float gradX = (lumRight - lumLeft) * (1.0f / 0.4f);

    const float lumBelow = tex2DLayeredLod(patchAtlas, atlasX, atlasY + job.templatePatchAtlasSize * 0.2f, job.templatePatchAtlasLayer, 0.0f).x;
    const float lumAbove = tex2DLayeredLod(patchAtlas, atlasX, atlasY - job.templatePatchAtlasSize * 0.2f, job.templatePatchAtlasLayer, 0.0f).x;
    const float gradY = (lumBelow - lumAbove) * (1.0f / 0.4f);

    // Gradient multiplied by the Jacobian of the affine warp: the parameter
    // pairs (0,1), (2,3) and (4,5) are weighted by x, y and 1 respectively.
    const float jacobianWeight[3] = { patchX, patchY, 1.0f };
    #pragma unroll
    for (unsigned k = 0; k < 3; k++) {
        steepestDescentImages.value[2*k + 0] = jacobianWeight[k] * gradX;
        steepestDescentImages.value[2*k + 1] = jacobianWeight[k] * gradY;
    }
}

/**
 * Computes the mean of channel 0 of @p templateImage over all threads of the
 * block and stores it in @p destination.
 *
 * Must be called by every thread of the block (contains __syncthreads()).
 * The block is expected to consist of PairwisePatchAlignmentPatchSize x
 * PairwisePatchAlignmentPatchSize threads.
 *
 * @param templateImage  per-thread value to average
 * @param sharedMemory   scratch space with at least PairwisePatchAlignmentPatchSqrSize floats
 * @param destination    receives the mean; written by thread 0 and visible to
 *                       all threads of the block on return (if it points to shared memory)
 */
__device__ void computeAvgLum(Image<1> &templateImage,
                              float *sharedMemory,
                              float *destination)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y*PairwisePatchAlignmentPatchSize;

    sharedMemory[fullIndex] = templateImage.value[0];

    __syncthreads();

    // Tree reduction in shared memory until at most one warp's worth of
    // partial sums is left. Halving with rounding up keeps this correct for
    // element counts that are not a power of two (e.g. 24*24 = 576, which ends
    // with 18 partial sums); for power-of-two counts it is the usual scheme.
    unsigned remaining = PairwisePatchAlignmentPatchSqrSize;
    while (remaining > 32) {
        const unsigned half = (remaining + 1) / 2;
        if (fullIndex + half < remaining) {
            sharedMemory[fullIndex] += sharedMemory[fullIndex+half];
        }
        __syncthreads();
        remaining = half;
    }

    if (cuUtils::getWarpIdInBlock2D() == 0) {
        // Lanes beyond the remaining partial sums contribute nothing; their
        // slots hold values that are already part of the partial sums.
        float sum = (fullIndex < remaining) ? sharedMemory[fullIndex] : 0.0f;

        #pragma unroll
        for (int j=16; j>=1; j/=2) {
            sum += __shfl_xor(sum, j, 32);
        }

        if (fullIndex == 0) {
            *destination = sum / (PairwisePatchAlignmentPatchSize*PairwisePatchAlignmentPatchSize);
        }
    }
    __syncthreads();
}


/**
 * Fits a linear luminance correction: finds @p scale and @p offset minimising
 * the sum over all texels of (scale * warped + offset - template)^2.
 *
 * With X = template and Y = warped values the normal equations are solved in
 * closed form from sum(Y), sum(Y*Y) and sum(X*Y); sum(X) is taken from
 * @p avgTemplateLum. If the system is (nearly) singular the identity
 * correction (scale = 1, offset = 0) is returned.
 *
 * Must be called by every thread of the block (contains __syncthreads()).
 *
 * @param templateImage         this thread's template texel
 * @param warpedAlignmentImage  this thread's warped texel
 * @param avgTemplateLum        mean template luminance over the patch
 * @param sharedMemory          scratch space with at least 3 * PairwisePatchAlignmentPatchSqrSize floats;
 *                              may be reused by the caller after return
 * @param scale                 receives the contrast factor (same value in every thread)
 * @param offset                receives the brightness offset (same value in every thread)
 */
__device__ void computeLeastSquaresContrastBrightness(Image<1> &templateImage,
                                                      Image<1> &warpedAlignmentImage,
                                                      const float &avgTemplateLum,
                                                      float *sharedMemory,
                                                      float &scale,
                                                      float &offset)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y*PairwisePatchAlignmentPatchSize;

    sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] = warpedAlignmentImage.value[0];
    sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] = warpedAlignmentImage.value[0] * warpedAlignmentImage.value[0];
    sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] = templateImage.value[0] * warpedAlignmentImage.value[0];

    __syncthreads();

    // Tree reduction of the three sums until at most one warp's worth of
    // partial sums is left. Halving with rounding up keeps this correct for
    // element counts that are not a power of two (e.g. 24*24 = 576).
    unsigned remaining = PairwisePatchAlignmentPatchSqrSize;
    while (remaining > 32) {
        const unsigned half = (remaining + 1) / 2;
        if (fullIndex + half < remaining) {
            sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex+half];
            sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex+half];
            sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex+half];
        }
        __syncthreads();
        remaining = half;
    }

    if (cuUtils::getWarpIdInBlock2D() == 0) {
        // Lanes beyond the remaining partial sums contribute nothing.
        const bool holdsPartialSum = fullIndex < remaining;
        float sumY  = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] : 0.0f;
        float sumYY = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] : 0.0f;
        float sumXY = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] : 0.0f;

        #pragma unroll
        for (int j=16; j>=1; j/=2) {
            sumY += __shfl_xor(sumY, j, 32);
            sumYY += __shfl_xor(sumYY, j, 32);
            sumXY += __shfl_xor(sumXY, j, 32);
        }

        if (fullIndex == 0) {
            // Thread 0 solves the 2x2 system and publishes offset / scale in
            // the first slot of the first two scratch arrays.
            float sumX = avgTemplateLum * (PairwisePatchAlignmentPatchSqrSize);
            float det = sumY*sumY - (PairwisePatchAlignmentPatchSqrSize)*sumYY;
            if (fabsf(det) < 1e-9f) {
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] = 0.0f;
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] = 1.0f;
            } else {
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] = (-sumYY * sumX + sumY * sumXY)/det;
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] = (sumY*sumX  - (PairwisePatchAlignmentPatchSqrSize)*sumXY)/det;
            }
        }
    }
    __syncthreads();
    offset = sharedMemory[PairwisePatchAlignmentPatchSqrSize*0];
    scale = sharedMemory[PairwisePatchAlignmentPatchSqrSize*1];
    // Make sure everybody has read the result before the scratch space is reused.
    __syncthreads();
}


/**
 * Computes the parameter update of one alignment iteration:
 *
 *   1. fit contrast/brightness of the warped patch to the template,
 *   2. form the difference image (corrected warped - template),
 *   3. accumulate v[k] = sum over all texels of steepestDescent[k] * difference,
 *   4. solve hessianMatrix * result = v with conjugate gradients (warp 0 only).
 *
 * Must be called by every thread of the block (contains __syncthreads()).
 * Only warp 0 holds a meaningful return value; in all other warps the
 * returned vector is zero in m_elements[0].
 *
 * @param hessianMatrix          6x6 system matrix (only used by warp 0)
 * @param steepestDescentImages  this thread's six steepest descent values
 * @param templateImage          this thread's template texel
 * @param warpedAlignmentImage   this thread's warped texel
 * @param avgTemplateLum         mean template luminance over the patch
 * @param sharedMemory           scratch space with at least 3 * PairwisePatchAlignmentPatchSqrSize floats
 */
__device__ WarpVector<6> computeUpdateStep(const WarpMatrix<6, 6> &hessianMatrix,
                                  const Image<6> &steepestDescentImages,
                                  Image<1> &templateImage,
                                  Image<1> &warpedAlignmentImage,
                                  const float &avgTemplateLum,
                                  float *sharedMemory)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y*PairwisePatchAlignmentPatchSize;

    float lumOffset, lumScale;
    computeLeastSquaresContrastBrightness(templateImage, warpedAlignmentImage, avgTemplateLum, sharedMemory, lumScale, lumOffset);

    Image<1> differenceImage = (warpedAlignmentImage * lumScale + lumOffset) - templateImage;

    WarpVector<6> v;
    // Lanes 6 and 7 of each group of eight are padding and are never assigned
    // below; give them (and the warps that skip the assignment) a defined value.
    v.m_elements[0] = 0.0f;

    // The six sums are computed in two batches of three to fit the scratch space.
    for (unsigned i = 0; i < 6; i+=3) {
        sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] = steepestDescentImages.value[(i+0)] * differenceImage.value[0];
        sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] = steepestDescentImages.value[(i+1)] * differenceImage.value[0];
        sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] = steepestDescentImages.value[(i+2)] * differenceImage.value[0];

        __syncthreads();

        // Tree reduction until at most one warp's worth of partial sums is
        // left. Halving with rounding up keeps this correct for element counts
        // that are not a power of two (e.g. 24*24 = 576).
        unsigned remaining = PairwisePatchAlignmentPatchSqrSize;
        while (remaining > 32) {
            const unsigned half = (remaining + 1) / 2;
            if (fullIndex + half < remaining) {
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex+half];
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex+half];
                sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex+half];
            }
            __syncthreads();
            remaining = half;
        }

        if (cuUtils::getWarpIdInBlock2D() == 0) {
            // Lanes beyond the remaining partial sums contribute nothing.
            const bool holdsPartialSum = fullIndex < remaining;
            float sum1 = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] : 0.0f;
            float sum2 = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] : 0.0f;
            float sum3 = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] : 0.0f;

            #pragma unroll
            for (int j=16; j>=1; j/=2) {
                sum1 += __shfl_xor(sum1, j, 32);
                sum2 += __shfl_xor(sum2, j, 32);
                sum3 += __shfl_xor(sum3, j, 32);
            }

            // Every lane now holds all three totals; store each in the lanes
            // that own the corresponding vector element.
            if ((cub::LaneId() % 8) == i+0)
                v.m_elements[0] = sum1;

            if ((cub::LaneId() % 8) == i+1)
                v.m_elements[0] = sum2;

            if ((cub::LaneId() % 8) == i+2)
                v.m_elements[0] = sum3;
        }
    }

    WarpVector<6> result;
    result.m_elements[0] = 0.0f;
    if (cuUtils::getWarpIdInBlock2D() == 0)
        result = cuUtils::conjugateGradientSolve(hessianMatrix, v, 6);

    return result;
}


/**
 * Computes three elements of the 6x6 matrix H = sum over all texels of
 * steepestDescent[i] * steepestDescent[j], namely (i1, j1), (i2, j2) and
 * (i3, j3), and writes each of them to both dst[i*6+j] and dst[j*6+i].
 *
 * Must be called by every thread of the block (contains __syncthreads()).
 * The writes to @p dst are done by lane 0 of warp 0 and are NOT followed by a
 * barrier; callers must synchronise before other threads read @p dst.
 *
 * @param steepestDescentImages  this thread's six steepest descent values
 * @param sharedMemory           scratch space with at least 3 * PairwisePatchAlignmentPatchSqrSize floats
 * @param dst                    6x6 output matrix (36 floats)
 */
template<unsigned i1, unsigned j1, unsigned i2, unsigned j2, unsigned i3, unsigned j3>
__device__ void computeHessianElements(Image<6> &steepestDescentImages, float *sharedMemory, float *dst)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y*PairwisePatchAlignmentPatchSize;

    sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] = steepestDescentImages.value[i1] * steepestDescentImages.value[j1];
    sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] = steepestDescentImages.value[i2] * steepestDescentImages.value[j2];
    sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] = steepestDescentImages.value[i3] * steepestDescentImages.value[j3];

    __syncthreads();

    // Tree reduction until at most one warp's worth of partial sums is left.
    // Halving with rounding up keeps this correct for element counts that are
    // not a power of two (e.g. 24*24 = 576).
    unsigned remaining = PairwisePatchAlignmentPatchSqrSize;
    while (remaining > 32) {
        const unsigned half = (remaining + 1) / 2;
        if (fullIndex + half < remaining) {
            sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex+half];
            sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex+half];
            sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] += sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex+half];
        }
        __syncthreads();
        remaining = half;
    }

    if (cuUtils::getWarpIdInBlock2D() == 0) {
        // Lanes beyond the remaining partial sums contribute nothing.
        const bool holdsPartialSum = fullIndex < remaining;
        float sum1 = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*0 + fullIndex] : 0.0f;
        float sum2 = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*1 + fullIndex] : 0.0f;
        float sum3 = holdsPartialSum ? sharedMemory[PairwisePatchAlignmentPatchSqrSize*2 + fullIndex] : 0.0f;

        #pragma unroll
        for (int j=16; j>=1; j/=2) {
            sum1 += __shfl_xor(sum1, j, 32);
            sum2 += __shfl_xor(sum2, j, 32);
            sum3 += __shfl_xor(sum3, j, 32);
        }


        if (cub::LaneId() == 0) { // todo: speed up
            // The matrix is symmetric, so each sum fills (i, j) and (j, i).
            dst[i1*6+j1] = dst[j1*6+i1] = sum1;
            dst[i2*6+j2] = dst[j2*6+i2] = sum2;
            dst[i3*6+j3] = dst[j3*6+i3] = sum3;
        }
    }

}



/**
 * Fills the symmetric 6x6 matrix @p dst with the sums over all texels of
 * steepestDescent[i] * steepestDescent[j]. The 21 elements of the lower
 * triangle are computed in seven batches of three; each batch also writes the
 * mirrored element.
 *
 * Must be called by every thread of the block (contains __syncthreads()).
 * On return @p dst is complete and, if it points to shared memory, visible to
 * every thread of the block.
 *
 * @param steepestDescentImages  this thread's six steepest descent values
 * @param sharedMemory           scratch space with at least 3 * PairwisePatchAlignmentPatchSqrSize floats
 * @param dst                    6x6 output matrix (36 floats)
 */
__device__ void computeHessian(Image<6> &steepestDescentImages, float *sharedMemory, float *dst)
{
    computeHessianElements<0, 0,
                           1, 0,
                           1, 1>(steepestDescentImages, sharedMemory, dst);
    computeHessianElements<2, 0,
                           2, 1,
                           2, 2>(steepestDescentImages, sharedMemory, dst);
    computeHessianElements<3, 0,
                           3, 1,
                           3, 2>(steepestDescentImages, sharedMemory, dst);
    computeHessianElements<3, 3,
                           4, 0,
                           4, 1>(steepestDescentImages, sharedMemory, dst);
    computeHessianElements<4, 2,
                           4, 3,
                           4, 4>(steepestDescentImages, sharedMemory, dst);
    computeHessianElements<5, 0,
                           5, 1,
                           5, 2>(steepestDescentImages, sharedMemory, dst);
    computeHessianElements<5, 3,
                           5, 4,
                           5, 5>(steepestDescentImages, sharedMemory, dst);

    // dst is written by a single lane of warp 0; without a barrier the other
    // threads of the block could read it before those writes have happened.
    __syncthreads();
}

/**
 * Computes the sum over all texels of (A - B)^2.
 *
 * Must be called by every thread of the block (contains __syncthreads()).
 * Only the threads of warp 0 receive the sum; all other threads get 0.
 *
 * @param A             this thread's texel of the first image
 * @param B             this thread's texel of the second image
 * @param sharedMemory  scratch space with at least PairwisePatchAlignmentPatchSqrSize floats
 */
__device__ float computeDifference(const Image<1> &A, const Image<1> &B, float *sharedMemory)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y*PairwisePatchAlignmentPatchSize;

    float d = A.value[0] - B.value[0];
    sharedMemory[fullIndex] = d*d;

    __syncthreads();

    // Tree reduction until at most one warp's worth of partial sums is left.
    // Halving with rounding up keeps this correct for element counts that are
    // not a power of two (e.g. 24*24 = 576).
    unsigned remaining = PairwisePatchAlignmentPatchSqrSize;
    while (remaining > 32) {
        const unsigned half = (remaining + 1) / 2;
        if (fullIndex + half < remaining)
            sharedMemory[fullIndex] += sharedMemory[fullIndex+half];
        __syncthreads();
        remaining = half;
    }

    float sum = 0.0f;
    if (cuUtils::getWarpIdInBlock2D() == 0) {
        // Lanes beyond the remaining partial sums contribute nothing.
        sum = (fullIndex < remaining) ? sharedMemory[fullIndex] : 0.0f;

        #pragma unroll
        for (int j=16; j>=1; j/=2) {
            sum += __shfl_xor(sum, j, 32);
        }
    }
    return sum;
}


/**
 * Samples this thread's texel of the alignment patch, warped by the alignment
 * parameters currently stored in @p job.
 *
 * Lanes 0..5 of the calling warp must take part in the call (the parameters
 * are distributed with warp shuffles).
 */
static __device__ Image<1> sampleWarpedAlignmentPatch(const PairwisePatchAlignmentJob &job)
{
    SharedHomography hom;
    hom.bcastLoad(job.alignment);

    float x = (threadIdx.x - patchTexelOffset);
    float y = (threadIdx.y - patchTexelOffset);

    // Hook for stretching the sampling pattern; currently the identity.
    const float stretch = 1.0f;
    x = x*stretch;
    y = y*stretch;

    Image<1> warpedImage;
    warpedImage.sampleFromAtlas(hom.transformXCoord(x, y), hom.transformYCoord(x, y), job.alignmentPatchAtlasLayer, 0.0f);
    return warpedImage;
}


/**
 * Iteratively refines the affine alignment of one patch pair per thread block.
 *
 * Launch layout: one block per job (blockIdx.x is the job index; blocks with
 * an index >= kernelParams.numJobs exit immediately). Each block is expected
 * to consist of PairwisePatchAlignmentPatchSize x PairwisePatchAlignmentPatchSize
 * threads, one per patch texel. The kernel uses statically allocated shared
 * memory only.
 *
 * Per job:
 *   1. The job is copied to shared memory.
 *   2. The template patch, its steepest descent values, its mean luminance and
 *      the 6x6 system matrix are computed once.
 *   3. For PairwisePatchAlignmentMaxIterations iterations the alignment patch
 *      is sampled with the current parameters, an update step is computed and
 *      the parameters are replaced by compose(p, inverse(deltaP)).
 *   4. The final parameters are written back to kernelParams.jobs[job].alignment
 *      and the value returned by computeDifference() for the template and the
 *      contrast/brightness corrected warped patch is stored in alignmentError.
 *
 * If COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT is defined, the
 * unaligned, template and aligned images are additionally written to
 * kernelParams.debugOutput[job].
 */
extern "C" __global__ void __launch_bounds__(24*24, 2048/24/24) computePairwisePatchAlignment(PairwisePatchAlignmentKernelParams kernelParams)
{
    const unsigned fullIndex = threadIdx.x + threadIdx.y*PairwisePatchAlignmentPatchSize;

    // The condition is uniform across the block, so returning before the
    // barriers below is safe.
    const unsigned jobIndex = blockIdx.x;
    if (jobIndex >= kernelParams.numJobs)
        return;

    __shared__ PairwisePatchAlignmentJob job;
    if (cuUtils::getWarpIdInBlock2D() == 0) {
        cuUtils::warpCopy(&job, &kernelParams.jobs[jobIndex]);
    }

    __syncthreads();

#ifdef COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT
    {
        PairwisePatchAlignmentDebugOutput &debugOutput = kernelParams.debugOutput[jobIndex];

        // Alignment patch sampled with the initial parameters.
        Image<1> warpedImage = sampleWarpedAlignmentPatch(job);

        debugOutput.unalignedImage[fullIndex] = warpedImage.value[0];
    }
#endif

    // Scratch space shared by all block-wide reductions.
    __shared__ union {
        float reductionMem[PairwisePatchAlignmentPatchSqrSize*3];
    } tempStorage;

    Image<1> templateImage;
    Image<6> steepestDescentImages;
    loadTemplateImage(job, templateImage, steepestDescentImages);

    __shared__ float avgTemplateLum;
    computeAvgLum(templateImage, tempStorage.reductionMem, &avgTemplateLum);

    WarpMatrix<6, 6> hessianMatrix;
    {
        __shared__ float hessianMem[6*6];
        computeHessian(steepestDescentImages, tempStorage.reductionMem, hessianMem);
        // hessianMem is filled by a single lane of warp 0; make sure those
        // writes have happened before any thread loads the matrix.
        __syncthreads();
        hessianMatrix.load(hessianMem);
    }

    for (unsigned iter = 0; iter < PairwisePatchAlignmentMaxIterations; iter++) {
        // Wait for the parameter update of the previous iteration.
        __syncthreads();

        Image<1> warpedImage = sampleWarpedAlignmentPatch(job);

        WarpVector<6> deltaP = computeUpdateStep(hessianMatrix,
                                  steepestDescentImages,
                                  templateImage,
                                  warpedImage,
                                  avgTemplateLum,
                                  tempStorage.reductionMem);

        // Only warp 0 holds a valid deltaP; it updates the shared parameters.
        if (cuUtils::getWarpIdInBlock2D() == 0) {
            WarpVector<6> p;
            p.load(job.alignment);

            p = composeTransformation(p, invertTransformation(deltaP));
            p.store(job.alignment);
        }
    }

    if (cuUtils::getWarpIdInBlock2D() == 0)
        cuUtils::warpMemcpy<6*4>(&kernelParams.jobs[jobIndex].alignment, &job.alignment);

    __syncthreads();
    {
#ifdef COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT
        PairwisePatchAlignmentDebugOutput &debugOutput = kernelParams.debugOutput[jobIndex];

        debugOutput.templateImage[fullIndex] = templateImage.value[0];
#endif

        // Evaluate the final alignment.
        Image<1> warpedImage = sampleWarpedAlignmentPatch(job);

        float lumOffset, lumScale;
        computeLeastSquaresContrastBrightness(templateImage, warpedImage, avgTemplateLum, tempStorage.reductionMem, lumScale, lumOffset);


#ifdef COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT
        debugOutput.alignedImage[fullIndex] = warpedImage.value[0] * lumScale + lumOffset;
#endif

        float error = computeDifference(templateImage, warpedImage * lumScale + lumOffset, tempStorage.reductionMem);
        if (fullIndex == 0) {
            kernelParams.jobs[jobIndex].alignmentError = error;
        }
    }


}
