/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "NewTrackObservationTester.h"
#include "PatchAtlas.h"
#include "../cudaKernels/newObservationTest.h"
#include "../tools/RasterImage.h"


/**
 * Loads the precompiled newObservationTest module and resolves every handle
 * that test() needs from it:
 *  - the "computeNewObservationProbability" kernel,
 *  - the "patchAtlas" texture reference, configured for linear texel and
 *    mipmap filtering with normalized texture coordinates,
 *  - the "patchAtlasConstants" constant-memory symbol.
 *
 * NOTE(review): the fatbin path is relative, so it is resolved against the
 * process working directory — confirm this is intended for deployment.
 */
NewTrackObservationTester::NewTrackObservationTester()
{
    m_codeModule.loadFromFile("../SFMBackend/kernels/Release/newObservationTest.fatbin");

    m_testKernel.reset(m_codeModule.getKernel("computeNewObservationProbability"));

    m_patchAtlasTexRef.reset(m_codeModule.getTexReference("patchAtlas"));
    {
        auto &atlasTexture = *m_patchAtlasTexRef;
        atlasTexture.setTexelFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_LINEAR);
        atlasTexture.setMipmapFilterMode(CudaUtils::CudaTextureReference::FILTER_MODE_LINEAR);
        atlasTexture.setCoordinateNormalization(true);
    }

    m_kernelPatchAtlasConstantParams.reset(m_codeModule.getConstantMemory("patchAtlasConstants"));
}

// Nothing to release by hand: the owned kernel, texture-reference and
// constant-memory handles are destroyed together with their smart-pointer
// members.
NewTrackObservationTester::~NewTrackObservationTester() = default;


void NewTrackObservationTester::test(std::vector<Track> &tracks, std::vector<NewObservation> &newObservations, PatchAtlas *patchAtlas)
{
    m_patchAtlasTexRef->bindMipmappedTexture(patchAtlas->getAtlasTexture());
    m_patchAtlasTexRef->setMinMaxMipLevel(0.0f, patchAtlas->getAtlasTexture()->getNumLevel()-1);

    unsigned memorySize = tracks.size() * sizeof(TrackHead);
    for (unsigned i = 0; i < tracks.size(); i++) {
        memorySize += tracks[i].observations.size() * sizeof(TrackObservation);
    }

    std::vector<unsigned> trackOffsets;
    std::vector<NewObservationCandidate> cpuNewObservationData;
    cpuNewObservationData.resize(newObservations.size());
    trackOffsets.resize(tracks.size());
    std::vector<unsigned char> cpuTrackData;
    cpuTrackData.resize(memorySize);

    {
        unsigned char *ptr = &cpuTrackData[0];
        for (unsigned i = 0; i < tracks.size(); i++) {
            trackOffsets[i] = ((uint64_t)ptr - (uint64_t)&cpuTrackData[0]);
            TrackHead *head = (TrackHead*) ptr;
            ptr += sizeof(TrackHead);

            head->size = tracks[i].size;

            LinAlg::Vector3f wsPos = tracks[i].worldSpacePosition.StripHom() * (1.0f / tracks[i].worldSpacePosition[3]);

            head->trackSurfaceToWorld[0*4+0] = tracks[i].orientation[0][0];
            head->trackSurfaceToWorld[0*4+1] = tracks[i].orientation[0][1];
            head->trackSurfaceToWorld[0*4+2] = tracks[i].orientation[0][2];
            head->trackSurfaceToWorld[0*4+3] = wsPos[0];

            head->trackSurfaceToWorld[1*4+0] = tracks[i].orientation[1][0];
            head->trackSurfaceToWorld[1*4+1] = tracks[i].orientation[1][1];
            head->trackSurfaceToWorld[1*4+2] = tracks[i].orientation[1][2];
            head->trackSurfaceToWorld[1*4+3] = wsPos[1];

            head->trackSurfaceToWorld[2*4+0] = tracks[i].orientation[2][0];
            head->trackSurfaceToWorld[2*4+1] = tracks[i].orientation[2][1];
            head->trackSurfaceToWorld[2*4+2] = tracks[i].orientation[2][2];
            head->trackSurfaceToWorld[2*4+3] = wsPos[2];

            head->numObservations = tracks[i].observations.size();

            for (unsigned j = 0; j < tracks[i].observations.size(); j++) {
                TrackObservation *obs = (TrackObservation*) ptr;
                ptr += sizeof(TrackObservation);

                obs->patchAtlasIndex = tracks[i].observations[j].patchAtlasIndex;
                obs->screenSpaceOffset[0] = tracks[i].observations[j].screenSpaceOffset[0];
                obs->screenSpaceOffset[1] = tracks[i].observations[j].screenSpaceOffset[1];
            }
        }
    }
    m_trackData.resize(cpuTrackData.size());
    m_trackData.upload(&cpuTrackData[0], cpuTrackData.size());

    for (unsigned i = 0; i < cpuNewObservationData.size(); i++) {
        cpuNewObservationData[i].track = (TrackHead*) (((uint64_t)m_trackData.getPtr()) + trackOffsets[newObservations[i].trackIndex]);
        cpuNewObservationData[i].patchAtlasIndex = newObservations[i].patchAtlasIndex;
        cpuNewObservationData[i].screenSpaceOffset[0] = newObservations[i].screenSpaceOffset[0];
        cpuNewObservationData[i].screenSpaceOffset[1] = newObservations[i].screenSpaceOffset[1];
        cpuNewObservationData[i].error = 1234.0f;
    }
    m_newObservationData.resize(cpuNewObservationData.size()*sizeof(NewObservationCandidate));
    m_newObservationData.upload(&cpuNewObservationData[0], cpuNewObservationData.size()*sizeof(NewObservationCandidate));

    m_kernelPatchAtlasConstantParams->upload(&patchAtlas->getConstants(), sizeof(PatchAtlasConstants));

    m_patchAtlasData.resize(patchAtlas->getPatchParams().size() * sizeof(PatchAtlasPatchParams));
    m_patchAtlasData.upload(&patchAtlas->getPatchParams()[0], patchAtlas->getPatchParams().size() * sizeof(PatchAtlasPatchParams));


#ifdef NEWOBS_OUTPUT_DEBUG_PAIRS
    CudaUtils::CudaDeviceMemory debugVData;
    debugVData.resize(24*24*maxDebugOutputPairs*2*4);
#endif

    {
        NewObservationTestParams kernelParams;
        kernelParams.newCandidates = (NewObservationCandidate*)m_newObservationData.getPtr();
        kernelParams.patchAtlasPatchParams = (PatchAtlasPatchParams*) m_patchAtlasData.getPtr();
        kernelParams.trackData = (unsigned char*) m_trackData.getPtr();
#ifdef NEWOBS_OUTPUT_DEBUG_PAIRS
        kernelParams.debugOutput = (uint32_t*)debugVData.getPtr();
#endif

        for (unsigned i = 0; i < newObservations.size(); i+=32768) {
            kernelParams.blockOffset = i;
            m_testKernel->launch(LinAlg::Fill(8u, 8u, 1u),
                                          LinAlg::Fill<unsigned>(std::min<unsigned>(32768, newObservations.size()-i), 1u, 1u),
                                          &kernelParams, sizeof(kernelParams));
        }
    }

    m_newObservationData.download(&cpuNewObservationData[0], cpuNewObservationData.size()*sizeof(NewObservationCandidate));


    for (unsigned i = 0; i < cpuNewObservationData.size(); i++) {
        newObservations[i].error = cpuNewObservationData[i].error;
     //   std::cout << "Error for " << i<< "  " << newObservations[i].error << std::endl;
    }
/*
    std::cout << "E = [ ";

    for (unsigned i = 0; i < cpuNewObservationData.size(); i++) {
        std::cout << "  " << newObservations[i].error;
    }
    std::cout << "];" << std::endl;
    */


#ifdef NEWOBS_OUTPUT_DEBUG_PAIRS
    std::vector<uint32_t> debugData;
    debugData.resize(24*24*maxDebugOutputPairs*2);
    debugVData.download(&debugData[0], debugData.size()*4);

    {
        RasterImage debugImage;
        unsigned numTrackInImage = std::min<unsigned>(newObservations.size(), maxDebugOutputPairs);
        const unsigned debugPatchSize = 24;
        debugImage.resize(numTrackInImage*(debugPatchSize+4), (debugPatchSize+4)*2);
        for (unsigned i = 0; i < numTrackInImage; i++) {

            LinAlg::Vector<3, unsigned char> color = LinAlg::clampColor(LinAlg::ColorRamp(newObservations[i].error/20.0f));

            uint32_t colorUint32 = (color[0] << 0) |
                                   (color[1] << 8) |
                                   (color[2] << 16) |
                                   (0xFF << 24);

            debugImage.drawBox(LinAlg::Fill<int>(i*(24+4), 0),LinAlg::Fill<int>(i*(debugPatchSize+4)+(debugPatchSize+3), (debugPatchSize+4)*2-1), colorUint32, true);

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    const unsigned char *srcPixel = (const unsigned char*) &debugData[i*debugPatchSize*debugPatchSize*2+y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(0+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::min(srcPixel[0] * srcPixel[3]*1/255, 255);
                    dstPixel[1] = std::min(srcPixel[1] * srcPixel[3]*1/255, 255);
                    dstPixel[2] = std::min(srcPixel[2] * srcPixel[3]*1/255, 255);
                    dstPixel[3] = 255;
                }

            for (unsigned y = 0; y < debugPatchSize; y++)
                for (unsigned x = 0; x < debugPatchSize; x++) {
                    const unsigned char *srcPixel = (const unsigned char*) &debugData[i*debugPatchSize*debugPatchSize*2+debugPatchSize*debugPatchSize+y*debugPatchSize+x];
                    unsigned char *dstPixel = (unsigned char*) &debugImage.getData()[(debugPatchSize+4+y+2)*debugImage.getWidth() + i*(debugPatchSize+4)+x+2];
                    dstPixel[0] = std::min(srcPixel[0] * srcPixel[3]*1/255, 255);
                    dstPixel[1] = std::min(srcPixel[1] * srcPixel[3]*1/255, 255);
                    dstPixel[2] = std::min(srcPixel[2] * srcPixel[3]*1/255, 255);
                    dstPixel[3] = 255;
                }
        }
        debugImage.writeToFile("NewObservationDump.png");
    }
#endif

}
