/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "Track.h"

#include <assert.h>
#include <cstring>

#include "Frame.h"
#include "InternalCameraCalibration.h"

#include "../tools/PCVToolbox.hpp"

namespace SFM {


/*
 * Creates a track that is not yet attached to any bundle adjustment and starts
 * out in STATE_DISABLED without a reference observation.
 */
Track::Track()
{
    // Bundle-adjustment bookkeeping: no BA instance until setup() is called, and the
    // handle holds (unsigned)-1, the "not registered" value checked in activateBA().
    m_ba = NULL;
    m_baTrackHandle = -1;
    m_baSubesetVotes = 0;

    // Lifecycle state.
    m_state = STATE_DISABLED;
    m_referenceObservation = NULL;

    // Geometric / weighting defaults.
    m_size = -1.0f;
    m_trackWeight = 1.0f;
    m_needsNormalUpdate = true;
}

/*
 * Destroys the track. Switching to STATE_DISABLED makes switchState() call
 * deactivateBA() when the track was active, so it is removed from the bundle adjustment.
 */
Track::~Track()
{
    if (m_state != STATE_DISABLED)
        switchState(STATE_DISABLED);
}

/*
 * Stores the bundle adjustment instance this track registers itself with
 * (used by activateBA(), deactivateBA() and readBackFromBA()).
 */
void Track::setup(BundleAdjustment *ba)
{
    this->m_ba = ba;
}


/*
 * Transitions the track to newState.
 * Leaving STATE_ACTIVE removes the track from the bundle adjustment; entering
 * STATE_ACTIVE adds it. Switching to the current state is a no-op.
 */
void Track::switchState(State newState)
{
    if (newState == m_state)
        return;

    const bool leavingActive = (m_state == STATE_ACTIVE);
    const bool enteringActive = (newState == STATE_ACTIVE);

    if (leavingActive)
        deactivateBA();
    if (enteringActive)
        activateBA();

    m_state = newState;
}


void Track::activateBA()
{
    assert(m_baTrackHandle == (unsigned)-1);

    m_baTrackHandle = m_ba->addTrack();

    m_ba->setTrack(m_baTrackHandle, m_lastWSPositionEstimate);

    for (TrackObservation &observation : m_observations)
        observation.activateBA();
}

void Track::deactivateBA()
{
    for (TrackObservation &observation : m_observations)
        if (observation.activeInBA())
            observation.deactivateBA();

    m_ba->removeTrack(m_baTrackHandle);
    m_baTrackHandle = -1;
}

/*
 * Adds an observation of this track in `frame` at feature index `framePatchIndex`.
 * Flags the normal for recomputation and records an observation change.
 * Returns the index of the new observation within m_observations.
 */
unsigned Track::addObservation(Frame *frame, unsigned framePatchIndex)
{
    const unsigned newIndex = m_observations.allocate();
    m_observations[newIndex].setup(this, frame, framePatchIndex);

    // The observation set changed: normal must be re-estimated, change timestamp bumped.
    setNeedsNormalUpdate();
    m_majorChangeTimestamp.observationsChanged();

    return newIndex;
}

/*
 * Makes `obs` the reference observation of this track (NULL clears it).
 *
 * Invalidates every observation's cached reference homography. With a non-NULL
 * reference, m_size is recomputed from the reference feature's size, the absolute
 * value of the third view-matrix row applied to the track position, and element
 * [0][0] of the projection matrix (presumably the focal length — NOTE(review): confirm).
 * If m_normal is still the zero vector, it is initialized to point from the track
 * towards the reference camera. With NULL, m_size is set to 0.
 */
void Track::setReferenceObservation(TrackObservation *obs)
{
    if (m_referenceObservation == obs)
        return;

    m_referenceObservation = obs;

    for (TrackObservation &other : m_observations)
        other.invalidateReferenceObsHomography();

    if (m_referenceObservation == NULL) {
        m_size = 0.0f;
        return;
    }

    Frame *refFrame = m_referenceObservation->getFrame();

    LinAlg::Vector3f euclideanPos;
    euclideanPos = m_lastWSPositionEstimate.StripHom() * (1.0f / m_lastWSPositionEstimate[3]);

    float distance = std::abs(refFrame->getCamera().getViewMatrix()[2] * euclideanPos.AddHom(1.0f));

    const auto featureSize = refFrame->getFeaturePoint(m_referenceObservation->getFrameFeaturePointIndex()).size;
    m_size = featureSize * distance / refFrame->getCamera().getInternalCalibration()->getProjectionMatrix()[0][0];

    // Only seed the normal if no estimate exists yet.
    if (m_normal.SQRLen() == 0.0f)
        m_normal = (refFrame->getCamera().getCameraPosition() - euclideanPos).normalized();
}


void Track::readBackFromBA()
{
    const float *wsPos = m_ba->getTrackPosition(m_baTrackHandle);
    m_lastWSPositionEstimate[0] = wsPos[0];
    m_lastWSPositionEstimate[1] = wsPos[1];
    m_lastWSPositionEstimate[2] = wsPos[2];
    m_lastWSPositionEstimate[3] = wsPos[3];

    if (m_referenceObservation != NULL) {
        LinAlg::Vector3f wsPos;
        wsPos = m_lastWSPositionEstimate.StripHom() * (1.0f / m_lastWSPositionEstimate[3]);
        float distance = std::abs(m_referenceObservation->getFrame()->getCamera().getViewMatrix()[2] * wsPos.AddHom(1.0f));

        m_size = m_referenceObservation->getFrame()->getFeaturePoint(m_referenceObservation->getFrameFeaturePointIndex()).size
         * distance / m_referenceObservation->getFrame()->getCamera().getInternalCalibration()->getProjectionMatrix()[0][0];
    }

    for (TrackObservation &observation : m_observations)
        observation.recomputeUndistortedLocation();
}

void Track::obsChanged()
{
    m_majorChangeTimestamp.observationsChanged();
    setNeedsNormalUpdate();

    unsigned numActiveObs = 0;
    for (TrackObservation &observation : m_observations)
        if (!observation.isFaulty())
            numActiveObs++;

    if (numActiveObs < 2) {
        switchState(STATE_DISABLED);
    } else {
        if ((m_referenceObservation == NULL) || m_referenceObservation->isFaulty()) {
            chooseNewReferenceObservation();
        }
    }
}

void Track::chooseNewReferenceObservation()
{
    setReferenceObservation(chooseBestReferenceObservation());
}

/*
 * Returns the non-faulty observation whose camera direction (track -> camera, normalized)
 * has the largest dot product with m_normal, or NULL if there is no candidate.
 * If `score` is non-NULL it receives that best dot product (or -10 if nothing was chosen).
 */
TrackObservation *Track::chooseBestReferenceObservation(float *score)
{
    // Cosines lie in [-1, 1], so any candidate with a well-defined cosine beats this.
    float bestCosine = -10.0f;
    TrackObservation *bestObs = NULL;

    LinAlg::Vector3f trackPos = m_lastWSPositionEstimate.StripHom() / m_lastWSPositionEstimate[3];

    for (TrackObservation &candidate : m_observations) {
        if (candidate.isFaulty())
            continue;

        LinAlg::Vector3f toCamera = candidate.getFrame()->getCamera().getCameraPosition() - trackPos;
        float cosine = m_normal * toCamera.normalized();

        if (cosine > bestCosine) {
            bestCosine = cosine;
            bestObs = &candidate;
        }
    }

    if (score)
        *score = bestCosine;
    return bestObs;
}

void Track::updateTrackMatchingDescriptor(const uint32_t *data)
{
    if (m_trackMatchingDescriptor == NULL)
        m_trackMatchingDescriptor = std::unique_ptr<TrackMatchingDescriptor>(new TrackMatchingDescriptor());

#if 0
    m_trackMatchingDescriptor->updatePreWarpedDescriptor(data);
#else
    m_trackMatchingDescriptor->clear();

    bool atInfinity = std::abs(m_lastWSPositionEstimate[3]) < 1e-20f;
    LinAlg::Vector3f euclPos;
    if (!atInfinity)
        euclPos = m_lastWSPositionEstimate.StripHom() / m_lastWSPositionEstimate[3];

    for (TrackObservation &observation : m_observations)
        if (!observation.isFaulty()) {
            Frame *frame = observation.getFrame();
            unsigned featurePointIndex = observation.getFrameFeaturePointIndex();
            LinAlg::Vector3f viewDir;
            if (atInfinity)
                viewDir = m_lastWSPositionEstimate.StripHom();
            else
                viewDir = euclPos - frame->getCamera().getCameraPosition();

            viewDir.normalize();
            m_trackMatchingDescriptor->addDescriptor(frame->getSiftDescriptorDB().getDescriptor(frame->getFeaturePoint(featurePointIndex).siftDBIndex), viewDir);
        }
#endif
}

void Track::setTrackWeight(float trackWeight)
{
    m_trackWeight = trackWeight;
    for (TrackObservation &observation : m_observations)
        if (!observation.isFaulty())
            observation.trackWeightChanged();
}


/*
 * Returns false if `f` is +/-infinity or NaN, true otherwise (zeros, normals and
 * subnormals are finite).
 *
 * The IEEE-754 exponent field is inspected directly rather than calling std::isfinite,
 * presumably so the check survives -ffast-math / -ffinite-math-only, under which the
 * compiler may fold std::isfinite to `true` (NOTE(review): confirm this is the intent).
 *
 * The bits are copied with std::memcpy because reading the inactive member of a union
 * (type punning) is undefined behaviour in C++; compilers turn this memcpy into a plain
 * register move.
 */
inline bool fastMathIsFinite(float f)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "fastMathIsFinite assumes a 32-bit IEEE-754 float");

    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));

    // An all-ones 8-bit exponent (bits 23..30) encodes infinity or NaN.
    return ((bits >> 23) & 0xFFu) != 0xFFu;
}


void Track::updateNormal()
{
    if (m_state == STATE_DISABLED)
        return;

    LinAlg::Matrix3x3f covar;
    covar *= 0.0f;


    LinAlg::Vector3f trackWSPos = m_lastWSPositionEstimate.StripHom() / m_lastWSPositionEstimate[3];

    LinAlg::Vector3f viewDirection1 = (trackWSPos - m_referenceObservation->getFrame()->getCamera().getCameraPosition()).normalized();
    LinAlg::Vector3f u1, v1;
    {
        LinAlg::Vector4f screenPos = m_referenceObservation->getFrame()->getCamera().getProjectionViewMatrix() * m_lastWSPositionEstimate;
        screenPos /= screenPos[3];

        LinAlg::Vector4f screenPosRight = screenPos;
        screenPosRight[0] += 0.01f;

        LinAlg::Vector4f screenPosTop = screenPos;
        screenPosTop[1] += 0.01f;

        LinAlg::Vector4f p;
        p = m_referenceObservation->getFrame()->getCamera().getInvProjectionViewMatrix() * screenPosRight;
        u1 = p.StripHom() / p[3] - trackWSPos;
        p = m_referenceObservation->getFrame()->getCamera().getInvProjectionViewMatrix() * screenPosTop;
        v1 = p.StripHom() / p[3] - trackWSPos;
    }

    for (TrackObservation &observation : m_observations) {
        if (&observation == m_referenceObservation)
            continue;
        if (observation.isFaulty())
            continue;

        assert(observation.referenceObsHomographyValid());

        LinAlg::Vector3f viewDirection2 = (trackWSPos - observation.getFrame()->getCamera().getCameraPosition()).normalized();


        LinAlg::Vector3f u2, v2;
        {
            LinAlg::Vector4f screenPos = observation.getFrame()->getCamera().getProjectionViewMatrix() * m_lastWSPositionEstimate;
            screenPos /= screenPos[3];

            LinAlg::Vector4f screenPosRight = screenPos;
            screenPosRight[0] += 0.01f * observation.getReferenceObsScreenToThisScreen()[0][0];
            screenPosRight[1] += 0.01f * observation.getReferenceObsScreenToThisScreen()[1][0];

            LinAlg::Vector4f screenPosTop = screenPos;
            screenPosTop[0] += 0.01f * observation.getReferenceObsScreenToThisScreen()[0][1];
            screenPosTop[1] += 0.01f * observation.getReferenceObsScreenToThisScreen()[1][1];

            LinAlg::Vector4f p;
            p = observation.getFrame()->getCamera().getInvProjectionViewMatrix() * screenPosRight;
            u2 = p.StripHom() / p[3] - trackWSPos;
            p = observation.getFrame()->getCamera().getInvProjectionViewMatrix() * screenPosTop;
            v2 = p.StripHom() / p[3] - trackWSPos;
        }

        LinAlg::Vector3f patchPointU;
        {
            float lambda1, lambda2;
            PCV::getPositionOfClosestProximityBetweenLines<3, float>(u1, viewDirection1, u2, viewDirection2, lambda1, lambda2);
            patchPointU = (u1 + viewDirection1 * lambda1 + u2 +viewDirection2 * lambda2) * 0.5f;
        }

        LinAlg::Vector3f patchPointV;
        {
            float lambda1, lambda2;
            PCV::getPositionOfClosestProximityBetweenLines<3, float>(v1, viewDirection1, v2, viewDirection2, lambda1, lambda2);
            patchPointV = (v1 + viewDirection1 * lambda1 + v2 +viewDirection2 * lambda2) * 0.5f;
        }

        for (unsigned k = 0; k < 3; k++)
            for (unsigned l = k; l < 3; l++) {
                covar[k][l] += patchPointU[k]*patchPointU[l] + patchPointV[k]*patchPointV[l];
                assert(fastMathIsFinite(covar[k][l]));
            }

    }

    for (unsigned k = 0; k < 3; k++)
        for (unsigned l = 0; l < k; l++)
            covar[k][l] = covar[l][k];

    std::cout << "covar = " << covar << std::endl;
            for (unsigned k = 0; k < 3; k++)
                for (unsigned l = k; l < 3; l++) {
                    assert(fastMathIsFinite(covar[k][l]));
                }

    LinAlg::Vector3f eigenValues;
    LinAlg::Matrix3x3f eigenVectors;
    LinAlg::computeEigenValues(covar, eigenValues, &eigenVectors);
  //  std::cout << (std::string) eigenValues << std::endl;

    unsigned smallestEV = 0;
    for (unsigned k = 1; k < 3; k++)
        if (eigenValues[k] < eigenValues[smallestEV]) smallestEV = k;

    m_normal[0] = eigenVectors[0][smallestEV];
    m_normal[1] = eigenVectors[1][smallestEV];
    m_normal[2] = eigenVectors[2][smallestEV];

    if (m_normal * viewDirection1 > 0.0f)
        m_normal = m_normal.negated();

    LinAlg::Vector3f Up = LinAlg::cross(u1, m_normal).normalized();
    LinAlg::Vector3f Right = LinAlg::cross(m_normal, Up).normalized();


    m_orientation[0] = Right;
    m_orientation[1] = Up;
    m_orientation[2] = m_normal;
    m_orientation = m_orientation.T();
}

/*
 * Lets the major-change tracker decide, from the current position estimate and size,
 * whether `currentTimestamp` becomes the new major-change timestamp.
 */
void Track::checkForMajorChange(uint32_t currentTimestamp)
{
    auto &tracker = m_majorChangeTimestamp;
    tracker.checkUpdateTimestamp(currentTimestamp, m_lastWSPositionEstimate, m_size);
}

/*
 * Records `currentTimestamp` as the last major change, snapshots the position and clears
 * the observations-changed flag when any of the following holds:
 *  - observations changed since the last update,
 *  - `size` is (near) zero or negative,
 *  - either the new or the stored position has a (near) zero homogeneous w, or
 *  - the dehomogenized position moved by more than size/2 since the stored snapshot
 *    (squared distance > (size/2)^2).
 */
void Track::MajorChangeTimestamp::checkUpdateTimestamp(uint32_t currentTimestamp, const LinAlg::Vector4f &lastWSPositionEstimate, float size)
{
    bool majorChange = m_observationsChangedSinceLastTimestampUpdate ||
                       (size < 1e-20f) ||
                       (fabs(lastWSPositionEstimate[3]) < 1e-20f) ||
                       (fabs(m_WSPositionEstimate[3]) < 1e-20f);

    // Only compare positions when both are finite and the size is meaningful.
    if (!majorChange) {
        float sqrD = (lastWSPositionEstimate.StripHom() / lastWSPositionEstimate[3] - m_WSPositionEstimate.StripHom() / m_WSPositionEstimate[3]).SQRLen();
        majorChange = sqrD > size*size*0.25f;
    }

    if (!majorChange)
        return;

    m_observationsChangedSinceLastTimestampUpdate = false;
    m_WSPositionEstimate = lastWSPositionEstimate;
    m_lastMajorChangeTimestamp = currentTimestamp;
}

}

