From 2e09003d62d36fa551a04276148f64daf4c5fefc Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 10:12:18 +0200 Subject: [PATCH 01/31] add get_modelid helper --- mujoco_warp/_src/support.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index 847bce8b..dd649e68 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -349,3 +349,17 @@ def jac( jacr = cdof_ang return jacp, jacr + +@wp.func +def get_modelId(array: wp.array(dtype=Any), worldid: int) -> int: + return worldid % array.shape[0] + +@wp.func +def get_modelId(array: wp.array2d(dtype=Any), worldid: int) -> int: + return worldid % array.shape[0] + +@wp.func +def get_modelId(array: wp.array3d(dtype=Any), worldid: int) -> int: + return worldid % array.shape[0] + + From d7d6a9e9f881e63d8e479981ba1761de5244fd94 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 10:49:53 +0200 Subject: [PATCH 02/31] start implementing things with helper --- mujoco_warp/_src/constraint.py | 7 ++++--- mujoco_warp/_src/io.py | 8 +++++++- mujoco_warp/_src/smooth.py | 5 +++-- mujoco_warp/_src/support.py | 20 +++++++++++--------- mujoco_warp/_src/types.py | 4 ++-- 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index efaa02f9..7b569e3e 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -18,6 +18,7 @@ from . import math from . import support from . import types +from .support import get_batched_value from .warp_util import event_scope wp.config.enable_backward = False @@ -182,7 +183,7 @@ def _efc_equality_joint( # Two joint constraint qposadr2 = m.jnt_qposadr[jntid_2] dofadr2 = m.jnt_dofadr[jntid_2] - dif = d.qpos[worldid, qposadr2] - m.qpos0[qposadr2] + dif = d.qpos[worldid, qposadr2] - get_batched_value(m.qpos0, worldid, qposadr2) # Horner's method for polynomials rhs = data[0] + dif * (data[1] + dif * (data[2] + dif * (data[3] + dif * data[4]))) @@ -190,14 +191,14 @@ def _efc_equality_joint( 2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4]) ) - pos = d.qpos[worldid, qposadr1] - m.qpos0[qposadr1] - rhs + pos = d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - rhs Jqvel = d.qvel[worldid, dofadr1] - d.qvel[worldid, dofadr2] * deriv_2 invweight = m.dof_invweight0[dofadr1] + m.dof_invweight0[dofadr2] d.efc.J[efcid, dofadr2] = -deriv_2 else: # Single joint constraint - pos = d.qpos[worldid, qposadr1] - m.qpos0[qposadr1] - data[0] + pos = d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - data[0] Jqvel = d.qvel[worldid, dofadr1] invweight = m.dof_invweight0[dofadr1] diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 299ce2b4..ab55dde8 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -69,6 +69,12 @@ def geom_pair(m: mujoco.MjModel) -> Tuple[np.array, np.array]: return np.array(geompairs), np.array(pairids) +def create_nmodel_batched_array(mjm_array, dtype): + array = wp.array(mjm_array, dtype=dtype) + array.ndim += 1 + array.shape = (1,) + array.shape + array.strides = (0,) + array.strides + return array def put_model(mjm: mujoco.MjModel) -> types.Model: # check supported features @@ -161,7 +167,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.opt.depth_extension = wp.float32(0.1) # warp only m.stat.meaninertia = mjm.stat.meaninertia - m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1) + m.qpos0 = create_nmodel_batched_array(mjm.qpos0, dtype=wp.float32) m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) # dof lower triangle row and column indices diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index bc6a0ae6..7c9d8b8b 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -31,6 +31,7 @@ from .types import array2df from .types import array3df from .types import vec10 +from .support import get_batched_value from .warp_util import event_scope from .warp_util import kernel @@ -95,9 +96,9 @@ def _level(m: Model, d: Data, leveladr: int): # correct for off-center rotation xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) elif jnt_type == wp.static(JointType.SLIDE.value): - xpos += xaxis * (qpos[qadr] - m.qpos0[qadr]) + xpos += xaxis * (qpos[qadr] - get_batched_value(m.qpos0, worldid, qadr)) elif jnt_type == wp.static(JointType.HINGE.value): - qpos0 = m.qpos0[qadr] + qpos0 = get_batched_value(m.qpos0, worldid, qadr) qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index dd649e68..eb3a5acc 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from typing import Tuple, Any import mujoco import warp as wp @@ -350,16 +350,18 @@ def jac( return jacp, jacr -@wp.func -def get_modelId(array: wp.array(dtype=Any), worldid: int) -> int: - return worldid % array.shape[0] @wp.func -def get_modelId(array: wp.array2d(dtype=Any), worldid: int) -> int: - return worldid % array.shape[0] +def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32, i: wp.int32): + modelid = worldid % array.shape[0] + return array[modelid, i] @wp.func -def get_modelId(array: wp.array3d(dtype=Any), worldid: int) -> int: - return worldid % array.shape[0] - +def get_batched_value(array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32): + modelid = worldid % array.shape[0] + return array[modelid, i, j] +@wp.func +def get_batched_value(array: wp.array4d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32, k: wp.int32): + modelid = worldid % array.shape[0] + return array[modelid, i, j, k] diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 86b13041..f9c8d802 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -554,7 +554,7 @@ class Model: npair: number of predefined geom pairs () opt: physics options stat: model statistics - qpos0: qpos values at default pose (nq,) + qpos0: qpos values at default pose (nmodel, nq) qpos_spring: reference pose for springs (nq,) body_tree: BFS ordering of body ids body_treeadr: starting index of each body tree level @@ -764,7 +764,7 @@ class Model: npair: int opt: Option stat: Statistic - qpos0: wp.array(dtype=wp.float32, ndim=1) + qpos0: wp.array(dtype=wp.float32, ndim=2) qpos_spring: wp.array(dtype=wp.float32, ndim=1) body_tree: wp.array(dtype=wp.int32, ndim=1) # warp only body_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only From 6ba92b0eac149574b2b29c921700b5bf047f60cd Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 14:57:19 +0200 Subject: [PATCH 03/31] update ast_analyzer_test --- contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py index dc2326dd..ca498c3d 100644 --- a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py +++ b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py @@ -136,7 +136,7 @@ def test_all_issues( @kernel def test_no_issues( # Model: - qpos0: wp.array(dtype=wp.float32, ndim=1), + qpos0: wp.array(dtype=wp.float32, ndim=2), geom_pos: wp.array(dtype=wp.vec3, ndim=1), # Data in: qpos_in: wp.array(dtype=wp.float32, ndim=2), From b21f290abca5687fee335cd3ed3c09c5ff13565d Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 15:15:18 +0200 Subject: [PATCH 04/31] add test + dof_damping --- mujoco_warp/_src/forward.py | 7 +++--- mujoco_warp/_src/io_test.py | 43 +++++++++++++++++++++++++++++++++++++ mujoco_warp/_src/passive.py | 3 ++- mujoco_warp/_src/support.py | 4 ++++ mujoco_warp/_src/types.py | 6 +++--- 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 01fc6385..39e1d0fd 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -26,6 +26,7 @@ from . import smooth from . import solver from .support import xfrc_accumulate +from .support import get_batched_value from .types import MJ_MINVAL from .types import BiasType from .types import Data @@ -193,7 +194,7 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * m.dof_damping[tid] + d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_value(m.dof_damping, worldid, tid) d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] @@ -215,14 +216,14 @@ def eulerdamp_fused_dense(m: Model, d: Data): def tile_eulerdamp(adr: int, size: int, tilesize: int): @kernel def eulerdamp( - m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + m: Model, d: Data, damping: wp.array2d(dtype=wp.float32), leveladr: int ): worldid, nodeid = wp.tid() dofid = m.qLD_tile[leveladr + nodeid] M_tile = wp.tile_load( d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) ) - damping_tile = wp.tile_load(damping, shape=(tilesize,), offset=(dofid,)) + damping_tile = wp.tile_load(get_batched_value(damping, worldid), shape=(tilesize,), offset=(dofid,)) damping_scaled = damping_tile * m.opt.timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index c54f8bc7..e65b0ae3 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -24,6 +24,14 @@ from . import test_util +# tolerance for difference between MuJoCo and MJWarp smooth calculations - mostly +# due to float precision +_TOLERANCE = 5e-5 + +def _assert_eq(a, b, name): + tol = _TOLERANCE * 10 # avoid test noise + err_msg = f"mismatch: {name}" + np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol) class IOTest(absltest.TestCase): def test_make_put_data(self): @@ -291,6 +299,41 @@ def test_option_physical_constants(self): mjwarp.put_model(mjm) + def test_model_batching(self): + mjm, mjd, _, _ = test_util.fixture("humanoid/humanoid.xml", kick=True) + + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd, nworld=3) + + # manually create a batch of damping values + damping_orig = mjm.dof_damping + dof_damping = np.zeros((2, len(damping_orig)), dtype=np.float32) + dof_damping[0, :] = damping_orig + dof_damping[1, :] = damping_orig * 0.5 + + # set the batched damping values + m.dof_damping = wp.from_numpy(dof_damping, dtype=wp.float32) + + mjwarp.passive(m, d) + + # mujoco reference, just have 3 separate model/data strctures + mujoco.mj_passive(mjm, mjd) + + m2, d2, _, _ = test_util.fixture("humanoid/humanoid.xml") + d2.qvel = mjd.qvel # need to copy qvel because of randomization + m2.dof_damping *= 0.5 + + m3, d3, _, _ = test_util.fixture("humanoid/humanoid.xml") + d3.qvel = mjd.qvel # need to copy qvel because of randomization + + mujoco.mj_passive(m2, d2) + mujoco.mj_passive(m3, d3) + + _assert_eq(d.qfrc_damper.numpy()[0, :], mjd.qfrc_damper, "qfrc_damper") + _assert_eq(d.qfrc_damper.numpy()[1, :], d2.qfrc_damper, "qfrc_damper") + _assert_eq(d.qfrc_damper.numpy()[0, :], d3.qfrc_damper, "qfrc_damper") + + if __name__ == "__main__": wp.init() absltest.main() diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index c8cbc30b..fbc688a9 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -22,6 +22,7 @@ from .types import Model from .warp_util import event_scope from .warp_util import kernel +from .support import get_batched_value @event_scope @@ -93,7 +94,7 @@ def _spring(m: Model, d: Data): @kernel def _damper_passive(m: Model, d: Data): worldid, dofid = wp.tid() - damping = m.dof_damping[dofid] + damping = get_batched_value(m.dof_damping, worldid, dofid) qfrc_damper = -damping * d.qvel[worldid, dofid] d.qfrc_damper[worldid, dofid] = qfrc_damper diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index eb3a5acc..fa630c89 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -350,6 +350,10 @@ def jac( return jacp, jacr +@wp.func +def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32): + modelid = worldid % array.shape[0] + return array[modelid] @wp.func def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32, i: wp.int32): diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index f9c8d802..3a322590 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -434,7 +434,7 @@ class Constraint: gauss: gauss Cost (nworld,) cost: constraint + Gauss cost (nworld,) prev_cost: cost from previous iter (nworld,) - solver_niter: number of solver iterations (nworld,) + solver_niter: number of solver iterations (nworld, active: active (quadratic) constraints (njmax,) gtol: linesearch termination tolerance (nworld,) mv: qM @ search (nworld, nv) @@ -620,7 +620,7 @@ class Model: dof_parentid: id of dof's parent; -1: none (nv,) dof_Madr: dof address in M-diagonal (nv,) dof_armature: dof armature inertia/mass (nv,) - dof_damping: damping coefficient (nv,) + dof_damping: damping coefficient (nmodel, nv) dof_invweight0: diag. inverse inertia in qpos0 (nv,) dof_frictionloss: dof friction loss (nv,) dof_solimp: constraint solver impedance: frictionloss (nv, NIMP) @@ -830,7 +830,7 @@ class Model: dof_parentid: wp.array(dtype=wp.int32, ndim=1) dof_Madr: wp.array(dtype=wp.int32, ndim=1) dof_armature: wp.array(dtype=wp.float32, ndim=1) - dof_damping: wp.array(dtype=wp.float32, ndim=1) + dof_damping: wp.array(dtype=wp.float32, ndim=2) dof_invweight0: wp.array(dtype=wp.float32, ndim=1) dof_frictionloss: wp.array(dtype=wp.float32, ndim=1) dof_solimp: wp.array(dtype=vec5, ndim=1) From 819d9c1fbdd32652441649be088ca8530096fc5f Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 15:24:37 +0200 Subject: [PATCH 05/31] fixes --- mujoco_warp/_src/forward.py | 12 ++++++------ mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/support.py | 3 ++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 39e1d0fd..f769693d 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -26,7 +26,7 @@ from . import smooth from . import solver from .support import xfrc_accumulate -from .support import get_batched_value +from .support import get_batched_array from .types import MJ_MINVAL from .types import BiasType from .types import Data @@ -194,7 +194,7 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_value(m.dof_damping, worldid, tid) + d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_array(m.dof_damping, worldid, tid) d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] @@ -223,7 +223,7 @@ def eulerdamp( M_tile = wp.tile_load( d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) ) - damping_tile = wp.tile_load(get_batched_value(damping, worldid), shape=(tilesize,), offset=(dofid,)) + damping_tile = wp.tile_load(get_batched_array(damping, worldid), shape=(tilesize,), offset=(dofid,)) damping_scaled = damping_tile * m.opt.timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) @@ -392,7 +392,7 @@ def actuator_bias_gain_vel(m: Model, d: Data): d.act_vel_integration[worldid, actid] = bias_vel + gain_vel * ctrl def qderiv_actuator_damping_fused( - m: Model, d: Data, damping: wp.array(dtype=wp.float32) + m: Model, d: Data, damping: wp.array2d(dtype=wp.float32) ): if actuation_enabled: block_dim = 64 @@ -408,7 +408,7 @@ def qderiv_actuator_damping_tiled( ): @kernel def qderiv_actuator_fused_kernel( - m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + m: Model, d: Data, damping: wp.array2d(dtype=wp.float32), leveladr: int ): worldid, nodeid = wp.tid() offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] @@ -435,7 +435,7 @@ def qderiv_actuator_fused_kernel( ) if wp.static(passive_enabled): - dof_damping = wp.tile_load(damping, shape=tilesize_nv, offset=offset_nv) + dof_damping = wp.tile_load(get_batched_array(damping, worldid), shape=tilesize_nv, offset=offset_nv) negative = wp.neg(dof_damping) qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index ab55dde8..dc2e6c2a 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -458,7 +458,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) - m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1) + m.dof_damping = create_nmodel_batched_array(mjm.dof_damping, dtype=wp.float32) m.dof_frictionloss = wp.array(mjm.dof_frictionloss, dtype=wp.float32, ndim=1) m.dof_solimp = wp.array(mjm.dof_solimp, dtype=types.vec5, ndim=1) m.dof_solref = wp.array(mjm.dof_solref, dtype=wp.vec2, ndim=1) diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index fa630c89..fbaf0735 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -351,7 +351,8 @@ def jac( return jacp, jacr @wp.func -def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32): +def get_batched_array(array: wp.array2d(dtype=Any), worldid: wp.int32): + """Returns the array slice for the given worldid.""" modelid = worldid % array.shape[0] return array[modelid] From db068dc7646a6e0a14d082383b1577ce698804d5 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 15:26:23 +0200 Subject: [PATCH 06/31] small fix --- mujoco_warp/_src/forward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index f769693d..32e0c523 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -27,6 +27,7 @@ from . import solver from .support import xfrc_accumulate from .support import get_batched_array +from .support import get_batched_value from .types import MJ_MINVAL from .types import BiasType from .types import Data @@ -194,7 +195,7 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_array(m.dof_damping, worldid, tid) + d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_value(m.dof_damping, worldid, tid) d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] From 173679a71f6bdd3366179c1f53e8b42071495236 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 15:35:26 +0200 Subject: [PATCH 07/31] qpos_spring --- mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/passive.py | 24 ++++++++++++------------ mujoco_warp/_src/types.py | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index dc2e6c2a..ab207e1f 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -168,7 +168,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.stat.meaninertia = mjm.stat.meaninertia m.qpos0 = create_nmodel_batched_array(mjm.qpos0, dtype=wp.float32) - m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) + m.qpos_spring = create_nmodel_batched_array(mjm.qpos_spring, dtype=wp.float32) # dof lower triangle row and column indices dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index fbc688a9..fe295ea7 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -47,9 +47,9 @@ def _spring(m: Model, d: Data): if jnt_type == wp.static(JointType.FREE.value): dif = wp.vec3( - d.qpos[worldid, qposid + 0] - m.qpos_spring[qposid + 0], - d.qpos[worldid, qposid + 1] - m.qpos_spring[qposid + 1], - d.qpos[worldid, qposid + 2] - m.qpos_spring[qposid + 2], + d.qpos[worldid, qposid + 0] - get_batched_value(m.qpos_spring, worldid, qposid + 0), + d.qpos[worldid, qposid + 1] - get_batched_value(m.qpos_spring, worldid, qposid + 1), + d.qpos[worldid, qposid + 2] - get_batched_value(m.qpos_spring, worldid, qposid + 2), ) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] @@ -61,10 +61,10 @@ def _spring(m: Model, d: Data): d.qpos[worldid, qposid + 6], ) ref = wp.quat( - m.qpos_spring[qposid + 3], - m.qpos_spring[qposid + 4], - m.qpos_spring[qposid + 5], - m.qpos_spring[qposid + 6], + get_batched_value(m.qpos_spring, worldid, qposid + 3), + get_batched_value(m.qpos_spring, worldid, qposid + 4), + get_batched_value(m.qpos_spring, worldid, qposid + 5), + get_batched_value(m.qpos_spring, worldid, qposid + 6), ) dif = math.quat_sub(rot, ref) d.qfrc_spring[worldid, dofid + 3] = -stiffness * dif[0] @@ -78,17 +78,17 @@ def _spring(m: Model, d: Data): d.qpos[worldid, qposid + 3], ) ref = wp.quat( - m.qpos_spring[qposid + 0], - m.qpos_spring[qposid + 1], - m.qpos_spring[qposid + 2], - m.qpos_spring[qposid + 3], + get_batched_value(m.qpos_spring, worldid, qposid + 0), + get_batched_value(m.qpos_spring, worldid, qposid + 1), + get_batched_value(m.qpos_spring, worldid, qposid + 2), + get_batched_value(m.qpos_spring, worldid, qposid + 3), ) dif = math.quat_sub(rot, ref) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] else: # mjJNT_SLIDE, mjJNT_HINGE - fdif = d.qpos[worldid, qposid] - m.qpos_spring[qposid] + fdif = d.qpos[worldid, qposid] - get_batched_value(m.qpos_spring, worldid, qposid) d.qfrc_spring[worldid, dofid] = -stiffness * fdif @kernel diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 3a322590..44dcb56a 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -555,7 +555,7 @@ class Model: opt: physics options stat: model statistics qpos0: qpos values at default pose (nmodel, nq) - qpos_spring: reference pose for springs (nq,) + qpos_spring: reference pose for springs (nmodel, nq,) body_tree: BFS ordering of body ids body_treeadr: starting index of each body tree level actuator_moment_offset_nv: tiling configuration @@ -765,7 +765,7 @@ class Model: opt: Option stat: Statistic qpos0: wp.array(dtype=wp.float32, ndim=2) - qpos_spring: wp.array(dtype=wp.float32, ndim=1) + qpos_spring: wp.array(dtype=wp.float32, ndim=2) body_tree: wp.array(dtype=wp.int32, ndim=1) # warp only body_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only actuator_moment_offset_nv: wp.array(dtype=wp.int32, ndim=1) # warp only From bf6f809554e3d11053eca13f8393eab26fd72415 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 15:42:51 +0200 Subject: [PATCH 08/31] body_pos --- mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/smooth.py | 4 ++-- mujoco_warp/_src/types.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index ab207e1f..b7814bb9 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -369,7 +369,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) - m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) + m.body_pos = create_nmodel_batched_array(mjm.body_pos, dtype=wp.vec3) m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 7c9d8b8b..c2868523 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -62,7 +62,7 @@ def _level(m: Model, d: Data, leveladr: int): if jntnum == 0: # no joints - apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] + xpos = (d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid)) + d.xpos[worldid, pid] xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint @@ -75,7 +75,7 @@ def _level(m: Model, d: Data, leveladr: int): # regular or no joints # apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] + xpos = (d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid)) + d.xpos[worldid, pid] xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) for _ in range(jntnum): diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 44dcb56a..90314717 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -588,7 +588,7 @@ class Model: body_dofadr: start addr of dofs; -1: no dofs (nbody,) body_geomnum: number of geoms (nbody,) body_geomadr: start addr of geoms; -1: no geoms (nbody,) - body_pos: position offset rel. to parent body (nbody, 3) + body_pos: position offset rel. to parent body (nmodel, nbody, 3) body_quat: orientation offset rel. to parent body (nbody, 4) body_ipos: local position of center of mass (nbody, 3) body_iquat: local orientation of inertia ellipsoid (nbody, 4) @@ -798,7 +798,7 @@ class Model: body_dofadr: wp.array(dtype=wp.int32, ndim=1) body_geomnum: wp.array(dtype=wp.int32, ndim=1) body_geomadr: wp.array(dtype=wp.int32, ndim=1) - body_pos: wp.array(dtype=wp.vec3, ndim=1) + body_pos: wp.array(dtype=wp.vec3, ndim=2) body_quat: wp.array(dtype=wp.quat, ndim=1) body_ipos: wp.array(dtype=wp.vec3, ndim=1) body_iquat: wp.array(dtype=wp.quat, ndim=1) From 3af4cf91aacaec84ab1271be22af35a96d8b8530 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 15:49:52 +0200 Subject: [PATCH 09/31] ipos/iquat/mass etc --- mujoco_warp/_src/io.py | 10 +++++----- mujoco_warp/_src/sensor.py | 7 ++++--- mujoco_warp/_src/smooth.py | 24 ++++++++++++------------ mujoco_warp/_src/types.py | 20 ++++++++++---------- 4 files changed, 31 insertions(+), 30 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index b7814bb9..f4a06c3d 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -370,12 +370,12 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) m.body_pos = create_nmodel_batched_array(mjm.body_pos, dtype=wp.vec3) - m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) - m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) - m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) + m.body_quat = create_nmodel_batched_array(mjm.body_quat, dtype=wp.quat) + m.body_ipos = create_nmodel_batched_array(mjm.body_ipos, dtype=wp.vec3) + m.body_iquat = create_nmodel_batched_array(mjm.body_iquat, dtype=wp.quat) m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) - m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) - m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) + m.body_inertia = create_nmodel_batched_array(mjm.body_inertia, dtype=wp.vec3) + m.body_mass = create_nmodel_batched_array(mjm.body_mass, dtype=wp.float32) m.body_subtreemass = wp.array(mjm.body_subtreemass, dtype=wp.float32, ndim=1) subtree_mass = np.copy(mjm.body_mass) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index b443a7f8..e56afb63 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -27,6 +27,7 @@ from .types import SensorType from .warp_util import event_scope from .warp_util import kernel +from .support import get_batched_value @wp.func @@ -136,15 +137,15 @@ def _frame_quat( m: Model, d: Data, worldid: int, objid: int, objtype: int, refid: int ) -> wp.quat: if objtype == int(ObjType.BODY.value): - quat = math.mul_quat(d.xquat[worldid, objid], m.body_iquat[objid]) + quat = math.mul_quat(d.xquat[worldid, objid], get_batched_value(m.body_iquat, worldid, objid)) if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, refid], m.body_iquat[refid]) + refquat = math.mul_quat(d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid)) elif objtype == int(ObjType.XBODY.value): quat = d.xquat[worldid, objid] if refid == -1: return quat - refquat = d.xquat[worldid, refid] + refquat = math.mul_quat(d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid)) elif objtype == int(ObjType.GEOM.value): quat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[objid]], m.geom_quat[objid]) if refid == -1: diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index c2868523..6bc210a4 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -63,7 +63,7 @@ def _level(m: Model, d: Data, leveladr: int): # no joints - apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] xpos = (d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid)) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + xquat = math.mul_quat(d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid)) elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = m.jnt_qposadr[jntadr] @@ -76,7 +76,7 @@ def _level(m: Model, d: Data, leveladr: int): # apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] xpos = (d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid)) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + xquat = math.mul_quat(d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid)) for _ in range(jntnum): qadr = m.jnt_qposadr[jntadr] @@ -112,9 +112,9 @@ def _level(m: Model, d: Data, leveladr: int): xquat = wp.normalize(xquat) d.xquat[worldid, bodyid] = xquat d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) - d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(m.body_ipos[bodyid], xquat) + d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(get_batched_value(m.body_ipos, worldid, bodyid), xquat) d.ximat[worldid, bodyid] = math.quat_to_mat( - math.mul_quat(xquat, m.body_iquat[bodyid]) + math.mul_quat(xquat, get_batched_value(m.body_iquat, worldid, bodyid)) ) @kernel @@ -161,7 +161,7 @@ def com_pos(m: Model, d: Data): @kernel def subtree_com_init(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[bodyid] + d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * get_batched_value(m.body_mass, worldid, bodyid) @kernel def subtree_com_acc(m: Model, d: Data, leveladr: int): @@ -179,8 +179,8 @@ def subtree_div(m: Model, d: Data): def cinert(m: Model, d: Data): worldid, bodyid = wp.tid() mat = d.ximat[worldid, bodyid] - inert = m.body_inertia[bodyid] - mass = m.body_mass[bodyid] + inert = get_batched_value(m.body_inertia, worldid, bodyid) + mass = get_batched_value(m.body_mass, worldid, bodyid) dif = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] # express inertia in com-based frame (mju_inertCom) @@ -1107,11 +1107,11 @@ def _forward(m: Model, d: Data): # update linear velocity lin -= wp.cross(xipos - subtree_com_root, ang) - d.subtree_linvel[worldid, bodyid] = m.body_mass[bodyid] * lin + d.subtree_linvel[worldid, bodyid] = get_batched_value(m.body_mass, worldid, bodyid) * lin dv = wp.transpose(ximat) @ ang - dv[0] *= m.body_inertia[bodyid][0] - dv[1] *= m.body_inertia[bodyid][1] - dv[2] *= m.body_inertia[bodyid][2] + dv[0] *= get_batched_value(m.body_inertia, worldid, bodyid)[0] + dv[1] *= get_batched_value(m.body_inertia, worldid, bodyid)[1] + dv[2] *= get_batched_value(m.body_inertia, worldid, bodyid)[2] d.subtree_angmom[worldid, bodyid] = ximat @ dv d.subtree_bodyvel[worldid, bodyid] = wp.spatial_vector(ang, lin) @@ -1149,7 +1149,7 @@ def _angular_momentum(m: Model, d: Data, leveladr: int): vel = d.subtree_bodyvel[worldid, bodyid] linvel = d.subtree_linvel[worldid, bodyid] linvel_parent = d.subtree_linvel[worldid, pid] - mass = m.body_mass[bodyid] + mass = get_batched_value(m.body_mass, worldid, bodyid) subtreemass = m.body_subtreemass[bodyid] # momentum wrt body i diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 90314717..4f28d8f0 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -589,13 +589,13 @@ class Model: body_geomnum: number of geoms (nbody,) body_geomadr: start addr of geoms; -1: no geoms (nbody,) body_pos: position offset rel. to parent body (nmodel, nbody, 3) - body_quat: orientation offset rel. to parent body (nbody, 4) - body_ipos: local position of center of mass (nbody, 3) - body_iquat: local orientation of inertia ellipsoid (nbody, 4) - body_mass: mass (nbody,) + body_quat: orientation offset rel. to parent body (nmodel, nbody, 4) + body_ipos: local position of center of mass (nmodel, nbody, 3) + body_iquat: local orientation of inertia ellipsoid (nmodel, nbody, 4) + body_mass: mass (nmodel, nbody,) body_subtreemass: mass of subtree starting at this body (nbody,) subtree_mass: mass of subtree (nbody,) - body_inertia: diagonal inertia in ipos/iquat frame (nbody, 3) + body_inertia: diagonal inertia in ipos/iquat frame (nmodel, nbody, 3) body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2) body_contype: OR over all geom contypes (nbody,) body_conaffinity: OR over all geom conaffinities (nbody,) @@ -799,13 +799,13 @@ class Model: body_geomnum: wp.array(dtype=wp.int32, ndim=1) body_geomadr: wp.array(dtype=wp.int32, ndim=1) body_pos: wp.array(dtype=wp.vec3, ndim=2) - body_quat: wp.array(dtype=wp.quat, ndim=1) - body_ipos: wp.array(dtype=wp.vec3, ndim=1) - body_iquat: wp.array(dtype=wp.quat, ndim=1) - body_mass: wp.array(dtype=wp.float32, ndim=1) + body_quat: wp.array(dtype=wp.quat, ndim=2) + body_ipos: wp.array(dtype=wp.vec3, ndim=2) + body_iquat: wp.array(dtype=wp.quat, ndim=2) + body_mass: wp.array(dtype=wp.float32, ndim=2) body_subtreemass: wp.array(dtype=wp.float32, ndim=1) subtree_mass: wp.array(dtype=wp.float32, ndim=1) - body_inertia: wp.array(dtype=wp.vec3, ndim=1) + body_inertia: wp.array(dtype=wp.vec3, ndim=2) body_invweight0: wp.array(dtype=wp.float32, ndim=2) body_contype: wp.array(dtype=wp.int32, ndim=1) body_conaffinity: wp.array(dtype=wp.int32, ndim=1) From 171d8d53e58a48bcf2669ea14aef14fcc948fded Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 16:13:37 +0200 Subject: [PATCH 10/31] more conversions --- mujoco_warp/_src/constraint.py | 31 ++++++++++++++++--------------- mujoco_warp/_src/io.py | 16 ++++++++-------- mujoco_warp/_src/passive.py | 2 +- mujoco_warp/_src/smooth.py | 8 ++++---- mujoco_warp/_src/support.py | 6 ++++++ mujoco_warp/_src/types.py | 32 ++++++++++++++++---------------- 6 files changed, 51 insertions(+), 44 deletions(-) diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 7b569e3e..82ec7c2f 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -19,6 +19,7 @@ from . import support from . import types from .support import get_batched_value +from .support import get_batched_array from .warp_util import event_scope wp.config.enable_backward = False @@ -132,7 +133,7 @@ def _efc_equality_connect( d.efc.J[efcid + 2, dofid] = j1mj2[2] Jqvel += j1mj2 * d.qvel[worldid, dofid] - invweight = m.body_invweight0[body1id, 0] + m.body_invweight0[body2id, 0] + invweight = get_batched_value(m.body_invweight0, worldid, body1id, 0) + get_batched_value(m.body_invweight0, worldid, body2id, 0) pos_imp = wp.length(pos) solref = m.eq_solref[i_eq] @@ -326,7 +327,7 @@ def _efc_equality_weld( crotq = math.mul_quat(quat1, quat) # copy axis components crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale - invweight_t = m.body_invweight0[body1id, 0] + m.body_invweight0[body2id, 0] + invweight_t = get_batched_value(m.body_invweight0, worldid, body1id, 0) + get_batched_value(m.body_invweight0, worldid, body2id, 0) pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) @@ -349,7 +350,7 @@ def _efc_equality_weld( i_eq, ) - invweight_r = m.body_invweight0[body1id, 1] + m.body_invweight0[body2id, 1] + invweight_r = get_batched_value(m.body_invweight0, worldid, body1id, 1) + get_batched_value(m.body_invweight0, worldid, body2id, 1) for i in range(3): _update_efc_row( @@ -377,9 +378,9 @@ def _efc_limit_slide_hinge( jntid = m.jnt_limited_slide_hinge_adr[jntlimitedid] qpos = d.qpos[worldid, m.jnt_qposadr[jntid]] - jnt_range = m.jnt_range[jntid] + jnt_range = get_batched_array(m.jnt_range, worldid, jntid) dist_min, dist_max = qpos - jnt_range[0], jnt_range[1] - qpos - pos = wp.min(dist_min, dist_max) - m.jnt_margin[jntid] + pos = wp.min(dist_min, dist_max) - get_batched_value(m.jnt_margin, worldid, jntid) active = pos < 0 if active: @@ -400,9 +401,9 @@ def _efc_limit_slide_hinge( pos, pos, m.dof_invweight0[dofadr], - m.jnt_solref[jntid], - m.jnt_solimp[jntid], - m.jnt_margin[jntid], + get_batched_value(m.jnt_solref, worldid, jntid), + get_batched_value(m.jnt_solimp, worldid, jntid), + get_batched_value(m.jnt_margin, worldid, jntid), Jqvel, 0.0, dofadr, @@ -424,8 +425,8 @@ def _efc_limit_ball( ) axis_angle = math.quat_to_vel(jnt_quat) axis, angle = math.normalize_with_norm(axis_angle) - jnt_margin = m.jnt_margin[jntid] - jnt_range = m.jnt_range[jntid] + jnt_margin = get_batched_value(m.jnt_margin, worldid, jntid) + jnt_range = get_batched_array(m.jnt_range, worldid, jntid) pos = wp.max(jnt_range[0], jnt_range[1]) - angle - jnt_margin active = pos < 0 @@ -452,9 +453,9 @@ def _efc_limit_ball( pos, pos, m.dof_invweight0[dofadr], - m.jnt_solref[jntid], - m.jnt_solimp[jntid], - jnt_margin, + get_batched_value(m.jnt_solref, worldid, jntid), + get_batched_value(m.jnt_solimp, worldid, jntid), + get_batched_value(m.jnt_margin, worldid, jntid), Jqvel, 0.0, jntid, @@ -548,7 +549,7 @@ def _efc_contact_pyramidal( frame = d.contact.frame[conid] # pyramidal has common invweight across all edges - invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0] + invweight = get_batched_value(m.body_invweight0, worldid, body1, 0) + get_batched_value(m.body_invweight0, worldid, body2, 0) if condim > 1: dimid2 = dimid / 2 + 1 @@ -649,7 +650,7 @@ def _efc_contact_elliptic( d.efc.J[efcid, i] = J Jqvel += J * d.qvel[worldid, i] - invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0] + invweight = get_batched_value(m.body_invweight0, worldid, body1, 0) + get_batched_value(m.body_invweight0, worldid, body2, 0) ref = d.contact.solref[conid] pos_aref = pos diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index f4a06c3d..d76e9ef3 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -383,8 +383,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: for i in range(mjm.nbody - 1, -1, -1): subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] - m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) - m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) + m.subtree_mass = create_nmodel_batched_array(subtree_mass, dtype=wp.float32) + m.body_invweight0 = create_nmodel_batched_array(mjm.body_invweight0, dtype=wp.float32) m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) @@ -396,15 +396,15 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: ) m.jnt_limited_ball_adr = wp.array(jnt_limited_ball_adr, dtype=wp.int32, ndim=1) m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1) - m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1) - m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1) + m.jnt_solref = create_nmodel_batched_array(mjm.jnt_solref, dtype=wp.vec2f) + m.jnt_solimp = create_nmodel_batched_array(mjm.jnt_solimp, dtype=types.vec5) m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1) m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1) m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1) - m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1) - m.jnt_range = wp.array(mjm.jnt_range, dtype=wp.float32, ndim=2) - m.jnt_margin = wp.array(mjm.jnt_margin, dtype=wp.float32, ndim=1) - m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) + m.jnt_pos = create_nmodel_batched_array(mjm.jnt_pos, dtype=wp.vec3) + m.jnt_range = create_nmodel_batched_array(mjm.jnt_range, dtype=wp.float32) + m.jnt_stiffness = create_nmodel_batched_array(mjm.jnt_stiffness, dtype=wp.float32) + m.jnt_margin = create_nmodel_batched_array(mjm.jnt_margin, dtype=wp.float32) m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index fe295ea7..6cb1ad7c 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -36,7 +36,7 @@ def passive(m: Model, d: Data): @kernel def _spring(m: Model, d: Data): worldid, jntid = wp.tid() - stiffness = m.jnt_stiffness[jntid] + stiffness = get_batched_value(m.jnt_stiffness, worldid, jntid) dofid = m.jnt_dofadr[jntid] if stiffness == 0.0: diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 6bc210a4..686d97e0 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -82,7 +82,7 @@ def _level(m: Model, d: Data, leveladr: int): qadr = m.jnt_qposadr[jntadr] jnt_type = m.jnt_type[jntadr] jnt_axis = m.jnt_axis[jntadr] - xanchor = math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + xpos + xanchor = math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) + xpos xaxis = math.rot_vec_quat(jnt_axis, xquat) if jnt_type == wp.static(JointType.BALL.value): @@ -94,7 +94,7 @@ def _level(m: Model, d: Data, leveladr: int): ) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + xpos = xanchor - math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) elif jnt_type == wp.static(JointType.SLIDE.value): xpos += xaxis * (qpos[qadr] - get_batched_value(m.qpos0, worldid, qadr)) elif jnt_type == wp.static(JointType.HINGE.value): @@ -102,7 +102,7 @@ def _level(m: Model, d: Data, leveladr: int): qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + xpos = xanchor - math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) d.xanchor[worldid, jntadr] = xanchor d.xaxis[worldid, jntadr] = xaxis @@ -173,7 +173,7 @@ def subtree_com_acc(m: Model, d: Data, leveladr: int): @kernel def subtree_div(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] /= m.subtree_mass[bodyid] + d.subtree_com[worldid, bodyid] /= get_batched_value(m.subtree_mass, worldid, bodyid) @kernel def cinert(m: Model, d: Data): diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index fbaf0735..fa0b95aa 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -361,6 +361,12 @@ def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32, i: wp.int modelid = worldid % array.shape[0] return array[modelid, i] +@wp.func +def get_batched_array(array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32): + """Returns the array slice for the given worldid.""" + modelid = worldid % array.shape[0] + return array[modelid, i] + @wp.func def get_batched_value(array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32): modelid = worldid % array.shape[0] diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 4f28d8f0..00ed9b18 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -594,9 +594,9 @@ class Model: body_iquat: local orientation of inertia ellipsoid (nmodel, nbody, 4) body_mass: mass (nmodel, nbody,) body_subtreemass: mass of subtree starting at this body (nbody,) - subtree_mass: mass of subtree (nbody,) + subtree_mass: mass of subtree (nmodel, nbody,) body_inertia: diagonal inertia in ipos/iquat frame (nmodel, nbody, 3) - body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2) + body_invweight0: mean inv inert in qpos0 (trn, rot) (nmodel, nbody, 2) body_contype: OR over all geom contypes (nbody,) body_conaffinity: OR over all geom conaffinities (nbody,) jnt_type: type of joint (mjtJoint) (njnt,) @@ -605,14 +605,14 @@ class Model: jnt_bodyid: id of joint's body (njnt,) jnt_limited: does joint have limits (njnt,) jnt_actfrclimited: does joint have actuator force limits (njnt,) - jnt_solref: constraint solver reference: limit (njnt, mjNREF) - jnt_solimp: constraint solver impedance: limit (njnt, mjNIMP) - jnt_pos: local anchor position (njnt, 3) + jnt_solref: constraint solver reference: limit (nmodel, njnt, mjNREF) + jnt_solimp: constraint solver impedance: limit (nmodel, njnt, mjNIMP) + jnt_pos: local anchor position (nmodel, njnt, 3) jnt_axis: local joint axis (njnt, 3) - jnt_stiffness: stiffness coefficient (njnt,) - jnt_range: joint limits (njnt, 2) + jnt_stiffness: stiffness coefficient (nmodel, njnt,) + jnt_range: joint limits (nmodel, njnt, 2) jnt_actfrcrange: range of total actuator force (njnt, 2) - jnt_margin: min distance for limit detection (njnt,) + jnt_margin: min distance for limit detection (nmodel, njnt,) jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr jnt_limited_ball_adr: limited/ball jntadr dof_bodyid: id of dof's body (nv,) @@ -804,9 +804,9 @@ class Model: body_iquat: wp.array(dtype=wp.quat, ndim=2) body_mass: wp.array(dtype=wp.float32, ndim=2) body_subtreemass: wp.array(dtype=wp.float32, ndim=1) - subtree_mass: wp.array(dtype=wp.float32, ndim=1) + subtree_mass: wp.array(dtype=wp.float32, ndim=2) body_inertia: wp.array(dtype=wp.vec3, ndim=2) - body_invweight0: wp.array(dtype=wp.float32, ndim=2) + body_invweight0: wp.array(dtype=wp.float32, ndim=3) body_contype: wp.array(dtype=wp.int32, ndim=1) body_conaffinity: wp.array(dtype=wp.int32, ndim=1) jnt_type: wp.array(dtype=wp.int32, ndim=1) @@ -815,14 +815,14 @@ class Model: jnt_bodyid: wp.array(dtype=wp.int32, ndim=1) jnt_limited: wp.array(dtype=wp.int32, ndim=1) jnt_actfrclimited: wp.array(dtype=wp.bool, ndim=1) - jnt_solref: wp.array(dtype=wp.vec2, ndim=1) - jnt_solimp: wp.array(dtype=vec5, ndim=1) - jnt_pos: wp.array(dtype=wp.vec3, ndim=1) + jnt_solref: wp.array(dtype=wp.vec2, ndim=2) + jnt_solimp: wp.array(dtype=vec5, ndim=2) + jnt_pos: wp.array(dtype=wp.vec3, ndim=2) jnt_axis: wp.array(dtype=wp.vec3, ndim=1) - jnt_stiffness: wp.array(dtype=wp.float32, ndim=1) - jnt_range: wp.array(dtype=wp.float32, ndim=2) + jnt_stiffness: wp.array(dtype=wp.float32, ndim=2) + jnt_range: wp.array(dtype=wp.float32, ndim=3) jnt_actfrcrange: wp.array(dtype=wp.vec2, ndim=1) - jnt_margin: wp.array(dtype=wp.float32, ndim=1) + jnt_margin: wp.array(dtype=wp.float32, ndim=2) jnt_limited_slide_hinge_adr: wp.array(dtype=wp.int32, ndim=1) # warp only jnt_limited_ball_adr: wp.array(dtype=wp.int32, ndim=1) # warp only dof_bodyid: wp.array(dtype=wp.int32, ndim=1) From 3270655274e065c58026670b54a9cd60cefc2d1b Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Fri, 25 Apr 2025 17:45:56 +0200 Subject: [PATCH 11/31] committing WIP to get things working --- mujoco_warp/_src/collision_box.py | 9 +- mujoco_warp/_src/collision_convex.py | 7 +- mujoco_warp/_src/collision_driver.py | 7 +- mujoco_warp/_src/collision_primitive.py | 50 +++---- mujoco_warp/_src/constraint.py | 53 ++++--- mujoco_warp/_src/forward.py | 27 ++-- mujoco_warp/_src/io.py | 104 +++++++------- mujoco_warp/_src/sensor.py | 8 +- mujoco_warp/_src/smooth.py | 38 ++--- mujoco_warp/_src/types.py | 176 ++++++++++++------------ 10 files changed, 243 insertions(+), 236 deletions(-) diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py index ae3f2849..7758c985 100644 --- a/mujoco_warp/_src/collision_box.py +++ b/mujoco_warp/_src/collision_box.py @@ -25,6 +25,7 @@ from .types import Data from .types import GeomType from .types import Model +from .support import get_batched_value BOX_BOX_BLOCK_DIM = 32 @@ -210,7 +211,7 @@ def box_box_kernel( worldid = d.collision_worldid[bp_idx] geoms, margin, gap, condim, friction, solref, solreffriction, solimp = ( - contact_params(m, d, tid) + contact_params(m, d, tid, worldid) ) # transformations @@ -220,8 +221,8 @@ def box_box_kernel( trans_atob = b_mat_inv @ (a_pos - b_pos) rot_atob = b_mat_inv @ a_mat - a_size = m.geom_size[ga] - b_size = m.geom_size[gb] + a_size = get_batched_value(m.geom_size, worldid, ga) + b_size = get_batched_value(m.geom_size, worldid, gb) a = box(rot_atob, trans_atob, a_size) b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) @@ -312,7 +313,7 @@ def box_box_kernel( for i in range(4): pos[i] = pos[idx] - margin = wp.max(m.geom_margin[ga], m.geom_margin[gb]) + margin = wp.max(get_batched_value(m.geom_margin, worldid, ga), get_batched_value(m.geom_margin, worldid, gb)) for i in range(4): pos_glob = b_mat @ pos[i] + b_pos n_glob = b_mat @ sep_axis diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 7ebde685..b2051009 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -29,6 +29,7 @@ from .types import Data from .types import GeomType from .types import Model +from .support import get_batched_value # XXX disable backward pass codegen globally for now # enabling backward pass leads to 10min compile time @@ -732,10 +733,10 @@ def gjk_epa_sparse( if m.geom_type[g1] != type1 or m.geom_type[g2] != type2: return - info1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) - info2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) + info1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) + info2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) - margin = wp.max(m.geom_margin[g1], m.geom_margin[g2]) + margin = wp.max(get_batched_value(m.geom_margin, worldid, g1), get_batched_value(m.geom_margin, worldid, g2)) simplex, normal = _gjk(m, info1, info2) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 00fdc5b1..e38a01e1 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -26,14 +26,15 @@ from .types import DisableBit from .types import Model from .warp_util import event_scope +from .support import get_batched_value wp.set_module_options({"enable_backward": False}) @wp.func def _sphere_filter(m: Model, d: Data, geom1: int, geom2: int, worldid: int) -> bool: - margin1 = m.geom_margin[geom1] - margin2 = m.geom_margin[geom2] + margin1 = get_batched_value(m.geom_margin, worldid, geom1) + margin2 = get_batched_value(m.geom_margin, worldid, geom2) pos1 = d.geom_xpos[worldid, geom1] pos2 = d.geom_xpos[worldid, geom2] size1 = m.geom_rbound[geom1] @@ -111,7 +112,7 @@ def _sap_project(m: Model, d: Data, direction: wp.vec3): # geom is a plane rbound = MJ_MAXVAL - radius = rbound + m.geom_margin[geomid] + radius = rbound + get_batched_value(m.geom_margin, worldid, geomid) center = wp.dot(direction, xpos) d.sap_projection_lower[worldid, geomid] = center - radius diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 426ec093..692631bf 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -24,6 +24,8 @@ from .types import GeomType from .types import Model from .types import vec5 +from .support import get_batched_value +from .support import get_batched_array wp.set_module_options({"enable_backward": False}) @@ -44,12 +46,13 @@ def _geom( m: Model, geom_xpos: wp.array(dtype=wp.vec3), geom_xmat: wp.array(dtype=wp.mat33), + worldid: int, ) -> Geom: geom = Geom() geom.pos = geom_xpos[gid] rot = geom_xmat[gid] geom.rot = rot - geom.size = m.geom_size[gid] + geom.size = get_batched_value(m.geom_size, worldid, gid) geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane dataid = m.geom_dataid[gid] if dataid >= 0: @@ -737,27 +740,27 @@ def plane_cylinder( @wp.func -def contact_params(m: Model, d: Data, cid: int): +def contact_params(m: Model, d: Data, cid: int, worldid: int): geoms = d.collision_pair[cid] pairid = d.collision_pairid[cid] if pairid > -1: - margin = m.pair_margin[pairid] - gap = m.pair_gap[pairid] + margin = get_batched_value(m.pair_margin, worldid, pairid) + gap = get_batched_value(m.pair_gap, worldid, pairid) condim = m.pair_dim[pairid] - friction = m.pair_friction[pairid] - solref = m.pair_solref[pairid] - solreffriction = m.pair_solreffriction[pairid] - solimp = m.pair_solimp[pairid] + friction = get_batched_value(m.pair_friction, worldid, pairid) + solref = get_batched_value(m.pair_solref, worldid, pairid) + solreffriction = get_batched_value(m.pair_solreffriction, worldid, pairid) + solimp = get_batched_value(m.pair_solimp, worldid, pairid) else: g1 = geoms[0] g2 = geoms[1] - p1 = m.geom_priority[g1] - p2 = m.geom_priority[g2] + p1 = get_batched_value(m.geom_priority, worldid, g1) + p2 = get_batched_value(m.geom_priority, worldid, g2) - solmix1 = m.geom_solmix[g1] - solmix2 = m.geom_solmix[g2] + solmix1 = get_batched_value(m.geom_solmix, worldid, g1) + solmix2 = get_batched_value(m.geom_solmix, worldid, g2) mix = solmix1 / (solmix1 + solmix2) mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix) @@ -765,8 +768,8 @@ def contact_params(m: Model, d: Data, cid: int): mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) - margin = wp.max(m.geom_margin[g1], m.geom_margin[g2]) - gap = wp.max(m.geom_gap[g1], m.geom_gap[g2]) + margin = wp.max(get_batched_value(m.geom_margin, worldid, g1), get_batched_value(m.geom_margin, worldid, g2)) + gap = wp.max(get_batched_value(m.geom_gap, worldid, g1), get_batched_value(m.geom_gap, worldid, g2)) condim1 = m.geom_condim[g1] condim2 = m.geom_condim[g2] @@ -774,7 +777,7 @@ def contact_params(m: Model, d: Data, cid: int): p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2) ) - geom_friction = wp.max(m.geom_friction[g1], m.geom_friction[g2]) + geom_friction = wp.max(get_batched_value(m.geom_friction, worldid, g1), get_batched_value(m.geom_friction, worldid, g2)) friction = vec5( geom_friction[0], geom_friction[0], @@ -783,14 +786,14 @@ def contact_params(m: Model, d: Data, cid: int): geom_friction[2], ) - if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0: - solref = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2] + if get_batched_value(m.geom_solref, worldid, g1).x > 0.0 and get_batched_value(m.geom_solref, worldid, g2).x > 0.0: + solref = mix * get_batched_value(m.geom_solref, worldid, g1) + (1.0 - mix) * get_batched_value(m.geom_solref, worldid, g2) else: - solref = wp.min(m.geom_solref[g1], m.geom_solref[g2]) + solref = wp.min(get_batched_value(m.geom_solref, worldid, g1), get_batched_value(m.geom_solref, worldid, g2)) solreffriction = wp.vec2(0.0, 0.0) - solimp = mix * m.geom_solimp[g1] + (1.0 - mix) * m.geom_solimp[g2] + solimp = mix * get_batched_value(m.geom_solimp, worldid, g1) + (1.0 - mix) * get_batched_value(m.geom_solimp, worldid, g2) return geoms, margin, gap, condim, friction, solref, solreffriction, solimp @@ -868,16 +871,17 @@ def _primitive_narrowphase( if tid >= d.ncollision[0]: return + worldid = d.collision_worldid[tid] + geoms, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( - m, d, tid + m, d, tid, worldid ) g1 = geoms[0] g2 = geoms[1] - worldid = d.collision_worldid[tid] - geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) - geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) + geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) + geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) type1 = m.geom_type[g1] type2 = m.geom_type[g2] diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 82ec7c2f..94d76a4e 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -100,7 +100,7 @@ def _efc_equality_connect( for i in range(wp.static(3)): d.efc.worldid[efcid + i] = worldid - data = m.eq_data[i_eq] + data = get_batched_value(m.eq_data, worldid, i_eq) anchor1 = wp.vec3f(data[0], data[1], data[2]) anchor2 = wp.vec3f(data[3], data[4], data[5]) @@ -136,8 +136,8 @@ def _efc_equality_connect( invweight = get_batched_value(m.body_invweight0, worldid, body1id, 0) + get_batched_value(m.body_invweight0, worldid, body2id, 0) pos_imp = wp.length(pos) - solref = m.eq_solref[i_eq] - solimp = m.eq_solimp[i_eq] + solref = get_batched_value(m.eq_solref, worldid, i_eq) + solimp = get_batched_value(m.eq_solimp, worldid, i_eq) for i in range(3): efcidi = efcid + i @@ -175,7 +175,7 @@ def _efc_equality_joint( jntid_1 = m.eq_obj1id[i_eq] jntid_2 = m.eq_obj2id[i_eq] - data = m.eq_data[i_eq] + data = get_batched_value(m.eq_data, worldid, i_eq) dofadr1 = m.jnt_dofadr[jntid_1] qposadr1 = m.jnt_qposadr[jntid_1] d.efc.J[efcid, dofadr1] = 1.0 @@ -194,14 +194,14 @@ def _efc_equality_joint( pos = d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - rhs Jqvel = d.qvel[worldid, dofadr1] - d.qvel[worldid, dofadr2] * deriv_2 - invweight = m.dof_invweight0[dofadr1] + m.dof_invweight0[dofadr2] + invweight = get_batched_value(m.dof_invweight0, worldid, dofadr1) + get_batched_value(m.dof_invweight0, worldid, dofadr2) d.efc.J[efcid, dofadr2] = -deriv_2 else: # Single joint constraint pos = d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - data[0] Jqvel = d.qvel[worldid, dofadr1] - invweight = m.dof_invweight0[dofadr1] + invweight = get_batched_value(m.dof_invweight0, worldid, dofadr1) # Update constraint parameters _update_efc_row( @@ -211,8 +211,8 @@ def _efc_equality_joint( pos, pos, invweight, - m.eq_solref[i_eq], - m.eq_solimp[i_eq], + get_batched_value(m.eq_solref, worldid, i_eq), + get_batched_value(m.eq_solimp, worldid, i_eq), wp.float32(0.0), Jqvel, 0.0, @@ -228,7 +228,7 @@ def _efc_friction( # TODO(team): tendon worldid, dofid = wp.tid() - if m.dof_frictionloss[dofid] <= 0.0: + if get_batched_value(m.dof_frictionloss, worldid, dofid) <= 0.0: return efcid = wp.atomic_add(d.nefc, 0, 1) @@ -244,13 +244,13 @@ def _efc_friction( efcid, 0.0, 0.0, - m.dof_invweight0[dofid], - m.dof_solref[dofid], - m.dof_solimp[dofid], + get_batched_value(m.dof_invweight0, worldid, dofid), + get_batched_value(m.dof_solref, worldid, dofid), + get_batched_value(m.dof_solimp, worldid, dofid), 0.0, Jqvel, - m.dof_frictionloss[dofid], - dofid, + get_batched_value(m.dof_frictionloss, worldid, dofid), + m.jnt_bodyid[m.dof_jntid[dofid]], ) @@ -274,7 +274,7 @@ def _efc_equality_weld( obj1id = m.eq_obj1id[i_eq] obj2id = m.eq_obj2id[i_eq] - data = m.eq_data[i_eq] + data = get_batched_value(m.eq_data, worldid, i_eq) anchor1 = wp.vec3(data[0], data[1], data[2]) anchor2 = wp.vec3(data[3], data[4], data[5]) relpose = wp.quat(data[6], data[7], data[8], data[9]) @@ -286,9 +286,8 @@ def _efc_equality_weld( body2id = m.site_bodyid[obj2id] pos1 = d.site_xpos[worldid, obj1id] pos2 = d.site_xpos[worldid, obj2id] - - quat = math.mul_quat(d.xquat[worldid, body1id], m.site_quat[obj1id]) - quat1 = math.quat_inv(math.mul_quat(d.xquat[worldid, body2id], m.site_quat[obj2id])) + quat = math.mul_quat(d.xquat[worldid, body1id], get_batched_value(m.site_quat, worldid, obj1id)) + quat1 = math.quat_inv(math.mul_quat(d.xquat[worldid, body2id], get_batched_value(m.site_quat, worldid, obj2id))) else: body1id = obj1id @@ -331,8 +330,8 @@ def _efc_equality_weld( pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) - solref = m.eq_solref[i_eq] - solimp = m.eq_solimp[i_eq] + solref = get_batched_value(m.eq_solref, worldid, i_eq) + solimp = get_batched_value(m.eq_solimp, worldid, i_eq) for i in range(3): _update_efc_row( @@ -400,7 +399,7 @@ def _efc_limit_slide_hinge( efcid, pos, pos, - m.dof_invweight0[dofadr], + get_batched_value(m.dof_invweight0, worldid, dofadr), get_batched_value(m.jnt_solref, worldid, jntid), get_batched_value(m.jnt_solimp, worldid, jntid), get_batched_value(m.jnt_margin, worldid, jntid), @@ -452,7 +451,7 @@ def _efc_limit_ball( efcid, pos, pos, - m.dof_invweight0[dofadr], + get_batched_value(m.dof_invweight0, worldid, dofadr), get_batched_value(m.jnt_solref, worldid, jntid), get_batched_value(m.jnt_solimp, worldid, jntid), get_batched_value(m.jnt_margin, worldid, jntid), @@ -470,10 +469,10 @@ def _efc_limit_tendon( worldid, tenlimitedid = wp.tid() tenid = m.tendon_limited_adr[tenlimitedid] - ten_range = m.tendon_range[tenid] + ten_range = get_batched_value(m.tendon_range, worldid, tenid) length = d.ten_length[worldid, tenid] dist_min, dist_max = length - ten_range[0], ten_range[1] - length - ten_margin = m.tendon_margin[tenid] + ten_margin = get_batched_value(m.tendon_margin, worldid, tenid) pos = wp.min(dist_min, dist_max) - ten_margin active = pos < 0 @@ -505,9 +504,9 @@ def _efc_limit_tendon( efcid, pos, pos, - m.tendon_invweight0[tenid], - m.tendon_solref_lim[tenid], - m.tendon_solimp_lim[tenid], + get_batched_value(m.tendon_invweight0, worldid, tenid), + get_batched_value(m.tendon_solref_lim, worldid, tenid), + get_batched_value(m.tendon_solimp_lim, worldid, tenid), ten_margin, Jqvel, 0.0, diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 32e0c523..6816cd9b 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -130,7 +130,7 @@ def _next_act( # advance the actuation if m.actuator_dyntype[actid] == wp.static(DynType.FILTEREXACT.value): - dyn_prm = m.actuator_dynprm[actid] + dyn_prm = get_batched_value(m.actuator_dynprm, worldid, actid) tau = wp.max(MJ_MINVAL, dyn_prm[0]) act += act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau)) else: @@ -138,7 +138,7 @@ def _next_act( # clamp to actrange if m.actuator_actlimited[actid]: - actrange = m.actuator_actrange[actid] + actrange = get_batched_value(m.actuator_actrange, worldid, actid) act = wp.clamp(act, actrange[0], actrange[1]) d.act[worldid, actid] = act @@ -380,10 +380,10 @@ def actuator_bias_gain_vel(m: Model, d: Data): actuator_dyntype = m.actuator_dyntype[actid] if actuator_biastype == wp.static(BiasType.AFFINE.value): - bias_vel = m.actuator_biasprm[actid][2] + bias_vel = get_batched_value(m.actuator_biasprm, worldid, actid)[2] if actuator_gaintype == wp.static(GainType.AFFINE.value): - gain_vel = m.actuator_gainprm[actid][2] + gain_vel = get_batched_value(m.actuator_gainprm, worldid, actid)[2] ctrl = d.ctrl[worldid, actid] @@ -629,7 +629,7 @@ def _force(m: Model, d: Data): dsbl_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value) if m.actuator_ctrllimited[uid] and not dsbl_clampctrl: - r = m.actuator_ctrlrange[uid] + r = get_batched_value(m.actuator_ctrlrange, worldid, uid) ctrl = wp.clamp(ctrl, r[0], r[1]) if m.na: @@ -640,7 +640,7 @@ def _force(m: Model, d: Data): elif dyntype == int(DynType.FILTER.value) or dyntype == int( DynType.FILTEREXACT.value ): - dynprm = m.actuator_dynprm[uid] + dynprm = get_batched_value(m.actuator_dynprm, worldid, uid) actadr = m.actuator_actadr[uid] act = d.act[worldid, actadr] d.act_dot[worldid, actadr] = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL) @@ -659,7 +659,7 @@ def _force(m: Model, d: Data): # gain gaintype = m.actuator_gaintype[uid] - gainprm = m.actuator_gainprm[uid] + gainprm = get_batched_value(m.actuator_gainprm, worldid, uid) gain = 0.0 if gaintype == int(GainType.FIXED.value): @@ -671,7 +671,7 @@ def _force(m: Model, d: Data): # bias biastype = m.actuator_biastype[uid] - biasprm = m.actuator_biasprm[uid] + biasprm = get_batched_value(m.actuator_biasprm, worldid, uid) bias = 0.0 # BiasType.NONE if biastype == int(BiasType.AFFINE.value): @@ -684,7 +684,7 @@ def _force(m: Model, d: Data): # TODO(team): tendon total force clamping if m.actuator_forcelimited[uid]: - r = m.actuator_forcerange[uid] + r = get_batched_value(m.actuator_forcerange, worldid, uid) f = wp.clamp(f, r[0], r[1]) d.actuator_force[worldid, uid] = f @@ -693,10 +693,11 @@ def _qfrc_limited(m: Model, d: Data): worldid, dofid = wp.tid() jntid = m.dof_jntid[dofid] if m.jnt_actfrclimited[jntid]: + range = get_batched_value(m.jnt_actfrcrange, worldid, jntid) d.qfrc_actuator[worldid, dofid] = wp.clamp( d.qfrc_actuator[worldid, dofid], - m.jnt_actfrcrange[jntid][0], - m.jnt_actfrcrange[jntid][1], + range[0], + range[1], ) if m.opt.is_sparse: @@ -711,8 +712,8 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): s += moment[worldid, uid, vid] * force[worldid, uid] jntid = m.dof_jntid[vid] if m.jnt_actfrclimited[jntid]: - r = m.jnt_actfrcrange[jntid] - s = wp.clamp(s, r[0], r[1]) + range = get_batched_value(m.jnt_actfrcrange, worldid, jntid) + s = wp.clamp(s, range[0], range[1]) qfrc[worldid, vid] = s wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d]) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index d76e9ef3..62ecaa5f 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -406,22 +406,22 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.jnt_stiffness = create_nmodel_batched_array(mjm.jnt_stiffness, dtype=wp.float32) m.jnt_margin = create_nmodel_batched_array(mjm.jnt_margin, dtype=wp.float32) m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) - m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) + m.jnt_actfrcrange = create_nmodel_batched_array(mjm.jnt_actfrcrange, dtype=wp.vec2) m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) - m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) - m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) + m.geom_conaffinity = create_nmodel_batched_array(mjm.geom_conaffinity, dtype=wp.int32) + m.geom_contype = create_nmodel_batched_array(mjm.geom_contype, dtype=wp.int32) m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) - m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) - m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) + m.geom_pos = create_nmodel_batched_array(mjm.geom_pos, dtype=wp.vec3) + m.geom_quat = create_nmodel_batched_array(mjm.geom_quat, dtype=wp.quat) m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) - m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) - m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) - m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) - m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) - m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) - m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) - m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1) + m.geom_priority = create_nmodel_batched_array(mjm.geom_priority, dtype=wp.int32) + m.geom_solmix = create_nmodel_batched_array(mjm.geom_solmix, dtype=wp.float32) + m.geom_solref = create_nmodel_batched_array(mjm.geom_solref, dtype=wp.vec2) + m.geom_solimp = create_nmodel_batched_array(mjm.geom_solimp, dtype=types.vec5) + m.geom_friction = create_nmodel_batched_array(mjm.geom_friction, dtype=wp.vec3) + m.geom_margin = create_nmodel_batched_array(mjm.geom_margin, dtype=wp.float32) + m.geom_gap = create_nmodel_batched_array(mjm.geom_gap, dtype=wp.float32) m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) @@ -433,55 +433,55 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.eq_obj2id = wp.array(mjm.eq_obj2id, dtype=wp.int32, ndim=1) m.eq_objtype = wp.array(mjm.eq_objtype, dtype=wp.int32, ndim=1) m.eq_active0 = wp.array(mjm.eq_active0, dtype=wp.bool, ndim=1) - m.eq_solref = wp.array(mjm.eq_solref, dtype=wp.vec2, ndim=1) - m.eq_solimp = wp.array(mjm.eq_solimp, dtype=types.vec5, ndim=1) + m.eq_solref = create_nmodel_batched_array(mjm.eq_solref, dtype=wp.vec2) + m.eq_solimp = create_nmodel_batched_array(mjm.eq_solimp, dtype=types.vec5) m.eq_data = wp.array(mjm.eq_data, dtype=types.vec11, ndim=1) - m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) - m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) + m.site_pos = create_nmodel_batched_array(mjm.site_pos, dtype=wp.vec3) + m.site_quat = create_nmodel_batched_array(mjm.site_quat, dtype=wp.quat) m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32, ndim=1) m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32, ndim=1) m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32, ndim=1) - m.cam_pos = wp.array(mjm.cam_pos, dtype=wp.vec3, ndim=1) - m.cam_quat = wp.array(mjm.cam_quat, dtype=wp.quat, ndim=1) - m.cam_poscom0 = wp.array(mjm.cam_poscom0, dtype=wp.vec3, ndim=1) - m.cam_pos0 = wp.array(mjm.cam_pos0, dtype=wp.vec3, ndim=1) + m.cam_pos = create_nmodel_batched_array(mjm.cam_pos, dtype=wp.vec3) + m.cam_quat = create_nmodel_batched_array(mjm.cam_quat, dtype=wp.quat) + m.cam_poscom0 = create_nmodel_batched_array(mjm.cam_poscom0, dtype=wp.vec3) + m.cam_pos0 = create_nmodel_batched_array(mjm.cam_pos0, dtype=wp.vec3) m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32, ndim=1) m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32, ndim=1) m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32, ndim=1) - m.light_pos = wp.array(mjm.light_pos, dtype=wp.vec3, ndim=1) - m.light_dir = wp.array(mjm.light_dir, dtype=wp.vec3, ndim=1) - m.light_poscom0 = wp.array(mjm.light_poscom0, dtype=wp.vec3, ndim=1) - m.light_pos0 = wp.array(mjm.light_pos0, dtype=wp.vec3, ndim=1) + m.light_pos = create_nmodel_batched_array(mjm.light_pos, dtype=wp.vec3) + m.light_dir = create_nmodel_batched_array(mjm.light_dir, dtype=wp.vec3) + m.light_poscom0 = create_nmodel_batched_array(mjm.light_poscom0, dtype=wp.vec3) + m.light_pos0 = create_nmodel_batched_array(mjm.light_pos0, dtype=wp.vec3) m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) - m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) + m.dof_armature = create_nmodel_batched_array(mjm.dof_armature, dtype=wp.float32) m.dof_damping = create_nmodel_batched_array(mjm.dof_damping, dtype=wp.float32) - m.dof_frictionloss = wp.array(mjm.dof_frictionloss, dtype=wp.float32, ndim=1) - m.dof_solimp = wp.array(mjm.dof_solimp, dtype=types.vec5, ndim=1) - m.dof_solref = wp.array(mjm.dof_solref, dtype=wp.vec2, ndim=1) + m.dof_frictionloss = create_nmodel_batched_array(mjm.dof_frictionloss, dtype=wp.float32) + m.dof_solimp = create_nmodel_batched_array(mjm.dof_solimp, dtype=types.vec5) + m.dof_solref = create_nmodel_batched_array(mjm.dof_solref, dtype=wp.vec2) m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) m.dof_tri_col = wp.from_numpy(dof_tri_col, dtype=wp.int32) - m.dof_invweight0 = wp.array(mjm.dof_invweight0, dtype=wp.float32, ndim=1) + m.dof_invweight0 = create_nmodel_batched_array(mjm.dof_invweight0, dtype=wp.float32) m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) - m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) + m.actuator_ctrlrange = create_nmodel_batched_array(mjm.actuator_ctrlrange, dtype=wp.vec2) m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) - m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) + m.actuator_forcerange = create_nmodel_batched_array(mjm.actuator_forcerange, dtype=wp.vec2) m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) - m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=types.vec10f, ndim=1) + m.actuator_gainprm = create_nmodel_batched_array(mjm.actuator_gainprm, dtype=types.vec10f) m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) - m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=types.vec10f, ndim=1) - m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) + m.actuator_biasprm = create_nmodel_batched_array(mjm.actuator_biasprm, dtype=types.vec10f) + m.actuator_gear = create_nmodel_batched_array(mjm.actuator_gear, dtype=wp.spatial_vector) m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) - m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1) + m.actuator_actrange = create_nmodel_batched_array(mjm.actuator_actrange, dtype=wp.vec2) m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) m.actuator_actnum = wp.array(mjm.actuator_actnum, dtype=wp.int32, ndim=1) m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) - m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) + m.actuator_dynprm = create_nmodel_batched_array(mjm.actuator_dynprm, dtype=types.vec10f) m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) # pre-compute indices of equality constraints @@ -508,13 +508,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: # predefined collision pairs m.pair_dim = wp.array(mjm.pair_dim, dtype=wp.int32, ndim=1) m.pair_geom1 = wp.array(mjm.pair_geom1, dtype=wp.int32, ndim=1) - m.pair_geom2 = wp.array(mjm.pair_geom2, dtype=wp.int32, ndim=1) - m.pair_solref = wp.array(mjm.pair_solref, dtype=wp.vec2, ndim=1) - m.pair_solreffriction = wp.array(mjm.pair_solreffriction, dtype=wp.vec2, ndim=1) - m.pair_solimp = wp.array(mjm.pair_solimp, dtype=types.vec5, ndim=1) - m.pair_margin = wp.array(mjm.pair_margin, dtype=wp.float32, ndim=1) - m.pair_gap = wp.array(mjm.pair_gap, dtype=wp.float32, ndim=1) - m.pair_friction = wp.array(mjm.pair_friction, dtype=types.vec5, ndim=1) + m.pair_geom2 = create_nmodel_batched_array(mjm.pair_geom2, dtype=wp.int32) + m.pair_solref = create_nmodel_batched_array(mjm.pair_solref, dtype=wp.vec2) + m.pair_solreffriction = create_nmodel_batched_array(mjm.pair_solreffriction, dtype=wp.vec2) + m.pair_solimp = create_nmodel_batched_array(mjm.pair_solimp, dtype=types.vec5) + m.pair_margin = create_nmodel_batched_array(mjm.pair_margin, dtype=wp.float32) + m.pair_gap = create_nmodel_batched_array(mjm.pair_gap, dtype=wp.float32) + m.pair_friction = create_nmodel_batched_array(mjm.pair_friction, dtype=types.vec5) m.condim_max = np.max(mjm.geom_condim) # TODO(team): get max after filtering # tendon @@ -524,15 +524,15 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.tendon_limited_adr = wp.array( np.nonzero(mjm.tendon_limited)[0], dtype=wp.int32, ndim=1 ) - m.tendon_solref_lim = wp.array(mjm.tendon_solref_lim, dtype=wp.vec2f, ndim=1) - m.tendon_solimp_lim = wp.array(mjm.tendon_solimp_lim, dtype=types.vec5, ndim=1) - m.tendon_range = wp.array(mjm.tendon_range, dtype=wp.vec2f, ndim=1) - m.tendon_margin = wp.array(mjm.tendon_margin, dtype=wp.float32, ndim=1) - m.tendon_length0 = wp.array(mjm.tendon_length0, dtype=wp.float32, ndim=1) - m.tendon_invweight0 = wp.array(mjm.tendon_invweight0, dtype=wp.float32, ndim=1) - m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) - m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) - m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) + m.tendon_solref_lim = create_nmodel_batched_array(mjm.tendon_solref_lim, dtype=wp.vec2f) + m.tendon_solimp_lim = create_nmodel_batched_array(mjm.tendon_solimp_lim, dtype=types.vec5) + m.tendon_range = create_nmodel_batched_array(mjm.tendon_range, dtype=wp.vec2f) + m.tendon_margin = create_nmodel_batched_array(mjm.tendon_margin, dtype=wp.float32) + m.tendon_length0 = create_nmodel_batched_array(mjm.tendon_length0, dtype=wp.float32) + m.tendon_invweight0 = create_nmodel_batched_array(mjm.tendon_invweight0, dtype=wp.float32) + m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32) + m.wrap_prm = create_nmodel_batched_array(mjm.wrap_prm, dtype=wp.float32) + m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32) # fixed tendon tendon_jnt_adr = [] diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index e56afb63..4c3a4e44 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -147,15 +147,15 @@ def _frame_quat( return quat refquat = math.mul_quat(d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid)) elif objtype == int(ObjType.GEOM.value): - quat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[objid]], m.geom_quat[objid]) + quat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[objid]], get_batched_value(m.geom_quat, worldid, objid)) if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[refid]], m.geom_quat[refid]) + refquat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[refid]], get_batched_value(m.geom_quat, worldid, refid)) elif objtype == int(ObjType.SITE.value): - quat = math.mul_quat(d.xquat[worldid, m.site_bodyid[objid]], m.site_quat[objid]) + quat = math.mul_quat(d.xquat[worldid, m.site_bodyid[objid]], get_batched_value(m.site_quat, worldid, objid)) if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, m.site_bodyid[refid]], m.site_quat[refid]) + refquat = math.mul_quat(d.xquat[worldid, m.site_bodyid[refid]], get_batched_value(m.site_quat, worldid, refid)) # TODO(team): camera diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 686d97e0..1c1e35f5 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -123,9 +123,9 @@ def geom_local_to_global(m: Model, d: Data): bodyid = m.geom_bodyid[geomid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(m.geom_pos[geomid], xquat) + d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(get_batched_value(m.geom_pos, worldid, geomid), xquat) d.geom_xmat[worldid, geomid] = math.quat_to_mat( - math.mul_quat(xquat, m.geom_quat[geomid]) + math.mul_quat(xquat, get_batched_value(m.geom_quat, worldid, geomid)) ) @kernel @@ -134,9 +134,9 @@ def site_local_to_global(m: Model, d: Data): bodyid = m.site_bodyid[siteid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(m.site_pos[siteid], xquat) + d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(get_batched_value(m.site_pos, worldid, siteid), xquat) d.site_xmat[worldid, siteid] = math.quat_to_mat( - math.mul_quat(xquat, m.site_quat[siteid]) + math.mul_quat(xquat, get_batched_value(m.site_quat, worldid, siteid)) ) wp.launch(_root, dim=(d.nworld), inputs=[m, d]) @@ -265,9 +265,9 @@ def cam_local_to_global(m: Model, d: Data): bodyid = m.cam_bodyid[camid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat(m.cam_pos[camid], xquat) + d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat(get_batched_value(m.cam_pos, worldid, camid), xquat) d.cam_xmat[worldid, camid] = math.quat_to_mat( - math.mul_quat(xquat, m.cam_quat[camid]) + math.mul_quat(xquat, get_batched_value(m.cam_quat, worldid, camid)) ) @kernel @@ -281,10 +281,10 @@ def cam_fn(m: Model, d: Data): return elif m.cam_mode[camid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] - d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[camid] + d.cam_xpos[worldid, camid] = body_xpos + get_batched_value(m.cam_pos0, worldid, camid) elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): d.cam_xpos[worldid, camid] = ( - d.subtree_com[worldid, m.cam_bodyid[camid]] + m.cam_poscom0[camid] + d.subtree_com[worldid, m.cam_bodyid[camid]] + get_batched_value(m.cam_poscom0, worldid, camid) ) elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ camid @@ -313,9 +313,9 @@ def light_local_to_global(m: Model, d: Data): xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] d.light_xpos[worldid, lightid] = xpos + math.rot_vec_quat( - m.light_pos[lightid], xquat + get_batched_value(m.light_pos, worldid, lightid), xquat ) - d.light_xdir[worldid, lightid] = math.rot_vec_quat(m.light_dir[lightid], xquat) + d.light_xdir[worldid, lightid] = math.rot_vec_quat(get_batched_value(m.light_dir, worldid, lightid), xquat) @kernel def light_fn(m: Model, d: Data): @@ -328,10 +328,10 @@ def light_fn(m: Model, d: Data): return elif m.light_mode[lightid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] - d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[lightid] + d.light_xpos[worldid, lightid] = body_xpos + get_batched_value(m.light_pos0, worldid, lightid) elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): d.light_xpos[worldid, lightid] = ( - d.subtree_com[worldid, m.light_bodyid[lightid]] + m.light_poscom0[lightid] + d.subtree_com[worldid, m.light_bodyid[lightid]] + get_batched_value(m.light_poscom0, worldid, lightid) ) elif m.light_mode[lightid] == wp.static( CamLightType.TARGETBODY.value @@ -372,7 +372,7 @@ def qM_sparse(m: Model, d: Data): bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - d.qM[worldid, 0, madr_ij] = m.dof_armature[dofid] + d.qM[worldid, 0, madr_ij] = get_batched_value(m.dof_armature, worldid, dofid) # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) @@ -389,7 +389,7 @@ def qM_dense(m: Model, d: Data): bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - M = m.dof_armature[dofid] + M = get_batched_value(m.dof_armature, worldid, dofid) # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) @@ -690,7 +690,7 @@ def _cfrc_ext_equality(m: Model, d: Data): worldid = d.efc.worldid[efcid] id = d.efc.id[efcid] - eq_data = m.eq_data[id] + eq_data = get_batched_value(m.eq_data, worldid, id) body_semantic = m.eq_objtype[id] == wp.static(ObjType.BODY.value) obj1 = m.eq_obj1id[id] @@ -711,7 +711,7 @@ def _cfrc_ext_equality(m: Model, d: Data): else: offset = wp.vec3(eq_data[3], eq_data[4], eq_data[5]) else: - offset = m.site_pos[obj1] + offset = get_batched_value(m.site_pos, worldid, obj1) # transform point on body1: local -> global pos = d.xmat[worldid, bodyid1] @ offset + d.xpos[worldid, bodyid1] @@ -733,7 +733,7 @@ def _cfrc_ext_equality(m: Model, d: Data): else: offset = wp.vec3(eq_data[0], eq_data[1], eq_data[2]) else: - offset = m.site_pos[obj2] + offset = get_batched_value(m.site_pos, worldid, obj2) # transform point on body2: local -> global pos = d.xmat[worldid, bodyid2] @ offset + d.xpos[worldid, bodyid2] @@ -812,7 +812,7 @@ def _transmission( ): worldid, actid = wp.tid() trntype = m.actuator_trntype[actid] - gear = m.actuator_gear[actid] + gear = get_batched_value(m.actuator_gear, worldid, actid) if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static( TrnType.JOINTINPARENT.value ): @@ -1198,7 +1198,7 @@ def _joint_tendon(m: Model, d: Data): wrap_jnt_adr = m.wrap_jnt_adr[wrapid] wrap_objid = m.wrap_objid[wrap_jnt_adr] - prm = m.wrap_prm[wrap_jnt_adr] + prm = get_batched_value(m.wrap_prm, worldid, wrap_jnt_adr) # add to length L = prm * d.qpos[worldid, m.jnt_qposadr[wrap_objid]] diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 00ed9b18..0e91b490 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -611,7 +611,7 @@ class Model: jnt_axis: local joint axis (njnt, 3) jnt_stiffness: stiffness coefficient (nmodel, njnt,) jnt_range: joint limits (nmodel, njnt, 2) - jnt_actfrcrange: range of total actuator force (njnt, 2) + jnt_actfrcrange: range of total actuator force (nmodel, njnt, 2) jnt_margin: min distance for limit detection (nmodel, njnt,) jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr jnt_limited_ball_adr: limited/ball jntadr @@ -619,49 +619,49 @@ class Model: dof_jntid: id of dof's joint (nv,) dof_parentid: id of dof's parent; -1: none (nv,) dof_Madr: dof address in M-diagonal (nv,) - dof_armature: dof armature inertia/mass (nv,) + dof_armature: dof armature inertia/mass (nmodel, nv) dof_damping: damping coefficient (nmodel, nv) - dof_invweight0: diag. inverse inertia in qpos0 (nv,) - dof_frictionloss: dof friction loss (nv,) - dof_solimp: constraint solver impedance: frictionloss (nv, NIMP) - dof_solref: constraint solver reference: frictionloss (nv, NREF) + dof_invweight0: diag. inverse inertia in qpos0 (nmodel, nv) + dof_frictionloss: dof friction loss (nmodel, nv) + dof_solimp: constraint solver impedance: frictionloss (nmodel, nv, NIMP) + dof_solref: constraint solver reference: frictionloss (nmodel, nv, NREF) dof_tri_row: np.tril_indices (mjm.nv)[0] dof_tri_col: np.tril_indices (mjm.nv)[1] geom_type: geometric type (mjtGeom) (ngeom,) - geom_contype: geom contact type (ngeom,) - geom_conaffinity: geom contact affinity (ngeom,) + geom_contype: geom contact type (nmodel, ngeom) + geom_conaffinity: geom contact affinity (nmodel, ngeom,) geom_condim: contact dimensionality (1, 3, 4, 6) (ngeom,) geom_bodyid: id of geom's body (ngeom,) geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,) - geom_priority: geom contact priority (ngeom,) - geom_solmix: mixing coef for solref/imp in geom pair (ngeom,) - geom_solref: constraint solver reference: contact (ngeom, mjNREF) - geom_solimp: constraint solver impedance: contact (ngeom, mjNIMP) + geom_priority: geom contact priority (nmodel, ngeom) + geom_solmix: mixing coef for solref/imp in geom pair (nmodel, ngeom,) + geom_solref: constraint solver reference: contact (nmodel, ngeom, mjNREF) + geom_solimp: constraint solver impedance: contact (nmodel, ngeom, mjNIMP) geom_size: geom-specific size parameters (ngeom, 3) geom_aabb: bounding box, (center, size) (ngeom, 6) geom_rbound: radius of bounding sphere (ngeom,) - geom_pos: local position offset rel. to body (ngeom, 3) - geom_quat: local orientation offset rel. to body (ngeom, 4) - geom_friction: friction for (slide, spin, roll) (ngeom, 3) - geom_margin: detect contact if dist Date: Mon, 28 Apr 2025 11:01:21 +0200 Subject: [PATCH 12/31] fixes --- contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py | 2 +- mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/types.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py index ca498c3d..42d0440c 100644 --- a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py +++ b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py @@ -137,7 +137,7 @@ def test_all_issues( def test_no_issues( # Model: qpos0: wp.array(dtype=wp.float32, ndim=2), - geom_pos: wp.array(dtype=wp.vec3, ndim=1), + geom_pos: wp.array(dtype=wp.vec3, ndim=2), # Data in: qpos_in: wp.array(dtype=wp.float32, ndim=2), qvel_in: wp.array(dtype=wp.float32, ndim=2), diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 62ecaa5f..3ce3e79b 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -435,7 +435,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.eq_active0 = wp.array(mjm.eq_active0, dtype=wp.bool, ndim=1) m.eq_solref = create_nmodel_batched_array(mjm.eq_solref, dtype=wp.vec2) m.eq_solimp = create_nmodel_batched_array(mjm.eq_solimp, dtype=types.vec5) - m.eq_data = wp.array(mjm.eq_data, dtype=types.vec11, ndim=1) + m.eq_data = create_nmodel_batched_array(mjm.eq_data, dtype=types.vec11) m.site_pos = create_nmodel_batched_array(mjm.site_pos, dtype=wp.vec3) m.site_quat = create_nmodel_batched_array(mjm.site_quat, dtype=wp.quat) m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 0e91b490..051e4955 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -672,7 +672,7 @@ class Model: eq_active0: initial enable/disable constraint state (neq,) eq_solref: constraint solver reference (nmodel, neq, mjNREF) eq_solimp: constraint solver impedance (nmodel, neq, mjNIMP) - eq_data: numeric data for constraint (neq, mjNEQDATA) + eq_data: numeric data for constraint (nmodel, neq, mjNEQDATA) eq_connect_adr: eq_* addresses of type `CONNECT` eq_wld_adr: eq_* addresses of type `WELD` eq_jnt_adr: eq_* addresses of type `JOINT` From 555c81595d11839905312fc8e7d6d617975b06e4 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Mon, 28 Apr 2025 11:27:40 +0200 Subject: [PATCH 13/31] fix geom_size --- mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/types.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 3ce3e79b..b5093a3d 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -414,7 +414,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) m.geom_pos = create_nmodel_batched_array(mjm.geom_pos, dtype=wp.vec3) m.geom_quat = create_nmodel_batched_array(mjm.geom_quat, dtype=wp.quat) - m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) + m.geom_size = create_nmodel_batched_array(mjm.geom_size, dtype=wp.vec3) m.geom_priority = create_nmodel_batched_array(mjm.geom_priority, dtype=wp.int32) m.geom_solmix = create_nmodel_batched_array(mjm.geom_solmix, dtype=wp.float32) m.geom_solref = create_nmodel_batched_array(mjm.geom_solref, dtype=wp.vec2) diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 051e4955..44aa6358 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -637,7 +637,7 @@ class Model: geom_solmix: mixing coef for solref/imp in geom pair (nmodel, ngeom,) geom_solref: constraint solver reference: contact (nmodel, ngeom, mjNREF) geom_solimp: constraint solver impedance: contact (nmodel, ngeom, mjNIMP) - geom_size: geom-specific size parameters (ngeom, 3) + geom_size: geom-specific size parameters (nmodel, ngeom, 3) geom_aabb: bounding box, (center, size) (ngeom, 6) geom_rbound: radius of bounding sphere (ngeom,) geom_pos: local position offset rel. to body (nmodel, ngeom, 3) From d7facd1456b46d8d0089636316b5ccb6e3ac02b1 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Mon, 28 Apr 2025 11:38:32 +0200 Subject: [PATCH 14/31] formatting and linting --- mujoco_warp/_src/collision_box.py | 7 ++- mujoco_warp/_src/collision_convex.py | 7 ++- mujoco_warp/_src/collision_driver.py | 2 +- mujoco_warp/_src/collision_primitive.py | 38 ++++++++---- mujoco_warp/_src/constraint.py | 46 +++++++++++---- mujoco_warp/_src/forward.py | 14 +++-- mujoco_warp/_src/io.py | 60 +++++++++++++------ mujoco_warp/_src/io_test.py | 5 +- mujoco_warp/_src/passive.py | 11 ++-- mujoco_warp/_src/sensor.py | 34 ++++++++--- mujoco_warp/_src/smooth.py | 78 +++++++++++++++++-------- mujoco_warp/_src/support.py | 15 ++++- 12 files changed, 230 insertions(+), 87 deletions(-) diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py index 7758c985..b482c0e7 100644 --- a/mujoco_warp/_src/collision_box.py +++ b/mujoco_warp/_src/collision_box.py @@ -22,10 +22,10 @@ from .collision_primitive import contact_params from .collision_primitive import write_contact from .math import make_frame +from .support import get_batched_value from .types import Data from .types import GeomType from .types import Model -from .support import get_batched_value BOX_BOX_BLOCK_DIM = 32 @@ -313,7 +313,10 @@ def box_box_kernel( for i in range(4): pos[i] = pos[idx] - margin = wp.max(get_batched_value(m.geom_margin, worldid, ga), get_batched_value(m.geom_margin, worldid, gb)) + margin = wp.max( + get_batched_value(m.geom_margin, worldid, ga), + get_batched_value(m.geom_margin, worldid, gb), + ) for i in range(4): pos_glob = b_mat @ pos[i] + b_pos n_glob = b_mat @ sep_axis diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index b2051009..71d562fa 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -24,12 +24,12 @@ from .math import orthonormal from .support import all_same from .support import any_different +from .support import get_batched_value from .types import MJ_MINVAL from .types import NUM_GEOM_TYPES from .types import Data from .types import GeomType from .types import Model -from .support import get_batched_value # XXX disable backward pass codegen globally for now # enabling backward pass leads to 10min compile time @@ -736,7 +736,10 @@ def gjk_epa_sparse( info1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) info2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) - margin = wp.max(get_batched_value(m.geom_margin, worldid, g1), get_batched_value(m.geom_margin, worldid, g2)) + margin = wp.max( + get_batched_value(m.geom_margin, worldid, g1), + get_batched_value(m.geom_margin, worldid, g2), + ) simplex, normal = _gjk(m, info1, info2) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index e38a01e1..1580360a 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -20,13 +20,13 @@ from .collision_box import box_box_narrowphase from .collision_convex import gjk_narrowphase from .collision_primitive import primitive_narrowphase +from .support import get_batched_value from .types import MJ_MAXVAL from .types import MJ_MINVAL from .types import Data from .types import DisableBit from .types import Model from .warp_util import event_scope -from .support import get_batched_value wp.set_module_options({"enable_backward": False}) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 692631bf..78f36860 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -19,13 +19,13 @@ from .math import closest_segment_to_segment_points from .math import make_frame from .math import normalize_with_norm +from .support import get_batched_array +from .support import get_batched_value from .types import MJ_MINVAL from .types import Data from .types import GeomType from .types import Model from .types import vec5 -from .support import get_batched_value -from .support import get_batched_array wp.set_module_options({"enable_backward": False}) @@ -768,8 +768,14 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) - margin = wp.max(get_batched_value(m.geom_margin, worldid, g1), get_batched_value(m.geom_margin, worldid, g2)) - gap = wp.max(get_batched_value(m.geom_gap, worldid, g1), get_batched_value(m.geom_gap, worldid, g2)) + margin = wp.max( + get_batched_value(m.geom_margin, worldid, g1), + get_batched_value(m.geom_margin, worldid, g2), + ) + gap = wp.max( + get_batched_value(m.geom_gap, worldid, g1), + get_batched_value(m.geom_gap, worldid, g2), + ) condim1 = m.geom_condim[g1] condim2 = m.geom_condim[g2] @@ -777,7 +783,10 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2) ) - geom_friction = wp.max(get_batched_value(m.geom_friction, worldid, g1), get_batched_value(m.geom_friction, worldid, g2)) + geom_friction = wp.max( + get_batched_value(m.geom_friction, worldid, g1), + get_batched_value(m.geom_friction, worldid, g2), + ) friction = vec5( geom_friction[0], geom_friction[0], @@ -786,14 +795,24 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): geom_friction[2], ) - if get_batched_value(m.geom_solref, worldid, g1).x > 0.0 and get_batched_value(m.geom_solref, worldid, g2).x > 0.0: - solref = mix * get_batched_value(m.geom_solref, worldid, g1) + (1.0 - mix) * get_batched_value(m.geom_solref, worldid, g2) + if ( + get_batched_value(m.geom_solref, worldid, g1).x > 0.0 + and get_batched_value(m.geom_solref, worldid, g2).x > 0.0 + ): + solref = mix * get_batched_value(m.geom_solref, worldid, g1) + ( + 1.0 - mix + ) * get_batched_value(m.geom_solref, worldid, g2) else: - solref = wp.min(get_batched_value(m.geom_solref, worldid, g1), get_batched_value(m.geom_solref, worldid, g2)) + solref = wp.min( + get_batched_value(m.geom_solref, worldid, g1), + get_batched_value(m.geom_solref, worldid, g2), + ) solreffriction = wp.vec2(0.0, 0.0) - solimp = mix * get_batched_value(m.geom_solimp, worldid, g1) + (1.0 - mix) * get_batched_value(m.geom_solimp, worldid, g2) + solimp = mix * get_batched_value(m.geom_solimp, worldid, g1) + ( + 1.0 - mix + ) * get_batched_value(m.geom_solimp, worldid, g2) return geoms, margin, gap, condim, friction, solref, solreffriction, solimp @@ -879,7 +898,6 @@ def _primitive_narrowphase( g1 = geoms[0] g2 = geoms[1] - geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 94d76a4e..edd9f011 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -18,8 +18,8 @@ from . import math from . import support from . import types -from .support import get_batched_value from .support import get_batched_array +from .support import get_batched_value from .warp_util import event_scope wp.config.enable_backward = False @@ -133,7 +133,9 @@ def _efc_equality_connect( d.efc.J[efcid + 2, dofid] = j1mj2[2] Jqvel += j1mj2 * d.qvel[worldid, dofid] - invweight = get_batched_value(m.body_invweight0, worldid, body1id, 0) + get_batched_value(m.body_invweight0, worldid, body2id, 0) + invweight = get_batched_value( + m.body_invweight0, worldid, body1id, 0 + ) + get_batched_value(m.body_invweight0, worldid, body2id, 0) pos_imp = wp.length(pos) solref = get_batched_value(m.eq_solref, worldid, i_eq) @@ -192,14 +194,22 @@ def _efc_equality_joint( 2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4]) ) - pos = d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - rhs + pos = ( + d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - rhs + ) Jqvel = d.qvel[worldid, dofadr1] - d.qvel[worldid, dofadr2] * deriv_2 - invweight = get_batched_value(m.dof_invweight0, worldid, dofadr1) + get_batched_value(m.dof_invweight0, worldid, dofadr2) + invweight = get_batched_value( + m.dof_invweight0, worldid, dofadr1 + ) + get_batched_value(m.dof_invweight0, worldid, dofadr2) d.efc.J[efcid, dofadr2] = -deriv_2 else: # Single joint constraint - pos = d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - data[0] + pos = ( + d.qpos[worldid, qposadr1] + - get_batched_value(m.qpos0, worldid, qposadr1) + - data[0] + ) Jqvel = d.qvel[worldid, dofadr1] invweight = get_batched_value(m.dof_invweight0, worldid, dofadr1) @@ -286,8 +296,14 @@ def _efc_equality_weld( body2id = m.site_bodyid[obj2id] pos1 = d.site_xpos[worldid, obj1id] pos2 = d.site_xpos[worldid, obj2id] - quat = math.mul_quat(d.xquat[worldid, body1id], get_batched_value(m.site_quat, worldid, obj1id)) - quat1 = math.quat_inv(math.mul_quat(d.xquat[worldid, body2id], get_batched_value(m.site_quat, worldid, obj2id))) + quat = math.mul_quat( + d.xquat[worldid, body1id], get_batched_value(m.site_quat, worldid, obj1id) + ) + quat1 = math.quat_inv( + math.mul_quat( + d.xquat[worldid, body2id], get_batched_value(m.site_quat, worldid, obj2id) + ) + ) else: body1id = obj1id @@ -326,7 +342,9 @@ def _efc_equality_weld( crotq = math.mul_quat(quat1, quat) # copy axis components crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale - invweight_t = get_batched_value(m.body_invweight0, worldid, body1id, 0) + get_batched_value(m.body_invweight0, worldid, body2id, 0) + invweight_t = get_batched_value( + m.body_invweight0, worldid, body1id, 0 + ) + get_batched_value(m.body_invweight0, worldid, body2id, 0) pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) @@ -349,7 +367,9 @@ def _efc_equality_weld( i_eq, ) - invweight_r = get_batched_value(m.body_invweight0, worldid, body1id, 1) + get_batched_value(m.body_invweight0, worldid, body2id, 1) + invweight_r = get_batched_value( + m.body_invweight0, worldid, body1id, 1 + ) + get_batched_value(m.body_invweight0, worldid, body2id, 1) for i in range(3): _update_efc_row( @@ -548,7 +568,9 @@ def _efc_contact_pyramidal( frame = d.contact.frame[conid] # pyramidal has common invweight across all edges - invweight = get_batched_value(m.body_invweight0, worldid, body1, 0) + get_batched_value(m.body_invweight0, worldid, body2, 0) + invweight = get_batched_value( + m.body_invweight0, worldid, body1, 0 + ) + get_batched_value(m.body_invweight0, worldid, body2, 0) if condim > 1: dimid2 = dimid / 2 + 1 @@ -649,7 +671,9 @@ def _efc_contact_elliptic( d.efc.J[efcid, i] = J Jqvel += J * d.qvel[worldid, i] - invweight = get_batched_value(m.body_invweight0, worldid, body1, 0) + get_batched_value(m.body_invweight0, worldid, body2, 0) + invweight = get_batched_value( + m.body_invweight0, worldid, body1, 0 + ) + get_batched_value(m.body_invweight0, worldid, body2, 0) ref = d.contact.solref[conid] pos_aref = pos diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 6816cd9b..3d4a31df 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -25,9 +25,9 @@ from . import sensor from . import smooth from . import solver -from .support import xfrc_accumulate from .support import get_batched_array from .support import get_batched_value +from .support import xfrc_accumulate from .types import MJ_MINVAL from .types import BiasType from .types import Data @@ -195,7 +195,9 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_value(m.dof_damping, worldid, tid) + d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_value( + m.dof_damping, worldid, tid + ) d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] @@ -224,7 +226,9 @@ def eulerdamp( M_tile = wp.tile_load( d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) ) - damping_tile = wp.tile_load(get_batched_array(damping, worldid), shape=(tilesize,), offset=(dofid,)) + damping_tile = wp.tile_load( + get_batched_array(damping, worldid), shape=(tilesize,), offset=(dofid,) + ) damping_scaled = damping_tile * m.opt.timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) @@ -436,7 +440,9 @@ def qderiv_actuator_fused_kernel( ) if wp.static(passive_enabled): - dof_damping = wp.tile_load(get_batched_array(damping, worldid), shape=tilesize_nv, offset=offset_nv) + dof_damping = wp.tile_load( + get_batched_array(damping, worldid), shape=tilesize_nv, offset=offset_nv + ) negative = wp.neg(dof_damping) qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index b5093a3d..292706d4 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -69,12 +69,14 @@ def geom_pair(m: mujoco.MjModel) -> Tuple[np.array, np.array]: return np.array(geompairs), np.array(pairids) + def create_nmodel_batched_array(mjm_array, dtype): - array = wp.array(mjm_array, dtype=dtype) - array.ndim += 1 - array.shape = (1,) + array.shape - array.strides = (0,) + array.strides - return array + array = wp.array(mjm_array, dtype=dtype) + array.ndim += 1 + array.shape = (1,) + array.shape + array.strides = (0,) + array.strides + return array + def put_model(mjm: mujoco.MjModel) -> types.Model: # check supported features @@ -459,7 +461,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) m.dof_armature = create_nmodel_batched_array(mjm.dof_armature, dtype=wp.float32) m.dof_damping = create_nmodel_batched_array(mjm.dof_damping, dtype=wp.float32) - m.dof_frictionloss = create_nmodel_batched_array(mjm.dof_frictionloss, dtype=wp.float32) + m.dof_frictionloss = create_nmodel_batched_array( + mjm.dof_frictionloss, dtype=wp.float32 + ) m.dof_solimp = create_nmodel_batched_array(mjm.dof_solimp, dtype=types.vec5) m.dof_solref = create_nmodel_batched_array(mjm.dof_solref, dtype=wp.vec2) m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) @@ -468,20 +472,34 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) - m.actuator_ctrlrange = create_nmodel_batched_array(mjm.actuator_ctrlrange, dtype=wp.vec2) + m.actuator_ctrlrange = create_nmodel_batched_array( + mjm.actuator_ctrlrange, dtype=wp.vec2 + ) m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) - m.actuator_forcerange = create_nmodel_batched_array(mjm.actuator_forcerange, dtype=wp.vec2) + m.actuator_forcerange = create_nmodel_batched_array( + mjm.actuator_forcerange, dtype=wp.vec2 + ) m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) - m.actuator_gainprm = create_nmodel_batched_array(mjm.actuator_gainprm, dtype=types.vec10f) + m.actuator_gainprm = create_nmodel_batched_array( + mjm.actuator_gainprm, dtype=types.vec10f + ) m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) - m.actuator_biasprm = create_nmodel_batched_array(mjm.actuator_biasprm, dtype=types.vec10f) - m.actuator_gear = create_nmodel_batched_array(mjm.actuator_gear, dtype=wp.spatial_vector) + m.actuator_biasprm = create_nmodel_batched_array( + mjm.actuator_biasprm, dtype=types.vec10f + ) + m.actuator_gear = create_nmodel_batched_array( + mjm.actuator_gear, dtype=wp.spatial_vector + ) m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) - m.actuator_actrange = create_nmodel_batched_array(mjm.actuator_actrange, dtype=wp.vec2) + m.actuator_actrange = create_nmodel_batched_array( + mjm.actuator_actrange, dtype=wp.vec2 + ) m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) m.actuator_actnum = wp.array(mjm.actuator_actnum, dtype=wp.int32, ndim=1) m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) - m.actuator_dynprm = create_nmodel_batched_array(mjm.actuator_dynprm, dtype=types.vec10f) + m.actuator_dynprm = create_nmodel_batched_array( + mjm.actuator_dynprm, dtype=types.vec10f + ) m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) # pre-compute indices of equality constraints @@ -510,7 +528,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.pair_geom1 = wp.array(mjm.pair_geom1, dtype=wp.int32, ndim=1) m.pair_geom2 = create_nmodel_batched_array(mjm.pair_geom2, dtype=wp.int32) m.pair_solref = create_nmodel_batched_array(mjm.pair_solref, dtype=wp.vec2) - m.pair_solreffriction = create_nmodel_batched_array(mjm.pair_solreffriction, dtype=wp.vec2) + m.pair_solreffriction = create_nmodel_batched_array( + mjm.pair_solreffriction, dtype=wp.vec2 + ) m.pair_solimp = create_nmodel_batched_array(mjm.pair_solimp, dtype=types.vec5) m.pair_margin = create_nmodel_batched_array(mjm.pair_margin, dtype=wp.float32) m.pair_gap = create_nmodel_batched_array(mjm.pair_gap, dtype=wp.float32) @@ -524,12 +544,18 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.tendon_limited_adr = wp.array( np.nonzero(mjm.tendon_limited)[0], dtype=wp.int32, ndim=1 ) - m.tendon_solref_lim = create_nmodel_batched_array(mjm.tendon_solref_lim, dtype=wp.vec2f) - m.tendon_solimp_lim = create_nmodel_batched_array(mjm.tendon_solimp_lim, dtype=types.vec5) + m.tendon_solref_lim = create_nmodel_batched_array( + mjm.tendon_solref_lim, dtype=wp.vec2f + ) + m.tendon_solimp_lim = create_nmodel_batched_array( + mjm.tendon_solimp_lim, dtype=types.vec5 + ) m.tendon_range = create_nmodel_batched_array(mjm.tendon_range, dtype=wp.vec2f) m.tendon_margin = create_nmodel_batched_array(mjm.tendon_margin, dtype=wp.float32) m.tendon_length0 = create_nmodel_batched_array(mjm.tendon_length0, dtype=wp.float32) - m.tendon_invweight0 = create_nmodel_batched_array(mjm.tendon_invweight0, dtype=wp.float32) + m.tendon_invweight0 = create_nmodel_batched_array( + mjm.tendon_invweight0, dtype=wp.float32 + ) m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32) m.wrap_prm = create_nmodel_batched_array(mjm.wrap_prm, dtype=wp.float32) m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32) diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index e65b0ae3..72d02f34 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -28,11 +28,13 @@ # due to float precision _TOLERANCE = 5e-5 + def _assert_eq(a, b, name): tol = _TOLERANCE * 10 # avoid test noise err_msg = f"mismatch: {name}" np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol) + class IOTest(absltest.TestCase): def test_make_put_data(self): """Tests that make_data and put_data are producing the same shapes for all warp arrays.""" @@ -298,7 +300,6 @@ def test_option_physical_constants(self): with self.assertRaises(NotImplementedError): mjwarp.put_model(mjm) - def test_model_batching(self): mjm, mjd, _, _ = test_util.fixture("humanoid/humanoid.xml", kick=True) @@ -310,7 +311,7 @@ def test_model_batching(self): dof_damping = np.zeros((2, len(damping_orig)), dtype=np.float32) dof_damping[0, :] = damping_orig dof_damping[1, :] = damping_orig * 0.5 - + # set the batched damping values m.dof_damping = wp.from_numpy(dof_damping, dtype=wp.float32) diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index 6cb1ad7c..39c3eb6f 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -16,13 +16,13 @@ import warp as wp from . import math +from .support import get_batched_value from .types import Data from .types import DisableBit from .types import JointType from .types import Model from .warp_util import event_scope from .warp_util import kernel -from .support import get_batched_value @event_scope @@ -47,9 +47,12 @@ def _spring(m: Model, d: Data): if jnt_type == wp.static(JointType.FREE.value): dif = wp.vec3( - d.qpos[worldid, qposid + 0] - get_batched_value(m.qpos_spring, worldid, qposid + 0), - d.qpos[worldid, qposid + 1] - get_batched_value(m.qpos_spring, worldid, qposid + 1), - d.qpos[worldid, qposid + 2] - get_batched_value(m.qpos_spring, worldid, qposid + 2), + d.qpos[worldid, qposid + 0] + - get_batched_value(m.qpos_spring, worldid, qposid + 0), + d.qpos[worldid, qposid + 1] + - get_batched_value(m.qpos_spring, worldid, qposid + 1), + d.qpos[worldid, qposid + 2] + - get_batched_value(m.qpos_spring, worldid, qposid + 2), ) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 4c3a4e44..ed1e611d 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -20,6 +20,7 @@ from . import math from . import smooth +from .support import get_batched_value from .types import Data from .types import DisableBit from .types import Model @@ -27,7 +28,6 @@ from .types import SensorType from .warp_util import event_scope from .warp_util import kernel -from .support import get_batched_value @wp.func @@ -137,25 +137,43 @@ def _frame_quat( m: Model, d: Data, worldid: int, objid: int, objtype: int, refid: int ) -> wp.quat: if objtype == int(ObjType.BODY.value): - quat = math.mul_quat(d.xquat[worldid, objid], get_batched_value(m.body_iquat, worldid, objid)) + quat = math.mul_quat( + d.xquat[worldid, objid], get_batched_value(m.body_iquat, worldid, objid) + ) if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid)) + refquat = math.mul_quat( + d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid) + ) elif objtype == int(ObjType.XBODY.value): quat = d.xquat[worldid, objid] if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid)) + refquat = math.mul_quat( + d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid) + ) elif objtype == int(ObjType.GEOM.value): - quat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[objid]], get_batched_value(m.geom_quat, worldid, objid)) + quat = math.mul_quat( + d.xquat[worldid, m.geom_bodyid[objid]], + get_batched_value(m.geom_quat, worldid, objid), + ) if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, m.geom_bodyid[refid]], get_batched_value(m.geom_quat, worldid, refid)) + refquat = math.mul_quat( + d.xquat[worldid, m.geom_bodyid[refid]], + get_batched_value(m.geom_quat, worldid, refid), + ) elif objtype == int(ObjType.SITE.value): - quat = math.mul_quat(d.xquat[worldid, m.site_bodyid[objid]], get_batched_value(m.site_quat, worldid, objid)) + quat = math.mul_quat( + d.xquat[worldid, m.site_bodyid[objid]], + get_batched_value(m.site_quat, worldid, objid), + ) if refid == -1: return quat - refquat = math.mul_quat(d.xquat[worldid, m.site_bodyid[refid]], get_batched_value(m.site_quat, worldid, refid)) + refquat = math.mul_quat( + d.xquat[worldid, m.site_bodyid[refid]], + get_batched_value(m.site_quat, worldid, refid), + ) # TODO(team): camera diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 1c1e35f5..075b139c 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -19,6 +19,7 @@ from . import math from . import support +from .support import get_batched_value from .types import MJ_MINVAL from .types import CamLightType from .types import Data @@ -31,7 +32,6 @@ from .types import array2df from .types import array3df from .types import vec10 -from .support import get_batched_value from .warp_util import event_scope from .warp_util import kernel @@ -62,8 +62,12 @@ def _level(m: Model, d: Data, leveladr: int): if jntnum == 0: # no joints - apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid)) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid)) + xpos = ( + d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid) + ) + d.xpos[worldid, pid] + xquat = math.mul_quat( + d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid) + ) elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = m.jnt_qposadr[jntadr] @@ -75,14 +79,20 @@ def _level(m: Model, d: Data, leveladr: int): # regular or no joints # apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid)) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid)) + xpos = ( + d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid) + ) + d.xpos[worldid, pid] + xquat = math.mul_quat( + d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid) + ) for _ in range(jntnum): qadr = m.jnt_qposadr[jntadr] jnt_type = m.jnt_type[jntadr] jnt_axis = m.jnt_axis[jntadr] - xanchor = math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) + xpos + xanchor = ( + math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) + xpos + ) xaxis = math.rot_vec_quat(jnt_axis, xquat) if jnt_type == wp.static(JointType.BALL.value): @@ -94,7 +104,9 @@ def _level(m: Model, d: Data, leveladr: int): ) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) + xpos = xanchor - math.rot_vec_quat( + get_batched_value(m.jnt_pos, worldid, jntadr), xquat + ) elif jnt_type == wp.static(JointType.SLIDE.value): xpos += xaxis * (qpos[qadr] - get_batched_value(m.qpos0, worldid, qadr)) elif jnt_type == wp.static(JointType.HINGE.value): @@ -102,7 +114,9 @@ def _level(m: Model, d: Data, leveladr: int): qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) + xpos = xanchor - math.rot_vec_quat( + get_batched_value(m.jnt_pos, worldid, jntadr), xquat + ) d.xanchor[worldid, jntadr] = xanchor d.xaxis[worldid, jntadr] = xaxis @@ -112,7 +126,9 @@ def _level(m: Model, d: Data, leveladr: int): xquat = wp.normalize(xquat) d.xquat[worldid, bodyid] = xquat d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) - d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(get_batched_value(m.body_ipos, worldid, bodyid), xquat) + d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat( + get_batched_value(m.body_ipos, worldid, bodyid), xquat + ) d.ximat[worldid, bodyid] = math.quat_to_mat( math.mul_quat(xquat, get_batched_value(m.body_iquat, worldid, bodyid)) ) @@ -123,7 +139,9 @@ def geom_local_to_global(m: Model, d: Data): bodyid = m.geom_bodyid[geomid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(get_batched_value(m.geom_pos, worldid, geomid), xquat) + d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat( + get_batched_value(m.geom_pos, worldid, geomid), xquat + ) d.geom_xmat[worldid, geomid] = math.quat_to_mat( math.mul_quat(xquat, get_batched_value(m.geom_quat, worldid, geomid)) ) @@ -134,7 +152,9 @@ def site_local_to_global(m: Model, d: Data): bodyid = m.site_bodyid[siteid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(get_batched_value(m.site_pos, worldid, siteid), xquat) + d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat( + get_batched_value(m.site_pos, worldid, siteid), xquat + ) d.site_xmat[worldid, siteid] = math.quat_to_mat( math.mul_quat(xquat, get_batched_value(m.site_quat, worldid, siteid)) ) @@ -161,7 +181,9 @@ def com_pos(m: Model, d: Data): @kernel def subtree_com_init(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * get_batched_value(m.body_mass, worldid, bodyid) + d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * get_batched_value( + m.body_mass, worldid, bodyid + ) @kernel def subtree_com_acc(m: Model, d: Data, leveladr: int): @@ -265,7 +287,9 @@ def cam_local_to_global(m: Model, d: Data): bodyid = m.cam_bodyid[camid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat(get_batched_value(m.cam_pos, worldid, camid), xquat) + d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat( + get_batched_value(m.cam_pos, worldid, camid), xquat + ) d.cam_xmat[worldid, camid] = math.quat_to_mat( math.mul_quat(xquat, get_batched_value(m.cam_quat, worldid, camid)) ) @@ -281,11 +305,13 @@ def cam_fn(m: Model, d: Data): return elif m.cam_mode[camid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] - d.cam_xpos[worldid, camid] = body_xpos + get_batched_value(m.cam_pos0, worldid, camid) - elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): - d.cam_xpos[worldid, camid] = ( - d.subtree_com[worldid, m.cam_bodyid[camid]] + get_batched_value(m.cam_poscom0, worldid, camid) + d.cam_xpos[worldid, camid] = body_xpos + get_batched_value( + m.cam_pos0, worldid, camid ) + elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): + d.cam_xpos[worldid, camid] = d.subtree_com[ + worldid, m.cam_bodyid[camid] + ] + get_batched_value(m.cam_poscom0, worldid, camid) elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ camid ] == wp.static(CamLightType.TARGETBODYCOM.value): @@ -315,7 +341,9 @@ def light_local_to_global(m: Model, d: Data): d.light_xpos[worldid, lightid] = xpos + math.rot_vec_quat( get_batched_value(m.light_pos, worldid, lightid), xquat ) - d.light_xdir[worldid, lightid] = math.rot_vec_quat(get_batched_value(m.light_dir, worldid, lightid), xquat) + d.light_xdir[worldid, lightid] = math.rot_vec_quat( + get_batched_value(m.light_dir, worldid, lightid), xquat + ) @kernel def light_fn(m: Model, d: Data): @@ -328,11 +356,13 @@ def light_fn(m: Model, d: Data): return elif m.light_mode[lightid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] - d.light_xpos[worldid, lightid] = body_xpos + get_batched_value(m.light_pos0, worldid, lightid) - elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): - d.light_xpos[worldid, lightid] = ( - d.subtree_com[worldid, m.light_bodyid[lightid]] + get_batched_value(m.light_poscom0, worldid, lightid) + d.light_xpos[worldid, lightid] = body_xpos + get_batched_value( + m.light_pos0, worldid, lightid ) + elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): + d.light_xpos[worldid, lightid] = d.subtree_com[ + worldid, m.light_bodyid[lightid] + ] + get_batched_value(m.light_poscom0, worldid, lightid) elif m.light_mode[lightid] == wp.static( CamLightType.TARGETBODY.value ) or m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): @@ -1107,7 +1137,9 @@ def _forward(m: Model, d: Data): # update linear velocity lin -= wp.cross(xipos - subtree_com_root, ang) - d.subtree_linvel[worldid, bodyid] = get_batched_value(m.body_mass, worldid, bodyid) * lin + d.subtree_linvel[worldid, bodyid] = ( + get_batched_value(m.body_mass, worldid, bodyid) * lin + ) dv = wp.transpose(ximat) @ ang dv[0] *= get_batched_value(m.body_inertia, worldid, bodyid)[0] dv[1] *= get_batched_value(m.body_inertia, worldid, bodyid)[1] diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index fa0b95aa..ca9970e2 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Tuple, Any +from typing import Any, Tuple import mujoco import warp as wp @@ -350,29 +350,38 @@ def jac( return jacp, jacr + @wp.func def get_batched_array(array: wp.array2d(dtype=Any), worldid: wp.int32): """Returns the array slice for the given worldid.""" modelid = worldid % array.shape[0] return array[modelid] + @wp.func def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32, i: wp.int32): modelid = worldid % array.shape[0] return array[modelid, i] + @wp.func def get_batched_array(array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32): """Returns the array slice for the given worldid.""" modelid = worldid % array.shape[0] return array[modelid, i] + @wp.func -def get_batched_value(array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32): +def get_batched_value( + array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32 +): modelid = worldid % array.shape[0] return array[modelid, i, j] + @wp.func -def get_batched_value(array: wp.array4d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32, k: wp.int32): +def get_batched_value( + array: wp.array4d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32, k: wp.int32 +): modelid = worldid % array.shape[0] return array[modelid, i, j, k] From 7a8bb01b2b5071fdf2b22913127acb23c3068f7c Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:02:48 +0200 Subject: [PATCH 15/31] remove get_batched_arrays --- mujoco_warp/_src/collision_primitive.py | 1 - mujoco_warp/_src/constraint.py | 5 ++--- mujoco_warp/_src/forward.py | 5 ++--- mujoco_warp/_src/support.py | 14 -------------- 4 files changed, 4 insertions(+), 21 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 78f36860..7338b6c6 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -19,7 +19,6 @@ from .math import closest_segment_to_segment_points from .math import make_frame from .math import normalize_with_norm -from .support import get_batched_array from .support import get_batched_value from .types import MJ_MINVAL from .types import Data diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index edd9f011..d0a42381 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -18,7 +18,6 @@ from . import math from . import support from . import types -from .support import get_batched_array from .support import get_batched_value from .warp_util import event_scope @@ -397,7 +396,7 @@ def _efc_limit_slide_hinge( jntid = m.jnt_limited_slide_hinge_adr[jntlimitedid] qpos = d.qpos[worldid, m.jnt_qposadr[jntid]] - jnt_range = get_batched_array(m.jnt_range, worldid, jntid) + jnt_range = m.jnt_range[worldid, jntid] dist_min, dist_max = qpos - jnt_range[0], jnt_range[1] - qpos pos = wp.min(dist_min, dist_max) - get_batched_value(m.jnt_margin, worldid, jntid) active = pos < 0 @@ -445,7 +444,7 @@ def _efc_limit_ball( axis_angle = math.quat_to_vel(jnt_quat) axis, angle = math.normalize_with_norm(axis_angle) jnt_margin = get_batched_value(m.jnt_margin, worldid, jntid) - jnt_range = get_batched_array(m.jnt_range, worldid, jntid) + jnt_range = m.jnt_range[worldid, jntid] pos = wp.max(jnt_range[0], jnt_range[1]) - angle - jnt_margin active = pos < 0 diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 3d4a31df..d436850e 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -25,7 +25,6 @@ from . import sensor from . import smooth from . import solver -from .support import get_batched_array from .support import get_batched_value from .support import xfrc_accumulate from .types import MJ_MINVAL @@ -227,7 +226,7 @@ def eulerdamp( d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) ) damping_tile = wp.tile_load( - get_batched_array(damping, worldid), shape=(tilesize,), offset=(dofid,) + damping[worldid], shape=(tilesize,), offset=(dofid,) ) damping_scaled = damping_tile * m.opt.timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) @@ -441,7 +440,7 @@ def qderiv_actuator_fused_kernel( if wp.static(passive_enabled): dof_damping = wp.tile_load( - get_batched_array(damping, worldid), shape=tilesize_nv, offset=offset_nv + damping[worldid], shape=tilesize_nv, offset=offset_nv ) negative = wp.neg(dof_damping) qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index ca9970e2..54b816d5 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -351,26 +351,12 @@ def jac( return jacp, jacr -@wp.func -def get_batched_array(array: wp.array2d(dtype=Any), worldid: wp.int32): - """Returns the array slice for the given worldid.""" - modelid = worldid % array.shape[0] - return array[modelid] - @wp.func def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32, i: wp.int32): modelid = worldid % array.shape[0] return array[modelid, i] - -@wp.func -def get_batched_array(array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32): - """Returns the array slice for the given worldid.""" - modelid = worldid % array.shape[0] - return array[modelid, i] - - @wp.func def get_batched_value( array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32 From ea5c629e4469cae7408772ded97a0b652225900f Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:06:22 +0200 Subject: [PATCH 16/31] first parts of get_batched_value --- mujoco_warp/_src/collision_box.py | 9 ++++----- mujoco_warp/_src/collision_convex.py | 5 ++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py index b482c0e7..1ab37e74 100644 --- a/mujoco_warp/_src/collision_box.py +++ b/mujoco_warp/_src/collision_box.py @@ -22,7 +22,6 @@ from .collision_primitive import contact_params from .collision_primitive import write_contact from .math import make_frame -from .support import get_batched_value from .types import Data from .types import GeomType from .types import Model @@ -221,8 +220,8 @@ def box_box_kernel( trans_atob = b_mat_inv @ (a_pos - b_pos) rot_atob = b_mat_inv @ a_mat - a_size = get_batched_value(m.geom_size, worldid, ga) - b_size = get_batched_value(m.geom_size, worldid, gb) + a_size = m.geom_size[worldid, ga] + b_size = m.geom_size[worldid, gb] a = box(rot_atob, trans_atob, a_size) b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) @@ -314,8 +313,8 @@ def box_box_kernel( pos[i] = pos[idx] margin = wp.max( - get_batched_value(m.geom_margin, worldid, ga), - get_batched_value(m.geom_margin, worldid, gb), + m.geom_margin[worldid, ga], + m.geom_margin[worldid, gb], ) for i in range(4): pos_glob = b_mat @ pos[i] + b_pos diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 71d562fa..144bf724 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -24,7 +24,6 @@ from .math import orthonormal from .support import all_same from .support import any_different -from .support import get_batched_value from .types import MJ_MINVAL from .types import NUM_GEOM_TYPES from .types import Data @@ -737,8 +736,8 @@ def gjk_epa_sparse( info2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid) margin = wp.max( - get_batched_value(m.geom_margin, worldid, g1), - get_batched_value(m.geom_margin, worldid, g2), + m.geom_margin[worldid, g1], + m.geom_margin[worldid, g2], ) simplex, normal = _gjk(m, info1, info2) From c7cb2e70096ae1f71afccc0743e79205b43d29f9 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:06:44 +0200 Subject: [PATCH 17/31] collision driver --- mujoco_warp/_src/collision_driver.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 1580360a..d68e8aff 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -20,7 +20,6 @@ from .collision_box import box_box_narrowphase from .collision_convex import gjk_narrowphase from .collision_primitive import primitive_narrowphase -from .support import get_batched_value from .types import MJ_MAXVAL from .types import MJ_MINVAL from .types import Data @@ -33,8 +32,8 @@ @wp.func def _sphere_filter(m: Model, d: Data, geom1: int, geom2: int, worldid: int) -> bool: - margin1 = get_batched_value(m.geom_margin, worldid, geom1) - margin2 = get_batched_value(m.geom_margin, worldid, geom2) + margin1 = m.geom_margin[worldid, geom1] + margin2 = m.geom_margin[worldid, geom2] pos1 = d.geom_xpos[worldid, geom1] pos2 = d.geom_xpos[worldid, geom2] size1 = m.geom_rbound[geom1] @@ -112,7 +111,7 @@ def _sap_project(m: Model, d: Data, direction: wp.vec3): # geom is a plane rbound = MJ_MAXVAL - radius = rbound + get_batched_value(m.geom_margin, worldid, geomid) + radius = rbound + m.geom_margin[worldid, geomid] center = wp.dot(direction, xpos) d.sap_projection_lower[worldid, geomid] = center - radius From a4ee24ee47ee2c6a4300389567dec4baa41fa124 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:07:59 +0200 Subject: [PATCH 18/31] primitive collisions --- mujoco_warp/_src/collision_primitive.py | 51 ++++++++++++------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 7338b6c6..28e64a97 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -19,7 +19,6 @@ from .math import closest_segment_to_segment_points from .math import make_frame from .math import normalize_with_norm -from .support import get_batched_value from .types import MJ_MINVAL from .types import Data from .types import GeomType @@ -51,7 +50,7 @@ def _geom( geom.pos = geom_xpos[gid] rot = geom_xmat[gid] geom.rot = rot - geom.size = get_batched_value(m.geom_size, worldid, gid) + geom.size = m.geom_size[worldid, gid] geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane dataid = m.geom_dataid[gid] if dataid >= 0: @@ -744,22 +743,22 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): pairid = d.collision_pairid[cid] if pairid > -1: - margin = get_batched_value(m.pair_margin, worldid, pairid) - gap = get_batched_value(m.pair_gap, worldid, pairid) + margin = m.pair_margin[worldid, pairid] + gap = m.pair_gap[worldid, pairid] condim = m.pair_dim[pairid] - friction = get_batched_value(m.pair_friction, worldid, pairid) - solref = get_batched_value(m.pair_solref, worldid, pairid) - solreffriction = get_batched_value(m.pair_solreffriction, worldid, pairid) - solimp = get_batched_value(m.pair_solimp, worldid, pairid) + friction = m.pair_friction[worldid, pairid] + solref = m.pair_solref[worldid, pairid] + solreffriction = m.pair_solreffriction[worldid, pairid] + solimp = m.pair_solimp[worldid, pairid] else: g1 = geoms[0] g2 = geoms[1] - p1 = get_batched_value(m.geom_priority, worldid, g1) - p2 = get_batched_value(m.geom_priority, worldid, g2) + p1 = m.geom_priority[worldid, g1] + p2 = m.geom_priority[worldid, g2] - solmix1 = get_batched_value(m.geom_solmix, worldid, g1) - solmix2 = get_batched_value(m.geom_solmix, worldid, g2) + solmix1 = m.geom_solmix[worldid, g1] + solmix2 = m.geom_solmix[worldid, g2] mix = solmix1 / (solmix1 + solmix2) mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix) @@ -768,12 +767,12 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) margin = wp.max( - get_batched_value(m.geom_margin, worldid, g1), - get_batched_value(m.geom_margin, worldid, g2), + m.geom_margin[worldid, g1], + m.geom_margin[worldid, g2], ) gap = wp.max( - get_batched_value(m.geom_gap, worldid, g1), - get_batched_value(m.geom_gap, worldid, g2), + m.geom_gap[worldid, g1], + m.geom_gap[worldid, g2], ) condim1 = m.geom_condim[g1] @@ -783,8 +782,8 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): ) geom_friction = wp.max( - get_batched_value(m.geom_friction, worldid, g1), - get_batched_value(m.geom_friction, worldid, g2), + m.geom_friction[worldid, g1], + m.geom_friction[worldid, g2], ) friction = vec5( geom_friction[0], @@ -795,23 +794,23 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): ) if ( - get_batched_value(m.geom_solref, worldid, g1).x > 0.0 - and get_batched_value(m.geom_solref, worldid, g2).x > 0.0 + m.geom_solref[worldid, g1].x > 0.0 + and m.geom_solref[worldid, g2].x > 0.0 ): - solref = mix * get_batched_value(m.geom_solref, worldid, g1) + ( + solref = mix * m.geom_solref[worldid, g1] + ( 1.0 - mix - ) * get_batched_value(m.geom_solref, worldid, g2) + ) * m.geom_solref[worldid, g2] else: solref = wp.min( - get_batched_value(m.geom_solref, worldid, g1), - get_batched_value(m.geom_solref, worldid, g2), + m.geom_solref[worldid, g1], + m.geom_solref[worldid, g2], ) solreffriction = wp.vec2(0.0, 0.0) - solimp = mix * get_batched_value(m.geom_solimp, worldid, g1) + ( + solimp = mix * m.geom_solimp[worldid, g1] + ( 1.0 - mix - ) * get_batched_value(m.geom_solimp, worldid, g2) + ) * m.geom_solimp[worldid, g2] return geoms, margin, gap, condim, friction, solref, solreffriction, solimp From a0ae5c19432716f38a8f8527837479e7efd0399d Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:17:07 +0200 Subject: [PATCH 19/31] constraint --- mujoco_warp/_src/constraint.py | 95 +++++++++++++++------------------- 1 file changed, 41 insertions(+), 54 deletions(-) diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index d0a42381..5a859de5 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -18,7 +18,6 @@ from . import math from . import support from . import types -from .support import get_batched_value from .warp_util import event_scope wp.config.enable_backward = False @@ -99,7 +98,7 @@ def _efc_equality_connect( for i in range(wp.static(3)): d.efc.worldid[efcid + i] = worldid - data = get_batched_value(m.eq_data, worldid, i_eq) + data = m.eq_data[worldid, i_eq] anchor1 = wp.vec3f(data[0], data[1], data[2]) anchor2 = wp.vec3f(data[3], data[4], data[5]) @@ -132,13 +131,11 @@ def _efc_equality_connect( d.efc.J[efcid + 2, dofid] = j1mj2[2] Jqvel += j1mj2 * d.qvel[worldid, dofid] - invweight = get_batched_value( - m.body_invweight0, worldid, body1id, 0 - ) + get_batched_value(m.body_invweight0, worldid, body2id, 0) + invweight = m.body_invweight0[worldid, body1id, 0] + m.body_invweight0[worldid, body2id, 0] pos_imp = wp.length(pos) - solref = get_batched_value(m.eq_solref, worldid, i_eq) - solimp = get_batched_value(m.eq_solimp, worldid, i_eq) + solref = m.eq_solref[worldid, i_eq] + solimp = m.eq_solimp[worldid, i_eq] for i in range(3): efcidi = efcid + i @@ -176,7 +173,7 @@ def _efc_equality_joint( jntid_1 = m.eq_obj1id[i_eq] jntid_2 = m.eq_obj2id[i_eq] - data = get_batched_value(m.eq_data, worldid, i_eq) + data = m.eq_data[worldid, i_eq] dofadr1 = m.jnt_dofadr[jntid_1] qposadr1 = m.jnt_qposadr[jntid_1] d.efc.J[efcid, dofadr1] = 1.0 @@ -185,7 +182,7 @@ def _efc_equality_joint( # Two joint constraint qposadr2 = m.jnt_qposadr[jntid_2] dofadr2 = m.jnt_dofadr[jntid_2] - dif = d.qpos[worldid, qposadr2] - get_batched_value(m.qpos0, worldid, qposadr2) + dif = d.qpos[worldid, qposadr2] - m.qpos0[worldid, qposadr2] # Horner's method for polynomials rhs = data[0] + dif * (data[1] + dif * (data[2] + dif * (data[3] + dif * data[4]))) @@ -194,23 +191,21 @@ def _efc_equality_joint( ) pos = ( - d.qpos[worldid, qposadr1] - get_batched_value(m.qpos0, worldid, qposadr1) - rhs + d.qpos[worldid, qposadr1] - m.qpos0[worldid, qposadr1] - rhs ) Jqvel = d.qvel[worldid, dofadr1] - d.qvel[worldid, dofadr2] * deriv_2 - invweight = get_batched_value( - m.dof_invweight0, worldid, dofadr1 - ) + get_batched_value(m.dof_invweight0, worldid, dofadr2) + invweight = m.dof_invweight0[worldid, dofadr1] + m.dof_invweight0[worldid, dofadr2] d.efc.J[efcid, dofadr2] = -deriv_2 else: # Single joint constraint pos = ( d.qpos[worldid, qposadr1] - - get_batched_value(m.qpos0, worldid, qposadr1) + - m.qpos0[worldid, qposadr1] - data[0] ) Jqvel = d.qvel[worldid, dofadr1] - invweight = get_batched_value(m.dof_invweight0, worldid, dofadr1) + invweight = m.dof_invweight0[worldid, dofadr1] # Update constraint parameters _update_efc_row( @@ -220,8 +215,8 @@ def _efc_equality_joint( pos, pos, invweight, - get_batched_value(m.eq_solref, worldid, i_eq), - get_batched_value(m.eq_solimp, worldid, i_eq), + m.eq_solref[worldid, i_eq], + m.eq_solimp[worldid, i_eq], wp.float32(0.0), Jqvel, 0.0, @@ -237,7 +232,7 @@ def _efc_friction( # TODO(team): tendon worldid, dofid = wp.tid() - if get_batched_value(m.dof_frictionloss, worldid, dofid) <= 0.0: + if m.dof_frictionloss[worldid, dofid] <= 0.0: return efcid = wp.atomic_add(d.nefc, 0, 1) @@ -253,12 +248,12 @@ def _efc_friction( efcid, 0.0, 0.0, - get_batched_value(m.dof_invweight0, worldid, dofid), - get_batched_value(m.dof_solref, worldid, dofid), - get_batched_value(m.dof_solimp, worldid, dofid), + m.dof_invweight0[worldid, dofid], + m.dof_solref[worldid, dofid], + m.dof_solimp[worldid, dofid], 0.0, Jqvel, - get_batched_value(m.dof_frictionloss, worldid, dofid), + m.dof_frictionloss[worldid, dofid], m.jnt_bodyid[m.dof_jntid[dofid]], ) @@ -283,7 +278,7 @@ def _efc_equality_weld( obj1id = m.eq_obj1id[i_eq] obj2id = m.eq_obj2id[i_eq] - data = get_batched_value(m.eq_data, worldid, i_eq) + data = m.eq_data[worldid, i_eq] anchor1 = wp.vec3(data[0], data[1], data[2]) anchor2 = wp.vec3(data[3], data[4], data[5]) relpose = wp.quat(data[6], data[7], data[8], data[9]) @@ -296,11 +291,11 @@ def _efc_equality_weld( pos1 = d.site_xpos[worldid, obj1id] pos2 = d.site_xpos[worldid, obj2id] quat = math.mul_quat( - d.xquat[worldid, body1id], get_batched_value(m.site_quat, worldid, obj1id) + d.xquat[worldid, body1id], m.site_quat[worldid, obj1id] ) quat1 = math.quat_inv( math.mul_quat( - d.xquat[worldid, body2id], get_batched_value(m.site_quat, worldid, obj2id) + d.xquat[worldid, body2id], m.site_quat[worldid, obj2id] ) ) @@ -341,14 +336,12 @@ def _efc_equality_weld( crotq = math.mul_quat(quat1, quat) # copy axis components crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale - invweight_t = get_batched_value( - m.body_invweight0, worldid, body1id, 0 - ) + get_batched_value(m.body_invweight0, worldid, body2id, 0) + invweight_t = m.body_invweight0[worldid, body1id, 0] + m.body_invweight0[worldid, body2id, 0] pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) - solref = get_batched_value(m.eq_solref, worldid, i_eq) - solimp = get_batched_value(m.eq_solimp, worldid, i_eq) + solref = m.eq_solref[worldid, i_eq] + solimp = m.eq_solimp[worldid, i_eq] for i in range(3): _update_efc_row( @@ -366,9 +359,7 @@ def _efc_equality_weld( i_eq, ) - invweight_r = get_batched_value( - m.body_invweight0, worldid, body1id, 1 - ) + get_batched_value(m.body_invweight0, worldid, body2id, 1) + invweight_r = m.body_invweight0[worldid, body1id, 1] + m.body_invweight0[worldid, body2id, 1] for i in range(3): _update_efc_row( @@ -398,7 +389,7 @@ def _efc_limit_slide_hinge( qpos = d.qpos[worldid, m.jnt_qposadr[jntid]] jnt_range = m.jnt_range[worldid, jntid] dist_min, dist_max = qpos - jnt_range[0], jnt_range[1] - qpos - pos = wp.min(dist_min, dist_max) - get_batched_value(m.jnt_margin, worldid, jntid) + pos = wp.min(dist_min, dist_max) - m.jnt_margin[worldid, jntid] active = pos < 0 if active: @@ -418,10 +409,10 @@ def _efc_limit_slide_hinge( efcid, pos, pos, - get_batched_value(m.dof_invweight0, worldid, dofadr), - get_batched_value(m.jnt_solref, worldid, jntid), - get_batched_value(m.jnt_solimp, worldid, jntid), - get_batched_value(m.jnt_margin, worldid, jntid), + m.dof_invweight0[worldid, dofadr], + m.jnt_solref[worldid, jntid], + m.jnt_solimp[worldid, jntid], + m.jnt_margin[worldid, jntid], Jqvel, 0.0, dofadr, @@ -443,7 +434,7 @@ def _efc_limit_ball( ) axis_angle = math.quat_to_vel(jnt_quat) axis, angle = math.normalize_with_norm(axis_angle) - jnt_margin = get_batched_value(m.jnt_margin, worldid, jntid) + jnt_margin = m.jnt_margin[worldid, jntid] jnt_range = m.jnt_range[worldid, jntid] pos = wp.max(jnt_range[0], jnt_range[1]) - angle - jnt_margin @@ -470,10 +461,10 @@ def _efc_limit_ball( efcid, pos, pos, - get_batched_value(m.dof_invweight0, worldid, dofadr), - get_batched_value(m.jnt_solref, worldid, jntid), - get_batched_value(m.jnt_solimp, worldid, jntid), - get_batched_value(m.jnt_margin, worldid, jntid), + m.dof_invweight0[worldid, dofadr], + m.jnt_solref[worldid, jntid], + m.jnt_solimp[worldid, jntid], + m.jnt_margin[worldid, jntid], Jqvel, 0.0, jntid, @@ -488,10 +479,10 @@ def _efc_limit_tendon( worldid, tenlimitedid = wp.tid() tenid = m.tendon_limited_adr[tenlimitedid] - ten_range = get_batched_value(m.tendon_range, worldid, tenid) + ten_range = m.tendon_range[worldid, tenid] length = d.ten_length[worldid, tenid] dist_min, dist_max = length - ten_range[0], ten_range[1] - length - ten_margin = get_batched_value(m.tendon_margin, worldid, tenid) + ten_margin = m.tendon_margin[worldid, tenid] pos = wp.min(dist_min, dist_max) - ten_margin active = pos < 0 @@ -523,9 +514,9 @@ def _efc_limit_tendon( efcid, pos, pos, - get_batched_value(m.tendon_invweight0, worldid, tenid), - get_batched_value(m.tendon_solref_lim, worldid, tenid), - get_batched_value(m.tendon_solimp_lim, worldid, tenid), + m.tendon_invweight0[worldid, tenid], + m.tendon_solref_lim[worldid, tenid], + m.tendon_solimp_lim[worldid, tenid], ten_margin, Jqvel, 0.0, @@ -567,9 +558,7 @@ def _efc_contact_pyramidal( frame = d.contact.frame[conid] # pyramidal has common invweight across all edges - invweight = get_batched_value( - m.body_invweight0, worldid, body1, 0 - ) + get_batched_value(m.body_invweight0, worldid, body2, 0) + invweight = m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] if condim > 1: dimid2 = dimid / 2 + 1 @@ -670,9 +659,7 @@ def _efc_contact_elliptic( d.efc.J[efcid, i] = J Jqvel += J * d.qvel[worldid, i] - invweight = get_batched_value( - m.body_invweight0, worldid, body1, 0 - ) + get_batched_value(m.body_invweight0, worldid, body2, 0) + invweight = m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] ref = d.contact.solref[conid] pos_aref = pos From 6ed86cb7329462509601346916cd74a574f4f4a1 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:19:01 +0200 Subject: [PATCH 20/31] forward --- mujoco_warp/_src/forward.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index d436850e..925c331e 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -25,7 +25,6 @@ from . import sensor from . import smooth from . import solver -from .support import get_batched_value from .support import xfrc_accumulate from .types import MJ_MINVAL from .types import BiasType @@ -129,7 +128,7 @@ def _next_act( # advance the actuation if m.actuator_dyntype[actid] == wp.static(DynType.FILTEREXACT.value): - dyn_prm = get_batched_value(m.actuator_dynprm, worldid, actid) + dyn_prm = m.actuator_dynprm[worldid, actid] tau = wp.max(MJ_MINVAL, dyn_prm[0]) act += act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau)) else: @@ -137,7 +136,7 @@ def _next_act( # clamp to actrange if m.actuator_actlimited[actid]: - actrange = get_batched_value(m.actuator_actrange, worldid, actid) + actrange = m.actuator_actrange[worldid, actid] act = wp.clamp(act, actrange[0], actrange[1]) d.act[worldid, actid] = act @@ -194,9 +193,7 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * get_batched_value( - m.dof_damping, worldid, tid - ) + d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * m.dof_damping[worldid, tid] d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] @@ -383,10 +380,10 @@ def actuator_bias_gain_vel(m: Model, d: Data): actuator_dyntype = m.actuator_dyntype[actid] if actuator_biastype == wp.static(BiasType.AFFINE.value): - bias_vel = get_batched_value(m.actuator_biasprm, worldid, actid)[2] + bias_vel = m.actuator_biasprm[worldid, actid][2] if actuator_gaintype == wp.static(GainType.AFFINE.value): - gain_vel = get_batched_value(m.actuator_gainprm, worldid, actid)[2] + gain_vel = m.actuator_gainprm[worldid, actid][2] ctrl = d.ctrl[worldid, actid] @@ -634,7 +631,7 @@ def _force(m: Model, d: Data): dsbl_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value) if m.actuator_ctrllimited[uid] and not dsbl_clampctrl: - r = get_batched_value(m.actuator_ctrlrange, worldid, uid) + r = m.actuator_ctrlrange[worldid, uid] ctrl = wp.clamp(ctrl, r[0], r[1]) if m.na: @@ -645,7 +642,7 @@ def _force(m: Model, d: Data): elif dyntype == int(DynType.FILTER.value) or dyntype == int( DynType.FILTEREXACT.value ): - dynprm = get_batched_value(m.actuator_dynprm, worldid, uid) + dynprm = m.actuator_dynprm[worldid, uid] actadr = m.actuator_actadr[uid] act = d.act[worldid, actadr] d.act_dot[worldid, actadr] = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL) @@ -664,7 +661,7 @@ def _force(m: Model, d: Data): # gain gaintype = m.actuator_gaintype[uid] - gainprm = get_batched_value(m.actuator_gainprm, worldid, uid) + gainprm = m.actuator_gainprm[worldid, uid] gain = 0.0 if gaintype == int(GainType.FIXED.value): @@ -676,7 +673,7 @@ def _force(m: Model, d: Data): # bias biastype = m.actuator_biastype[uid] - biasprm = get_batched_value(m.actuator_biasprm, worldid, uid) + biasprm = m.actuator_biasprm[worldid, uid] bias = 0.0 # BiasType.NONE if biastype == int(BiasType.AFFINE.value): @@ -689,7 +686,7 @@ def _force(m: Model, d: Data): # TODO(team): tendon total force clamping if m.actuator_forcelimited[uid]: - r = get_batched_value(m.actuator_forcerange, worldid, uid) + r = m.actuator_forcerange[worldid, uid] f = wp.clamp(f, r[0], r[1]) d.actuator_force[worldid, uid] = f @@ -698,7 +695,7 @@ def _qfrc_limited(m: Model, d: Data): worldid, dofid = wp.tid() jntid = m.dof_jntid[dofid] if m.jnt_actfrclimited[jntid]: - range = get_batched_value(m.jnt_actfrcrange, worldid, jntid) + range = m.jnt_actfrcrange[worldid, jntid] d.qfrc_actuator[worldid, dofid] = wp.clamp( d.qfrc_actuator[worldid, dofid], range[0], @@ -717,7 +714,7 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): s += moment[worldid, uid, vid] * force[worldid, uid] jntid = m.dof_jntid[vid] if m.jnt_actfrclimited[jntid]: - range = get_batched_value(m.jnt_actfrcrange, worldid, jntid) + range = m.jnt_actfrcrange[worldid, jntid] s = wp.clamp(s, range[0], range[1]) qfrc[worldid, vid] = s From 72fd650e1b2338f22ef5230fefde5b653825f957 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:22:01 +0200 Subject: [PATCH 21/31] passive --- mujoco_warp/_src/passive.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index 39c3eb6f..a2e7c3cb 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -16,7 +16,6 @@ import warp as wp from . import math -from .support import get_batched_value from .types import Data from .types import DisableBit from .types import JointType @@ -36,7 +35,7 @@ def passive(m: Model, d: Data): @kernel def _spring(m: Model, d: Data): worldid, jntid = wp.tid() - stiffness = get_batched_value(m.jnt_stiffness, worldid, jntid) + stiffness = m.jnt_stiffness[worldid, jntid] dofid = m.jnt_dofadr[jntid] if stiffness == 0.0: @@ -48,11 +47,11 @@ def _spring(m: Model, d: Data): if jnt_type == wp.static(JointType.FREE.value): dif = wp.vec3( d.qpos[worldid, qposid + 0] - - get_batched_value(m.qpos_spring, worldid, qposid + 0), + - m.qpos_spring[worldid, qposid + 0], d.qpos[worldid, qposid + 1] - - get_batched_value(m.qpos_spring, worldid, qposid + 1), + - m.qpos_spring[worldid, qposid + 1], d.qpos[worldid, qposid + 2] - - get_batched_value(m.qpos_spring, worldid, qposid + 2), + - m.qpos_spring[worldid, qposid + 2], ) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] @@ -64,10 +63,10 @@ def _spring(m: Model, d: Data): d.qpos[worldid, qposid + 6], ) ref = wp.quat( - get_batched_value(m.qpos_spring, worldid, qposid + 3), - get_batched_value(m.qpos_spring, worldid, qposid + 4), - get_batched_value(m.qpos_spring, worldid, qposid + 5), - get_batched_value(m.qpos_spring, worldid, qposid + 6), + m.qpos_spring[worldid, qposid + 3], + m.qpos_spring[worldid, qposid + 4], + m.qpos_spring[worldid, qposid + 5], + m.qpos_spring[worldid, qposid + 6], ) dif = math.quat_sub(rot, ref) d.qfrc_spring[worldid, dofid + 3] = -stiffness * dif[0] @@ -81,23 +80,23 @@ def _spring(m: Model, d: Data): d.qpos[worldid, qposid + 3], ) ref = wp.quat( - get_batched_value(m.qpos_spring, worldid, qposid + 0), - get_batched_value(m.qpos_spring, worldid, qposid + 1), - get_batched_value(m.qpos_spring, worldid, qposid + 2), - get_batched_value(m.qpos_spring, worldid, qposid + 3), + m.qpos_spring[worldid, qposid + 0], + m.qpos_spring[worldid, qposid + 1], + m.qpos_spring[worldid, qposid + 2], + m.qpos_spring[worldid, qposid + 3], ) dif = math.quat_sub(rot, ref) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] else: # mjJNT_SLIDE, mjJNT_HINGE - fdif = d.qpos[worldid, qposid] - get_batched_value(m.qpos_spring, worldid, qposid) + fdif = d.qpos[worldid, qposid] - m.qpos_spring[worldid, qposid] d.qfrc_spring[worldid, dofid] = -stiffness * fdif @kernel def _damper_passive(m: Model, d: Data): worldid, dofid = wp.tid() - damping = get_batched_value(m.dof_damping, worldid, dofid) + damping = m.dof_damping[worldid, dofid] qfrc_damper = -damping * d.qvel[worldid, dofid] d.qfrc_damper[worldid, dofid] = qfrc_damper From 991d6ddb6bad92b7112fcaa699c2ea47a9ac2c93 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:22:36 +0200 Subject: [PATCH 22/31] sensor --- mujoco_warp/_src/sensor.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index ed1e611d..a547d976 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -20,7 +20,6 @@ from . import math from . import smooth -from .support import get_batched_value from .types import Data from .types import DisableBit from .types import Model @@ -138,41 +137,41 @@ def _frame_quat( ) -> wp.quat: if objtype == int(ObjType.BODY.value): quat = math.mul_quat( - d.xquat[worldid, objid], get_batched_value(m.body_iquat, worldid, objid) + d.xquat[worldid, objid], m.body_iquat[worldid, objid] ) if refid == -1: return quat refquat = math.mul_quat( - d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid) + d.xquat[worldid, refid], m.body_iquat[worldid, refid] ) elif objtype == int(ObjType.XBODY.value): quat = d.xquat[worldid, objid] if refid == -1: return quat refquat = math.mul_quat( - d.xquat[worldid, refid], get_batched_value(m.body_iquat, worldid, refid) + d.xquat[worldid, refid], m.body_iquat[worldid, refid] ) elif objtype == int(ObjType.GEOM.value): quat = math.mul_quat( d.xquat[worldid, m.geom_bodyid[objid]], - get_batched_value(m.geom_quat, worldid, objid), + m.geom_quat[worldid, objid], ) if refid == -1: return quat refquat = math.mul_quat( d.xquat[worldid, m.geom_bodyid[refid]], - get_batched_value(m.geom_quat, worldid, refid), + m.geom_quat[worldid, refid], ) elif objtype == int(ObjType.SITE.value): quat = math.mul_quat( d.xquat[worldid, m.site_bodyid[objid]], - get_batched_value(m.site_quat, worldid, objid), + m.site_quat[worldid, objid], ) if refid == -1: return quat refquat = math.mul_quat( d.xquat[worldid, m.site_bodyid[refid]], - get_batched_value(m.site_quat, worldid, refid), + m.site_quat[worldid, refid], ) # TODO(team): camera From 40fac973804c118318199fc96371fdb5967800f8 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:24:11 +0200 Subject: [PATCH 23/31] smooth --- mujoco_warp/_src/smooth.py | 65 +++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 36 deletions(-) diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 075b139c..37e47029 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -19,7 +19,6 @@ from . import math from . import support -from .support import get_batched_value from .types import MJ_MINVAL from .types import CamLightType from .types import Data @@ -63,10 +62,10 @@ def _level(m: Model, d: Data, leveladr: int): # no joints - apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] xpos = ( - d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid) + d.xmat[worldid, pid] * m.body_pos[worldid, bodyid] ) + d.xpos[worldid, pid] xquat = math.mul_quat( - d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid) + d.xquat[worldid, pid], m.body_quat[worldid, bodyid] ) elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint @@ -80,10 +79,10 @@ def _level(m: Model, d: Data, leveladr: int): # apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] xpos = ( - d.xmat[worldid, pid] * get_batched_value(m.body_pos, worldid, bodyid) + d.xmat[worldid, pid] * m.body_pos[worldid, bodyid] ) + d.xpos[worldid, pid] xquat = math.mul_quat( - d.xquat[worldid, pid], get_batched_value(m.body_quat, worldid, bodyid) + d.xquat[worldid, pid], m.body_quat[worldid, bodyid] ) for _ in range(jntnum): @@ -91,7 +90,7 @@ def _level(m: Model, d: Data, leveladr: int): jnt_type = m.jnt_type[jntadr] jnt_axis = m.jnt_axis[jntadr] xanchor = ( - math.rot_vec_quat(get_batched_value(m.jnt_pos, worldid, jntadr), xquat) + xpos + math.rot_vec_quat(m.jnt_pos[worldid, jntadr], xquat) + xpos ) xaxis = math.rot_vec_quat(jnt_axis, xquat) @@ -105,17 +104,17 @@ def _level(m: Model, d: Data, leveladr: int): xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation xpos = xanchor - math.rot_vec_quat( - get_batched_value(m.jnt_pos, worldid, jntadr), xquat + m.jnt_pos[worldid, jntadr], xquat ) elif jnt_type == wp.static(JointType.SLIDE.value): - xpos += xaxis * (qpos[qadr] - get_batched_value(m.qpos0, worldid, qadr)) + xpos += xaxis * (qpos[qadr] - m.qpos0[worldid, qadr]) elif jnt_type == wp.static(JointType.HINGE.value): - qpos0 = get_batched_value(m.qpos0, worldid, qadr) + qpos0 = m.qpos0[worldid, qadr] qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation xpos = xanchor - math.rot_vec_quat( - get_batched_value(m.jnt_pos, worldid, jntadr), xquat + m.jnt_pos[worldid, jntadr], xquat ) d.xanchor[worldid, jntadr] = xanchor @@ -127,10 +126,10 @@ def _level(m: Model, d: Data, leveladr: int): d.xquat[worldid, bodyid] = xquat d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat( - get_batched_value(m.body_ipos, worldid, bodyid), xquat + m.body_ipos[worldid, bodyid], xquat ) d.ximat[worldid, bodyid] = math.quat_to_mat( - math.mul_quat(xquat, get_batched_value(m.body_iquat, worldid, bodyid)) + math.mul_quat(xquat, m.body_iquat[worldid, bodyid]) ) @kernel @@ -140,10 +139,10 @@ def geom_local_to_global(m: Model, d: Data): xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat( - get_batched_value(m.geom_pos, worldid, geomid), xquat + m.geom_pos[worldid, geomid], xquat ) d.geom_xmat[worldid, geomid] = math.quat_to_mat( - math.mul_quat(xquat, get_batched_value(m.geom_quat, worldid, geomid)) + math.mul_quat(xquat, m.geom_quat[worldid, geomid]) ) @kernel @@ -153,10 +152,10 @@ def site_local_to_global(m: Model, d: Data): xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat( - get_batched_value(m.site_pos, worldid, siteid), xquat + m.site_pos[worldid, siteid], xquat ) d.site_xmat[worldid, siteid] = math.quat_to_mat( - math.mul_quat(xquat, get_batched_value(m.site_quat, worldid, siteid)) + math.mul_quat(xquat, m.site_quat[worldid, siteid]) ) wp.launch(_root, dim=(d.nworld), inputs=[m, d]) @@ -181,9 +180,7 @@ def com_pos(m: Model, d: Data): @kernel def subtree_com_init(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * get_batched_value( - m.body_mass, worldid, bodyid - ) + d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[worldid, bodyid] @kernel def subtree_com_acc(m: Model, d: Data, leveladr: int): @@ -195,14 +192,14 @@ def subtree_com_acc(m: Model, d: Data, leveladr: int): @kernel def subtree_div(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] /= get_batched_value(m.subtree_mass, worldid, bodyid) + d.subtree_com[worldid, bodyid] /= m.subtree_mass[worldid, bodyid] @kernel def cinert(m: Model, d: Data): worldid, bodyid = wp.tid() mat = d.ximat[worldid, bodyid] - inert = get_batched_value(m.body_inertia, worldid, bodyid) - mass = get_batched_value(m.body_mass, worldid, bodyid) + inert = m.body_inertia[worldid, bodyid] + mass = m.body_mass[worldid, bodyid] dif = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] # express inertia in com-based frame (mju_inertCom) @@ -288,10 +285,10 @@ def cam_local_to_global(m: Model, d: Data): xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat( - get_batched_value(m.cam_pos, worldid, camid), xquat + m.cam_pos[worldid, camid], xquat ) d.cam_xmat[worldid, camid] = math.quat_to_mat( - math.mul_quat(xquat, get_batched_value(m.cam_quat, worldid, camid)) + math.mul_quat(xquat, m.cam_quat[worldid, camid]) ) @kernel @@ -305,13 +302,11 @@ def cam_fn(m: Model, d: Data): return elif m.cam_mode[camid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] - d.cam_xpos[worldid, camid] = body_xpos + get_batched_value( - m.cam_pos0, worldid, camid - ) + d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[worldid, camid] elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): d.cam_xpos[worldid, camid] = d.subtree_com[ worldid, m.cam_bodyid[camid] - ] + get_batched_value(m.cam_poscom0, worldid, camid) + ] + m.cam_poscom0[worldid, camid] elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ camid ] == wp.static(CamLightType.TARGETBODYCOM.value): @@ -339,10 +334,10 @@ def light_local_to_global(m: Model, d: Data): xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] d.light_xpos[worldid, lightid] = xpos + math.rot_vec_quat( - get_batched_value(m.light_pos, worldid, lightid), xquat + m.light_pos[worldid, lightid], xquat ) d.light_xdir[worldid, lightid] = math.rot_vec_quat( - get_batched_value(m.light_dir, worldid, lightid), xquat + m.light_dir[worldid, lightid], xquat ) @kernel @@ -356,13 +351,11 @@ def light_fn(m: Model, d: Data): return elif m.light_mode[lightid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] - d.light_xpos[worldid, lightid] = body_xpos + get_batched_value( - m.light_pos0, worldid, lightid - ) + d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[worldid, lightid] elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): d.light_xpos[worldid, lightid] = d.subtree_com[ worldid, m.light_bodyid[lightid] - ] + get_batched_value(m.light_poscom0, worldid, lightid) + ] + m.light_poscom0[worldid, lightid] elif m.light_mode[lightid] == wp.static( CamLightType.TARGETBODY.value ) or m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): @@ -402,7 +395,7 @@ def qM_sparse(m: Model, d: Data): bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - d.qM[worldid, 0, madr_ij] = get_batched_value(m.dof_armature, worldid, dofid) + d.qM[worldid, 0, madr_ij] = m.dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) @@ -419,7 +412,7 @@ def qM_dense(m: Model, d: Data): bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - M = get_batched_value(m.dof_armature, worldid, dofid) + M = m.dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) From b8bf457554b054f9104743fa0f478d836a655dbb Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:25:36 +0200 Subject: [PATCH 24/31] smooth part 2 --- mujoco_warp/_src/smooth.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 37e47029..4bd8feca 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -713,7 +713,7 @@ def _cfrc_ext_equality(m: Model, d: Data): worldid = d.efc.worldid[efcid] id = d.efc.id[efcid] - eq_data = get_batched_value(m.eq_data, worldid, id) + eq_data = m.eq_data[worldid, id] body_semantic = m.eq_objtype[id] == wp.static(ObjType.BODY.value) obj1 = m.eq_obj1id[id] @@ -734,7 +734,7 @@ def _cfrc_ext_equality(m: Model, d: Data): else: offset = wp.vec3(eq_data[3], eq_data[4], eq_data[5]) else: - offset = get_batched_value(m.site_pos, worldid, obj1) + offset = m.site_pos[worldid, obj1] # transform point on body1: local -> global pos = d.xmat[worldid, bodyid1] @ offset + d.xpos[worldid, bodyid1] @@ -756,7 +756,7 @@ def _cfrc_ext_equality(m: Model, d: Data): else: offset = wp.vec3(eq_data[0], eq_data[1], eq_data[2]) else: - offset = get_batched_value(m.site_pos, worldid, obj2) + offset = m.site_pos[worldid, obj2] # transform point on body2: local -> global pos = d.xmat[worldid, bodyid2] @ offset + d.xpos[worldid, bodyid2] @@ -835,7 +835,7 @@ def _transmission( ): worldid, actid = wp.tid() trntype = m.actuator_trntype[actid] - gear = get_batched_value(m.actuator_gear, worldid, actid) + gear = m.actuator_gear[worldid, actid] if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static( TrnType.JOINTINPARENT.value ): @@ -1131,12 +1131,12 @@ def _forward(m: Model, d: Data): lin -= wp.cross(xipos - subtree_com_root, ang) d.subtree_linvel[worldid, bodyid] = ( - get_batched_value(m.body_mass, worldid, bodyid) * lin + m.body_mass[worldid, bodyid] * lin ) dv = wp.transpose(ximat) @ ang - dv[0] *= get_batched_value(m.body_inertia, worldid, bodyid)[0] - dv[1] *= get_batched_value(m.body_inertia, worldid, bodyid)[1] - dv[2] *= get_batched_value(m.body_inertia, worldid, bodyid)[2] + dv[0] *= m.body_inertia[worldid, bodyid][0] + dv[1] *= m.body_inertia[worldid, bodyid][1] + dv[2] *= m.body_inertia[worldid, bodyid][2] d.subtree_angmom[worldid, bodyid] = ximat @ dv d.subtree_bodyvel[worldid, bodyid] = wp.spatial_vector(ang, lin) @@ -1174,7 +1174,7 @@ def _angular_momentum(m: Model, d: Data, leveladr: int): vel = d.subtree_bodyvel[worldid, bodyid] linvel = d.subtree_linvel[worldid, bodyid] linvel_parent = d.subtree_linvel[worldid, pid] - mass = get_batched_value(m.body_mass, worldid, bodyid) + mass = m.body_mass[worldid, bodyid] subtreemass = m.body_subtreemass[bodyid] # momentum wrt body i @@ -1223,7 +1223,7 @@ def _joint_tendon(m: Model, d: Data): wrap_jnt_adr = m.wrap_jnt_adr[wrapid] wrap_objid = m.wrap_objid[wrap_jnt_adr] - prm = get_batched_value(m.wrap_prm, worldid, wrap_jnt_adr) + prm = m.wrap_prm[worldid, wrap_jnt_adr] # add to length L = prm * d.qpos[worldid, m.jnt_qposadr[wrap_objid]] From fdbc94b05e8114a32112d3d9d96e22384bc9c490 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:25:49 +0200 Subject: [PATCH 25/31] remove infra --- mujoco_warp/_src/support.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index 54b816d5..1d8d7de7 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -349,25 +349,3 @@ def jac( jacr = cdof_ang return jacp, jacr - - - -@wp.func -def get_batched_value(array: wp.array2d(dtype=Any), worldid: wp.int32, i: wp.int32): - modelid = worldid % array.shape[0] - return array[modelid, i] - -@wp.func -def get_batched_value( - array: wp.array3d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32 -): - modelid = worldid % array.shape[0] - return array[modelid, i, j] - - -@wp.func -def get_batched_value( - array: wp.array4d(dtype=Any), worldid: wp.int32, i: wp.int32, j: wp.int32, k: wp.int32 -): - modelid = worldid % array.shape[0] - return array[modelid, i, j, k] From 55389b51538645461d8dacd4a4fb177e0ad59966 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:34:54 +0200 Subject: [PATCH 26/31] fix an issue in gjk --- mujoco_warp/_src/collision_convex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 3cdceafa..87a1eda9 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -754,7 +754,7 @@ def _gjk_epa_sparse(m: Model, d: Data): worldid = d.collision_worldid[tid] geoms, margin, gap, condim, friction, solref, solreffriction, solimp = ( - contact_params(m, d, tid) + contact_params(m, d, tid, worldid) ) g1 = geoms[0] From badb9195b8798289a0b765348d1635d9635b6cfc Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:42:45 +0200 Subject: [PATCH 27/31] subtree mass --- mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/smooth.py | 4 ++-- mujoco_warp/_src/types.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index f8fd9c72..258021cd 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -375,7 +375,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) m.body_inertia = create_nmodel_batched_array(mjm.body_inertia, dtype=wp.vec3) m.body_mass = create_nmodel_batched_array(mjm.body_mass, dtype=wp.float32) - m.body_subtreemass = wp.array(mjm.body_subtreemass, dtype=wp.float32, ndim=1) + m.body_subtreemass = create_nmodel_batched_array(mjm.body_subtreemass, dtype=wp.float32) subtree_mass = np.copy(mjm.body_mass) # TODO(team): should this be [mjm.nbody - 1, 0) ? diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 4bd8feca..e338ea9c 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -1150,7 +1150,7 @@ def _linear_momentum(m: Model, d: Data, leveladr: int): if bodyid: pid = m.body_parentid[bodyid] wp.atomic_add(d.subtree_linvel[worldid], pid, d.subtree_linvel[worldid, bodyid]) - d.subtree_linvel[worldid, bodyid] /= wp.max(MJ_MINVAL, m.body_subtreemass[bodyid]) + d.subtree_linvel[worldid, bodyid] /= wp.max(MJ_MINVAL, m.body_subtreemass[worldid,bodyid]) body_treeadr = m.body_treeadr.numpy() for i in reversed(range(len(body_treeadr))): @@ -1175,7 +1175,7 @@ def _angular_momentum(m: Model, d: Data, leveladr: int): linvel = d.subtree_linvel[worldid, bodyid] linvel_parent = d.subtree_linvel[worldid, pid] mass = m.body_mass[worldid, bodyid] - subtreemass = m.body_subtreemass[bodyid] + subtreemass = m.body_subtreemass[worldid, bodyid] # momentum wrt body i dx = xipos - com diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index c39f0dbd..ec0dd993 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -603,7 +603,7 @@ class Model: body_ipos: local position of center of mass (nmodel, nbody, 3) body_iquat: local orientation of inertia ellipsoid (nmodel, nbody, 4) body_mass: mass (nmodel, nbody,) - body_subtreemass: mass of subtree starting at this body (nbody,) + body_subtreemass: mass of subtree starting at this body (nmodel, nbody,) subtree_mass: mass of subtree (nmodel, nbody,) body_inertia: diagonal inertia in ipos/iquat frame (nmodel, nbody, 3) body_invweight0: mean inv inert in qpos0 (trn, rot) (nmodel, nbody, 2) @@ -815,7 +815,7 @@ class Model: body_ipos: wp.array(dtype=wp.vec3, ndim=2) body_iquat: wp.array(dtype=wp.quat, ndim=2) body_mass: wp.array(dtype=wp.float32, ndim=2) - body_subtreemass: wp.array(dtype=wp.float32, ndim=1) + body_subtreemass: wp.array(dtype=wp.float32, ndim=2) subtree_mass: wp.array(dtype=wp.float32, ndim=2) body_inertia: wp.array(dtype=wp.vec3, ndim=2) body_invweight0: wp.array(dtype=wp.float32, ndim=3) From 25369ea7a8c4a68644801a9082cb91610fac4444 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:44:59 +0200 Subject: [PATCH 28/31] geom for body --- mujoco_warp/_src/io.py | 2 +- mujoco_warp/_src/types.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 258021cd..2f2c83a9 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -385,7 +385,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.subtree_mass = create_nmodel_batched_array(subtree_mass, dtype=wp.float32) m.body_invweight0 = create_nmodel_batched_array(mjm.body_invweight0, dtype=wp.float32) m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) - m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) + m.body_geomadr = create_nmodel_batched_array(mjm.body_geomadr, dtype=wp.int32) m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index ec0dd993..ffb79102 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -597,7 +597,7 @@ class Model: body_dofnum: number of motion degrees of freedom (nbody,) body_dofadr: start addr of dofs; -1: no dofs (nbody,) body_geomnum: number of geoms (nbody,) - body_geomadr: start addr of geoms; -1: no geoms (nbody,) + body_geomadr: start addr of geoms; -1: no geoms (nmodel, nbody) body_pos: position offset rel. to parent body (nmodel, nbody, 3) body_quat: orientation offset rel. to parent body (nmodel, nbody, 4) body_ipos: local position of center of mass (nmodel, nbody, 3) @@ -809,7 +809,7 @@ class Model: body_dofnum: wp.array(dtype=wp.int32, ndim=1) body_dofadr: wp.array(dtype=wp.int32, ndim=1) body_geomnum: wp.array(dtype=wp.int32, ndim=1) - body_geomadr: wp.array(dtype=wp.int32, ndim=1) + body_geomadr: wp.array(dtype=wp.int32, ndim=2) body_pos: wp.array(dtype=wp.vec3, ndim=2) body_quat: wp.array(dtype=wp.quat, ndim=2) body_ipos: wp.array(dtype=wp.vec3, ndim=2) From 00ea73bde9e356df276cc1f999110837b3ff4327 Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:50:23 +0200 Subject: [PATCH 29/31] geomadr --- mujoco_warp/_src/collision_primitive.py | 2 +- mujoco_warp/_src/io.py | 4 ++-- mujoco_warp/_src/types.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index b39e95c7..3bd8eff0 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -52,7 +52,7 @@ def _geom( geom.rot = rot geom.size = m.geom_size[worldid, gid] geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane - dataid = m.geom_dataid[gid] + dataid = m.geom_dataid[worldid, gid] if dataid >= 0: geom.vertadr = m.mesh_vertadr[dataid] geom.vertnum = m.mesh_vertnum[dataid] diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 2f2c83a9..683c79c1 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -385,7 +385,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.subtree_mass = create_nmodel_batched_array(subtree_mass, dtype=wp.float32) m.body_invweight0 = create_nmodel_batched_array(mjm.body_invweight0, dtype=wp.float32) m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) - m.body_geomadr = create_nmodel_batched_array(mjm.body_geomadr, dtype=wp.int32) + m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) @@ -423,7 +423,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.geom_gap = create_nmodel_batched_array(mjm.geom_gap, dtype=wp.float32) m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) - m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) + m.geom_dataid = create_nmodel_batched_array(mjm.geom_dataid, dtype=wp.int32) m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index ffb79102..57a638eb 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -597,7 +597,7 @@ class Model: body_dofnum: number of motion degrees of freedom (nbody,) body_dofadr: start addr of dofs; -1: no dofs (nbody,) body_geomnum: number of geoms (nbody,) - body_geomadr: start addr of geoms; -1: no geoms (nmodel, nbody) + body_geomadr: start addr of geoms; -1: no geoms (nbody,) body_pos: position offset rel. to parent body (nmodel, nbody, 3) body_quat: orientation offset rel. to parent body (nmodel, nbody, 4) body_ipos: local position of center of mass (nmodel, nbody, 3) @@ -642,7 +642,7 @@ class Model: geom_conaffinity: geom contact affinity (nmodel, ngeom,) geom_condim: contact dimensionality (1, 3, 4, 6) (ngeom,) geom_bodyid: id of geom's body (ngeom,) - geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,) + geom_dataid: id of geom's mesh/hfield; -1: none (nmodel, ngeom) geom_priority: geom contact priority (nmodel, ngeom) geom_solmix: mixing coef for solref/imp in geom pair (nmodel, ngeom,) geom_solref: constraint solver reference: contact (nmodel, ngeom, mjNREF) @@ -809,7 +809,7 @@ class Model: body_dofnum: wp.array(dtype=wp.int32, ndim=1) body_dofadr: wp.array(dtype=wp.int32, ndim=1) body_geomnum: wp.array(dtype=wp.int32, ndim=1) - body_geomadr: wp.array(dtype=wp.int32, ndim=2) + body_geomadr: wp.array(dtype=wp.int32, ndim=1) body_pos: wp.array(dtype=wp.vec3, ndim=2) body_quat: wp.array(dtype=wp.quat, ndim=2) body_ipos: wp.array(dtype=wp.vec3, ndim=2) @@ -854,7 +854,7 @@ class Model: geom_conaffinity: wp.array(dtype=wp.int32, ndim=2) geom_condim: wp.array(dtype=wp.int32, ndim=1) geom_bodyid: wp.array(dtype=wp.int32, ndim=1) - geom_dataid: wp.array(dtype=wp.int32, ndim=1) + geom_dataid: wp.array(dtype=wp.int32, ndim=2) geom_priority: wp.array(dtype=wp.int32, ndim=2) geom_solmix: wp.array(dtype=wp.float32, ndim=2) geom_solref: wp.array(dtype=wp.vec2, ndim=2) From d5c7d81fe5af18d837bee8ba5f812f06fd8c432f Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 14:55:56 +0200 Subject: [PATCH 30/31] rbound --- mujoco_warp/_src/collision_driver.py | 6 +++--- mujoco_warp/_src/io.py | 4 ++-- mujoco_warp/_src/types.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index d68e8aff..db0d72f3 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -36,8 +36,8 @@ def _sphere_filter(m: Model, d: Data, geom1: int, geom2: int, worldid: int) -> b margin2 = m.geom_margin[worldid, geom2] pos1 = d.geom_xpos[worldid, geom1] pos2 = d.geom_xpos[worldid, geom2] - size1 = m.geom_rbound[geom1] - size2 = m.geom_rbound[geom2] + size1 = m.geom_rbound[worldid, geom1] + size2 = m.geom_rbound[worldid, geom2] bound = size1 + size2 + wp.max(margin1, margin2) dif = pos2 - pos1 @@ -105,7 +105,7 @@ def _sap_project(m: Model, d: Data, direction: wp.vec3): worldid, geomid = wp.tid() xpos = d.geom_xpos[worldid, geomid] - rbound = m.geom_rbound[geomid] + rbound = m.geom_rbound[worldid, geomid] if rbound == 0.0: # geom is a plane diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 683c79c1..22753d23 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -422,7 +422,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.geom_margin = create_nmodel_batched_array(mjm.geom_margin, dtype=wp.float32) m.geom_gap = create_nmodel_batched_array(mjm.geom_gap, dtype=wp.float32) m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) - m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) + m.geom_rbound = create_nmodel_batched_array(mjm.geom_rbound, dtype=wp.float32) m.geom_dataid = create_nmodel_batched_array(mjm.geom_dataid, dtype=wp.int32) m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) @@ -523,7 +523,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: # predefined collision pairs m.pair_dim = wp.array(mjm.pair_dim, dtype=wp.int32, ndim=1) m.pair_geom1 = wp.array(mjm.pair_geom1, dtype=wp.int32, ndim=1) - m.pair_geom2 = create_nmodel_batched_array(mjm.pair_geom2, dtype=wp.int32) + m.pair_geom2 = wp.array(mjm.pair_geom2, dtype=wp.int32, ndim=1) m.pair_solref = create_nmodel_batched_array(mjm.pair_solref, dtype=wp.vec2) m.pair_solreffriction = create_nmodel_batched_array( mjm.pair_solreffriction, dtype=wp.vec2 diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 57a638eb..6b3a441f 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -649,7 +649,7 @@ class Model: geom_solimp: constraint solver impedance: contact (nmodel, ngeom, mjNIMP) geom_size: geom-specific size parameters (nmodel, ngeom, 3) geom_aabb: bounding box, (center, size) (ngeom, 6) - geom_rbound: radius of bounding sphere (ngeom,) + geom_rbound: radius of bounding sphere (nmodel, ngeom) geom_pos: local position offset rel. to body (nmodel, ngeom, 3) geom_quat: local orientation offset rel. to body (nmodel, ngeom, 4) geom_friction: friction for (slide, spin, roll) (nmodel, ngeom, 3) @@ -861,7 +861,7 @@ class Model: geom_solimp: wp.array(dtype=vec5, ndim=2) geom_size: wp.array(dtype=wp.vec3, ndim=2) geom_aabb: wp.array(dtype=wp.vec3, ndim=2) - geom_rbound: wp.array(dtype=wp.float32, ndim=1) + geom_rbound: wp.array(dtype=wp.float32, ndim=2) geom_pos: wp.array(dtype=wp.vec3, ndim=2) geom_quat: wp.array(dtype=wp.quat, ndim=2) geom_friction: wp.array(dtype=wp.vec3, ndim=2) From 16caa15bfd6917134f5cfca40b07365e50fecc1d Mon Sep 17 00:00:00 2001 From: Alain Denzler Date: Wed, 30 Apr 2025 15:03:44 +0200 Subject: [PATCH 31/31] formatting --- mujoco_warp/_src/collision_primitive.py | 15 +++---- mujoco_warp/_src/constraint.py | 38 +++++++++--------- mujoco_warp/_src/forward.py | 4 +- mujoco_warp/_src/io.py | 4 +- mujoco_warp/_src/passive.py | 9 ++--- mujoco_warp/_src/sensor.py | 12 ++---- mujoco_warp/_src/smooth.py | 53 ++++++++++--------------- 7 files changed, 57 insertions(+), 78 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 3bd8eff0..e5715a5a 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -793,13 +793,10 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): geom_friction[2], ) - if ( - m.geom_solref[worldid, g1].x > 0.0 - and m.geom_solref[worldid, g2].x > 0.0 - ): - solref = mix * m.geom_solref[worldid, g1] + ( - 1.0 - mix - ) * m.geom_solref[worldid, g2] + if m.geom_solref[worldid, g1].x > 0.0 and m.geom_solref[worldid, g2].x > 0.0: + solref = ( + mix * m.geom_solref[worldid, g1] + (1.0 - mix) * m.geom_solref[worldid, g2] + ) else: solref = wp.min( m.geom_solref[worldid, g1], @@ -808,9 +805,7 @@ def contact_params(m: Model, d: Data, cid: int, worldid: int): solreffriction = wp.vec2(0.0, 0.0) - solimp = mix * m.geom_solimp[worldid, g1] + ( - 1.0 - mix - ) * m.geom_solimp[worldid, g2] + solimp = mix * m.geom_solimp[worldid, g1] + (1.0 - mix) * m.geom_solimp[worldid, g2] return geoms, margin, gap, condim, friction, solref, solreffriction, solimp diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 5a859de5..436c2c1b 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -131,7 +131,9 @@ def _efc_equality_connect( d.efc.J[efcid + 2, dofid] = j1mj2[2] Jqvel += j1mj2 * d.qvel[worldid, dofid] - invweight = m.body_invweight0[worldid, body1id, 0] + m.body_invweight0[worldid, body2id, 0] + invweight = ( + m.body_invweight0[worldid, body1id, 0] + m.body_invweight0[worldid, body2id, 0] + ) pos_imp = wp.length(pos) solref = m.eq_solref[worldid, i_eq] @@ -190,20 +192,14 @@ def _efc_equality_joint( 2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4]) ) - pos = ( - d.qpos[worldid, qposadr1] - m.qpos0[worldid, qposadr1] - rhs - ) + pos = d.qpos[worldid, qposadr1] - m.qpos0[worldid, qposadr1] - rhs Jqvel = d.qvel[worldid, dofadr1] - d.qvel[worldid, dofadr2] * deriv_2 invweight = m.dof_invweight0[worldid, dofadr1] + m.dof_invweight0[worldid, dofadr2] d.efc.J[efcid, dofadr2] = -deriv_2 else: # Single joint constraint - pos = ( - d.qpos[worldid, qposadr1] - - m.qpos0[worldid, qposadr1] - - data[0] - ) + pos = d.qpos[worldid, qposadr1] - m.qpos0[worldid, qposadr1] - data[0] Jqvel = d.qvel[worldid, dofadr1] invweight = m.dof_invweight0[worldid, dofadr1] @@ -290,13 +286,9 @@ def _efc_equality_weld( body2id = m.site_bodyid[obj2id] pos1 = d.site_xpos[worldid, obj1id] pos2 = d.site_xpos[worldid, obj2id] - quat = math.mul_quat( - d.xquat[worldid, body1id], m.site_quat[worldid, obj1id] - ) + quat = math.mul_quat(d.xquat[worldid, body1id], m.site_quat[worldid, obj1id]) quat1 = math.quat_inv( - math.mul_quat( - d.xquat[worldid, body2id], m.site_quat[worldid, obj2id] - ) + math.mul_quat(d.xquat[worldid, body2id], m.site_quat[worldid, obj2id]) ) else: @@ -336,7 +328,9 @@ def _efc_equality_weld( crotq = math.mul_quat(quat1, quat) # copy axis components crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale - invweight_t = m.body_invweight0[worldid, body1id, 0] + m.body_invweight0[worldid, body2id, 0] + invweight_t = ( + m.body_invweight0[worldid, body1id, 0] + m.body_invweight0[worldid, body2id, 0] + ) pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) @@ -359,7 +353,9 @@ def _efc_equality_weld( i_eq, ) - invweight_r = m.body_invweight0[worldid, body1id, 1] + m.body_invweight0[worldid, body2id, 1] + invweight_r = ( + m.body_invweight0[worldid, body1id, 1] + m.body_invweight0[worldid, body2id, 1] + ) for i in range(3): _update_efc_row( @@ -558,7 +554,9 @@ def _efc_contact_pyramidal( frame = d.contact.frame[conid] # pyramidal has common invweight across all edges - invweight = m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] + invweight = ( + m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] + ) if condim > 1: dimid2 = dimid / 2 + 1 @@ -659,7 +657,9 @@ def _efc_contact_elliptic( d.efc.J[efcid, i] = J Jqvel += J * d.qvel[worldid, i] - invweight = m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] + invweight = ( + m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] + ) ref = d.contact.solref[conid] pos_aref = pos diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 7c3b78de..76b4c703 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -193,7 +193,9 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * m.dof_damping[worldid, tid] + d.qM_integration[worldid, 0, dof_Madr] += ( + m.opt.timestep * m.dof_damping[worldid, tid] + ) d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 22753d23..b29bb1d2 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -375,7 +375,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) m.body_inertia = create_nmodel_batched_array(mjm.body_inertia, dtype=wp.vec3) m.body_mass = create_nmodel_batched_array(mjm.body_mass, dtype=wp.float32) - m.body_subtreemass = create_nmodel_batched_array(mjm.body_subtreemass, dtype=wp.float32) + m.body_subtreemass = create_nmodel_batched_array( + mjm.body_subtreemass, dtype=wp.float32 + ) subtree_mass = np.copy(mjm.body_mass) # TODO(team): should this be [mjm.nbody - 1, 0) ? diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index a2e7c3cb..033aad9a 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -46,12 +46,9 @@ def _spring(m: Model, d: Data): if jnt_type == wp.static(JointType.FREE.value): dif = wp.vec3( - d.qpos[worldid, qposid + 0] - - m.qpos_spring[worldid, qposid + 0], - d.qpos[worldid, qposid + 1] - - m.qpos_spring[worldid, qposid + 1], - d.qpos[worldid, qposid + 2] - - m.qpos_spring[worldid, qposid + 2], + d.qpos[worldid, qposid + 0] - m.qpos_spring[worldid, qposid + 0], + d.qpos[worldid, qposid + 1] - m.qpos_spring[worldid, qposid + 1], + d.qpos[worldid, qposid + 2] - m.qpos_spring[worldid, qposid + 2], ) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 608fb14e..ae02738a 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -167,21 +167,15 @@ def _frame_quat( m: Model, d: Data, worldid: int, objid: int, objtype: int, refid: int ) -> wp.quat: if objtype == int(ObjType.BODY.value): - quat = math.mul_quat( - d.xquat[worldid, objid], m.body_iquat[worldid, objid] - ) + quat = math.mul_quat(d.xquat[worldid, objid], m.body_iquat[worldid, objid]) if refid == -1: return quat - refquat = math.mul_quat( - d.xquat[worldid, refid], m.body_iquat[worldid, refid] - ) + refquat = math.mul_quat(d.xquat[worldid, refid], m.body_iquat[worldid, refid]) elif objtype == int(ObjType.XBODY.value): quat = d.xquat[worldid, objid] if refid == -1: return quat - refquat = math.mul_quat( - d.xquat[worldid, refid], m.body_iquat[worldid, refid] - ) + refquat = math.mul_quat(d.xquat[worldid, refid], m.body_iquat[worldid, refid]) elif objtype == int(ObjType.GEOM.value): quat = math.mul_quat( d.xquat[worldid, m.geom_bodyid[objid]], diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index e338ea9c..ff7d4c62 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -61,12 +61,8 @@ def _level(m: Model, d: Data, leveladr: int): if jntnum == 0: # no joints - apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = ( - d.xmat[worldid, pid] * m.body_pos[worldid, bodyid] - ) + d.xpos[worldid, pid] - xquat = math.mul_quat( - d.xquat[worldid, pid], m.body_quat[worldid, bodyid] - ) + xpos = (d.xmat[worldid, pid] * m.body_pos[worldid, bodyid]) + d.xpos[worldid, pid] + xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[worldid, bodyid]) elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = m.jnt_qposadr[jntadr] @@ -78,20 +74,14 @@ def _level(m: Model, d: Data, leveladr: int): # regular or no joints # apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = ( - d.xmat[worldid, pid] * m.body_pos[worldid, bodyid] - ) + d.xpos[worldid, pid] - xquat = math.mul_quat( - d.xquat[worldid, pid], m.body_quat[worldid, bodyid] - ) + xpos = (d.xmat[worldid, pid] * m.body_pos[worldid, bodyid]) + d.xpos[worldid, pid] + xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[worldid, bodyid]) for _ in range(jntnum): qadr = m.jnt_qposadr[jntadr] jnt_type = m.jnt_type[jntadr] jnt_axis = m.jnt_axis[jntadr] - xanchor = ( - math.rot_vec_quat(m.jnt_pos[worldid, jntadr], xquat) + xpos - ) + xanchor = math.rot_vec_quat(m.jnt_pos[worldid, jntadr], xquat) + xpos xaxis = math.rot_vec_quat(jnt_axis, xquat) if jnt_type == wp.static(JointType.BALL.value): @@ -103,9 +93,7 @@ def _level(m: Model, d: Data, leveladr: int): ) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat( - m.jnt_pos[worldid, jntadr], xquat - ) + xpos = xanchor - math.rot_vec_quat(m.jnt_pos[worldid, jntadr], xquat) elif jnt_type == wp.static(JointType.SLIDE.value): xpos += xaxis * (qpos[qadr] - m.qpos0[worldid, qadr]) elif jnt_type == wp.static(JointType.HINGE.value): @@ -113,9 +101,7 @@ def _level(m: Model, d: Data, leveladr: int): qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat( - m.jnt_pos[worldid, jntadr], xquat - ) + xpos = xanchor - math.rot_vec_quat(m.jnt_pos[worldid, jntadr], xquat) d.xanchor[worldid, jntadr] = xanchor d.xaxis[worldid, jntadr] = xaxis @@ -180,7 +166,9 @@ def com_pos(m: Model, d: Data): @kernel def subtree_com_init(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[worldid, bodyid] + d.subtree_com[worldid, bodyid] = ( + d.xipos[worldid, bodyid] * m.body_mass[worldid, bodyid] + ) @kernel def subtree_com_acc(m: Model, d: Data, leveladr: int): @@ -304,9 +292,9 @@ def cam_fn(m: Model, d: Data): body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[worldid, camid] elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): - d.cam_xpos[worldid, camid] = d.subtree_com[ - worldid, m.cam_bodyid[camid] - ] + m.cam_poscom0[worldid, camid] + d.cam_xpos[worldid, camid] = ( + d.subtree_com[worldid, m.cam_bodyid[camid]] + m.cam_poscom0[worldid, camid] + ) elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ camid ] == wp.static(CamLightType.TARGETBODYCOM.value): @@ -353,9 +341,10 @@ def light_fn(m: Model, d: Data): body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[worldid, lightid] elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): - d.light_xpos[worldid, lightid] = d.subtree_com[ - worldid, m.light_bodyid[lightid] - ] + m.light_poscom0[worldid, lightid] + d.light_xpos[worldid, lightid] = ( + d.subtree_com[worldid, m.light_bodyid[lightid]] + + m.light_poscom0[worldid, lightid] + ) elif m.light_mode[lightid] == wp.static( CamLightType.TARGETBODY.value ) or m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): @@ -1130,9 +1119,7 @@ def _forward(m: Model, d: Data): # update linear velocity lin -= wp.cross(xipos - subtree_com_root, ang) - d.subtree_linvel[worldid, bodyid] = ( - m.body_mass[worldid, bodyid] * lin - ) + d.subtree_linvel[worldid, bodyid] = m.body_mass[worldid, bodyid] * lin dv = wp.transpose(ximat) @ ang dv[0] *= m.body_inertia[worldid, bodyid][0] dv[1] *= m.body_inertia[worldid, bodyid][1] @@ -1150,7 +1137,9 @@ def _linear_momentum(m: Model, d: Data, leveladr: int): if bodyid: pid = m.body_parentid[bodyid] wp.atomic_add(d.subtree_linvel[worldid], pid, d.subtree_linvel[worldid, bodyid]) - d.subtree_linvel[worldid, bodyid] /= wp.max(MJ_MINVAL, m.body_subtreemass[worldid,bodyid]) + d.subtree_linvel[worldid, bodyid] /= wp.max( + MJ_MINVAL, m.body_subtreemass[worldid, bodyid] + ) body_treeadr = m.body_treeadr.numpy() for i in reversed(range(len(body_treeadr))):