Skip to content

Commit b7a64f2

Browse files
Add torch pendulum (#32)
1 parent e7bb0d5 commit b7a64f2

File tree

6 files changed

+662
-54
lines changed

6 files changed

+662
-54
lines changed

CMakeLists.txt

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ option(CDDP-CPP_BUILD_TESTS "Whether to build tests." ON)
3030
option(CDDP-CPP_GUROBI "Whether to use Gurobi solver." OFF)
3131
option(GUROBI_ROOT "Path to Gurobi installation" "")
3232
set(GUROBI_ROOT /home/tom/.local/lib/gurobi1103/linux64)
33-
option(CDDP-CPP_TORCH "Whether to use LibTorch." ON)
33+
set(CDDP-CPP_TORCH "Whether to use LibTorch." ON) # cannot be turned off
3434
option(CDDP-CPP_TORCH_GPU "Whether to use GPU." ON)
3535

3636
# Find packages
@@ -66,6 +66,55 @@ if (CDDP-CPP_BUILD_TESTS)
6666
include(GoogleTest)
6767
endif()
6868

69+
# LibTorch
70+
if (CDDP-CPP_TORCH)
71+
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch/share/cmake/Torch/TorchConfig.cmake)
72+
message(STATUS "Found LibTorch at ${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch")
73+
else()
74+
message(STATUS "Downloading LibTorch...")
75+
# Download and extract LibTorch
76+
if (CDDP-CPP_TORCH_GPU)
77+
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu124.zip")
78+
else()
79+
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcpu.zip")
80+
endif()
81+
82+
# Set the download directory
83+
set(DOWLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/build)
84+
85+
# Create the download directory
86+
file(MAKE_DIRECTORY ${DOWLOAD_DIR})
87+
88+
# Download the file
89+
file(DOWNLOAD ${LIBTORCH_URL} ${DOWLOAD_DIR}/libtorch-shared-with-deps-latest.zip
90+
STATUS DOWNLOAD_STATUS
91+
SHOW_PROGRESS
92+
)
93+
94+
# Check if the download was successful
95+
list(GET DOWNLOAD_STATUS 0 DOWNLOAD_STATUS_CODE)
96+
if(NOT DOWNLOAD_STATUS_CODE EQUAL 0)
97+
message(FATAL_ERROR "Failed to download LibTorch.")
98+
endif()
99+
100+
# Extract the file
101+
execute_process(
102+
COMMAND ${CMAKE_COMMAND} -E tar xvf ${DOWLOAD_DIR}/libtorch-shared-with-deps-latest.zip
103+
WORKING_DIRECTORY ${DOWLOAD_DIR}
104+
)
105+
106+
# Remove the zip file
107+
file(REMOVE ${DOWLOAD_DIR}/libtorch-shared-with-deps-latest.zip)
108+
endif()
109+
110+
# Set the path to the LibTorch installation
111+
set(LIBTORCH_DIR ${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch)
112+
113+
find_package(Torch REQUIRED PATHS ${LIBTORCH_DIR} NO_DEFAULT_PATH)
114+
message(STATUS "Found LibTorch: ${TORCH_LIBRARIES}")
115+
endif()
116+
117+
69118
# Include directories
70119
include_directories(
71120
${CMAKE_CURRENT_SOURCE_DIR}/include
@@ -79,6 +128,7 @@ set(cddp_core_srcs
79128
src/cddp_core/objective.cpp
80129
src/cddp_core/constraint.cpp
81130
src/cddp_core/cddp_core.cpp
131+
src/cddp_core/torch_dynamical_system.cpp
82132
)
83133

84134
set(dynamics_model_srcs
@@ -102,11 +152,13 @@ target_link_libraries(${PROJECT_NAME}
102152
Python3::Python
103153
Python3::Module
104154
Python3::NumPy
155+
${TORCH_LIBRARIES}
105156
)
106157

107158
target_include_directories(${PROJECT_NAME} PUBLIC
108159
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/cddp-cpp>
109160
$<INSTALL_INTERFACE:include>
161+
${TORCH_INCLUDE_DIRS}
110162
)
111163

112164
# Gurobi
@@ -132,57 +184,6 @@ if (CDDP-CPP_GUROBI)
132184
endif()
133185
endif()
134186

135-
# LibTorch
136-
if (CDDP-CPP_TORCH)
137-
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch/share/cmake/Torch/TorchConfig.cmake)
138-
message(STATUS "Found LibTorch at ${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch")
139-
else()
140-
message(STATUS "Downloading LibTorch...")
141-
# Download and extract LibTorch
142-
if (CDDP-CPP_TORCH_GPU)
143-
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu124.zip")
144-
else()
145-
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcpu.zip")
146-
endif()
147-
148-
# Set the download directory
149-
set(DOWLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/build)
150-
151-
# Create the download directory
152-
file(MAKE_DIRECTORY ${DOWLOAD_DIR})
153-
154-
# Download the file
155-
file(DOWNLOAD ${LIBTORCH_URL} ${DOWLOAD_DIR}/libtorch-shared-with-deps-latest.zip
156-
STATUS DOWNLOAD_STATUS
157-
SHOW_PROGRESS
158-
)
159-
160-
# Check if the download was successful
161-
list(GET DOWNLOAD_STATUS 0 DOWNLOAD_STATUS_CODE)
162-
if(NOT DOWNLOAD_STATUS_CODE EQUAL 0)
163-
message(FATAL_ERROR "Failed to download LibTorch.")
164-
endif()
165-
166-
# Extract the file
167-
execute_process(
168-
COMMAND ${CMAKE_COMMAND} -E tar xvf ${DOWLOAD_DIR}/libtorch-shared-with-deps-latest.zip
169-
WORKING_DIRECTORY ${DOWLOAD_DIR}
170-
)
171-
172-
# Remove the zip file
173-
file(REMOVE ${DOWLOAD_DIR}/libtorch-shared-with-deps-latest.zip)
174-
endif()
175-
176-
# Set the path to the LibTorch installation
177-
set(LIBTORCH_DIR ${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch)
178-
179-
find_package(Torch REQUIRED PATHS ${LIBTORCH_DIR} NO_DEFAULT_PATH)
180-
target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES})
181-
target_include_directories(${PROJECT_NAME} PUBLIC ${TORCH_INCLUDE_DIRS})
182-
message(STATUS "Found LibTorch: ${TORCH_LIBRARIES}")
183-
endif()
184-
185-
186187
# Build and register tests.
187188
if (CDDP-CPP_BUILD_TESTS)
188189
add_subdirectory(tests)

include/cddp-cpp/cddp.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
#include <vector>
2222
#include <Eigen/Dense>
2323

24-
25-
24+
// #include "cddp-cpp/sdqp.hpp"
2625
#include "cddp_core/dynamical_system.hpp"
2726
#include "cddp_core/objective.hpp"
2827
#include "cddp_core/constraint.hpp"
2928
#include "cddp_core/cddp_core.hpp"
3029

30+
#include "cddp_core/torch_dynamical_system.hpp"
31+
3132
// Models
3233
#include "dynamics_model/pendulum.hpp"
3334
#include "dynamics_model/dubins_car.hpp"
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
Copyright 2024 Tomo Sasaki
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
#ifndef CDDP_TORCH_DYNAMICAL_SYSTEM_HPP
17+
#define CDDP_TORCH_DYNAMICAL_SYSTEM_HPP
18+
19+
#include "cddp_core/dynamical_system.hpp"
20+
#include <torch/torch.h>
21+
#include <Eigen/Dense>
22+
23+
namespace cddp {
24+
25+
class DynamicsModelInterface : public torch::nn::Module {
26+
public:
27+
virtual torch::Tensor forward(std::vector<torch::Tensor> inputs) = 0;
28+
virtual ~DynamicsModelInterface() = default;
29+
};
30+
31+
class TorchDynamicalSystem : public DynamicalSystem {
32+
public:
33+
TorchDynamicalSystem(int state_dim,
34+
int control_dim,
35+
double timestep,
36+
std::string integration_type,
37+
std::shared_ptr<DynamicsModelInterface> model,
38+
bool use_gpu = false);
39+
40+
// Override core dynamics methods
41+
Eigen::VectorXd getContinuousDynamics(const Eigen::VectorXd& state,
42+
const Eigen::VectorXd& control) const override;
43+
44+
Eigen::MatrixXd getStateJacobian(const Eigen::VectorXd& state,
45+
const Eigen::VectorXd& control) const override;
46+
47+
Eigen::MatrixXd getControlJacobian(const Eigen::VectorXd& state,
48+
const Eigen::VectorXd& control) const override;
49+
50+
// Add batch processing capability
51+
std::vector<Eigen::VectorXd> getBatchDynamics(
52+
const std::vector<Eigen::VectorXd>& states,
53+
const std::vector<Eigen::VectorXd>& controls) const;
54+
55+
// Optional: Override Hessian computations if needed
56+
Eigen::MatrixXd getStateHessian(const Eigen::VectorXd& state,
57+
const Eigen::VectorXd& control) const override;
58+
59+
Eigen::MatrixXd getControlHessian(const Eigen::VectorXd& state,
60+
const Eigen::VectorXd& control) const override;
61+
62+
private:
63+
// Helper methods for tensor conversions
64+
torch::Tensor eigenToTorch(const Eigen::VectorXd& eigen_vec, bool requires_grad = false) const;
65+
Eigen::VectorXd torchToEigen(const torch::Tensor& tensor) const;
66+
67+
std::shared_ptr<DynamicsModelInterface> model_;
68+
bool use_gpu_;
69+
torch::Device device_;
70+
};
71+
72+
} // namespace cddp
73+
74+
#endif // CDDP_TORCH_DYNAMICAL_SYSTEM_HPP

0 commit comments

Comments
 (0)