Skip to content

Commit 456937f

Browse files
author
Charlles Abreu
committed
Fixed python wrapping procedure
1 parent e20fa9f commit 456937f

File tree

5 files changed

+61
-24
lines changed

5 files changed

+61
-24
lines changed

python/openmmcppforces/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Execute SWIG to generate source code for the Python module.
22

3-
add_custom_target(PythonWrapper DEPENDS "${MODULE_NAME}.i")
4-
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/${MODULE_NAME}.i ${CMAKE_CURRENT_BINARY_DIR})
3+
file(GLOB SWIG_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.i)
4+
add_custom_target(PythonWrapper DEPENDS ${SWIG_SOURCE_FILES})
5+
foreach(file ${SWIG_SOURCE_FILES})
6+
configure_file(${file} ${CMAKE_CURRENT_BINARY_DIR})
7+
endforeach(file ${SWIG_SOURCE_FILES})
58
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/__init__.py ${CMAKE_CURRENT_BINARY_DIR})
69
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/py.typed ${CMAKE_CURRENT_BINARY_DIR} COPYONLY)
710

python/openmmcppforces/header.i

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
%header %{
2+
namespace OpenMM {
3+
4+
PyObject *copyVVec3ToList(std::vector<Vec3> vVec3) {
5+
int i, n;
6+
PyObject *pyList;
7+
8+
n=vVec3.size();
9+
pyList=PyList_New(n);
10+
PyObject* mm = PyImport_AddModule("openmm");
11+
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
12+
for (i=0; i<n; i++) {
13+
OpenMM::Vec3& v = vVec3.at(i);
14+
PyObject* args = Py_BuildValue("(d,d,d)", v[0], v[1], v[2]);
15+
PyObject* pyVec = PyObject_CallObject(vec3, args);
16+
Py_DECREF(args);
17+
PyList_SET_ITEM(pyList, i, pyVec);
18+
}
19+
return pyList;
20+
}
21+
22+
int isNumpyAvailable() {
23+
static bool initialized = false;
24+
static bool available = false;
25+
if (!initialized) {
26+
initialized = true;
27+
available = (_import_array() >= 0);
28+
}
29+
return available;
30+
}
31+
32+
} // namespace OpenMM
33+
%}

python/openmmcppforces/openmmcppforces.i

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,5 @@
1-
21
%module openmmcppforces
32

4-
%import(module="openmm") "swig/OpenMMSwigHeaders.i"
5-
%include "swig/typemaps.i"
6-
%include <std_string.i>
7-
%include <std_vector.i>
8-
%include <std_map.i>
9-
10-
namespace std {
11-
%template(vectord) vector<double>;
12-
%template(vectorstring) vector<string>;
13-
%template(mapstringstring) map<string,string>;
14-
%template(mapstringdouble) map<string,double>;
15-
};
16-
173
%{
184
#define SWIG_PYTHON_CAST_MODE
195
#include "ConcertedRMSDForce.h"
@@ -22,8 +8,22 @@ namespace std {
228
#include "OpenMMDrude.h"
239
#include "openmm/RPMDIntegrator.h"
2410
#include "openmm/RPMDMonteCarloBarostat.h"
11+
#include "openmm/Force.h"
12+
#include "openmm/Vec3.h"
13+
#include <numpy/ndarrayobject.h>
14+
15+
using namespace OpenMM;
2516
%}
2617

18+
%import(module="openmm") "swig/OpenMMSwigHeaders.i"
19+
%include "swig/typemaps.i"
20+
%include "header.i"
21+
%include "std_vector.i"
22+
23+
namespace std {
24+
%template(vectori) vector<int>;
25+
};
26+
2727
%pythoncode %{
2828
__version__ = "@CMAKE_PROJECT_VERSION@"
2929
%}
@@ -64,11 +64,11 @@ namespace OpenMMCPPForces {
6464
*/
6565
class ConcertedRMSDForce : public OpenMM::Force {
6666
public:
67-
explicit ConcertedRMSDForce(const std::vector<OpenMM::Vec3>& referencePositions);
67+
explicit ConcertedRMSDForce(const std::vector<Vec3>& referencePositions);
6868
/**
6969
* Get the reference positions to compute the deviation from.
7070
*/
71-
const std::vector<OpenMM::Vec3>& getReferencePositions() const;
71+
const std::vector<Vec3>& getReferencePositions() const;
7272
/**
7373
* Set the reference positions to compute the deviation from.
7474
*
@@ -79,7 +79,7 @@ public:
7979
* vector must equal the number of particles in the system, even if not all
8080
* particles are used in computing the concerted RMSD.
8181
*/
82-
void setReferencePositions(const std::vector<OpenMM::Vec3>& positions);
82+
void setReferencePositions(const std::vector<Vec3>& positions);
8383
/**
8484
* Add a group of particles to be included in the concerted RMSD calculation.
8585
*
@@ -125,12 +125,12 @@ public:
125125
void setGroup(int index, const std::vector<int>& particles);
126126
/**
127127
* Update the reference positions and particle groups in a Context to match those stored
128-
* in this Force object. This method provides an efficient way to update these parameters
128+
* in this OpenMM::Force object. This method provides an efficient way to update these parameters
129129
* in an existing Context without needing to reinitialize it. Simply call setReferencePositions()
130130
* and setGroup() to modify this object's parameters, then call updateParametersInContext()
131131
* to copy them over to the Context.
132132
*/
133-
void updateParametersInContext(OpenMM::Context& context);
133+
void updateParametersInContext(Context& context);
134134
/**
135135
* Returns whether or not this force makes use of periodic boundary
136136
* conditions.
@@ -139,7 +139,7 @@ public:
139139
*/
140140
bool usesPeriodicBoundaryConditions();
141141
/*
142-
* Add methods for casting a Force to an ExtendedCustomCVForce.
142+
* Add methods for casting a OpenMM::Force to an ExtendedCustomCVForce.
143143
*/
144144
%extend {
145145
static OpenMMCPPForces::ConcertedRMSDForce& cast(OpenMM::Force& force) {

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[build-system]
2-
requires = ["setuptools"]
2+
requires = ["setuptools", "numpy >= 1.19"]
33
build-backend = "setuptools.build_meta"
44

55
[project]

python/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from setuptools import setup, Extension
22
import os
33
import platform
4+
import numpy as np
45

56
openmm_dir = '@OPENMM_DIR@'
67
header_dir = '@PLUGIN_HEADER_DIR@'
@@ -26,7 +27,7 @@
2627
name='@MODULE_NAME@._@MODULE_NAME@',
2728
sources=[os.path.join(module_dir, '@WRAP_FILE@')],
2829
libraries=['OpenMM', '@PLUGIN_LIBRARY_NAME@'],
29-
include_dirs=[os.path.join(openmm_dir, 'include'), header_dir],
30+
include_dirs=[os.path.join(openmm_dir, 'include'), header_dir, np.get_include()],
3031
library_dirs=[os.path.join(openmm_dir, 'lib'), library_dir],
3132
extra_compile_args=extra_compile_args,
3233
extra_link_args=extra_link_args,

0 commit comments

Comments
 (0)