Skip to content

Commit

Permalink
Improve KeepoutFilter mask receiving performance (ros-navigation#3420)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyMerzlyakov authored Feb 22, 2023
1 parent 63b690a commit 99f7ff6
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ class BinaryFilter : public CostmapFilter

nav_msgs::msg::OccupancyGrid::SharedPtr filter_mask_;

std::string mask_frame_; // Frame where mask located in
std::string global_frame_; // Frame of currnet layer (master_grid)

double base_, multiplier_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ class CostmapFilter : public Layer
return filter_mask->data[my * filter_mask->info.width + mx];
}

/**
* @brief Get the cost of a cell in the filter mask
* @param filter_mask Filter mask to get the cost from
* @param mx The x coordinate of the cell
* @param my The y coordinate of the cell
* @return The cost to set the cell to
*/
unsigned char getMaskCost(
nav_msgs::msg::OccupancyGrid::ConstSharedPtr filter_mask,
const unsigned int mx, const unsigned int & my) const;

/**
* @brief: Name of costmap filter info topic
*/
Expand All @@ -220,7 +231,7 @@ class CostmapFilter : public Layer
std::string mask_topic_;

/**
* @brief: mask_frame_->global_frame_ transform tolerance
* @brief: mask_frame->global_frame_ transform tolerance
*/
tf2::Duration transform_tolerance_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ class KeepoutFilter : public CostmapFilter
rclcpp::Subscription<nav2_msgs::msg::CostmapFilterInfo>::SharedPtr filter_info_sub_;
rclcpp::Subscription<nav_msgs::msg::OccupancyGrid>::SharedPtr mask_sub_;

std::unique_ptr<Costmap2D> mask_costmap_;
nav_msgs::msg::OccupancyGrid::SharedPtr filter_mask_;

std::string mask_frame_; // Frame where mask located in
std::string global_frame_; // Frame of currnet layer (master_grid)
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ class SpeedFilter : public CostmapFilter

nav_msgs::msg::OccupancyGrid::SharedPtr filter_mask_;

std::string mask_frame_; // Frame where mask located in
std::string global_frame_; // Frame of currnet layer (master_grid)

double base_, multiplier_;
Expand Down
5 changes: 2 additions & 3 deletions nav2_costmap_2d/plugins/costmap_filters/binary_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace nav2_costmap_2d

BinaryFilter::BinaryFilter()
: filter_info_sub_(nullptr), mask_sub_(nullptr),
binary_state_pub_(nullptr), filter_mask_(nullptr), mask_frame_(""), global_frame_(""),
binary_state_pub_(nullptr), filter_mask_(nullptr), global_frame_(""),
default_state_(false), binary_state_(default_state_)
{
}
Expand Down Expand Up @@ -162,7 +162,6 @@ void BinaryFilter::maskCallback(
}

filter_mask_ = msg;
mask_frame_ = msg->header.frame_id;
}

void BinaryFilter::process(
Expand All @@ -183,7 +182,7 @@ void BinaryFilter::process(
geometry_msgs::msg::Pose2D mask_pose; // robot coordinates in mask frame

// Transforming robot pose from current layer frame to mask frame
if (!transformPose(global_frame_, pose, mask_frame_, mask_pose)) {
if (!transformPose(global_frame_, pose, filter_mask_->header.frame_id, mask_pose)) {
return;
}

Expand Down
21 changes: 21 additions & 0 deletions nav2_costmap_2d/plugins/costmap_filters/costmap_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
#include "tf2_geometry_msgs/tf2_geometry_msgs.hpp"
#include "geometry_msgs/msg/point_stamped.hpp"

#include "nav2_costmap_2d/cost_values.hpp"
#include "nav2_util/occ_grid_values.hpp"

namespace nav2_costmap_2d
{

Expand Down Expand Up @@ -206,4 +209,22 @@ bool CostmapFilter::worldToMask(
return true;
}

unsigned char CostmapFilter::getMaskCost(
nav_msgs::msg::OccupancyGrid::ConstSharedPtr filter_mask,
const unsigned int mx, const unsigned int & my) const
{
const unsigned int index = my * filter_mask->info.width + mx;

const char data = filter_mask->data[index];
if (data == nav2_util::OCC_GRID_UNKNOWN) {
return NO_INFORMATION;
} else {
// Linear conversion from OccupancyGrid data range [OCC_GRID_FREE..OCC_GRID_OCCUPIED]
// to costmap data range [FREE_SPACE..LETHAL_OBSTACLE]
return std::round(
static_cast<double>(data) * (LETHAL_OBSTACLE - FREE_SPACE) /
(nav2_util::OCC_GRID_OCCUPIED - nav2_util::OCC_GRID_FREE));
}
}

} // namespace nav2_costmap_2d
65 changes: 33 additions & 32 deletions nav2_costmap_2d/plugins/costmap_filters/keepout_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ namespace nav2_costmap_2d
{

KeepoutFilter::KeepoutFilter()
: filter_info_sub_(nullptr), mask_sub_(nullptr), mask_costmap_(nullptr),
mask_frame_(""), global_frame_("")
: filter_info_sub_(nullptr), mask_sub_(nullptr), filter_mask_(nullptr),
global_frame_("")
{
}

Expand Down Expand Up @@ -130,7 +130,7 @@ void KeepoutFilter::maskCallback(
throw std::runtime_error{"Failed to lock node"};
}

if (!mask_costmap_) {
if (!filter_mask_) {
RCLCPP_INFO(
logger_,
"KeepoutFilter: Received filter mask from %s topic.", mask_topic_.c_str());
Expand All @@ -139,12 +139,11 @@ void KeepoutFilter::maskCallback(
logger_,
"KeepoutFilter: New filter mask arrived from %s topic. Updating old filter mask.",
mask_topic_.c_str());
mask_costmap_.reset();
filter_mask_.reset();
}

// Making a new mask_costmap_
mask_costmap_ = std::make_unique<Costmap2D>(*msg);
mask_frame_ = msg->header.frame_id;
// Store filter_mask_
filter_mask_ = msg;
}

void KeepoutFilter::process(
Expand All @@ -154,7 +153,7 @@ void KeepoutFilter::process(
{
std::lock_guard<CostmapFilter::mutex_t> guard(*getMutex());

if (!mask_costmap_) {
if (!filter_mask_) {
// Show warning message every 2 seconds to not litter an output
RCLCPP_WARN_THROTTLE(
logger_, *(clock_), 2000,
Expand All @@ -167,20 +166,22 @@ void KeepoutFilter::process(
int mg_min_x, mg_min_y; // masger_grid indexes of bottom-left window corner
int mg_max_x, mg_max_y; // masger_grid indexes of top-right window corner

if (mask_frame_ != global_frame_) {
const std::string mask_frame = filter_mask_->header.frame_id;

if (mask_frame != global_frame_) {
// Filter mask and current layer are in different frames:
// prepare frame transformation if mask_frame_ != global_frame_
// prepare frame transformation if mask_frame != global_frame_
geometry_msgs::msg::TransformStamped transform;
try {
transform = tf_->lookupTransform(
mask_frame_, global_frame_, tf2::TimePointZero,
mask_frame, global_frame_, tf2::TimePointZero,
transform_tolerance_);
} catch (tf2::TransformException & ex) {
RCLCPP_ERROR(
logger_,
"KeepoutFilter: Failed to get costmap frame (%s) "
"transformation to mask frame (%s) with error: %s",
global_frame_.c_str(), mask_frame_.c_str(), ex.what());
global_frame_.c_str(), mask_frame.c_str(), ex.what());
return;
}
tf2::fromMsg(transform.transform, tf2_transform);
Expand All @@ -192,9 +193,9 @@ void KeepoutFilter::process(
} else {
// Filter mask and current layer are in the same frame:
// apply the following optimization - iterate only in overlapped
// (min_i, min_j)..(max_i, max_j) & mask_costmap_ area.
// (min_i, min_j)..(max_i, max_j) & filter_mask_ area.
//
// mask_costmap_
// filter_mask_
// *----------------------------*
// | |
// | |
Expand All @@ -213,10 +214,10 @@ void KeepoutFilter::process(
double wx, wy; // world coordinates

// Calculating bounds corresponding to bottom-left overlapping (1) corner
// mask_costmap_ -> master_grid intexes conversion
const double half_cell_size = 0.5 * mask_costmap_->getResolution();
wx = mask_costmap_->getOriginX() + half_cell_size;
wy = mask_costmap_->getOriginY() + half_cell_size;
// filter_mask_ -> master_grid indexes conversion
const double half_cell_size = 0.5 * filter_mask_->info.resolution;
wx = filter_mask_->info.origin.position.x + half_cell_size;
wy = filter_mask_->info.origin.position.y + half_cell_size;
master_grid.worldToMapNoBounds(wx, wy, mg_min_x, mg_min_y);
// Calculation of (1) corner bounds
if (mg_min_x >= max_i || mg_min_y >= max_j) {
Expand All @@ -227,11 +228,11 @@ void KeepoutFilter::process(
mg_min_y = std::max(min_j, mg_min_y);

// Calculating bounds corresponding to top-right window (2) corner
// mask_costmap_ -> master_grid intexes conversion
wx = mask_costmap_->getOriginX() +
mask_costmap_->getSizeInCellsX() * mask_costmap_->getResolution() + half_cell_size;
wy = mask_costmap_->getOriginY() +
mask_costmap_->getSizeInCellsY() * mask_costmap_->getResolution() + half_cell_size;
// filter_mask_ -> master_grid intexes conversion
wx = filter_mask_->info.origin.position.x +
filter_mask_->info.width * filter_mask_->info.resolution + half_cell_size;
wy = filter_mask_->info.origin.position.y +
filter_mask_->info.height * filter_mask_->info.resolution + half_cell_size;
master_grid.worldToMapNoBounds(wx, wy, mg_max_x, mg_max_y);
// Calculation of (2) corner bounds
if (mg_max_x <= min_i || mg_max_y <= min_j) {
Expand All @@ -251,8 +252,8 @@ void KeepoutFilter::process(
unsigned int i, j; // master_grid iterators
unsigned int index; // corresponding index of master_grid
double gl_wx, gl_wy; // world coordinates in a global_frame_
double msk_wx, msk_wy; // world coordinates in a mask_frame_
unsigned int mx, my; // mask_costmap_ coordinates
double msk_wx, msk_wy; // world coordinates in a mask_frame
unsigned int mx, my; // filter_mask_ coordinates
unsigned char data, old_data; // master_grid element data

// Main master_grid updating loop
Expand All @@ -262,11 +263,11 @@ void KeepoutFilter::process(
for (j = mg_min_y_u; j < mg_max_y_u; j++) {
index = master_grid.getIndex(i, j);
old_data = master_array[index];
// Calculating corresponding to (i, j) point at mask_costmap_:
// Calculating corresponding to (i, j) point at filter_mask_:
// Get world coordinates in global_frame_
master_grid.mapToWorld(i, j, gl_wx, gl_wy);
if (mask_frame_ != global_frame_) {
// Transform (i, j) point from global_frame_ to mask_frame_
if (mask_frame != global_frame_) {
// Transform (i, j) point from global_frame_ to mask_frame
tf2::Vector3 point(gl_wx, gl_wy, 0);
point = tf2_transform * point;
msk_wx = point.x();
Expand All @@ -276,9 +277,9 @@ void KeepoutFilter::process(
msk_wx = gl_wx;
msk_wy = gl_wy;
}
// Get mask coordinates corresponding to (i, j) point at mask_costmap_
if (mask_costmap_->worldToMap(msk_wx, msk_wy, mx, my)) {
data = mask_costmap_->getCost(mx, my);
// Get mask coordinates corresponding to (i, j) point at filter_mask_
if (worldToMask(filter_mask_, msk_wx, msk_wy, mx, my)) {
data = getMaskCost(filter_mask_, mx, my);
// Update if mask_ data is valid and greater than existing master_grid's one
if (data == NO_INFORMATION) {
continue;
Expand All @@ -303,7 +304,7 @@ bool KeepoutFilter::isActive()
{
std::lock_guard<CostmapFilter::mutex_t> guard(*getMutex());

if (mask_costmap_) {
if (filter_mask_) {
return true;
}
return false;
Expand Down
5 changes: 2 additions & 3 deletions nav2_costmap_2d/plugins/costmap_filters/speed_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace nav2_costmap_2d

SpeedFilter::SpeedFilter()
: filter_info_sub_(nullptr), mask_sub_(nullptr),
speed_limit_pub_(nullptr), filter_mask_(nullptr), mask_frame_(""), global_frame_(""),
speed_limit_pub_(nullptr), filter_mask_(nullptr), global_frame_(""),
speed_limit_(NO_SPEED_LIMIT), speed_limit_prev_(NO_SPEED_LIMIT)
{
}
Expand Down Expand Up @@ -169,7 +169,6 @@ void SpeedFilter::maskCallback(
}

filter_mask_ = msg;
mask_frame_ = msg->header.frame_id;
}

void SpeedFilter::process(
Expand All @@ -190,7 +189,7 @@ void SpeedFilter::process(
geometry_msgs::msg::Pose2D mask_pose; // robot coordinates in mask frame

// Transforming robot pose from current layer frame to mask frame
if (!transformPose(global_frame_, pose, mask_frame_, mask_pose)) {
if (!transformPose(global_frame_, pose, filter_mask_->header.frame_id, mask_pose)) {
return;
}

Expand Down
40 changes: 40 additions & 0 deletions nav2_costmap_2d/test/unit/costmap_filter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "nav_msgs/msg/occupancy_grid.hpp"
#include "geometry_msgs/msg/pose2_d.hpp"
#include "nav2_costmap_2d/costmap_2d.hpp"
#include "nav2_costmap_2d/cost_values.hpp"
#include "nav2_costmap_2d/costmap_filters/costmap_filter.hpp"

class CostmapFilterWrapper : public nav2_costmap_2d::CostmapFilter
Expand All @@ -36,6 +37,13 @@ class CostmapFilterWrapper : public nav2_costmap_2d::CostmapFilter
return nav2_costmap_2d::CostmapFilter::worldToMask(filter_mask, wx, wy, mx, my);
}

unsigned char getMaskCost(
nav_msgs::msg::OccupancyGrid::ConstSharedPtr filter_mask,
const unsigned int mx, const unsigned int & my) const
{
return nav2_costmap_2d::CostmapFilter::getMaskCost(filter_mask, mx, my);
}

// API coverage
void initializeFilter(const std::string &) {}
void process(
Expand Down Expand Up @@ -89,6 +97,38 @@ TEST(CostmapFilter, testWorldToMask)
ASSERT_FALSE(cf.worldToMask(mask, 6.0, 6.0, mx, my));
}

TEST(CostmapFilter, testGetMaskCost)
{
// Create occupancy grid for test as follows:
// [-1, 0,
// 50, 100]

const unsigned int width = 2;
const unsigned int height = 2;

auto mask = std::make_shared<nav_msgs::msg::OccupancyGrid>();
mask->header.frame_id = "map";
mask->info.resolution = 1.0;
mask->info.width = width;
mask->info.height = height;
mask->info.origin.position.x = 0.0;
mask->info.origin.position.y = 0.0;

mask->data.resize(width * height);
mask->data[0] = nav2_util::OCC_GRID_UNKNOWN;
mask->data[1] = nav2_util::OCC_GRID_FREE;
mask->data[2] = nav2_util::OCC_GRID_OCCUPIED / 2;
mask->data[3] = nav2_util::OCC_GRID_OCCUPIED;

CostmapFilterWrapper cf;

// Test all value cases
ASSERT_EQ(cf.getMaskCost(mask, 0, 0), nav2_costmap_2d::NO_INFORMATION);
ASSERT_EQ(cf.getMaskCost(mask, 1, 0), nav2_costmap_2d::FREE_SPACE);
ASSERT_EQ(cf.getMaskCost(mask, 0, 1), nav2_costmap_2d::LETHAL_OBSTACLE / 2);
ASSERT_EQ(cf.getMaskCost(mask, 1, 1), nav2_costmap_2d::LETHAL_OBSTACLE);
}

int main(int argc, char ** argv)
{
// Initialize the system
Expand Down

0 comments on commit 99f7ff6

Please sign in to comment.