Program Listing for File SpatialDetectionNetwork.hpp

Return to documentation for file (include/depthai/pipeline/node/SpatialDetectionNetwork.hpp)

#pragma once

#include <depthai/pipeline/Node.hpp>
#include <depthai/pipeline/node/DetectionNetwork.hpp>

#include "depthai/openvino/OpenVINO.hpp"

// standard
#include <cstdint>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// shared
#include <depthai-shared/properties/SpatialDetectionNetworkProperties.hpp>

namespace dai {
namespace node {

/**
 * @brief Detection-network node with an additional depth input.
 *
 * Derives (via NodeCRTP) from DetectionNetwork and is configured through
 * SpatialDetectionNetworkProperties. Its main output ("out") carries
 * SpatialImgDetections messages; auxiliary outputs expose the bounding-box
 * mapping, passthrough frames and SpatialLocationCalculatorData.
 */
class SpatialDetectionNetwork : public NodeCRTP<DetectionNetwork, SpatialDetectionNetwork, SpatialDetectionNetworkProperties> {
   public:
    /// Node type name: "SpatialDetectionNetwork".
    constexpr static const char* NAME = "SpatialDetectionNetwork";

   protected:
    /**
     * Constructs the node as part of pipeline @p par with id @p nodeId.
     * NOTE(review): protected, so presumably only reachable through the pipeline's
     * node-creation machinery or derived classes — confirm.
     */
    SpatialDetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId);
    /**
     * Constructs the node as part of pipeline @p par with id @p nodeId, taking
     * ownership of an already-built properties object @p props.
     */
    SpatialDetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId, std::unique_ptr<Properties> props);

   public:
    /**
     * Input "in" accepting ImgFrame messages (the data inference runs on).
     * NOTE(review): the (true, 5, true) arguments presumably mean blocking = true,
     * queue size = 5, wait-for-message = true — confirm against Node::Input's constructor.
     */
    Input input{*this, "in", Input::Type::SReceiver, true, 5, true, {{DatatypeEnum::ImgFrame, false}}};

    /**
     * Input "inputDepth" accepting ImgFrame messages carrying depth data.
     * NOTE(review): the (false, 4, true) arguments presumably mean non-blocking,
     * queue size = 4, wait-for-message = true — confirm against Node::Input's constructor.
     */
    Input inputDepth{*this, "inputDepth", Input::Type::SReceiver, false, 4, true, {{DatatypeEnum::ImgFrame, false}}};

    /// Output "out" emitting SpatialImgDetections messages (detections with spatial data).
    Output out{*this, "out", Output::Type::MSender, {{DatatypeEnum::SpatialImgDetections, false}}};

    /**
     * Output "boundingBoxMapping" emitting SpatialLocationCalculatorConfig messages.
     * NOTE(review): presumably describes the detected boxes as regions on the depth
     * frame (useful when drawing boxes on depth) — confirm in the implementation.
     */
    Output boundingBoxMapping{*this, "boundingBoxMapping", Output::Type::MSender, {{DatatypeEnum::SpatialLocationCalculatorConfig, false}}};

    /**
     * Output "passthrough" emitting ImgFrame messages.
     * NOTE(review): presumably the frame the inference was performed on — confirm.
     */
    Output passthrough{*this, "passthrough", Output::Type::MSender, {{DatatypeEnum::ImgFrame, false}}};

    /**
     * Output "passthroughDepth" emitting ImgFrame messages.
     * NOTE(review): presumably the depth frame used for the spatial calculation — confirm.
     */
    Output passthroughDepth{*this, "passthroughDepth", Output::Type::MSender, {{DatatypeEnum::ImgFrame, false}}};

    /// Output "spatialLocationCalculatorOutput" emitting SpatialLocationCalculatorData messages.
    Output spatialLocationCalculatorOutput{
        *this, "spatialLocationCalculatorOutput", Output::Type::MSender, {{DatatypeEnum::SpatialLocationCalculatorData, false}}};

    /**
     * Sets the scale factor applied to detected bounding boxes.
     * @param scaleFactor Scale factor.
     * NOTE(review): the accepted range is not visible here — check the implementation.
     */
    void setBoundingBoxScaleFactor(float scaleFactor);

    /**
     * Sets the lower depth threshold used for the spatial calculation.
     * @param lowerThreshold Lower threshold.
     * NOTE(review): units presumably millimeters — confirm.
     */
    void setDepthLowerThreshold(uint32_t lowerThreshold);

    /**
     * Sets the upper depth threshold used for the spatial calculation.
     * @param upperThreshold Upper threshold.
     * NOTE(review): units presumably millimeters — confirm.
     */
    void setDepthUpperThreshold(uint32_t upperThreshold);

    /**
     * Selects the algorithm used to derive a spatial value from the depth samples.
     * @param calculationAlgorithm One of dai::SpatialLocationCalculatorAlgorithm.
     */
    void setSpatialCalculationAlgorithm(dai::SpatialLocationCalculatorAlgorithm calculationAlgorithm);

    /**
     * Sets the step size used by the spatial calculation.
     * @param stepSize Step size.
     * NOTE(review): presumably the sampling stride over depth pixels in the ROI — confirm.
     */
    void setSpatialCalculationStepSize(int stepSize);
};

/**
 * @brief SpatialDetectionNetwork variant for MobileNet-style detection networks.
 *
 * Adds no members beyond a public constructor.
 * NOTE(review): the constructor presumably configures the MobileNet decoder — confirm in the .cpp.
 */
class MobileNetSpatialDetectionNetwork : public NodeCRTP<SpatialDetectionNetwork, MobileNetSpatialDetectionNetwork, SpatialDetectionNetworkProperties> {
   public:
    /// Constructs the node as part of pipeline @p par with id @p nodeId.
    MobileNetSpatialDetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId);
};

/**
 * @brief SpatialDetectionNetwork variant for YOLO-style detection networks.
 *
 * Exposes setters/getters for YOLO decoding parameters (class count, coordinate
 * size, anchors, anchor masks, IoU threshold).
 */
class YoloSpatialDetectionNetwork : public NodeCRTP<SpatialDetectionNetwork, YoloSpatialDetectionNetwork, SpatialDetectionNetworkProperties> {
   public:
    /// Constructs the node as part of pipeline @p par with id @p nodeId.
    YoloSpatialDetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId);

    /// Sets the number of classes the network detects.
    void setNumClasses(const int numClasses);
    /// Sets the number of coordinates per box. NOTE(review): presumably 4 for x/y/w/h — confirm.
    void setCoordinateSize(const int coordinates);
    /// Sets the anchor values. NOTE(review): presumably a flattened list of anchor sizes — confirm.
    void setAnchors(std::vector<float> anchors);
    /// Sets anchor masks. NOTE(review): keys presumably name output layers, values index into the anchors — confirm.
    void setAnchorMasks(std::map<std::string, std::vector<int>> anchorMasks);
    /// Sets the IoU threshold. NOTE(review): presumably used for non-maximum suppression — confirm.
    void setIouThreshold(float thresh);

    /// Returns the configured number of classes.
    int getNumClasses() const;
    /// Returns the configured coordinate size.
    int getCoordinateSize() const;
    /// Returns a copy of the configured anchors.
    std::vector<float> getAnchors() const;
    /// Returns a copy of the configured anchor masks.
    std::map<std::string, std::vector<int>> getAnchorMasks() const;
    /// Returns the configured IoU threshold.
    float getIouThreshold() const;
};

}  // namespace node
}  // namespace dai