/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "FeaturePointAlignmentEstimator.h"

#include "../cudaUtilities/cudaProfilingScope.h"
#include "../tools/CPUStopWatch.h"
#include "SFM.h"

namespace SFM {

// A ray described by an origin and a direction. Instances are produced by
// computeViewRayFromScreenSpacePosition() and consumed by
// computePlaneViewRayIntersectionPoint().
struct ViewRay
{
    // Origin of the ray; the active code path of
    // computeViewRayFromScreenSpacePosition() sets it to the camera position.
    LinAlg::Vector3f focalPoint;
    // Direction of the ray; filled from the result of normalized(), so
    // presumably unit length.
    LinAlg::Vector3f normalizedDir;
};


/**
 * Builds the view ray of a camera that passes through a screen space position.
 *
 * The screen space position is extended to (x, y, 1, 1), pushed through the
 * camera's inverse projection-view matrix and dehomogenized. The ray starts at
 * the camera position and points towards that unprojected point.
 *
 * @param camera               camera supplying position and inverse projection-view matrix
 * @param screenSpacePosition  2D position the ray has to pass through
 * @return ray with the camera position as origin and a normalized direction
 */
ViewRay computeViewRayFromScreenSpacePosition(const Camera &camera,
                                              const LinAlg::Vector2f &screenSpacePosition)
{
    ViewRay ray;
#if 1
    // Unproject the screen position (third coordinate fixed to 1) ...
    LinAlg::Vector4f unprojected = camera.getInvProjectionViewMatrix()
                                 * screenSpacePosition.AddHom(1.0f).AddHom(1.0f);
    const LinAlg::Vector3f pointOnRay = unprojected.StripHom() / unprojected[3];

    // ... and aim from the camera position at the unprojected point.
    ray.focalPoint = camera.getCameraPosition();
    ray.normalizedDir = (pointOnRay - ray.focalPoint).normalized();
#else
    // Disabled variant: does not query the camera position but spans the ray
    // between two unprojected points with third coordinate 1 and 2.
    LinAlg::Vector4f unprojectedFirst = camera.getInvProjectionViewMatrix()
                                      * screenSpacePosition.AddHom(1.0f).AddHom(1.0f);
    ray.focalPoint = unprojectedFirst.StripHom() / unprojectedFirst[3];

    LinAlg::Vector4f unprojectedSecond = camera.getInvProjectionViewMatrix()
                                       * screenSpacePosition.AddHom(2.0f).AddHom(1.0f);
    ray.normalizedDir = (unprojectedSecond.StripHom() / unprojectedSecond[3] - ray.focalPoint).normalized();
#endif
    return ray;
}

/**
 * Intersects a view ray with a plane.
 *
 * The plane is given as a 4-vector (n, w) describing all points x with
 * n * x + w = 0. The returned point is origin - dir * (n * origin + w) / (n * dir);
 * a common scale factor of the plane vector cancels out in that quotient.
 *
 * NOTE(review): if the ray is parallel to the plane the denominator is zero and
 * the result is not finite; no guard is applied here - confirm callers tolerate that.
 *
 * @param normalizedPlane  plane coefficients (n, w)
 * @param viewRay          ray to intersect; only read, hence taken by const reference
 * @return intersection point of the ray's supporting line with the plane
 */
LinAlg::Vector3f computePlaneViewRayIntersectionPoint(const LinAlg::Vector4f &normalizedPlane,
                                                      const ViewRay &viewRay)
{
    // Plane equation evaluated at the ray origin (the signed distance if n has unit length).
    const float originPlaneValue = normalizedPlane * viewRay.focalPoint.AddHom(1.0f);
    // Projection of the ray direction onto the plane normal.
    const float cosAlpha = normalizedPlane.StripHom() * viewRay.normalizedDir;
    return viewRay.focalPoint - viewRay.normalizedDir * originPlaneValue / cosAlpha;
}

/**
 * Projects a 3D position with a projection-view matrix.
 *
 * @param projectionView  matrix applied to the homogeneous position (x, y, z, 1)
 * @param WSpos           position to project
 * @return the first two components of the transformed vector, divided by its fourth component
 */
LinAlg::Vector2f projectPoint(const LinAlg::Matrix4x4f &projectionView, const LinAlg::Vector3f &WSpos)
{
    LinAlg::Vector4f homogeneous = projectionView * WSpos.AddHom(1.0f);
    // Perspective divide; the third component is dropped.
    const float w = homogeneous[3];
    return homogeneous.StripHom().StripHom() / w;
}

/**
 * Aligns the patch of the second feature point of every pair against the patch
 * of its template feature point and writes the result back into the pair.
 *
 * Steps:
 *  1. Collect all distinct feature points referenced by the pairs. Those that
 *     already own a valid patch cache slot get that slot locked, the others are
 *     scheduled for upload.
 *  2. Allocate cache slots for the scheduled feature points and transfer their
 *     packed patches into the cache.
 *  3. For every pair, build an initial template-screen -> alignment-screen
 *     affine map from the estimated track position and the patch normal.
 *  4. Run the pairwise aligner on all pairs.
 *  5. Store the resulting map, the updated second screen space position and
 *     the residual error in each pair.
 *  6. Advance the cache timestamp and unlock the slot of every collected
 *     feature point.
 *
 * If SFM_PERFORM_PATCH_ALIGNMENT is not defined the function only asserts.
 *
 * @param pairs  feature point pairs to align; resultingHomography,
 *               secondFPScreenSpacePosition and residualError are overwritten
 */
void FeaturePointAlignmentEstimator::process(std::vector<FeaturePointPair> &pairs)
{
    #ifdef SFM_PERFORM_PATCH_ALIGNMENT
    AddCudaScopedProfileInterval("realignTrackPatches");

    typedef std::pair<Frame*, unsigned> FeaturePointRef;

    // All distinct feature points touched by this call (unlocked again at the end) ...
    std::set<FeaturePointRef> neededFeaturePoints;
    // ... and the subset that has no valid cache slot yet.
    std::set<FeaturePointRef> neededFeaturePointsNeedingUpload;
    {
        AddCudaScopedProfileInterval("Gathering required observations");

        // Registers one feature point. Locking happens only the first time a
        // feature point is seen so that the number of lockSlot() calls matches
        // the single unlockSlot() call per distinct feature point below, even
        // when the same feature point takes part in several pairs.
        auto requireFeaturePoint = [&](Frame *frame, unsigned index) {
            const FeaturePointRef ref(frame, index);
            if (neededFeaturePoints.insert(ref).second) {
                if (frame->getFeaturePoint(index).patchCacheSlot.valid()) {
                    m_patchCache.lockSlot(frame->getFeaturePoint(index).patchCacheSlot.getSlotIndex());
                } else {
                    neededFeaturePointsNeedingUpload.insert(ref);
                }
            }
        };

        for (unsigned i = 0; i < pairs.size(); i++) {
            requireFeaturePoint(pairs[i].templateFPFrame, pairs[i].templateFPIndex);
            requireFeaturePoint(pairs[i].secondFPFrame, pairs[i].secondFPIndex);
        }
    }


    std::vector<PatchCache::SlotHandle> newSlots;
    newSlots.resize(neededFeaturePointsNeedingUpload.size());
    // data() is well defined for an empty vector, whereas &newSlots[0] is not.
    m_patchCache.allocateSlots(newSlots.data(), newSlots.size());
    std::vector<PackedPatchToPatchCacheTransfer::TransferData> transferData;
    transferData.resize(newSlots.size());
    if (!transferData.empty()) {
        unsigned index = 0;
        for (auto it = neededFeaturePointsNeedingUpload.begin(); it != neededFeaturePointsNeedingUpload.end(); ++it, index++) {
            auto &featurePoint = it->first->getFeaturePoint(it->second);

            // Hand the new slot to the transfer job and remember it in the feature point.
            transferData[index].slot = newSlots[index];
            featurePoint.patchCacheSlot = transferData[index].slot;
            transferData[index].packedPatch = &featurePoint.packedPatch;
        }

        AddCudaScopedProfileInterval("Uploading new patches");
        m_transferer.transfer(transferData, m_patchCache);
    }



    std::vector<PairwisePatchAligner::AlignmentPair> alignmentPairs;
    alignmentPairs.resize(pairs.size());
    {
        AddCudaScopedProfileInterval("Gathering data");

        // Screen space offset used for the finite differences below.
        const float screenSpaceStep = 0.001f;

        for (unsigned i = 0; i < alignmentPairs.size(); i++) {

            PairwisePatchAligner::AlignmentPair &alignmentPair = alignmentPairs[i];

            Frame *frame1 = pairs[i].templateFPFrame;
            Frame *frame2 = pairs[i].secondFPFrame;

            const Frame::ImageFeaturePoint &featurePoint1 = frame1->getFeaturePoint(pairs[i].templateFPIndex);
            const Frame::ImageFeaturePoint &featurePoint2 = frame2->getFeaturePoint(pairs[i].secondFPIndex);

            const LinAlg::Vector2f &screenSpacePosition1 = pairs[i].templateFPScreenSpacePosition;
            const LinAlg::Vector2f &screenSpacePosition2 = pairs[i].secondFPScreenSpacePosition;

            // After the upload step every referenced feature point must own a slot.
            assert(featurePoint1.patchCacheSlot.valid());
            assert(featurePoint2.patchCacheSlot.valid());

            alignmentPair.templatePatchCacheSlot = featurePoint1.patchCacheSlot.getSlotIndex();
            alignmentPair.templateScreenSpacePosition = screenSpacePosition1;

            alignmentPair.alignmentPatchCacheSlot = featurePoint2.patchCacheSlot.getSlotIndex();


            // NOTE(review): divides by the homogeneous coordinate without a zero check -
            // confirm estimatedTrackPosition[3] is never zero here.
            LinAlg::Vector3f trackWSPos = pairs[i].estimatedTrackPosition.StripHom() / pairs[i].estimatedTrackPosition[3];

            // Plane through the track position, oriented along the pair's normal.
            // (A previously disabled "#if 0" variant skipped this plane and simply
            // unprojected offset screen positions of the template camera with its
            // inverse projection-view matrix before reprojecting them.)
            LinAlg::Vector4f patchPlane = pairs[i].normal.AddHom(-(pairs[i].normal * trackWSPos));

            // Shoot two rays of the template camera, offset slightly upwards / to the
            // right of the projected track position, and intersect them with the plane.
            LinAlg::Vector2f projectedScreenSpacePosition1 = projectPoint(frame1->getCamera().getProjectionViewMatrix(), trackWSPos);
            ViewRay vrTop = computeViewRayFromScreenSpacePosition(frame1->getCamera(), projectedScreenSpacePosition1 + LinAlg::Fill(0.0f, screenSpaceStep));
            ViewRay vrRight = computeViewRayFromScreenSpacePosition(frame1->getCamera(), projectedScreenSpacePosition1 + LinAlg::Fill(screenSpaceStep, 0.0f));

            LinAlg::Vector3f wsTop = computePlaneViewRayIntersectionPoint(patchPlane, vrTop);
            LinAlg::Vector3f wsRight = computePlaneViewRayIntersectionPoint(patchPlane, vrRight);

            // Project center and both offset points into the second camera.
            const LinAlg::Matrix4x4f projectionView2 = frame2->getCamera().getProjectionViewMatrix();
            LinAlg::Vector2f screenCenter = projectPoint(projectionView2, trackWSPos);
            LinAlg::Vector2f screenTop = projectPoint(projectionView2, wsTop);
            LinAlg::Vector2f screenRight = projectPoint(projectionView2, wsRight);

            // Finite differences: movement in the second image per unit movement in the template image.
            LinAlg::Vector2f screenRightVector = (screenRight - screenCenter) * (1.0f / screenSpaceStep);
            LinAlg::Vector2f screenTopVector = (screenTop - screenCenter) * (1.0f / screenSpaceStep);

            // Rows 0 and 1: linear part from the finite differences, translation chosen
            // such that screenSpacePosition1 is mapped onto screenSpacePosition2.
            auto &templateToAlignment = alignmentPair.templateScreenToAlignmentScreen;
            for (int row = 0; row < 2; row++) {
                templateToAlignment[row][0] = screenRightVector[row];
                templateToAlignment[row][1] = screenTopVector[row];
                templateToAlignment[row][2] = -screenSpacePosition1[0] * templateToAlignment[row][0]
                                              -screenSpacePosition1[1] * templateToAlignment[row][1]
                                              + screenSpacePosition2[row];
            }
        }
    }


    Engine::CPUStopWatch optimizeStopWatch;

    if (alignmentPairs.size() > 0)
        m_aligner.optimize(alignmentPairs, &m_patchCache);

    std::cout << "pairwise aligned " << alignmentPairs.size() << " pairs in " << optimizeStopWatch.getNanoseconds() * 1e-6f << " ms." << std::endl;

    {
        AddCudaScopedProfileInterval("Storing data");

        for (unsigned i = 0; i < alignmentPairs.size(); i++) {
            const PairwisePatchAligner::AlignmentPair &alignmentPair = alignmentPairs[i];

            pairs[i].resultingHomography = alignmentPair.templateScreenToAlignmentScreen;
            // NOTE(review): the homogeneous component is stripped without a divide, which
            // presumes the aligner keeps the map affine - confirm against PairwisePatchAligner.
            pairs[i].secondFPScreenSpacePosition = (alignmentPair.templateScreenToAlignmentScreen * alignmentPair.templateScreenSpacePosition.AddHom(1.0f)).StripHom();
            pairs[i].residualError = alignmentPair.alignmentError;
        }
    }

    m_patchCache.incTimestamp();
    // Release every slot used by this call, one unlock per distinct feature point.
    for (auto it = neededFeaturePoints.begin(); it != neededFeaturePoints.end(); ++it) {
        m_patchCache.unlockSlot(it->first->getFeaturePoint(it->second).patchCacheSlot.getSlotIndex());
    }
    #else
    assert(false && "Feature point alignment was disabled!");
    #endif
}

}
