Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions src/systems/user_commands/UserCommands.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/

#include "UserCommands.hh"
#include <chrono>
#include <future>

#ifdef _MSC_VER
#pragma warning(push)
Expand Down Expand Up @@ -143,6 +145,10 @@ class UserCommandBase
/// \return True if command was properly executed.
public: virtual bool Execute() = 0;

/// \brief Promise that is used to exchange the result of the command
/// execution with the service handler.
public: std::promise<bool> promise;

/// \brief Message containing command.
protected: google::protobuf::Message *msg{nullptr};

Expand Down Expand Up @@ -456,6 +462,15 @@ class gz::sim::systems::UserCommandsPrivate
public: template <typename CommandT, typename InputT>
bool ServiceHandler(const InputT &_req, msgs::Boolean &_res);

/// \brief This is similar to \sa ServiceHandler but it blocks the request
/// until the command is actually executed.
/// \tparam CommandT Type for the command that associated with the service.
/// \tparam InputT Type form gz::msgs of the input parameter.
/// \param[in] _req Input parameter message of the service.
/// \param[out] _res Output parameter message of the service.
public: template <typename CommandT, typename InputT>
bool BlockingServiceHandler(const InputT &_req, msgs::Boolean &_res);

/// \brief Temlpate for advertising services
/// \tparam CommandT Type for the command that associated with the service.
/// \tparam InputT Type form gz::msgs of the input parameter.
Expand Down Expand Up @@ -485,6 +500,10 @@ class gz::sim::systems::UserCommandsPrivate

/// \brief Mutex to protect pending queue.
public: std::mutex pendingMutex;

/// \brief Global timeout settings for services.
/// \TODO(azeey) Consider making this configurable.
public: const unsigned int kServiceHandlerTimeoutMs{5000};
};

/// \brief Pose3d equality comparison function.
Expand Down Expand Up @@ -676,7 +695,9 @@ void UserCommands::PreUpdate(const UpdateInfo &/*_info*/,
for (auto &cmd : cmds)
{
// Execute
if (!cmd->Execute())
bool result = cmd->Execute();
cmd->promise.set_value(result);
if (!result)
continue;

// TODO(louise) Update command with current world state
Expand Down Expand Up @@ -721,8 +742,18 @@ void UserCommandsPrivate::AdvertiseService(const std::string &_topic,
{
this->node.Advertise(
_topic, &UserCommandsPrivate::ServiceHandler<CommandT, InputT>, this);

const auto blockingTopic = _topic + "/blocking";
this->node.Advertise(
blockingTopic,
&UserCommandsPrivate::BlockingServiceHandler<CommandT, InputT>, this);
if (_serviceName != nullptr)
gzmsg << _serviceName << " service on [" << _topic << "]" << std::endl;
{
gzmsg << _serviceName << " service on [" << _topic << "] (async)"
<< std::endl;
gzmsg << _serviceName << " service on [" << blockingTopic << "] (blocking)"
<< std::endl;
}
}

//////////////////////////////////////////////////
Expand All @@ -743,6 +774,31 @@ bool UserCommandsPrivate::ServiceHandler(const InputT &_req,
return true;
}

//////////////////////////////////////////////////
template <typename CommandT, typename InputT>
bool UserCommandsPrivate::BlockingServiceHandler(const InputT &_req,
msgs::Boolean &_res)
{
auto msg = _req.New();
msg->CopyFrom(_req);
auto cmd = std::make_unique<CommandT>(msg, this->iface);
auto future = cmd->promise.get_future();
// Push to pending
{
std::lock_guard<std::mutex> lock(this->pendingMutex);
this->pendingCmds.push_back(std::move(cmd));
}

// This blocks until the command is executed.
if (future.wait_for(std::chrono::milliseconds(kServiceHandlerTimeoutMs)) ==
std::future_status::ready)
{
_res.set_data(future.get());
return true;
}
return false;
}

//////////////////////////////////////////////////
UserCommandBase::UserCommandBase(google::protobuf::Message *_msg,
std::shared_ptr<UserCommandsInterface> &_iface)
Expand Down
144 changes: 106 additions & 38 deletions test/integration/user_commands.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*
*/

#include <future>
#include <string>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -62,6 +63,32 @@ class UserCommandsTest : public InternalFixture<::testing::Test>
{
};

struct AsyncRequestInfo {
bool retval{false};
msgs::Boolean response;
bool result{false};
};

// This calls a request from a new thread so that the calling function can
// continue even if the request blocks.
template <typename RequestT>
auto asyncRequest(transport::Node &_node, const std::string &_topic,
const RequestT &_req)
{
unsigned int timeout = 5000;
auto asyncRetval = std::async(std::launch::async, [&]
{
AsyncRequestInfo info;
info.retval =
_node.Request(_topic, _req, timeout, info.response, info.result);
return info;
});
// Sleep for a little bit for the async thread to spin up and make the service
// request
GZ_SLEEP_MS(10);
return asyncRetval;
}

/////////////////////////////////////////////////
// See https://github.com/gazebosim/gz-sim/issues/1175
TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
Expand Down Expand Up @@ -137,22 +164,23 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
auto pos = pose->mutable_position();
pos->set_z(10);

msgs::Boolean res;
bool result;
unsigned int timeout = 5000;
std::string service{"/world/empty/create"};

std::string service{"/world/empty/create/blocking"};
transport::Node node;
EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
auto requestDataFuture = asyncRequest(node, service, req);

// Check entity has not been created yet
EXPECT_EQ(kNullEntity, ecm->EntityByComponents(components::Model(),
components::Name("spawned_model")));

// Run an iteration and check it was created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 4, ecm->EntityCount());
entityCount = ecm->EntityCount();

Expand All @@ -169,12 +197,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.Clear();
req.set_sdf(modelStr);

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check it was not created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_FALSE(requestData.response.data());
}

EXPECT_EQ(entityCount, ecm->EntityCount());

Expand All @@ -183,12 +215,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.set_sdf(modelStr);
req.set_allow_renaming(true);

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check it was created with a new name
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 4, ecm->EntityCount());
entityCount = ecm->EntityCount();
Expand All @@ -202,12 +238,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.set_sdf(modelStr);
req.set_name("banana");

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check it was created with given name
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 4, ecm->EntityCount());
entityCount = ecm->EntityCount();
Expand All @@ -220,12 +260,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.Clear();
req.set_sdf(lightStr);

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check it was created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 2, ecm->EntityCount());
entityCount = ecm->EntityCount();
Expand All @@ -239,12 +283,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.Clear();
req.mutable_light()->set_name("light_test");
req.mutable_light()->set_parent_id(1);
EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check it was created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 2, ecm->EntityCount());
entityCount = ecm->EntityCount();
Expand All @@ -259,17 +307,13 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.set_sdf(modelStr);
req.set_name("acerola");

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
auto requestDataFuture1 = asyncRequest(node, service, req);

req.Clear();
req.set_sdf(modelStr);
req.set_name("coconut");

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
auto requestDataFuture2 = asyncRequest(node, service, req);

// Check neither exists yet
EXPECT_EQ(kNullEntity, ecm->EntityByComponents(components::Model(),
Expand All @@ -280,6 +324,18 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))

// Run an iteration and check both models were created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture1.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}
{
auto requestData = requestDataFuture2.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 8, ecm->EntityCount());
entityCount = ecm->EntityCount();
Expand All @@ -293,12 +349,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.Clear();
req.set_sdf(lightsStr);

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check only the 1st was created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}

EXPECT_EQ(entityCount + 2, ecm->EntityCount());
entityCount = ecm->EntityCount();
Expand All @@ -315,12 +375,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.Clear();
req.set_sdf(badStr);

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check nothing was created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_FALSE(requestData.response.data());
}

EXPECT_EQ(entityCount, ecm->EntityCount());

Expand All @@ -330,12 +394,16 @@ TEST_F(UserCommandsTest, GZ_UTILS_TEST_DISABLED_ON_WIN32(Create))
req.Clear();
req.set_sdf_filename(testModel);

EXPECT_TRUE(node.Request(service, req, timeout, res, result));
EXPECT_TRUE(result);
EXPECT_TRUE(res.data());
requestDataFuture = asyncRequest(node, service, req);

// Run an iteration and check it was created
server.Run(true, 1, false);
{
auto requestData = requestDataFuture.get();
EXPECT_TRUE(requestData.retval);
EXPECT_TRUE(requestData.result);
EXPECT_TRUE(requestData.response.data());
}
EXPECT_EQ(entityCount + 4, ecm->EntityCount());

EXPECT_NE(kNullEntity, ecm->EntityByComponents(components::Model(),
Expand Down
Loading