Skip to content

Commit 5574059

Browse files
committed
refactored ik controller
1 parent b795b46 commit 5574059

File tree

1 file changed

+81
-78
lines changed

1 file changed

+81
-78
lines changed

whole_body_controllers/src/ik_controller.cpp

Lines changed: 81 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,16 @@ auto IKController::on_init() -> controller_interface::CallbackReturn
6363
// lock all uncontrolled joints
6464
std::vector<std::string> controlled_joints = {"universe", "root_joint"};
6565
std::ranges::copy(params_.controlled_joints, std::back_inserter(controlled_joints));
66-
const std::vector<pinocchio::JointIndex> locked_joints =
67-
std::ranges::transform(model_->names, [this](const auto & name) {
68-
if (std::ranges::find(controlled_joints, name) == controlled_joints.end()) {
69-
return model_->getJointId(name);
70-
}
71-
});
66+
67+
std::vector<std::string> locked_joint_names;
68+
std::ranges::copy_if(model_->names, std::back_inserter(locked_joint_names), [&controlled_joints](const auto & name) {
69+
return std::ranges::find(controlled_joints, name) == controlled_joints.end();
70+
});
71+
72+
std::vector<pinocchio::JointIndex> locked_joints;
73+
std::ranges::transform(locked_joint_names, std::back_inserter(locked_joints), [this](const auto & name) {
74+
return model_->getJointId(name);
75+
});
7276

7377
// build the reduced model
7478
pinocchio::Model reduced_model;
@@ -169,7 +173,7 @@ auto IKController::on_configure(const rclcpp_lifecycle::State & /*previous_state
169173
loader_ = std::make_unique<pluginlib::ClassLoader<ik_solvers::IKSolver>>("ik_solvers", "ik_solvers::IKSolver");
170174
solver_ = loader_->createSharedInstance(params_.ik_solver);
171175
solver_->initialize(get_node(), model_, data_, params_.ik_solver);
172-
RCLCPP_INFO(logger_, "Configured the IK controller with solver %s", params_.ik_solver); // NOLINT
176+
RCLCPP_INFO(logger_, "Configured the IK controller with solver %s", params_.ik_solver.c_str()); // NOLINT
173177

174178
// TODO(evan-palmer): add controller state publisher
175179

@@ -274,10 +278,10 @@ auto IKController::update_reference_from_subscribers(const rclcpp::Time & /*time
274278
-> controller_interface::return_type
275279
{
276280
auto * current_reference = reference_.readFromRT();
277-
auto reference = common::messages::to_vector(*current_reference);
281+
std::vector<double> reference = common::messages::to_vector(*current_reference);
278282

279-
for (auto & [interface, value] : std::views::zip(reference_interfaces_, reference)) {
280-
interface = value;
283+
for (std::size_t i = 0; i < reference.size(); ++i) {
284+
reference_interfaces_[i] = reference[i];
281285
}
282286

283287
common::messages::reset_message(current_reference);
@@ -291,49 +295,67 @@ auto IKController::update_system_state_values() -> controller_interface::return_
291295
// appropriate frame, so we can just copy the values into the system state values. otherwise, we need to transform
292296
// the states first, then save them.
293297
if (params_.use_external_measured_vehicle_states) {
294-
const auto * vehicle_state_msg = vehicle_state_.readFromRT();
295-
const auto state = common::messages::to_vector(*vehicle_state_msg);
298+
const auto * state_msg = vehicle_state_.readFromRT();
299+
const auto state = common::messages::to_vector(*state_msg);
296300
std::ranges::copy(state.begin(), state.begin() + free_flyer_pos_dofs_.size(), position_state_values_.begin());
297301
std::ranges::copy(state.begin() + free_flyer_pos_dofs_.size(), state.end(), velocity_state_values_.begin());
298302
} else {
303+
auto save_states = [](const auto & interfaces, auto out) {
304+
std::ranges::transform(interfaces, out, [](const auto & interface) {
305+
return interface.get_optional().value_or(std::numeric_limits<double>::quiet_NaN());
306+
});
307+
};
308+
309+
// retrieve the vehicle position and velocity state interfaces
310+
const auto position_interfaces_end = state_interfaces_.begin() + free_flyer_pos_dofs_.size();
311+
const auto velocity_interfaces_start = position_interfaces_end + params_.controlled_joints.size();
312+
const auto velocity_interfaces_end = velocity_interfaces_start + free_flyer_vel_dofs_.size();
313+
314+
const auto position_interfaces = std::span(state_interfaces_.begin(), position_interfaces_end);
315+
const auto velocity_interfaces = std::span(velocity_interfaces_start, velocity_interfaces_end);
316+
317+
std::vector<double> position_states, velocity_states;
318+
position_states.reserve(position_interfaces.size());
319+
velocity_states.reserve(velocity_interfaces.size());
320+
321+
save_states(position_interfaces, std::back_inserter(position_states));
322+
save_states(velocity_interfaces, std::back_inserter(velocity_states));
323+
324+
// transform the states into the appropriate frame and save them
325+
geometry_msgs::msg::Pose pose;
326+
common::messages::to_msg(position_states, &pose);
327+
m2m::transforms::transform_message(pose);
328+
std::ranges::copy(common::messages::to_vector(pose), position_state_values_.begin());
329+
330+
geometry_msgs::msg::Twist twist;
331+
common::messages::to_msg(velocity_states, &twist);
332+
m2m::transforms::transform_message(twist);
333+
std::ranges::copy(common::messages::to_vector(twist), velocity_state_values_.begin());
334+
}
335+
336+
auto find_interface = [](const auto & interfaces, const std::string & name, const std::string & type) {
337+
return std::ranges::find_if(interfaces, [&name, &type](const auto & interface) {
338+
return interface.get_interface_name() == std::format("{}/{}", name, type);
339+
});
340+
};
341+
342+
// save the manipulator states
343+
for (const auto & [i, joint_name] : std::views::enumerate(params_.controlled_joints)) {
344+
const pinocchio::JointModel joint = model_->joints[model_->getJointId(joint_name)];
345+
346+
const auto pos_it = find_interface(state_interfaces_, joint_name, hardware_interface::HW_IF_POSITION);
347+
const double pos = pos_it->get_optional().value_or(std::numeric_limits<double>::quiet_NaN());
348+
position_state_values_[free_flyer_pos_dofs_.size() + i] = pos;
349+
350+
const auto vel_it = find_interface(state_interfaces_, joint_name, hardware_interface::HW_IF_VELOCITY);
351+
const double vel = vel_it->get_optional().value_or(std::numeric_limits<double>::quiet_NaN());
352+
velocity_state_values_[free_flyer_vel_dofs_.size() + i] = vel;
299353
}
300354

301-
// const pinocchio::JointModel root = model_->joints[model_->getJointId("root_joint")];
302-
// if (params_.use_external_measured_vehicle_states) {
303-
// const auto * vehicle_state = vehicle_state_.readFromRT();
304-
// std::ranges::copy(common::messages::to_vector(*vehicle_state), system_state_values_.begin() + root.idx_q());
305-
// } else {
306-
// std::vector<double> vehicle_states;
307-
// vehicle_states.reserve(free_flyer_pos_dofs_.size());
308-
// for (std::size_t i = 0; i < free_flyer_pos_dofs_.size(); ++i) {
309-
// const auto out = state_interfaces_[i].get_optional();
310-
// vehicle_states.push_back(out.value_or(std::numeric_limits<double>::quiet_NaN()));
311-
// }
312-
313-
// geometry_msgs::msg::Pose pose;
314-
// common::messages::to_msg(vehicle_states, &pose);
315-
// m2m::transforms::transform_message(pose);
316-
// std::ranges::copy(common::messages::to_vector(pose), system_state_values_.begin() + root.idx_q());
317-
// }
318-
319-
// // TODO(evan-palmer): debug changing behavior based on state interface order
320-
// for (const auto & joint_name : manipulator_dofs_) {
321-
// const pinocchio::JointModel joint = model_->joints[model_->getJointId(joint_name)];
322-
// auto it = std::ranges::find_if(
323-
// state_interfaces_, [&joint_name](const auto & interface) { return interface.get_prefix_name() == joint_name;
324-
// });
325-
326-
// if (it == state_interfaces_.end()) {
327-
// RCLCPP_ERROR(logger_, std::format("Could not find joint {} in state interfaces", joint_name).c_str()); //
328-
// NOLINT return controller_interface::return_type::ERROR;
329-
// }
330-
// system_state_values_[joint.idx_q()] = it->get_optional().value_or(std::numeric_limits<double>::quiet_NaN());
331-
// }
332-
333-
// if (std::ranges::any_of(system_state_values_, [](double x) { return std::isnan(x); })) {
334-
// RCLCPP_DEBUG(logger_, "Received system state with NaN value."); // NOLINT
335-
// return controller_interface::return_type::ERROR;
336-
// }
355+
if (std::ranges::any_of(position_state_values_, [](double x) { return std::isnan(x); })) {
356+
RCLCPP_DEBUG(logger_, "Received system state with NaN value."); // NOLINT
357+
return controller_interface::return_type::ERROR;
358+
}
337359

338360
return controller_interface::return_type::OK;
339361
}
@@ -365,14 +387,16 @@ auto IKController::update_and_write_commands(const rclcpp::Time & /*time*/, cons
365387
const Eigen::VectorXd q = Eigen::VectorXd::Map(position_state_values_.data(), position_state_values_.size());
366388
const Eigen::Affine3d target_pose = to_eigen(reference_interfaces_);
367389

390+
// TODO(anyone): add solver support for velocity states
391+
// right now we only use the positions for the solver
368392
const auto result = solver_->solve(period, target_pose, q);
369393

370394
if (!result.has_value()) {
371395
const auto err = result.error();
372396
if (err == ik_solvers::SolverError::NO_SOLUTION) {
373-
RCLCPP_WARN(logger_, "The solver could not find a solution to the current IK problem");
397+
RCLCPP_WARN(logger_, "The solver could not find a solution to the current IK problem"); // NOLINT
374398
} else if (err == ik_solvers::SolverError::SOLVER_ERROR) {
375-
RCLCPP_WARN(logger_, "The solver experienced an error while solving the IK problem");
399+
RCLCPP_WARN(logger_, "The solver experienced an error while solving the IK problem"); // NOLINT
376400
}
377401
return controller_interface::return_type::OK;
378402
}
@@ -391,40 +415,19 @@ auto IKController::update_and_write_commands(const rclcpp::Time & /*time*/, cons
391415
m2m::transforms::transform_message(pose);
392416
std::ranges::copy(common::messages::to_vector(pose), point.positions.begin());
393417

394-
if (point.positions.size() != n_pos_dofs_) {
395-
RCLCPP_ERROR(
396-
logger_,
397-
std::format(
398-
"IK solution has mismatched position dimensions. Expected {} but got {}", n_pos_dofs_, point.positions.size())
399-
.c_str());
400-
return controller_interface::return_type::ERROR;
401-
}
402-
403-
if (point.velocities.size() != n_vel_dofs_) {
404-
RCLCPP_ERROR(
405-
logger_,
406-
std::format(
407-
"IK solution has mismatched velocity dimensions. Expected {} but got {}", n_vel_dofs_, point.velocities.size())
408-
.c_str());
409-
return controller_interface::return_type::ERROR;
410-
}
411-
412-
if (has_pos_interface_) {
413-
for (std::size_t i = 0; i < n_pos_dofs_; ++i) {
418+
if (use_position_commands_) {
419+
for (const auto & [i, joint_name] : std::views::enumerate(position_interface_names_)) {
414420
if (!command_interfaces_[i].set_value(point.positions[i])) {
415-
RCLCPP_WARN( // NOLINT
416-
logger_,
417-
std::format("Failed to set position command value for joint {}", pos_dofs_[i]).c_str());
421+
RCLCPP_WARN(logger_, "Failed to set position command value for joint %s", joint_name.c_str()); // NOLINT
418422
}
419423
}
420424
}
421425

422-
if (has_vel_interface_) {
423-
for (std::size_t i = 0; i < n_vel_dofs_; ++i) {
424-
const std::size_t idx = has_pos_interface_ ? n_pos_dofs_ + i : i;
426+
if (use_velocity_commands_) {
427+
for (const auto & [i, joint_name] : std::views::enumerate(velocity_interface_names_)) {
428+
const std::size_t idx = use_position_commands_ ? position_interface_names_.size() + i : i;
425429
if (!command_interfaces_[idx].set_value(point.velocities[i])) {
426-
// NOLINTNEXTLINE
427-
RCLCPP_WARN(logger_, std::format("Failed to set velocity command value for joint {}", vel_dofs_[i]).c_str());
430+
RCLCPP_WARN(logger_, "Failed to set velocity command value for joint %s", joint_name.c_str()); // NOLINT
428431
}
429432
}
430433
}

0 commit comments

Comments
 (0)