Program Listing for File gmmtmap.hpp
↰ Return to documentation for file (include/mod/gmmtmap.hpp)
/*
* Copyright (c) Chittaranjan Srinivas Swaminathan
* This file is part of mod.
*
* mod is free software: you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation, either version 3 of the License,
* or (at your option) any later version.
*
* mod is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with mod. If not, see
* <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <array>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <Eigen/Core>
#include <boost/chrono.hpp>
#include <boost/geometry.hpp>
#include <boost/geometry/geometries/point_xy.hpp>
#include <boost/geometry/index/rtree.hpp>
#include <boost/log/trivial.hpp>

#include <mod/base.hpp>
namespace MoD {
// Short aliases for the Boost.Geometry namespaces used throughout this header.
namespace bg = boost::geometry;
namespace bgi = boost::geometry::index;

/// Boost.Geometry 2-D Cartesian point with double-precision coordinates.
using Point2D = bg::model::d2::point_xy<double>;

/// Boost.Geometry box model whose corners are Point2D.
using Box = bg::model::box<Point2D>;

/// Value stored in the R-tree: a point paired with two indices
/// (presumably {cluster_idx, mean_idx} — TODO confirm against the .cpp).
using TreeValue = std::pair<Point2D, std::array<std::size_t, 2>>;
/// @brief One cluster of the GMMT map: a mixing factor plus a sequence of 2-D means and
/// a sequence of headings.
///
/// NOTE(review): `heading` is presumably aligned index-for-index with `mean`
/// (GMMTMap::getHeadingAtDist indexes it with a mean index), but the two vectors are
/// not size-checked here — confirm the invariant at the point of construction.
struct GMMTMapCluster {
  /// Mixing factor (weight) of this cluster. The default member initializer guarantees
  /// that a default-constructed cluster never exposes an indeterminate value.
  double mixing_factor{0.0};
  /// Sequence of 2-D mean points (coordinate order presumably {x, y} — TODO confirm).
  std::vector<std::array<double, 2>> mean;
  /// Heading values for this cluster (see the note above on alignment with `mean`).
  std::vector<double> heading;

  inline GMMTMapCluster() = default;

  /// @param pi      Mixing factor of the cluster.
  /// @param mean    Mean points; copied into the cluster.
  /// @param heading Headings; taken by value and moved in, so callers passing an
  ///                rvalue avoid a second copy.
  inline GMMTMapCluster(double pi, const std::vector<std::array<double, 2>> &mean, std::vector<double> heading)
      : mixing_factor(pi), mean(mean), heading(std::move(heading)) {}
};
class GMMTMap : public Base {
public:
explicit GMMTMap(const std::string &fileName) { readFromXML(fileName); }
void readFromXML(const std::string &fileName);
void computeHeadingAndConstructRTree();
std::vector<TreeValue> getNearestNeighbors(double x, double y) const;
inline std::vector<TreeValue> operator()(double x, double y) const { return this->getNearestNeighbors(x, y); };
inline int getM() const { return M_; }
inline int getK() const { return K_; }
inline double getStdDev() const { return stddev_; }
inline double getMixingFactorByClusterID(size_t cluster_idx) {
if (cluster_idx >= this->clusters_.size()) {
BOOST_LOG_TRIVIAL(error) << "getMixingFactorByClusterID() called with "
"cluster_id >= number of clusters.";
return 1.0;
}
return this->clusters_[cluster_idx].mixing_factor;
}
inline double getHeadingAtDist(size_t cluster_idx, size_t mean_idx) {
if (cluster_idx >= this->clusters_.size()) {
BOOST_LOG_TRIVIAL(error) << "getHeadingAtDist() called with cluster_idx "
">= number of clusters.";
BOOST_LOG_TRIVIAL(error) << "Total clusters: " << this->clusters_.size() << ", Cluster ID: " << cluster_idx;
return 0.0;
}
if (mean_idx >= this->clusters_[cluster_idx].heading.size()) {
BOOST_LOG_TRIVIAL(error) << "getHeadingAtDist() called with mean_idx >= "
"number of traj-means in cluster.";
BOOST_LOG_TRIVIAL(error) << "Total means: " << this->clusters_[cluster_idx].heading.size()
<< ", Cluster ID and Mean ID: " << cluster_idx << ", " << mean_idx;
return 0.0;
}
return this->clusters_[cluster_idx].heading[mean_idx];
}
protected:
int M_;
int K_;
double stddev_;
std::vector<GMMTMapCluster> clusters_;
bgi::rtree<TreeValue, bgi::quadratic<16>> rtree_;
};
/// Shared-ownership handle to a mutable GMMTMap (std::shared_ptr is declared in <memory>).
using GMMTMapPtr = std::shared_ptr<GMMTMap>;
/// Shared-ownership handle to a read-only GMMTMap; only const member functions are callable through it.
using GMMTMapConstPtr = std::shared_ptr<const GMMTMap>;
} // namespace MoD