/*
    Structure from Motion with Deferred Feature Matching and Subset Bundle Adjustment
    Copyright (C) 2015 Andreas Ley <andy-ley@arcor.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef PFBUNDLEADJUSTMENT_H
#define PFBUNDLEADJUSTMENT_H

#include "../../tools/LinAlg.h"
#include "PFJacobi.h"

#include "../../config/BundleAdjustmentConfig.h"

#include "../../tools/ChunkedArray.hpp"

#include <assert.h>
#include <memory>
#include <vector>

namespace SFM {

/**
 * @namespace SFM::PFBundleAdjustment
 * @brief Contains all the classes of a CPU based forward bundle adjustment implementation.
 * @ingroup BundleAdjustment_Group
 */
namespace PFBundleAdjustment {

namespace detail {
struct BAStateInterface;
}

/**
 * @brief CPU based implementation of forward bundle adjustment.
 * @details Uses Levenberg-Marquardt and the Preconditioned Conjugate
 * Gradient method with Jacobi preconditioning. The internal
 * and external calibrations are split so that multiple cameras
 * can share the same internal calibration.
 * @ingroup BundleAdjustment_Group
 */
class PFBundleAdjustment
{
    public:
        /// Creates a bundle adjustment instance using the given structural configuration.
        PFBundleAdjustment(const config::BundleAdjustmentStructureConfig &structureConfig);
        // Non-copyable (the instance also holds move-only unique_ptr members).
        PFBundleAdjustment(const PFBundleAdjustment &) = delete;
        PFBundleAdjustment &operator=(const PFBundleAdjustment &) = delete;

        ~PFBundleAdjustment();

        /// Change minor bundle adjustment parameters like iteration counts and thresholds.
        void changeParameterConfig(const config::BundleAdjustmentParameterConfig &config);
        /// @brief Clears all internal data structures and allows structural changes like changing the radial distortion model.
        /// @param config Optional new structural configuration. Passing nullptr (the default) presumably keeps the current one — TODO confirm against the definition.
        void clear(const config::BundleAdjustmentStructureConfig *config = nullptr);

        using TrackHandle = unsigned;
        using CameraHandle = unsigned;
        using InternalCamCalibHandle = unsigned;
        using ObservationHandle = unsigned;

        /// Parametrization of the supported distortion models
        struct RadialDistortionParametrization {
            config::BundleAdjustmentStructureConfig::RadialDistortionType type;
            union {
                struct {
                } noRadialDistortion;
                struct {
                    float kappa[3];
                } polynomial234;
            };
            // polynomial234{} zero-initializes kappa, so a default constructed
            // parametrization never carries indeterminate coefficients.
            RadialDistortionParametrization() :
                        type(config::BundleAdjustmentStructureConfig::RadialDistortionType::NoRadialDistortion),
                        polynomial234{} { }
        };

        void reserveInternalCamCalibs(unsigned count);
        void reserveCameras(unsigned count);
        void reserveObservations(unsigned count);
        void reserveTracks(unsigned count);

        /// Adds a new internal calibration.
        InternalCamCalibHandle addIntCamCalib();
        /// Adds a new external calibration / camera. Must be bound to an internal calibration.
        CameraHandle addCamera(InternalCamCalibHandle intCamCalib);
        /// Adds a new track.
        TrackHandle addTrack();
        /// Adds a new observation. Must be bound to a camera and a track.
        ObservationHandle addObservation(CameraHandle camHandle, TrackHandle trackHandle);

        /// Sets the projection matrix for an internal calibration.
        void setInternalCalibProjectionMatrix(InternalCamCalibHandle intCamCalibHandle, const LinAlg::Matrix4x4f &projectionMatrix);
        /// @brief Sets the radial distortion parameters for an internal calibration.
        /// @details The radial distortion type must be the same one that was set as part of the structural configuration for the entire PFBundleAdjustment.
        void setInternalCalibRadialDistortion(InternalCamCalibHandle intCamCalibHandle, const RadialDistortionParametrization &radialDistortion);

        /// Sets the view matrix / external calibration of a camera.
        void setCamera(CameraHandle camHandle, const LinAlg::Matrix4x4f &viewMatrix);
        /// Sets the screen space position and (scalar) weight of an observation.
        void setObservation(ObservationHandle obsHandle, const LinAlg::Vector2f &screenSpacePosition, float weight);
        /// Sets the world space position of a track.
        void setTrack(TrackHandle trackHandle, const LinAlg::Vector4f &position);

        /// Returns the projection matrix of an internal calibration.
        inline const LinAlg::Matrix4x4f &getInternalCalibProjectionMatrix(InternalCamCalibHandle intCamCalibHandle) const {
            return m_baStructure.m_calibrations[intCamCalibHandle].projectionMatrix;
        }
        /// Returns the radial distortion of an internal calibration.
        inline const RadialDistortionParametrization &getInternalCalibRadialDistortion(InternalCamCalibHandle intCamCalibHandle) const {
            return m_baStructure.m_calibrations[intCamCalibHandle].radialDistortion;
        }


        /// Returns the view matrix of a camera.
        inline const LinAlg::Matrix4x4f &getCameraViewMatrix(CameraHandle camHandle) const {
            return m_baStructure.m_cameras[camHandle].viewMatrix;
        }
        /// Returns a pointer to the first component of the projective (homogeneous) world space position of a track.
        inline const float *getTrackPosition(TrackHandle trackHandle) const {
            return &m_baStructure.m_tracks[trackHandle].position[0];
        }

        /// Removes an internal calibration. The internal calibration must not be referenced by any camera.
        void removeIntCamCalib(InternalCamCalibHandle intCamCalib);
        /// Removes a camera. The camera must not be referenced by any observation.
        void removeCamera(CameraHandle camHandle);
        /// Removes an observation.
        void removeObservation(ObservationHandle obsHandle);
        /// Removes a track. The track must not be referenced by any observation.
        void removeTrack(TrackHandle trackHandle);

        /// Performs up to the specified number of Levenberg-Marquardt operations. Terminates early if convergence is detected.
        void iterate(unsigned iterations);

        /// Resets the step size (lambda) of Levenberg-Marquardt to the configured initial value.
        void restart() {
            m_lambda = m_parameterConfig.LevenbergMarquardt_InitialLambda;
        }

        /// Returns true if convergence was detected, i.e. lambda reached the configured convergence threshold.
        bool converged() const {
            return m_lambda >= m_parameterConfig.LevenbergMarquardt_MaxLambdaForConvergence;
        }

        /// Returns the minor parameter configuration in use.
        inline const config::BundleAdjustmentParameterConfig &getParameterConfig() const { return m_parameterConfig; }
        /// Returns the structural configuration in use.
        inline const config::BundleAdjustmentStructureConfig &getStructureConfig() const { return m_baStructure.m_structureConfig; }

        /// Record of one internal calibration, as stored in BAStructure::m_calibrations.
        struct Calibration {
            // ~0u (all bits set, the same value the former -1 produced) is presumably the "not yet shuffled" sentinel — TODO confirm.
            Calibration() : numObservations(0), numCameras(0), shuffledIndex(~0u) { }

            unsigned numObservations;
            unsigned numCameras;
            LinAlg::Matrix4x4f projectionMatrix;
            RadialDistortionParametrization radialDistortion;

            unsigned shuffledIndex;
        };
        /// Record of one camera (external calibration), as stored in BAStructure::m_cameras.
        struct Camera {
            Camera(unsigned calibrationIndex_) : numObservations(0), calibrationIndex(calibrationIndex_), shuffledIndex(~0u) { }

            unsigned numObservations;
            unsigned calibrationIndex;
            LinAlg::Matrix4x4f viewMatrix;
            unsigned shuffledIndex;
        };
        /// Record of one track, as stored in BAStructure::m_tracks.
        struct Track {
            Track() : numObservations(0), shuffledIndex(~0u) { }

            unsigned numObservations;
            LinAlg::Vector4f position;

            unsigned shuffledIndex;
        };
        /// Record of one observation, as stored in BAStructure::m_observations.
        struct Observation {
            // weight gets a deterministic default instead of an indeterminate value;
            // setObservation() is expected to overwrite it.
            Observation(unsigned cameraIndex_, unsigned calibrationIndex_, unsigned trackIndex_) :
                        cameraIndex(cameraIndex_), calibrationIndex(calibrationIndex_), trackIndex(trackIndex_),
                        weight(1.0f), shuffledIndex(~0u) { }

            unsigned cameraIndex;
            unsigned calibrationIndex;
            unsigned trackIndex;

            float weight;
            LinAlg::Vector2f screenPos;
            unsigned shuffledIndex;
        };


        /// @brief Problem description: calibrations, cameras, tracks and observations plus the structural configuration.
        /// @details Also transfers data to and from a detail::BAStateInterface (see initializeStates / copyDataToState / copyDataFromState).
        struct BAStructure {
            // Presumably set whenever the structure changed and the states must be rebuilt — TODO confirm.
            bool m_structureNeedsUpdate;

            ChunkedArray<Calibration> m_calibrations;
            ChunkedArray<Camera> m_cameras;
            ChunkedArray<Track> m_tracks;
            ChunkedArray<Observation> m_observations;

            unsigned m_numCalibrations;
            unsigned m_numCameras;
            unsigned m_numTracks;
            unsigned m_numObservations;

            config::BundleAdjustmentStructureConfig m_structureConfig;

            BAStructure();

            void clear();

            InternalCamCalibHandle addIntCamCalib();
            CameraHandle addCamera(InternalCamCalibHandle intCamCalib);
            TrackHandle addTrack();
            ObservationHandle addObservation(CameraHandle camHandle, TrackHandle trackHandle);

            void setInternalCalibProjectionMatrix(InternalCamCalibHandle intCamCalibHandle, const LinAlg::Matrix4x4f &projectionMatrix);
            void setInternalCalibRadialDistortion(InternalCamCalibHandle intCamCalibHandle, const RadialDistortionParametrization &radialDistortion);

            void setCamera(CameraHandle camHandle, const LinAlg::Matrix4x4f &viewMatrix);
            void setObservation(ObservationHandle obsHandle, const LinAlg::Vector2f &screenSpacePosition, float weight);
            void setTrack(TrackHandle trackHandle, const LinAlg::Vector4f &position);

            void removeIntCamCalib(InternalCamCalibHandle intCamCalib);
            void removeCamera(CameraHandle camHandle);
            void removeObservation(ObservationHandle obsHandle);
            void removeTrack(TrackHandle trackHandle);


            void initializeStates(std::unique_ptr<detail::BAStateInterface> *state0);
            void copyDataToState(detail::BAStateInterface &state);
            void copyDataFromState(const detail::BAStateInterface &state);
        };
    protected:
        config::BundleAdjustmentParameterConfig m_parameterConfig;

        BAStructure m_baStructure;

        /// Current Levenberg-Marquardt damping factor (see restart() and converged()).
        float m_lambda;


        std::unique_ptr<PFJacobiInterface> m_J;
        std::unique_ptr<PFJacobiTransposedInterface> m_JT;

        // Two state buffers. m_currentState / m_nextState presumably point into m_states[] — TODO confirm.
        std::unique_ptr<detail::BAStateInterface> m_states[2];
        detail::BAStateInterface *m_currentState;
        detail::BAStateInterface *m_nextState;


        // Solver work buffers. The single-letter names presumably follow the usual
        // (P)CG notation (x, r, z, p, A*p) — TODO confirm against the implementation.
        std::vector<float, AlignedAllocator<float>> m_residuals;
        std::vector<float, AlignedAllocator<float>> m_JT_times_residuals;
        std::vector<float, AlignedAllocator<float>> m_diag_JTJ_times_lambda;
        std::vector<float, AlignedAllocator<float>> m_x;
        std::vector<float, AlignedAllocator<float>> m_r;
        std::vector<float, AlignedAllocator<float>> m_z;
        std::vector<float, AlignedAllocator<float>> m_p;
        std::vector<float, AlignedAllocator<float>> m_Jp;
        std::vector<float, AlignedAllocator<float>> m_Ap;
        std::vector<float, AlignedAllocator<float>> m_preconditioner;
        std::vector<float, AlignedAllocator<float>> m_rcpPreconditioner;
};

/**
 * @brief Contains the details of the bundle adjustment algorithm.
 * @ingroup BundleAdjustment_Group
 */
namespace detail {

/**
 * @brief Interface of the bundle adjustment state.
 * @details The actual implementation and data structures depend on the number of parameters the
 * and update parameters that for example the internal calibration and radial distortion models need.
 * @ingroup BundleAdjustment_Group
 */
struct BAStateInterface {
    virtual ~BAStateInterface() { }
    virtual void initializeJacobian(std::unique_ptr<PFJacobiInterface> &J, std::unique_ptr<PFJacobiTransposedInterface> &JT) const = 0;
    virtual void computeResiduals(std::vector<float, AlignedAllocator<float>> &dst) const = 0;
    virtual void computeJacobi(PFJacobiInterface &J, PFJacobiTransposedInterface &JT) const = 0;
    virtual void performUpdateStep(const std::vector<float, AlignedAllocator<float>> &parameterStep, BAStateInterface &dst) const = 0;


    virtual void initializeFromStructure(PFBundleAdjustment::BAStructure &structure) = 0;
    virtual void copyDataFromStructure(const PFBundleAdjustment::BAStructure &structure) = 0;
    virtual void copyDataToStructure(PFBundleAdjustment::BAStructure &structure) const = 0;

    virtual void operator=(const BAStateInterface &rhs) = 0;
};

}

}

/// @}

}

#endif // PFBUNDLEADJUSTMENT_H
