/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "PatchAtlas.h"
#include <assert.h>

#include "../tools/RasterImage.h"
#include <sstream>




#include "../cudaUtilities/CudaCodeModule.h"
#include "../cudaUtilities/CudaKernel.h"

#include "../cudaUtilities/CudaTextureMemory.h"
#include "../cudaUtilities/CudaDeviceMemory.h"

#include "../cudaUtilities/CudaTextureReference.h"
#include "../cudaUtilities/CudaSurfaceReference.h"



/**
 * Creates the atlas texture and the host-side bookkeeping for all patch slots.
 *
 * The atlas is a layered, mip-mapped texture of 2048 x 2048 texels with 32
 * layers. Every layer is divided into a square grid of patch cells, each
 * PatchAtlasConstants::ATLAS_PATCH_SIZE texels wide. For every cell a
 * PatchAtlasPatchParams entry is filled with its grid position and layer, the
 * in-use bitmap is cleared, and a CudaDeviceMemory buffer large enough for all
 * patch parameter entries is allocated (nothing is uploaded here).
 */
PatchAtlas::PatchAtlas()
{
    const unsigned layerWidth = 2048;
    const unsigned numLayers = 32;

    m_atlasTexture.reset(new CudaUtils::CudaMipmappedTexture());
    m_atlasTexture->resize(layerWidth, layerWidth, numLayers,
                           CU_AD_FORMAT_UNSIGNED_INT8, 4,
                           CUDA_ARRAY3D_SURFACE_LDST | CUDA_ARRAY3D_LAYERED,
                           PatchAtlasConstants::ATLAS_MIP_LEVEL);

    // Both constants are expressed as fractions of the layer width.
    m_constants.halfTexelSize = 0.5f / (float)layerWidth;
    m_constants.patchSize = PatchAtlasConstants::ATLAS_PATCH_SIZE / (float)layerWidth;

    // Number of patch cells along one edge of a layer.
    const unsigned rows = layerWidth/PatchAtlasConstants::ATLAS_PATCH_SIZE;

    m_numPatches = rows * rows * numLayers;
    m_patchParams.resize(m_numPatches);

    // One bit per patch, packed into 64 bit words; start with every slot free.
    m_inUse.resize((m_numPatches + 63)/64);
    for (unsigned word = 0; word < m_inUse.size(); word++)
        m_inUse[word] = 0;

    // Patches are numbered layer by layer, row by row, so a running counter
    // visits exactly the indices layer*rows*rows + y*rows + x in order.
    unsigned patchIndex = 0;
    for (unsigned layer = 0; layer < numLayers; layer++) {
        for (unsigned cellY = 0; cellY < rows; cellY++) {
            for (unsigned cellX = 0; cellX < rows; cellX++) {
                PatchAtlasPatchParams &patch = m_patchParams[patchIndex];
                patch.patchX = cellX;
                patch.patchY = cellY;
                patch.layer = layer;
                patch.padding = 0;
                patchIndex++;
            }
        }
    }

    m_videoPatchParams.reset(new CudaUtils::CudaDeviceMemory());
    m_videoPatchParams->resize(m_numPatches * sizeof(PatchAtlasPatchParams));
}

/**
 * Destructor. Performs no explicit cleanup of its own; all members are torn
 * down by their own destructors.
 */
PatchAtlas::~PatchAtlas() = default;

void PatchAtlas::setSourceImageToAtlasPatchScaleOffset(unsigned index, float centerX, float centerY, float w, float h)
{
    PatchAtlasPatchParams &param = m_patchParams[index];
    param.sourceImageToAtlasPatchScaleOffsetX[0] = m_constants.patchSize / w;
    param.sourceImageToAtlasPatchScaleOffsetY[0] = m_constants.patchSize / h;
    param.sourceImageToAtlasPatchScaleOffsetX[1] = -centerX * param.sourceImageToAtlasPatchScaleOffsetX[0] +
                                    (param.patchX * PatchAtlasConstants::ATLAS_PATCH_SIZE) * m_constants.halfTexelSize * 2.0f +
                                    PatchAtlasConstants::ATLAS_PATCH_SIZE * m_constants.halfTexelSize;
    param.sourceImageToAtlasPatchScaleOffsetY[1] = -centerY * param.sourceImageToAtlasPatchScaleOffsetY[0] +
                                    (param.patchY * PatchAtlasConstants::ATLAS_PATCH_SIZE) * m_constants.halfTexelSize * 2.0f +
                                    PatchAtlasConstants::ATLAS_PATCH_SIZE * m_constants.halfTexelSize;
}

/**
 * Derives the world-to-atlas-patch matrix of a patch from a projection matrix.
 *
 * The patch's previously stored source-image-to-atlas scale/offset values are
 * written into a 4x4 matrix (scales on the diagonal of rows 0 and 1, offsets
 * in column 3), which is multiplied from the left onto P. The product is
 * handed to setWorldToAtlasPatchMatrix().
 *
 * @param index  Index of the patch in m_patchParams (not range checked).
 * @param P      Projection matrix to be combined with the patch transform.
 */
void PatchAtlas::setWorldToAtlasPatchFromProjectionMatrix(unsigned index, const LinAlg::Matrix4x4f &P)
{
    const PatchAtlasPatchParams &patch = m_patchParams[index];
    const auto &scaleOffsetX = patch.sourceImageToAtlasPatchScaleOffsetX;
    const auto &scaleOffsetY = patch.sourceImageToAtlasPatchScaleOffsetY;

    // NOTE(review): only four entries are assigned below; all others keep
    // whatever the default constructor of Matrix4x4f produces (presumably
    // identity) - confirm against the LinAlg implementation.
    LinAlg::Matrix4x4f imageToAtlas;

    // Row 0: x' = scaleX * x + offsetX
    imageToAtlas[0][0] = scaleOffsetX[0];
    imageToAtlas[0][3] = scaleOffsetX[1];

    // Row 1: y' = scaleY * y + offsetY
    imageToAtlas[1][1] = scaleOffsetY[0];
    imageToAtlas[1][3] = scaleOffsetY[1];

    setWorldToAtlasPatchMatrix(index, imageToAtlas * P);
}



/**
 * Reserves the free patch slot with the lowest index.
 *
 * The in-use bitmap is scanned word by word; completely occupied words are
 * skipped. Within a word the lowest clear bit whose patch index is below
 * m_numPatches is set and its index returned.
 *
 * @return Index of the newly reserved patch, or (unsigned)-1 if no slot is free.
 */
unsigned PatchAtlas::allocatePatch()
{
    const uint64_t allTaken = ~(uint64_t)0;

    for (unsigned word = 0; word < m_inUse.size(); word++) {
        if (m_inUse[word] == allTaken)
            continue;

        for (unsigned bit = 0; bit < 64; bit++) {
            const uint64_t mask = (uint64_t)1 << (uint64_t)bit;
            if ((m_inUse[word] & mask) != (uint64_t)0)
                continue; // slot already taken

            const unsigned patchIndex = word*64+bit;
            if (!(patchIndex < m_numPatches))
                continue; // bit lies beyond the last real patch

            m_inUse[word] |= mask;
            return patchIndex;
        }
    }
    return (unsigned)-1;
}

void PatchAtlas::freePatch(unsigned index)
{
    m_inUse[index / 64] &= ~((uint64_t)1 << (uint64_t)(index % 64));
}


/**
 * Counts the patch slots that are currently marked as in use.
 *
 * Words with all 64 bits set contribute 64 directly. For every other word the
 * set bits are counted one by one, ignoring bits whose patch index is not
 * below m_numPatches.
 *
 * @return Number of patches currently in use.
 */
unsigned PatchAtlas::getNumPatchesInUse() const
{
    const uint64_t allSet = ~(uint64_t)0;

    unsigned numInUse = 0;
    for (unsigned word = 0; word < m_inUse.size(); word++) {
        const uint64_t bits = m_inUse[word];

        if (bits == allSet) {
            numInUse += 64;
            continue;
        }

        // Shift the word down bit by bit; the loop ends as soon as no set
        // bits remain, so an empty word costs nothing.
        const unsigned firstPatch = word*64;
        uint64_t remaining = bits;
        for (unsigned bit = 0; remaining != (uint64_t)0; bit++, remaining >>= 1) {
            if ((remaining & (uint64_t)1) != (uint64_t)0 && (firstPatch+bit) < m_numPatches)
                numInUse++;
        }
    }
    return numInUse;
}


// Argument block for the "debugExtractFromLayeredArray" kernel launched by
// PatchAtlas::debugDumpLayer(). It is handed to the launch call by address
// together with its sizeof().
// NOTE(review): field order, types and padding presumably have to mirror the
// parameter layout declared by the kernel - confirm against the kernel source
// before changing anything here.
struct DebugKernelParams
{
    unsigned width;   // width of the mip level being extracted
    unsigned height;  // height of the mip level being extracted
    unsigned layer;   // index of the atlas layer to extract
    void *dstPtr;     // destination buffer, taken from CudaDeviceMemory::getPtr()
};

void PatchAtlas::debugDumpLayer(unsigned layerIndex)
{
#if 1
    CudaUtils::CudaCodeModule debugExtractCodeModule;
    std::unique_ptr<CudaUtils::CudaKernel> debugExtractKernel;
    std::unique_ptr<CudaUtils::CudaSurfaceReference> surfRef;

    debugExtractCodeModule.loadFromFile("../SFMBackend/kernels/Release/debugExtractTextureLayer.fatbin");

    debugExtractKernel = std::unique_ptr<CudaUtils::CudaKernel>(debugExtractCodeModule.getKernel("debugExtractFromLayeredArray"));
    surfRef = std::unique_ptr<CudaUtils::CudaSurfaceReference>(debugExtractCodeModule.getSurfReference("layeredArray"));

    CudaUtils::CudaDeviceMemory linearData;
    DebugKernelParams kernelParams;

    RasterImage rasterImage;
    for (unsigned level = 0; level < PatchAtlasConstants::ATLAS_MIP_LEVEL; level++) {
        rasterImage.resize(m_atlasTexture->getLevel(level).getWidth(), m_atlasTexture->getLevel(level).getHeight());
        linearData.resize(m_atlasTexture->getLevel(level).getWidth() * m_atlasTexture->getLevel(level).getHeight() * 4);
        kernelParams.width = m_atlasTexture->getLevel(level).getWidth();
        kernelParams.height = m_atlasTexture->getLevel(level).getHeight();
        kernelParams.layer = layerIndex;
        kernelParams.dstPtr = linearData.getPtr();

        surfRef->bindTexture(&m_atlasTexture->getLevel(level));

        debugExtractKernel->launch(LinAlg::Fill(16u, 16u, 1u),
                                      LinAlg::Fill<unsigned>((m_atlasTexture->getLevel(level).getWidth()+15)/16,
                                                             (m_atlasTexture->getLevel(level).getHeight()+15)/16, 1u),
                                      &kernelParams, sizeof(kernelParams));

        linearData.download(rasterImage.getData(), linearData.size());

        for (unsigned i = 0; i < rasterImage.getWidth() * rasterImage.getHeight(); i++) {

            unsigned char *pixel = (unsigned char*)&rasterImage.getData()[i];
            pixel[0] = pixel[0] * pixel[3]/255;
            pixel[1] = pixel[1] * pixel[3]/255;
            pixel[2] = pixel[2] * pixel[3]/255;
            pixel[3] = 255;
        }

        std::stringstream str;
        str << "patchAtlasLayer"<<layerIndex<<"level"<<level<<".png";

        rasterImage.writeToFile(str.str().c_str());
    }

#else
    RasterImage rasterImage;
    for (unsigned level = 0; level < PatchAtlasConstants::ATLAS_MIP_LEVEL; level++) {
        rasterImage.resize(m_atlasTexture->getLevel(level).getWidth(), m_atlasTexture->getLevel(level).getHeight());

        m_atlasTexture->getLevel(level).syncDownloadSingleLayer(rasterImage.getData(), rasterImage.getWidth()*4, layerIndex);

        for (unsigned i = 0; i < rasterImage.getWidth() * rasterImage.getHeight(); i++) {

            unsigned char *pixel = (unsigned char*)&rasterImage.getData()[i];
            pixel[0] = pixel[0] * pixel[3]/255;
            pixel[1] = pixel[1] * pixel[3]/255;
            pixel[2] = pixel[2] * pixel[3]/255;
            pixel[3] = 255;
        }

        std::stringstream str;
        str << "patchAtlasLayer"<<layerIndex<<"level"<<level<<".png";

        rasterImage.writeToFile(str.str().c_str());
    }
#endif
}

