/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "InitialPatchGatherer.h"

#include "../tools/RasterImage.h"

#include <algorithm>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <string.h>
#include "../cudaUtilities/cudaProfilingScope.h"

/**
 * Loads the initialPatchGathering kernel module and resolves every kernel,
 * texture reference and surface reference that gatherPatches() uses.
 *
 * NOTE(review): the module path is relative; presumably it is resolved against
 * the process's current working directory by loadFromFile — confirm.
 */
InitialPatchGatherer::InitialPatchGatherer()
{
    m_codeModule.loadFromFile("../SFMBackend/kernels/Release/initialPatchGathering.fatbin");

    // --- Stage 1: oriented blob scoring ------------------------------------
    m_findOrientedBlobsKernel.reset(m_codeModule.getKernel("findOrientedBlobs"));
    m_sourceImageTexRef.reset(m_codeModule.getTexReference("sourceImagePyramid"));
    m_orientedBlobScoreSurfRef.reset(m_codeModule.getSurfReference("orientedBlobScoreOutput"));

    // Source pyramid: normalized coordinates, linear filtering within and between mip levels.
    auto &sourceTex = *m_sourceImageTexRef;
    sourceTex.setTexelFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_LINEAR);
    sourceTex.setMipmapFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_LINEAR);
    sourceTex.setCoordinateNormalization(true);

    // --- Stage 2: non maximum suppression over the stage 1 score pyramid ----
    m_nonMaximumSuppressKernel.reset(m_codeModule.getKernel("nonMaximumSuppress"));
    m_stage1OutputPyramidTexRef.reset(m_codeModule.getTexReference("stage1OutputPyramid"));

    // Scores are read unfiltered (nearest texel, nearest mip level), with normalized coordinates.
    auto &scoreTex = *m_stage1OutputPyramidTexRef;
    scoreTex.setTexelFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_NEAREST);
    scoreTex.setMipmapFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_NEAREST);
    scoreTex.setCoordinateNormalization(true);

    // --- Stage 3: patch extraction ------------------------------------------
    m_patchExtractionKernel.reset(m_codeModule.getKernel("patchExtraction"));
    // Self-similarity computation is currently disabled:
    //m_computeSelfSimilaritiesKernel.reset(m_codeModule.getKernel("computeSelfSimilarities"));
}

// All owned resources are held in std::unique_ptr / wrapper members and are
// released by their own destructors; nothing extra to do here.
InitialPatchGatherer::~InitialPatchGatherer() = default;

/**
 * Parameter block for the "findOrientedBlobs" kernel (stage 1 of gatherPatches()).
 *
 * gatherPatches() hands this struct to the kernel as raw bytes (pointer + sizeof),
 * so field order and types must match the kernel-side definition exactly.
 * NOTE(review): the name mentions "EvaluatePatchUsefullness", but the only visible
 * user launches findOrientedBlobs; the name looks historical.
 */
struct EvaluatePatchUsefullnessStage1KernelParams {
    unsigned lod;       // source pyramid level being evaluated (gatherPatches() uses levels >= 1)
    float scaleX;       // 1 / width of that level
    float scaleY;       // 1 / height of that level
    unsigned yOffset;   // first row covered by the current launch (rows are launched in chunks of 1024)
    unsigned width;     // width of that level
    unsigned height;    // height of that level
};

/**
 * Parameter block for the "nonMaximumSuppress" kernel (stage 2 of gatherPatches()).
 *
 * Passed to the kernel as raw bytes (pointer + sizeof), so the layout must match
 * the kernel-side definition. Because it embeds a device pointer, the host pointer
 * size must match the device pointer size.
 * NOTE(review): presumably this requires a 64-bit host build — confirm.
 */
struct NonMaximumSuppressKernelParams {
    unsigned lod;           // score pyramid level being processed
    unsigned maxlod;        // number of score levels processed (score level count minus 3)
    float scaleX;           // 1 / width of score level `lod`
    float scaleY;           // 1 / height of score level `lod`
    float minThresh;        // score threshold (gatherPatches() uses 0.001f); presumably lower scores are rejected — confirm in kernel
    void *candidateData;    // device candidate buffer: one uint32 count followed by 2 uint32 words per candidate
    unsigned maxEntries;    // capacity of candidateData in candidates
};

/**
 * One extracted patch record, as written by the "patchExtraction" kernel and
 * downloaded into host memory by gatherPatches().
 *
 * Packed (GCC/Clang attribute) so the host layout has no padding; size and field
 * order must match what the kernel writes.
 * NOTE(review): the device-side definition is not visible here — keep both in sync.
 */
struct PatchExtractionPatches {
    float score;
    float x, y;         // presumably pixel units of pyramid level (int)lod; gatherPatches() divides by (level-0 size >> (int)lod)
    float angle;        // orientation; the disabled debug drawing passes it to std::cos/std::sin
    float lod;          // fractional pyramid level; the integer part selects the level used to normalize x/y
    uint32_t data[16*16]; // 16x16 patch block, one 32-bit word per sample; copied verbatim into PatchCandidate::data
    // Disabled debug fields; re-enabling them changes the struct size and must be mirrored in the kernel.
//    float covarMat[3];
    //float debugPath[10*4];
} __attribute__((packed));



/**
 * Parameter block for the "patchExtraction" kernel (stage 3 of gatherPatches()).
 *
 * Passed to the kernel as raw bytes (pointer + sizeof); the layout, including the
 * padding/alignment of the pointer members, must match the kernel-side definition.
 */
struct PatchExtractionParams {
    unsigned lod0Width;                 // width of source image level 0
    unsigned lod0Height;                // height of source image level 0
    uint32_t *candidateData;            // device pointer to the candidate entries (candidate buffer + 1 word, skipping the count)
    PatchExtractionPatches *patchData;  // device pointer to the output patch records
    unsigned offset;                    // index of the first candidate handled by this launch (launches cover 4096 candidates)
    unsigned numCandidates;             // total number of valid candidates
};


/**
 * Parameter block for the "computeSelfSimilarities" kernel.
 *
 * Currently unused: both the kernel lookup in the constructor and the launch in
 * gatherPatches() are commented out. Packed so the host layout has no padding;
 * it must match the kernel-side definition if that code is re-enabled.
 */
struct computeSelfSimilaritiesParams {
    uint32_t *patchData;    // device pointer to the patch records buffer
    unsigned count;         // total number of patches
    unsigned offset;        // index of the first patch handled by a launch
} __attribute__((packed));

/**
 * Detects oriented blob features on the GPU and returns them as patch candidates.
 *
 * Pipeline (all kernels come from the module loaded in the constructor):
 *  1. findOrientedBlobs: for every image level i >= 1, writes a float score map
 *     into level i-1 of m_orientedBlobScoreOutput.
 *  2. nonMaximumSuppress: runs over all score levels except the last three and
 *     appends candidates to m_patchCandidateArray (at most maxPatchCandidates).
 *  3. patchExtraction: turns each candidate into a PatchExtractionPatches record
 *     in m_patchesArray.
 *
 * @param image       Source image pyramid. With fewer than two levels no
 *                    candidates are produced.
 * @param candidates  Output; cleared and refilled. x/y are the kernel's x/y divided
 *                    by (level-0 width/height >> (int)lod).
 * @throws std::runtime_error if clearing a device-side counter with cuMemsetD8 fails.
 *
 * NOTE(review): the first 4 bytes of m_patchesArray are cleared like a counter,
 * yet patches are read back starting at offset 0. Confirm against the
 * patchExtraction kernel which layout is intended.
 */
void InitialPatchGatherer::gatherPatches(CudaUtils::CudaMipmappedTexture &image, std::vector<PatchCandidate> &candidates)
{
    AddCudaScopedProfileInterval("InitialPatchGatherer::gatherPatches");

    // Stage 1 sizes the score pyramid from image level 1 and writes image level i
    // into score level i-1. With fewer than two image levels there is nothing to
    // score, and getLevel(1) would be out of range.
    if (image.getNumLevel() < 2) {
        candidates.clear();
        return;
    }

    // ---- Stage 1: oriented blob scores --------------------------------------
    m_orientedBlobScoreOutput.resize(image.getLevel(1).getWidth(), image.getLevel(1).getHeight(), 0, CU_AD_FORMAT_FLOAT, 1, CUDA_ARRAY3D_SURFACE_LDST);

    m_sourceImageTexRef->bindMipmappedTexture(&image);
    m_sourceImageTexRef->setMinMaxMipLevel(0.0f, image.getNumLevel()-1);

    for (unsigned i = 1; i < image.getNumLevel(); i++) {
        m_orientedBlobScoreSurfRef->bindTexture(&m_orientedBlobScoreOutput.getLevel(i-1));

        EvaluatePatchUsefullnessStage1KernelParams params;
        params.lod = i;
        params.scaleX = 1.0f / image.getLevel(i).getWidth();
        params.scaleY = 1.0f / image.getLevel(i).getHeight();
        params.width = image.getLevel(i).getWidth();
        params.height = image.getLevel(i).getHeight();

        // Rows are covered by launches of at most 1024 rows each; yOffset tells
        // the kernel where the current launch starts.
        for (unsigned y = 0; y < image.getLevel(i).getHeight(); y += 1024u) {
            params.yOffset = y;
            m_findOrientedBlobsKernel->launch(LinAlg::Fill(32u, 4u, 1u),
                                              LinAlg::Fill((image.getLevel(i).getWidth() + 3)/4,
                                                           std::min(1024u, image.getLevel(i).getHeight()-y), 1u),
                                              &params, sizeof(params));
        }
    }

#if 0
    for (unsigned i = 1; i < image.getNumLevel(); i++) {
        RasterImage dstImage;
        dstImage.resize(m_orientedBlobScoreOutput.getLevel(i-1).getWidth(), m_orientedBlobScoreOutput.getLevel(i-1).getHeight());
        m_orientedBlobScoreOutput.getLevel(i-1).syncDownloadAll(dstImage.getData(), 4*dstImage.getWidth());

        for (unsigned i = 0; i < dstImage.getWidth() * dstImage.getHeight(); i++) {

            float f = *((float*)&dstImage.getData()[i]) * 1.0f;

            unsigned char *pixel = (unsigned char*)&dstImage.getData()[i];

            LinAlg::Vector<3, unsigned char> c = LinAlg::clampColor(LinAlg::ColorRamp(f));

            pixel[0] = c[0];
            pixel[1] = c[1];
            pixel[2] = c[2];
            pixel[3] = 255;
        }

        std::stringstream str;
        str << "debugOutput"<<i<<".png";

        dstImage.writeToFile(str.str().c_str());
    }
#endif

    // ---- Stage 2: non maximum suppression -----------------------------------
    std::cout << "Running non maximum supression" << std::endl;

    // Candidate buffer layout: one uint32 running count, then 2 uint32 words per
    // candidate (the disabled debug code decodes the first word as x | y << 16).
    const unsigned maxPatchCandidates = 50000;
    m_patchCandidateArray.resize(4*(1+maxPatchCandidates*2));
    if (cuMemsetD8((CUdeviceptr)m_patchCandidateArray.getPtr(), 0, 4) != CUDA_SUCCESS)
        throw std::runtime_error("InitialPatchGatherer::gatherPatches: failed to clear the patch candidate counter");

    // The last three score levels are not searched. Guard the subtraction so a
    // shallow pyramid does not wrap around to a huge unsigned level count.
    const unsigned numScoreLevels = m_orientedBlobScoreOutput.getNumLevel();
    const unsigned numNmsLevels = numScoreLevels > 3 ? numScoreLevels - 3 : 0;

    std::vector<unsigned> numPatchesUpToLevel(numNmsLevels);

    {
        m_stage1OutputPyramidTexRef->bindMipmappedTexture(&m_orientedBlobScoreOutput);
        m_stage1OutputPyramidTexRef->setMinMaxMipLevel(0.0f, numScoreLevels-1);
        NonMaximumSuppressKernelParams params;
        params.maxEntries = maxPatchCandidates;
        params.candidateData = m_patchCandidateArray.getPtr();
        params.maxlod = numNmsLevels;
        params.minThresh = 0.001f;

        for (unsigned i = 0; i < numNmsLevels; i++) {
            params.lod = i;
            params.scaleX = 1.0f / m_orientedBlobScoreOutput.getLevel(i).getWidth();
            params.scaleY = 1.0f / m_orientedBlobScoreOutput.getLevel(i).getHeight();


            m_nonMaximumSuppressKernel->launch(LinAlg::Fill(8u, 8u, 1u),
                                              LinAlg::Fill(m_orientedBlobScoreOutput.getLevel(i).getWidth(),
                                                           m_orientedBlobScoreOutput.getLevel(i).getHeight(), 1u),
                                                  &params, sizeof(params));

            // Only the running count (first word) is needed here. Downloading
            // the whole ~400 KB candidate buffer per level wasted bandwidth.
            uint32_t candidateCount = 0;
            m_patchCandidateArray.download(&candidateCount, sizeof(candidateCount));

            numPatchesUpToLevel[i] = std::min<uint32_t>(candidateCount, maxPatchCandidates);

            std::cout << numPatchesUpToLevel[i] << std::endl;
#if 0
            {
                std::vector<uint32_t> data;
                data.resize(1+maxPatchCandidates*2);
                m_patchCandidateArray.download(&data[0], data.size()*4);

                RasterImage dstImage;
                dstImage.resize(image.getLevel(i).getWidth(), image.getLevel(i).getHeight());

                image.getLevel(i).syncDownloadAll(dstImage.getData(), 4*dstImage.getWidth());

                for (unsigned i = 0; i < dstImage.getWidth() * dstImage.getHeight(); i++) {

                    unsigned char *pixel = (unsigned char*)&dstImage.getData()[i];
                    pixel[0] = pixel[0] * pixel[3]/255;
                    pixel[1] = pixel[1] * pixel[3]/255;
                    pixel[2] = pixel[2] * pixel[3]/255;
                    pixel[3] = 255;
                }

                unsigned startIndex = i == 0?0:numPatchesUpToLevel[i-1];
                for (unsigned i = startIndex; i < data[0]; i++) {
                    int x = (data[1+i*2] & 0xFFFF) * 2;
                    int y = ((data[1+i*2] & 0xFFFF0000) >> 16) * 2;
                    dstImage.drawCircle(LinAlg::Fill(x, y), 16, 0xFF0000FF);
                }

                std::stringstream str;
                str << "candidates"<<i<<".png";

                dstImage.writeToFile(str.str().c_str());

            }
#endif

        }
    }
    uint32_t numPatches = 0;
    m_patchCandidateArray.download(&numPatches, sizeof(numPatches));
    numPatches = std::min<uint32_t>(numPatches, maxPatchCandidates);

    // ---- Stage 3: patch extraction ------------------------------------------
    std::cout << "Extracting patches:" << numPatches << std::endl;

    m_patchesArray.resize(4 + maxPatchCandidates * sizeof(PatchExtractionPatches));
    if (cuMemsetD8((CUdeviceptr)m_patchesArray.getPtr(), 0, 4) != CUDA_SUCCESS)
        throw std::runtime_error("InitialPatchGatherer::gatherPatches: failed to clear the patch buffer header");

    {
        PatchExtractionParams params;
        params.lod0Width = image.getLevel(0).getWidth();
        params.lod0Height = image.getLevel(0).getHeight();
        params.candidateData = ((uint32_t*)m_patchCandidateArray.getPtr()) + 1;   // skip the count word
        params.patchData = (PatchExtractionPatches*) m_patchesArray.getPtr();
        params.numCandidates = numPatches;

        // Candidates are processed in launches of at most 4096.
        for (unsigned i = 0; i < numPatches; i += 4096) {
            params.offset = i;
            m_patchExtractionKernel->launch(LinAlg::Fill(32u, 4u, 1u),
                                            LinAlg::Fill((std::min(4096u, numPatches-i) + 3)/4,
                                                           1u, 1u),
                                                  &params, sizeof(params));
        }
    }
/*
    std::cout << "Computing self similarities" << std::endl;
    {
        computeSelfSimilaritiesParams params;
        params.count = numPatches;
        params.patchData = ((uint32_t*)m_patchesArray.getPtr());

        for (unsigned i = 0; i < numPatches; i+= 1024) {
            params.offset = i;
            m_computeSelfSimilaritiesKernel->launch(LinAlg::Fill(16u, 16u, 1u),
                                                      LinAlg::Fill(std::min(1024u, numPatches-i),
                                                                   1u, 1u),
                                                          &params, sizeof(params));
        }
    }
*/
    // Download only the records that were actually produced. The device buffer is
    // sized for maxPatchCandidates (~50 MB), most of which is normally unused.
    // data() is used because &v[0] on an empty vector is undefined.
    std::vector<PatchExtractionPatches> rawPatches(numPatches);
    if (numPatches > 0)
        m_patchesArray.download(rawPatches.data(), rawPatches.size()*sizeof(PatchExtractionPatches));

    {
        const auto lod0Width = image.getLevel(0).getWidth();
        const auto lod0Height = image.getLevel(0).getHeight();

        candidates.clear();
        candidates.reserve(numPatches);
        for (unsigned j = 0; j < numPatches; j++) {
            PatchCandidate candidate;
            candidate.score = rawPatches[j].score;
            candidate.angle = rawPatches[j].angle;
            candidate.lod = rawPatches[j].lod;
            // Normalize by the size of pyramid level (int)lod.
            candidate.x = rawPatches[j].x / (float)(lod0Width >> (int)rawPatches[j].lod);
            candidate.y = rawPatches[j].y / (float)(lod0Height >> (int)rawPatches[j].lod);
            memcpy(candidate.data, rawPatches[j].data, 16*16*4);
/*
            candidate.covarMat[0][0] = rawPatches[j].covarMat[0];
            candidate.covarMat[0][1] = rawPatches[j].covarMat[1];
            candidate.covarMat[1][0] = rawPatches[j].covarMat[1];
            candidate.covarMat[1][1] = rawPatches[j].covarMat[2];
*/
            candidates.push_back(candidate);
        }
    }



#if 0
    std::cout << "Draw debug output" << std::endl;

    {
        RasterImage dstImage;
        dstImage.resize(image.getLevel(0).getWidth(), image.getLevel(0).getHeight());

        image.getLevel(0).syncDownloadAll(dstImage.getData(), 4*dstImage.getWidth());

        for (unsigned i = 0; i < dstImage.getWidth() * dstImage.getHeight(); i++) {

            unsigned char *pixel = (unsigned char*)&dstImage.getData()[i];
            pixel[0] = pixel[0] * pixel[3]/255;
            pixel[1] = pixel[1] * pixel[3]/255;
            pixel[2] = pixel[2] * pixel[3]/255;
            pixel[3] = 255;
        }

        for (unsigned i = 0; i < candidates.size(); i++) {



            float x = candidates[i].x * dstImage.getWidth();
            float y = candidates[i].y * dstImage.getHeight();

            float cosAngle = std::cos(candidates[i].angle);
            float sinAngle = std::sin(candidates[i].angle);

            unsigned lod = (unsigned)candidates[i].lod;
            float scale = 1.0f + candidates[i].lod - lod;

            float patchW = 8.0f * (1 << lod) * scale;

            LinAlg::Vector<3, unsigned char> color = LinAlg::clampColor(LinAlg::ColorRamp(candidates[i].score*0.02f));
            uint32_t colorUint32 = (color[0] << 0) |
                                   (color[1] << 8) |
                                   (color[2] << 16) |
                                   (0xFF << 24);

            dstImage.drawLine(LinAlg::Fill<int>(x+patchW*(cosAngle-sinAngle), y+patchW*(sinAngle+cosAngle)),
                              LinAlg::Fill<int>(x+patchW*(-cosAngle-sinAngle), y+patchW*(-sinAngle+cosAngle)), colorUint32);

            dstImage.drawLine(LinAlg::Fill<int>(x+patchW*(-cosAngle-sinAngle), y+patchW*(-sinAngle+cosAngle)),
                              LinAlg::Fill<int>(x+patchW*(-cosAngle+sinAngle), y+patchW*(-sinAngle-cosAngle)), colorUint32);

            dstImage.drawLine(LinAlg::Fill<int>(x+patchW*(-cosAngle+sinAngle), y+patchW*(-sinAngle-cosAngle)),
                              LinAlg::Fill<int>(x+patchW*(cosAngle+sinAngle), y+patchW*(sinAngle-cosAngle)), colorUint32);

            dstImage.drawLine(LinAlg::Fill<int>(x+patchW*(cosAngle+sinAngle), y+patchW*(sinAngle-cosAngle)),
                             LinAlg::Fill<int>(x+patchW*(cosAngle-sinAngle), y+patchW*(sinAngle+cosAngle)), colorUint32);

            dstImage.drawLine(LinAlg::Fill<int>(x, y),
                              LinAlg::Fill<int>(x+patchW*cosAngle, y+patchW*sinAngle), 0xFF0000FF);


        }

        dstImage.writeToFile("patches.png");
    }
#endif
#if 0
    {
        unsigned patchRows = (int)(1+sqrtf(candidates.size()));
        RasterImage dstImage;
        dstImage.resize(20*patchRows, 20*patchRows);

        for (unsigned i = 0; i < candidates.size(); i++) {
            unsigned col = i % patchRows;
            unsigned row = i / patchRows;

            LinAlg::Vector<3, unsigned char> color = LinAlg::clampColor(LinAlg::ColorRamp(candidates[i].score*0.02f));
            uint32_t colorUint32 = (color[0] << 0) |
                                   (color[1] << 8) |
                                   (color[2] << 16) |
                                   (0xFF << 24);

            dstImage.drawBox(LinAlg::Fill<int>(col*20, row*20),LinAlg::Fill<int>(col*20+19, row*20+19), colorUint32, true);

            for (unsigned y = 0; y < 16; y++)
                for (unsigned x = 0; x < 16; x++) {
                    const unsigned char *srcPixel = (const unsigned char*) &candidates[i].data[y*16+x];
                    unsigned char *dstPixel = (unsigned char*) &dstImage.getData()[(row*20+y+2)*dstImage.getWidth() + col*20+x+2];
                    dstPixel[0] = std::min(srcPixel[0] * srcPixel[3]/255, 255);
                    dstPixel[1] = std::min(srcPixel[1] * srcPixel[3]/255, 255);
                    dstPixel[2] = std::min(srcPixel[2] * srcPixel[3]/255, 255);
                    dstPixel[3] = 255;
                }

        }
        dstImage.writeToFile("patchData.png");
    }
#endif
}

