/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "trackAlignmentOptimizer.h"
#include "PatchAtlas.h"
#include "../cudaKernels/trackOptimization.h"
#include "../tools/RasterImage.h"

/**
 * Loads the precompiled "trackOptimization" CUDA module and resolves every handle optimize() uses:
 *  - the "optimizeTracks" kernel together with the "patchAtlas" texture reference it samples
 *    (linear texel filtering, linear mipmap filtering, normalized coordinates) and the
 *    "patchAtlasConstants" constant memory;
 *  - the "extractPreWarpedProjections" kernel and its "ExtractionPreWarpMatrices" constant memory.
 *
 * NOTE(review): the fatbin path is relative; presumably it is resolved against the process's
 * working directory — confirm against CudaCodeModule::loadFromFile.
 */
TrackAlignmentOptimizer::TrackAlignmentOptimizer()
{
    m_codeModule.loadFromFile("../SFMBackend/kernels/Release/trackOptimization.fatbin");

    // Alignment kernel and the atlas texture it reads from.
    m_alignmentKernel.reset(m_codeModule.getKernel("optimizeTracks"));
    m_patchAtlasTexRef.reset(m_codeModule.getTexReference("patchAtlas"));

    m_patchAtlasTexRef->setTexelFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_LINEAR);
    m_patchAtlasTexRef->setMipmapFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_LINEAR);
    m_patchAtlasTexRef->setCoordinateNormalization(true);

    m_kernelPatchAtlasConstantParams.reset(m_codeModule.getConstantMemory("patchAtlasConstants"));

    // Pre-warped projection extraction kernel and its warp-matrix constants.
    m_extractionKernel.reset(m_codeModule.getKernel("extractPreWarpedProjections"));
    m_kernelPreWarpConstantParams.reset(m_codeModule.getConstantMemory("ExtractionPreWarpMatrices"));
}

// Defined out of line in this translation unit; the owned CUDA handles (kernels, texture
// reference, constant memories, device buffers) are released by their members' own destructors.
TrackAlignmentOptimizer::~TrackAlignmentOptimizer() = default;


void TrackAlignmentOptimizer::optimize(std::vector<Track> &tracks, PatchAtlas *patchAtlas, unsigned numIterations)
{
    m_patchAtlasTexRef->bindMipmappedTexture(patchAtlas->getAtlasTexture());
    m_patchAtlasTexRef->setMinMaxMipLevel(0.0f, patchAtlas->getAtlasTexture()->getNumLevel()-1);

    unsigned memorySize = tracks.size() * sizeof(TrackHead);
    for (unsigned i = 0; i < tracks.size(); i++) {
        memorySize += tracks[i].observations.size() * sizeof(TrackObservation);
    }

    std::vector<unsigned char*> cpuIndexData;
    cpuIndexData.resize(tracks.size());
    std::vector<unsigned char> cpuTrackData;
    cpuTrackData.resize(memorySize);


    {
        unsigned char *ptr = &cpuTrackData[0];
        for (unsigned i = 0; i < tracks.size(); i++) {
            cpuIndexData[i] = (unsigned char*) (ptr - &cpuTrackData[0]);
            TrackHead *head = (TrackHead*) ptr;
            ptr += sizeof(TrackHead);

            head->size = tracks[i].size;

            LinAlg::Vector3f wsPos = tracks[i].worldSpacePosition.StripHom() * (1.0f / tracks[i].worldSpacePosition[3]);

            head->trackSurfaceToWorld[0*4+0] = tracks[i].orientation[0][0];
            head->trackSurfaceToWorld[0*4+1] = tracks[i].orientation[0][1];
            head->trackSurfaceToWorld[0*4+2] = tracks[i].orientation[0][2];
            head->trackSurfaceToWorld[0*4+3] = wsPos[0];

            head->trackSurfaceToWorld[1*4+0] = tracks[i].orientation[1][0];
            head->trackSurfaceToWorld[1*4+1] = tracks[i].orientation[1][1];
            head->trackSurfaceToWorld[1*4+2] = tracks[i].orientation[1][2];
            head->trackSurfaceToWorld[1*4+3] = wsPos[1];

            head->trackSurfaceToWorld[2*4+0] = tracks[i].orientation[2][0];
            head->trackSurfaceToWorld[2*4+1] = tracks[i].orientation[2][1];
            head->trackSurfaceToWorld[2*4+2] = tracks[i].orientation[2][2];
            head->trackSurfaceToWorld[2*4+3] = wsPos[2];

            head->numObservations = tracks[i].observations.size();

 //   std::cout << "Track " << i << std::endl;
            for (unsigned j = 0; j < tracks[i].observations.size(); j++) {
 //   std::cout << " obs  " <<  j << std::endl;
                TrackObservation *obs = (TrackObservation*) ptr;
                ptr += sizeof(TrackObservation);

                obs->patchAtlasIndex = tracks[i].observations[j].patchAtlasIndex;
                obs->screenSpaceOffset[0] = tracks[i].observations[j].screenSpaceOffset[0];
                obs->screenSpaceOffset[1] = tracks[i].observations[j].screenSpaceOffset[1];
                obs->screenSize = tracks[i].observations[j].screenSize;

                LinAlg::Matrix4x4f P;

                PatchAtlasPatchParams param = patchAtlas->getPatchParams()[tracks[i].observations[j].patchAtlasIndex];
                memcpy(&P, param.worldToAtlasPatch, 4*4*4);

                LinAlg::Vector4f atlasPos = P * tracks[i].worldSpacePosition;
                atlasPos /= atlasPos[3];
/*
                std::cout << obs->patchAtlasIndex << std::endl;
                std::cout << (std::string) atlasPos << std::endl;
                std::cout << (unsigned)param.patchX << "  "  << (unsigned)param.patchY << std::endl;
                std::cout << param.patchX * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << "  "  << param.patchY * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << std::endl;
                std::cout << (param.patchX+1) * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << "  "  << (param.patchY+1) * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << std::endl;
                std::cout << std::endl;
                */
            }
        }
    }
    m_trackData.resize(cpuTrackData.size());
    m_trackData.upload(&cpuTrackData[0], cpuTrackData.size());

    for (unsigned i = 0; i < cpuIndexData.size(); i++) {
        cpuIndexData[i] = cpuIndexData[i] + (uint64_t)m_trackData.getPtr();
    }
    m_indexData.resize(cpuIndexData.size()*8);
    m_indexData.upload(&cpuIndexData[0], cpuIndexData.size()*8);

    m_kernelPatchAtlasConstantParams->upload(&patchAtlas->getConstants(), sizeof(PatchAtlasConstants));

    m_patchAtlasData.resize(patchAtlas->getPatchParams().size() * sizeof(PatchAtlasPatchParams));
    m_patchAtlasData.upload(&patchAtlas->getPatchParams()[0], patchAtlas->getPatchParams().size() * sizeof(PatchAtlasPatchParams));

#ifdef TRACK_ALIGN_OUTPUT_DEBUG_DATA
    CudaUtils::CudaDeviceMemory debugVData;
    debugVData.resize(24*24*tracks.size()*2*4);
#endif
    {
        OptimizeTracksKernalParams kernelParams;
        kernelParams.index = (unsigned char**)m_indexData.getPtr();
        kernelParams.patchAtlasPatchParams = (PatchAtlasPatchParams*) m_patchAtlasData.getPtr();
#ifdef TRACK_ALIGN_OUTPUT_DEBUG_DATA
        kernelParams.debugOutput = (uint32_t*)debugVData.getPtr();
#endif
        kernelParams.numIterations = numIterations;

        for (unsigned i = 0; i < tracks.size(); i+=64) {
            kernelParams.blockOffset = i;
            m_alignmentKernel->launch(LinAlg::Fill(8u, 8u, 1u),
                                          LinAlg::Fill<unsigned>(std::min<unsigned>(64, tracks.size()-i), 1u, 1u),
                                          &kernelParams, sizeof(kernelParams));
        }
    }
    CudaUtils::CudaDeviceMemory warpedProjectionVMem;
    {
        std::vector<LinAlg::Matrix4x4f> preWarpMatrices;
        preWarpMatrices.resize(11);
        preWarpMatrices[1] = LinAlg::RotateZ(-(float)M_PI / 4.0f * 0.0f) * LinAlg::Scale3D(LinAlg::Fill(1.5f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 4.0f * 0.0f);
        preWarpMatrices[2] = LinAlg::RotateZ(-(float)M_PI / 4.0f * 1.0f) * LinAlg::Scale3D(LinAlg::Fill(1.5f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 4.0f * 1.0f);
        preWarpMatrices[3] = LinAlg::RotateZ(-(float)M_PI / 4.0f * 2.0f) * LinAlg::Scale3D(LinAlg::Fill(1.5f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 4.0f * 2.0f);
        preWarpMatrices[4] = LinAlg::RotateZ(-(float)M_PI / 4.0f * 3.0f) * LinAlg::Scale3D(LinAlg::Fill(1.5f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 4.0f * 3.0f);
        preWarpMatrices[5] = LinAlg::RotateZ(-(float)M_PI / 6.0f * 0.0f) * LinAlg::Scale3D(LinAlg::Fill(2.0f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 6.0f * 0.0f);
        preWarpMatrices[6] = LinAlg::RotateZ(-(float)M_PI / 6.0f * 1.0f) * LinAlg::Scale3D(LinAlg::Fill(2.0f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 6.0f * 1.0f);
        preWarpMatrices[7] = LinAlg::RotateZ(-(float)M_PI / 6.0f * 2.0f) * LinAlg::Scale3D(LinAlg::Fill(2.0f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 6.0f * 2.0f);
        preWarpMatrices[8] = LinAlg::RotateZ(-(float)M_PI / 6.0f * 3.0f) * LinAlg::Scale3D(LinAlg::Fill(2.0f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 6.0f * 3.0f);
        preWarpMatrices[9] = LinAlg::RotateZ(-(float)M_PI / 6.0f * 4.0f) * LinAlg::Scale3D(LinAlg::Fill(2.0f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 6.0f * 4.0f);
        preWarpMatrices[10] = LinAlg::RotateZ(-(float)M_PI / 6.0f * 5.0f) * LinAlg::Scale3D(LinAlg::Fill(2.0f, 1.0f, 1.0f)) * LinAlg::RotateZ((float)M_PI / 6.0f * 5.0f);
        m_kernelPreWarpConstantParams->upload(&preWarpMatrices[0][0][0], 11*4*4*4);

        warpedProjectionVMem.resize(tracks.size()*16*16*4*11);
        {
#if 0
            ExtractProjectionsKernelParams params;
            params.index = (unsigned char**)m_indexData.getPtr();
            params.patchAtlasPatchParams = (PatchAtlasPatchParams*) m_patchAtlasData.getPtr();
            params.dst = (uint32_t*)warpedProjectionVMem.getPtr();
            for (unsigned i = 0; i < tracks.size(); i+=128) {
                params.blockOffset = i;
                m_extractionKernel->launch(LinAlg::Fill(8u, 8u, 1u),
                                              LinAlg::Fill<unsigned>(std::min<unsigned>(128, tracks.size()-i), 1u, 1u),
                                              &params, sizeof(params));
            }
#endif
        }

    }
    m_trackData.download(&cpuTrackData[0], cpuTrackData.size());

    std::vector<uint32_t> preWarpedProjectionData;
    preWarpedProjectionData.resize(tracks.size()*16*16*11);

    //warpedProjectionVMem.download(&preWarpedProjectionData[0], tracks.size()*16*16*11*4);

    {
        unsigned char *ptr = &cpuTrackData[0];
        for (unsigned i = 0; i < tracks.size(); i++) {
            cpuIndexData[i] = (unsigned char*) (ptr - &cpuTrackData[0]);
            TrackHead *head = (TrackHead*) ptr;
            ptr += sizeof(TrackHead);

            //head->size = tracks[i].size;

           // LinAlg::Vector3f wsPos = tracks[i].worldSpacePosition.StripHom() * (1.0f / tracks[i].worldSpacePosition[3]);

            tracks[i].orientation[0][0] = head->trackSurfaceToWorld[0*4+0];
            tracks[i].orientation[0][1] = head->trackSurfaceToWorld[0*4+1];
            tracks[i].orientation[0][2] = head->trackSurfaceToWorld[0*4+2];

            tracks[i].orientation[1][0] = head->trackSurfaceToWorld[1*4+0];
            tracks[i].orientation[1][1] = head->trackSurfaceToWorld[1*4+1];
            tracks[i].orientation[1][2] = head->trackSurfaceToWorld[1*4+2];

            tracks[i].orientation[2][0] = head->trackSurfaceToWorld[2*4+0];
            tracks[i].orientation[2][1] = head->trackSurfaceToWorld[2*4+1];
            tracks[i].orientation[2][2] = head->trackSurfaceToWorld[2*4+2];

            tracks[i].remainingError = head->remainingError;

            //memcpy(tracks[i].preWarpedPatches, &preWarpedProjectionData[i*16*16*11], 16*16*11*4);


    //std::cout << "Track " << i << std::endl;
            for (unsigned j = 0; j < tracks[i].observations.size(); j++) {
   // std::cout << " obs  " <<  j << std::endl;
                TrackObservation *obs = (TrackObservation*) ptr;
                ptr += sizeof(TrackObservation);
/*
                std::cout << tracks[i].observations[j].screenSpaceOffset[0] - obs->screenSpaceOffset[0] << std::endl;
                std::cout << tracks[i].observations[j].screenSpaceOffset[1] - obs->screenSpaceOffset[1] << std::endl;

                std::cout << (tracks[i].observations[j].screenSpaceOffset[0] - obs->screenSpaceOffset[0]) / obs->screenSize << " patch pixels" << std::endl;
                std::cout << (tracks[i].observations[j].screenSpaceOffset[1] - obs->screenSpaceOffset[1]) / obs->screenSize << " patch pixels" << std::endl;
*/
                tracks[i].observations[j].screenSpaceOffset[0] = obs->screenSpaceOffset[0];
                tracks[i].observations[j].screenSpaceOffset[1] = obs->screenSpaceOffset[1];

/*
                obs->patchAtlasIndex = tracks[i].observations[j].patchAtlasIndex;
                obs->screenSpaceOffset[0] = tracks[i].observations[j].screenSpaceOffset[0];
                obs->screenSpaceOffset[1] = tracks[i].observations[j].screenSpaceOffset[1];

                LinAlg::Matrix4x4f P;

                PatchAtlasPatchParams param = patchAtlas->getPatchParams()[tracks[i].observations[j].patchAtlasIndex];
                memcpy(&P, param.worldToAtlasPatch, 4*4*4);

                LinAlg::Vector4f atlasPos = P * tracks[i].worldSpacePosition;
                atlasPos /= atlasPos[3];

                std::cout << obs->patchAtlasIndex << std::endl;
                std::cout << (std::string) atlasPos << std::endl;
                std::cout << (unsigned)param.patchX << "  "  << (unsigned)param.patchY << std::endl;
                std::cout << param.patchX * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << "  "  << param.patchY * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << std::endl;
                std::cout << (param.patchX+1) * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << "  "  << (param.patchY+1) * PatchAtlasConstants::ATLAS_PATCH_SIZE / 2048.0f << std::endl;
                std::cout << std::endl;
                */
            }
        }
    }

#ifdef TRACK_ALIGN_OUTPUT_DEBUG_DATA
    std::vector<uint32_t> debugData;
    debugData.resize(24*24*tracks.size()*2);
    debugVData.download(&debugData[0], debugData.size()*4);

    {
        RasterImage debugImage;
        unsigned numTrackInImage = std::min<unsigned>(tracks.size(), 500u);
        const unsigned debugPatchSize = 24;
        debugImage.resize(numTrackInImage*(debugPatchSize+4), (debugPatchSize+4)*2);
        for (unsigned i = 0; i < numTrackInImage; i++) {

            /*
            LinAlg::Vector<3, unsigned char> color = LinAlg::clampColor(LinAlg::ColorRamp(errorData[i]*10.0f));

            uint32_t colorUint32 = (color[0] << 0) |
                                   (color[1] << 8) |
                                   (color[2] << 16) |
                                   (0xFF << 24);
            */
            uint32_t colorUint32 = 0xFF000000;

            debugImage.drawBox(LinAlg::Fill<int>(i*(24+4), 0),LinAlg::Fill<int>(i*(debugPatchSize+4)+(debugPatchSize+3), (debugPatchSize+4)*2-1), colorUint32, true);

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    const unsigned char *srcPixel = (const unsigned char*) &debugData[i*debugPatchSize*debugPatchSize*2+y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(0+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::min(srcPixel[0] * srcPixel[3]*1/255, 255);
                    dstPixel[1] = std::min(srcPixel[1] * srcPixel[3]*1/255, 255);
                    dstPixel[2] = std::min(srcPixel[2] * srcPixel[3]*1/255, 255);
                    dstPixel[3] = 255;
                }

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    const unsigned char *srcPixel = (const unsigned char*) &debugData[i*debugPatchSize*debugPatchSize*2+debugPatchSize*debugPatchSize+y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(debugPatchSize+4+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::min(srcPixel[0] * srcPixel[3]*1/255, 255);
                    dstPixel[1] = std::min(srcPixel[1] * srcPixel[3]*1/255, 255);
                    dstPixel[2] = std::min(srcPixel[2] * srcPixel[3]*1/255, 255);
                    dstPixel[3] = 255;
                }
        }
        debugImage.writeToFile("debugDump.png");
    }


    {
        RasterImage debugImage;
        unsigned numTrackInImage = std::min<unsigned>(tracks.size(), 200u);
        debugImage.resize(numTrackInImage*(16+4), (16+4)*11);
        for (unsigned i = 0; i < numTrackInImage; i++) {

            LinAlg::Vector<3, unsigned char> color = LinAlg::clampColor(LinAlg::ColorRamp(0.0f));//errorData[i]*10.0f));

            uint32_t colorUint32 = (color[0] << 0) |
                                   (color[1] << 8) |
                                   (color[2] << 16) |
                                   (0xFF << 24);

            debugImage.drawBox(LinAlg::Fill<int>(i*(16+4), 0),LinAlg::Fill<int>(i*(16+4)+(16+3), (16+4)*11-1), colorUint32, true);

            for (unsigned warp = 0; warp < 11; warp++) {
                for (unsigned y = 0; y < 16; y++)
                    for (unsigned x = 0; x < 16; x++) {
                        const unsigned char *srcPixel = (const unsigned char*) &preWarpedProjectionData[i*16*16*11+warp*16*16+y*16+x];
                        unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(warp*(16+4)+y+2)*debugImage.getWidth() + i*(16+4)+x+2];
                        dstPixel[0] = std::min(srcPixel[0] * srcPixel[3]*4/255, 255);
                        dstPixel[1] = std::min(srcPixel[1] * srcPixel[3]*4/255, 255);
                        dstPixel[2] = std::min(srcPixel[2] * srcPixel[3]*4/255, 255);
                        dstPixel[3] = 255;
                    }
            }
        }
        debugImage.writeToFile("warpedprojectedPatches.png");
    }

#endif
}
