/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "PairwisePatchAligner.h"
#include "PatchCache.h"
#include "../cudaKernels/computePairwisePatchAlignment.h"
#include "../tools/RasterImage.h"

PairwisePatchAligner::PairwisePatchAligner()
{
    typedef CudaUtils::CudaKernel Kernel;
    typedef CudaUtils::CudaTextureReference TexRef;

    // Load the precompiled device code. NOTE(review): the path is relative,
    // so this presumably depends on the process working directory — confirm.
    m_codeModule.loadFromFile("../SFMBackend/kernels/Release/computePairwisePatchAlignment.fatbin");

    // Wrap the kernel handle and the "patchAtlas" texture reference in owning smart pointers.
    m_alignmentKernel  = std::unique_ptr<Kernel>(m_codeModule.getKernel("computePairwisePatchAlignment"));
    m_patchAtlasTexRef = std::unique_ptr<TexRef>(m_codeModule.getTexReference("patchAtlas"));

    // Sampling setup for the atlas: linear filtering within a level, nearest
    // between mip levels, and normalized texture coordinates.
    auto &atlasTex = *m_patchAtlasTexRef;
    atlasTex.setTexelFilterMode(TexRef::FILTER_MODE_LINEAR);
    atlasTex.setMipmapFilterMode(TexRef::FILTER_MODE_NEAREST);
    atlasTex.setCoordinateNormalization(true);
}

// Defined out of line so that member cleanup (the owned kernel and texture
// reference handles) is emitted in this translation unit.
PairwisePatchAligner::~PairwisePatchAligner() = default;


void PairwisePatchAligner::optimize(std::vector<AlignmentPair> &pairs, PatchCache *patchCache)
{
    m_patchAtlasTexRef->bindMipmappedTexture(patchCache->getAtlasTexture());
    m_patchAtlasTexRef->setMinMaxMipLevel(0.0f, 0.0f);

    std::vector<PairwisePatchAlignmentJob> cpuJobData;
    cpuJobData.resize(pairs.size());


    for (unsigned i = 0; i < pairs.size(); i++) {
        PairwisePatchAlignmentJob &job = cpuJobData[i];
        const AlignmentPair &alignmentPair = pairs[i];

        const LinAlg::Vector2f &templateScreenToAtlasScale = patchCache->getScreenSpaceToCacheSpaceScale(alignmentPair.templatePatchCacheSlot);
        const LinAlg::Vector2f &templateScreenToAtlasOffset = patchCache->getScreenSpaceToCacheSpaceOffset(alignmentPair.templatePatchCacheSlot);
        job.templatePatchAtlasLayer = patchCache->getLayer(alignmentPair.templatePatchCacheSlot);

        const LinAlg::Vector2f &alignmentScreenToAtlasScale = patchCache->getScreenSpaceToCacheSpaceScale(alignmentPair.alignmentPatchCacheSlot);
        const LinAlg::Vector2f &alignmentScreenToAtlasOffset = patchCache->getScreenSpaceToCacheSpaceOffset(alignmentPair.alignmentPatchCacheSlot);
        job.alignmentPatchAtlasLayer = patchCache->getLayer(alignmentPair.alignmentPatchCacheSlot);


        LinAlg::Vector2f templatePatchAtlasLocation = (alignmentPair.templateScreenSpacePosition & templateScreenToAtlasScale) + templateScreenToAtlasOffset;
        job.templatePatchAtlasLocation[0] = templatePatchAtlasLocation[0];
        job.templatePatchAtlasLocation[1] = templatePatchAtlasLocation[1];
        job.templatePatchAtlasSize = 1.0f / PATCH_CACHE_LAYER_SIZE;
/*
        LinAlg::Matrix3x3f hom = LinAlg::Scale2D(LinAlg::Fill(job.templatePatchAtlasSize, job.templatePatchAtlasSize)) *
                                 LinAlg::Translation2D(templatePatchAtlasLocation) *
                                 LinAlg::Translation2D(templateScreenToAtlasOffset.negated()) *
                                 LinAlg::Scale2D(templateScreenToAtlasScale.getRcp()) *
                                 alignmentPair.templateScreenToAlignmentScreen *
                                 LinAlg::Scale2D(alignmentScreenToAtlasScale) *
                                 LinAlg::Translation2D(alignmentScreenToAtlasOffset);
*/
        LinAlg::Matrix3x3f hom =
                                 LinAlg::Translation2D(alignmentScreenToAtlasOffset) *
                                 LinAlg::Scale2D(alignmentScreenToAtlasScale) *
                                 alignmentPair.templateScreenToAlignmentScreen *
                                 LinAlg::Scale2D(templateScreenToAtlasScale.getRcp()) *
                                 LinAlg::Translation2D(templateScreenToAtlasOffset.negated()) *
                                 LinAlg::Translation2D(templatePatchAtlasLocation) *
                                 LinAlg::Scale2D(LinAlg::Fill(job.templatePatchAtlasSize, job.templatePatchAtlasSize));


        job.alignment[0] = hom[0][0] - 1.0f;
        job.alignment[1] = hom[1][0];
        job.alignment[2] = hom[0][1];
        job.alignment[3] = hom[1][1] - 1.0f;

        job.alignment[4] = hom[0][2];
        job.alignment[5] = hom[1][2];
    }



    m_jobData.resize(cpuJobData.size() * sizeof(PairwisePatchAlignmentJob));
    m_jobData.upload(&cpuJobData[0], cpuJobData.size() * sizeof(PairwisePatchAlignmentJob));


#ifdef COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT
    CudaUtils::CudaDeviceMemory debugVData;
    debugVData.resize(cpuJobData.size() * sizeof(PairwisePatchAlignmentDebugOutput));
#endif
    {
        PairwisePatchAlignmentKernelParams kernelParams;
        kernelParams.jobs = ((PairwisePatchAlignmentJob*) m_jobData.getPtr());
        kernelParams.numJobs = cpuJobData.size();
#ifdef COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT
        kernelParams.debugOutput = (PairwisePatchAlignmentDebugOutput*)debugVData.getPtr();
#endif
        m_alignmentKernel->launch(LinAlg::Fill(PairwisePatchAlignmentPatchSize, PairwisePatchAlignmentPatchSize, 1u),
                                      LinAlg::Fill<unsigned>(cpuJobData.size(), 1u, 1u),
                                      &kernelParams, sizeof(kernelParams));
    }

    m_jobData.download(&cpuJobData[0], cpuJobData.size() * sizeof(PairwisePatchAlignmentJob));

    for (unsigned i = 0; i < pairs.size(); i++) {
        const PairwisePatchAlignmentJob &job = cpuJobData[i];
        AlignmentPair &alignmentPair = pairs[i];

        const LinAlg::Vector2f &templateScreenToAtlasScale = patchCache->getScreenSpaceToCacheSpaceScale(alignmentPair.templatePatchCacheSlot);
        const LinAlg::Vector2f &templateScreenToAtlasOffset = patchCache->getScreenSpaceToCacheSpaceOffset(alignmentPair.templatePatchCacheSlot);

        const LinAlg::Vector2f &alignmentScreenToAtlasScale = patchCache->getScreenSpaceToCacheSpaceScale(alignmentPair.alignmentPatchCacheSlot);
        const LinAlg::Vector2f &alignmentScreenToAtlasOffset = patchCache->getScreenSpaceToCacheSpaceOffset(alignmentPair.alignmentPatchCacheSlot);

        LinAlg::Vector2f templatePatchAtlasLocation = (alignmentPair.templateScreenSpacePosition & templateScreenToAtlasScale) + templateScreenToAtlasOffset;

/*
        LinAlg::Matrix3x3f hom =
                                 LinAlg::Translation2D(alignmentScreenToAtlasOffset) *
                                 LinAlg::Scale2D(alignmentScreenToAtlasScale) *
                                 alignmentPair.templateScreenToAlignmentScreen *
                                 LinAlg::Scale2D(templateScreenToAtlasScale.getRcp()) *
                                 LinAlg::Translation2D(templateScreenToAtlasOffset.negated()) *
                                 LinAlg::Translation2D(templatePatchAtlasLocation) *
                                 LinAlg::Scale2D(LinAlg::Fill(job.templatePatchAtlasSize, job.templatePatchAtlasSize));
*/

        bool finite =
            std::isfinite(job.alignment[0]) &&
            std::isfinite(job.alignment[1]) &&
            std::isfinite(job.alignment[2]) &&
            std::isfinite(job.alignment[3]) &&
            std::isfinite(job.alignment[4]) &&
            std::isfinite(job.alignment[5]);


        if (!finite) {
            alignmentPair.alignmentError = 1e6f;
        } else {

            LinAlg::Matrix3x3f hom;
            hom[0][0] = job.alignment[0] +1.0f;
            hom[1][0] = job.alignment[1];
            hom[0][1] = job.alignment[2];
            hom[1][1] = job.alignment[3] +1.0f;
            hom[0][2] = job.alignment[4];
            hom[1][2] = job.alignment[5];


            alignmentPair.templateScreenToAlignmentScreen =
                                     LinAlg::Scale2D(alignmentScreenToAtlasScale.getRcp()) *
                                     LinAlg::Translation2D(alignmentScreenToAtlasOffset.negated()) *
                                     hom *
                                     LinAlg::Scale2D(LinAlg::Fill(1.0f/job.templatePatchAtlasSize, 1.0f/job.templatePatchAtlasSize)) *
                                     LinAlg::Translation2D(templatePatchAtlasLocation.negated()) *
                                     LinAlg::Translation2D(templateScreenToAtlasOffset) *
                                     LinAlg::Scale2D(templateScreenToAtlasScale);

            alignmentPair.alignmentError = job.alignmentError;
        }
    }

#ifdef COMPUTEPAIRWISEPATCHALIGNMENT_ENABLE_DEBUG_OUTPUT
    std::vector<PairwisePatchAlignmentDebugOutput> debugData;
    debugData.resize(cpuJobData.size());
    debugVData.download(&debugData[0], cpuJobData.size() * sizeof(PairwisePatchAlignmentDebugOutput));

    {
        RasterImage debugImage;
        unsigned numPairs = std::min<unsigned>(cpuJobData.size(), 500u);
        const unsigned debugPatchSize = PairwisePatchAlignmentPatchSize;
        debugImage.resize(numPairs*(debugPatchSize+4), (debugPatchSize+4)*3);
        for (unsigned i = 0; i < numPairs; i++) {

            LinAlg::Vector<3, unsigned char> color = LinAlg::clampColor(LinAlg::ColorRamp(cpuJobData[i].alignmentError / 1.5f));

            uint32_t colorUint32 = (color[0] << 0) |
                                   (color[1] << 8) |
                                   (color[2] << 16) |
                                   (0xFF << 24);

            debugImage.drawBox(LinAlg::Fill<int>(i*(debugPatchSize+4), 0),LinAlg::Fill<int>(i*(debugPatchSize+4)+(debugPatchSize+3), (debugPatchSize+4)*3-1), colorUint32, true);

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    float lum = debugData[i].templateImage[y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(0+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[1] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[2] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[3] = 255;
                }

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    float lum = debugData[i].alignedImage[y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(debugPatchSize+4+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[1] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[2] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[3] = 255;
                }

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    float lum = debugData[i].unalignedImage[y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[((debugPatchSize+4)*2+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[1] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[2] = std::max(std::min<int>(lum*255.0f, 255), 0);
                    dstPixel[3] = 255;
                }
        }
        debugImage.writeToFile("alignment.png");
    }
#endif


}
