diff --git a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py index 498377e4..87cb15b6 100644 --- a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py +++ b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py @@ -136,8 +136,8 @@ def test_all_issues( @kernel def test_no_issues( # Model: - qpos0: wp.array(dtype=float), - geom_pos: wp.array(dtype=wp.vec3), + qpos0: wp.array2d(dtype=float), + geom_pos: wp.array2d(dtype=wp.vec3), # Data in: qpos_in: wp.array2d(dtype=float), qvel_in: wp.array2d(dtype=float), diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py new file mode 100644 index 00000000..0ff58080 --- /dev/null +++ b/mujoco_warp/_src/collision_box.py @@ -0,0 +1,652 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import math +from typing import Any + +import warp as wp + +from .collision_primitive import contact_params +from .collision_primitive import write_contact +from .math import make_frame +from .types import Data +from .types import GeomType +from .types import Model +from .types import vec5 + +BOX_BOX_BLOCK_DIM = 32 + + +_HUGE_VAL = 1e6 +_TINY_VAL = 1e-6 + + +class vec16b(wp.types.vector(length=16, dtype=wp.int8)): + pass + + +class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)): + pass + + +class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)): + pass + + +class mat16_3f(wp.types.matrix(shape=(16, 3), dtype=wp.float32)): + pass + + +Box = mat83f + + +@wp.func +def _argmin(a: Any) -> wp.int32: + amin = wp.int32(0) + vmin = wp.float32(a[0]) + for i in range(1, len(a)): + if a[i] < vmin: + amin = i + vmin = a[i] + return amin + + +@wp.func +def box_normals(i: int) -> wp.vec3: + direction = wp.where(i < 3, -1.0, 1.0) + mod = i % 3 + if mod == 0: + return wp.vec3(0.0, direction, 0.0) + if mod == 1: + return wp.vec3(0.0, 0.0, direction) + return wp.vec3(-direction, 0.0, 0.0) + + +@wp.func +def box(R: wp.mat33, t: wp.vec3, size: wp.vec3) -> Box: + """Get a transformed box""" + x = size[0] + y = size[1] + z = size[2] + m = Box() + for i in range(8): + ix = wp.where(i & 4, x, -x) + iy = wp.where(i & 2, y, -y) + iz = wp.where(i & 1, z, -z) + m[i] = R @ wp.vec3(ix, iy, iz) + t + return m + + +@wp.func +def box_face_verts(box: Box, idx: int) -> mat43f: + """Get the quad corresponding to a box face""" + if idx == 0: + verts = wp.vec4i(0, 4, 5, 1) + if idx == 1: + verts = wp.vec4i(0, 2, 6, 4) + if idx == 2: + verts = wp.vec4i(6, 7, 5, 4) + if idx == 3: + verts = wp.vec4i(2, 3, 7, 6) + if idx == 4: + verts = wp.vec4i(1, 5, 7, 3) + if idx == 5: + verts = wp.vec4i(0, 1, 3, 2) + + m = mat43f() + for i in range(4): + m[i] = box[verts[i]] + return m + + +@wp.func +def get_box_axis(axis_idx: int, R: wp.mat33): + """Get the axis at index axis_idx. + R: rotation matrix from a to b + Axes 0-12 are face normals of boxes a & b + Axes 12-21 are edge cross products.""" + if axis_idx < 6: # a faces + axis = R @ wp.vec3(box_normals(axis_idx)) + is_degenerate = False + elif axis_idx < 12: # b faces + axis = wp.vec3(box_normals(axis_idx - 6)) + is_degenerate = False + else: # edges cross products + assert axis_idx < 21 + edges = axis_idx - 12 + axis_a, axis_b = edges / 3, edges % 3 + edge_a = wp.transpose(R)[axis_a] + if axis_b == 0: + axis = wp.vec3(0.0, -edge_a[2], edge_a[1]) + elif axis_b == 1: + axis = wp.vec3(edge_a[2], 0.0, -edge_a[0]) + else: + axis = wp.vec3(-edge_a[1], edge_a[0], 0.0) + is_degenerate = wp.length_sq(axis) < _TINY_VAL + return wp.normalize(axis), is_degenerate + + +@wp.func +def get_box_axis_support(axis: wp.vec3, degenerate_axis: bool, a: Box, b: Box): + """Get the overlap (or separating distance if negative) along `axis`, and the sign.""" + axis_d = wp.vec3d(axis) + support_a_max, support_b_max = wp.float32(-_HUGE_VAL), wp.float32(-_HUGE_VAL) + support_a_min, support_b_min = wp.float32(_HUGE_VAL), wp.float32(_HUGE_VAL) + for i in range(8): + vert_a = wp.vec3d(a[i]) + vert_b = wp.vec3d(b[i]) + proj_a = wp.float32(wp.dot(vert_a, axis_d)) + proj_b = wp.float32(wp.dot(vert_b, axis_d)) + support_a_max = wp.max(support_a_max, proj_a) + support_b_max = wp.max(support_b_max, proj_b) + support_a_min = wp.min(support_a_min, proj_a) + support_b_min = wp.min(support_b_min, proj_b) + dist1 = support_a_max - support_b_min + dist2 = support_b_max - support_a_min + dist = wp.where(degenerate_axis, _HUGE_VAL, wp.min(dist1, dist2)) + sign = wp.where(dist1 > dist2, -1, 1) + return dist, sign + + +@wp.struct +class AxisSupport: + best_dist: wp.float32 + best_sign: wp.int8 + best_idx: wp.int8 + + +@wp.func +def reduce_axis_support(a: AxisSupport, b: AxisSupport): + return wp.where(a.best_dist > b.best_dist, b, a) + + +@wp.func +def face_axis_alignment(a: wp.vec3, R: wp.mat33) -> wp.int32: + """Find the box faces most aligned with the axis `a`""" + max_dot = wp.float32(0.0) + max_idx = wp.int32(0) + for i in range(6): + d = wp.dot(R @ box_normals(i), a) + if d > max_dot: + max_dot = d + max_idx = i + return max_idx + + +@wp.kernel(enable_backward=False) +def _box_box( + # Model: + geom_type: wp.array(dtype=int), + geom_condim: wp.array(dtype=int), + geom_priority: wp.array(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), + pair_dim: wp.array(dtype=int), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), + # Data in: + nconmax_in: int, + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + collision_pair_in: wp.array(dtype=wp.vec2i), + collision_pairid_in: wp.array(dtype=int), + collision_worldid_in: wp.array(dtype=int), + ncollision_in: wp.array(dtype=int), + # In: + num_kernels_in: int, + # Data out: + ncon_out: wp.array(dtype=int), + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_dim_out: wp.array(dtype=int), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), +): + """Calculates contacts between pairs of boxes.""" + tid, axis_idx = wp.tid() + + for bp_idx in range(tid, min(ncollision_in[0], nconmax_in), num_kernels_in): + geoms = collision_pair_in[bp_idx] + + ga, gb = geoms[0], geoms[1] + + if geom_type[ga] != int(GeomType.BOX.value) or geom_type[gb] != int(GeomType.BOX.value): + continue + + worldid = collision_worldid_in[bp_idx] + + geoms, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( + geom_condim, + geom_priority, + geom_solmix, + geom_solref, + geom_solimp, + geom_friction, + geom_margin, + geom_gap, + pair_dim, + pair_solref, + pair_solreffriction, + pair_solimp, + pair_margin, + pair_gap, + pair_friction, + collision_pair_in, + collision_pairid_in, + tid, + worldid, + ) + + # transformations + a_pos, b_pos = geom_xpos_in[worldid, ga], geom_xpos_in[worldid, gb] + a_mat, b_mat = geom_xmat_in[worldid, ga], geom_xmat_in[worldid, gb] + b_mat_inv = wp.transpose(b_mat) + trans_atob = b_mat_inv @ (a_pos - b_pos) + rot_atob = b_mat_inv @ a_mat + + a_size = geom_size[worldid, ga] + b_size = geom_size[worldid, gb] + a = box(rot_atob, trans_atob, a_size) + b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) + + # box-box implementation + + # Inlined def collision_axis_tiled( a: Box, b: Box, R: wp.mat33, axis_idx: wp.int32,): + # Finds the axis of minimum separation. + # a: Box a vertices, in frame b + # b: Box b vertices, in frame b + # R: rotation matrix from a to b + # Returns: + # best_axis: vec3 + # best_sign: int32 + # best_idx: int32 + R = rot_atob + + # launch tiled with block_dim=21 + if axis_idx > 20: + continue + + axis, degenerate_axis = get_box_axis(axis_idx, R) + axis_dist, axis_sign = get_box_axis_support(axis, degenerate_axis, a, b) + + supports = wp.tile(AxisSupport(axis_dist, wp.int8(axis_sign), wp.int8(axis_idx))) + + face_supports = wp.tile_view(supports, offset=(0,), shape=(12,)) + edge_supports = wp.tile_view(supports, offset=(12,), shape=(9,)) + + face_supports_red = wp.tile_reduce(reduce_axis_support, face_supports) + edge_supports_red = wp.tile_reduce(reduce_axis_support, edge_supports) + + face = face_supports_red[0] + edge = edge_supports_red[0] + + if axis_idx > 0: # single thread + continue + + # choose the best separating axis + face_axis, _ = get_box_axis(wp.int32(face.best_idx), R) + best_axis = wp.vec3(face_axis) + best_sign = wp.int32(face.best_sign) + best_idx = wp.int32(face.best_idx) + best_dist = wp.float32(face.best_dist) + + if edge.best_dist < face.best_dist: + edge_axis, _ = get_box_axis(wp.int32(edge.best_idx), R) + if wp.abs(wp.dot(face_axis, edge_axis)) < 0.99: + best_axis = edge_axis + best_sign = wp.int32(edge.best_sign) + best_idx = wp.int32(edge.best_idx) + best_dist = wp.float32(edge.best_dist) + # end inlined collision_axis_tiled + + # if axis_idx != 0: + # continue + if best_dist < 0: + continue + + # get the (reference) face most aligned with the separating axis + a_max = face_axis_alignment(best_axis, rot_atob) + b_max = face_axis_alignment(best_axis, wp.identity(3, wp.float32)) + + sep_axis = wp.float32(best_sign) * best_axis + + if best_sign > 0: + b_min = (b_max + 3) % 6 + dist, pos = _create_contact_manifold( + box_face_verts(a, a_max), + rot_atob @ box_normals(a_max), + box_face_verts(b, b_min), + box_normals(b_min), + ) + else: + a_min = (a_max + 3) % 6 + dist, pos = _create_contact_manifold( + box_face_verts(b, b_max), + box_normals(b_max), + box_face_verts(a, a_min), + rot_atob @ box_normals(a_min), + ) + + # For edge contacts, we use the clipped face point, mainly for performance + # reasons. For small penetration, the clipped face point is roughly the edge + # contact point. + if best_idx > 11: # is_edge_contact + idx = _argmin(dist) + dist = wp.vec4f(dist[idx], 1.0, 1.0, 1.0) + for i in range(4): + pos[i] = pos[idx] + + margin = wp.max(geom_margin[worldid, ga], geom_margin[worldid, gb]) + for i in range(4): + pos_glob = b_mat @ pos[i] + b_pos + n_glob = b_mat @ sep_axis + + write_contact( + nconmax_in, + dist[i], + pos_glob, + make_frame(n_glob), + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + worldid, + ncon_out, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_worldid_out, + ) + + +@wp.func +def _closest_segment_point_plane(a: wp.vec3, b: wp.vec3, p0: wp.vec3, plane_normal: wp.vec3) -> wp.vec3: + """Gets the closest point between a line segment and a plane. + + Args: + a: first line segment point + b: second line segment point + p0: point on plane + plane_normal: plane normal + + Returns: + closest point between the line segment and the plane + """ + # Parametrize a line segment as S(t) = a + t * (b - a), plug it into the plane + # equation dot(n, S(t)) - d = 0, then solve for t to get the line-plane + # intersection. We then clip t to be in [0, 1] to be on the line segment. + n = plane_normal + d = wp.dot(p0, n) # shortest distance from origin to plane + denom = wp.dot(n, (b - a)) + t = (d - wp.dot(n, a)) / (denom + wp.where(denom == 0.0, _TINY_VAL, 0.0)) + t = wp.clamp(t, 0.0, 1.0) + segment_point = a + t * (b - a) + + return segment_point + + +@wp.func +def _project_poly_onto_plane(poly: Any, poly_n: wp.vec3, plane_n: wp.vec3, plane_pt: wp.vec3): + """Projects poly1 onto the poly2 plane along poly2's normal.""" + d = wp.dot(plane_pt, plane_n) + denom = wp.dot(poly_n, plane_n) + qn_scaled = poly_n / (denom + wp.where(denom == 0.0, _TINY_VAL, 0.0)) + + for i in range(len(poly)): + poly[i] = poly[i] + (d - wp.dot(poly[i], plane_n)) * qn_scaled + return poly + + +@wp.func +def _clip_edge_to_quad(subject_poly: mat43f, clipping_poly: mat43f, clipping_normal: wp.vec3): + p0 = mat43f() + p1 = mat43f() + mask = wp.vec4b() + for edge_idx in range(4): + subject_p0 = subject_poly[(edge_idx + 3) % 4] + subject_p1 = subject_poly[edge_idx] + + any_both_in_front = wp.int32(0) + clipped0_dist_max = wp.float32(-_HUGE_VAL) + clipped1_dist_max = wp.float32(-_HUGE_VAL) + clipped_p0_distmax = wp.vec3(0.0) + clipped_p1_distmax = wp.vec3(0.0) + + for clipping_edge_idx in range(4): + clipping_p0 = clipping_poly[(clipping_edge_idx + 3) % 4] + clipping_p1 = clipping_poly[clipping_edge_idx] + edge_normal = wp.cross(clipping_p1 - clipping_p0, clipping_normal) + + p0_in_front = wp.dot(subject_p0 - clipping_p0, edge_normal) > _TINY_VAL + p1_in_front = wp.dot(subject_p1 - clipping_p0, edge_normal) > _TINY_VAL + candidate_clipped_p = _closest_segment_point_plane(subject_p0, subject_p1, clipping_p1, edge_normal) + clipped_p0 = wp.where(p0_in_front, candidate_clipped_p, subject_p0) + clipped_p1 = wp.where(p1_in_front, candidate_clipped_p, subject_p1) + clipped_dist_p0 = wp.dot(clipped_p0 - subject_p0, subject_p1 - subject_p0) + clipped_dist_p1 = wp.dot(clipped_p1 - subject_p1, subject_p0 - subject_p1) + any_both_in_front |= wp.int32(p0_in_front and p1_in_front) + + if clipped_dist_p0 > clipped0_dist_max: + clipped0_dist_max = clipped_dist_p0 + clipped_p0_distmax = clipped_p0 + + if clipped_dist_p1 > clipped1_dist_max: + clipped1_dist_max = clipped_dist_p1 + clipped_p1_distmax = clipped_p1 + new_p0 = wp.where(any_both_in_front, subject_p0, clipped_p0_distmax) + new_p1 = wp.where(any_both_in_front, subject_p1, clipped_p1_distmax) + + mask_val = wp.int8( + wp.where( + wp.dot(subject_p0 - subject_p1, new_p0 - new_p1) < 0, + 0, + wp.int32(not any_both_in_front), + ) + ) + + p0[edge_idx] = new_p0 + p1[edge_idx] = new_p1 + mask[edge_idx] = mask_val + return p0, p1, mask + + +@wp.func +def _clip_quad(subject_quad: mat43f, subject_normal: wp.vec3, clipping_quad: mat43f, clipping_normal: wp.vec3): + """Clips a subject quad against a clipping quad. + Serial implementation. + """ + + subject_clipped_p0, subject_clipped_p1, subject_mask = _clip_edge_to_quad(subject_quad, clipping_quad, clipping_normal) + clipping_proj = _project_poly_onto_plane(clipping_quad, clipping_normal, subject_normal, subject_quad[0]) + clipping_clipped_p0, clipping_clipped_p1, clipping_mask = _clip_edge_to_quad(clipping_proj, subject_quad, subject_normal) + + clipped = mat16_3f() + mask = vec16b() + for i in range(4): + clipped[i] = subject_clipped_p0[i] + clipped[i + 4] = clipping_clipped_p0[i] + clipped[i + 8] = subject_clipped_p1[i] + clipped[i + 12] = clipping_clipped_p1[i] + mask[i] = subject_mask[i] + mask[i + 4] = clipping_mask[i] + mask[i + 8] = subject_mask[i] + mask[i + 8 + 4] = clipping_mask[i] + + return clipped, mask + + +# TODO(ca): tiling variant +@wp.func +def _manifold_points(poly: Any, mask: Any, clipping_norm: wp.vec3) -> wp.vec4b: + """Chooses four points on the polygon with approximately maximal area. Return the indices""" + n = len(poly) + + a_idx = wp.int32(0) + a_mask = wp.int8(mask[0]) + for i in range(n): + if mask[i] >= a_mask: + a_idx = i + a_mask = mask[i] + a = poly[a_idx] + + b_idx = wp.int32(0) + b_dist = wp.float32(-_HUGE_VAL) + for i in range(n): + dist = wp.length_sq(poly[i] - a) + wp.where(mask[i], 0.0, -_HUGE_VAL) + if dist >= b_dist: + b_idx = i + b_dist = dist + b = poly[b_idx] + + ab = wp.cross(clipping_norm, a - b) + + c_idx = wp.int32(0) + c_dist = wp.float32(-_HUGE_VAL) + for i in range(n): + ap = a - poly[i] + dist = wp.abs(wp.dot(ap, ab)) + wp.where(mask[i], 0.0, -_HUGE_VAL) + if dist >= c_dist: + c_idx = i + c_dist = dist + c = poly[c_idx] + + ac = wp.cross(clipping_norm, a - c) + bc = wp.cross(clipping_norm, b - c) + + d_idx = wp.int32(0) + d_dist = wp.float32(-2.0 * _HUGE_VAL) + for i in range(n): + ap = a - poly[i] + dist_ap = wp.abs(wp.dot(ap, ac)) + wp.where(mask[i], 0.0, -_HUGE_VAL) + bp = b - poly[i] + dist_bp = wp.abs(wp.dot(bp, bc)) + wp.where(mask[i], 0.0, -_HUGE_VAL) + if dist_ap + dist_bp >= d_dist: + d_idx = i + d_dist = dist_ap + dist_bp + d = poly[d_idx] + return wp.vec4b(wp.int8(a_idx), wp.int8(b_idx), wp.int8(c_idx), wp.int8(d_idx)) + + +@wp.func +def _create_contact_manifold(clipping_quad: mat43f, clipping_normal: wp.vec3, subject_quad: mat43f, subject_normal: wp.vec3): + # Clip the subject (incident) face onto the clipping (reference) face. + # The incident points are clipped points on the subject polygon. + incident, mask = _clip_quad(subject_quad, subject_normal, clipping_quad, clipping_normal) + + clipping_normal_neg = -clipping_normal + d = wp.dot(clipping_quad[0], clipping_normal_neg) + _TINY_VAL + + for i in range(16): + if wp.dot(incident[i], clipping_normal_neg) < d: + mask[i] = wp.int8(0) + + ref = _project_poly_onto_plane(incident, clipping_normal, clipping_normal, clipping_quad[0]) + + # Choose four contact points. + best = _manifold_points(ref, mask, clipping_normal) + contact_pts = mat43f() + dist = wp.vec4f() + + for i in range(4): + idx = wp.int32(best[i]) + contact_pt = ref[idx] + contact_pts[i] = contact_pt + penetration_dir = incident[idx] - contact_pt + penetration = wp.dot(penetration_dir, clipping_normal) + dist[i] = wp.where(mask[idx], penetration, 1.0) + + return dist, contact_pts + + +def box_box_narrowphase( + m: Model, + d: Data, +): + """Calculates contacts between pairs of boxes.""" + kernel_ratio = 16 + nthread = math.ceil(d.nconmax / kernel_ratio) # parallel threads excluding tile dim + wp.launch_tiled( + kernel=_box_box, + dim=nthread, + inputs=[ + m.geom_type, + m.geom_condim, + m.geom_priority, + m.geom_solmix, + m.geom_solref, + m.geom_solimp, + m.geom_size, + m.geom_friction, + m.geom_margin, + m.geom_gap, + m.pair_dim, + m.pair_solref, + m.pair_solreffriction, + m.pair_solimp, + m.pair_margin, + m.pair_gap, + m.pair_friction, + d.nconmax, + d.geom_xpos, + d.geom_xmat, + d.collision_pair, + d.collision_pairid, + d.collision_worldid, + d.ncollision, + nthread, + ], + outputs=[ + d.ncon, + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.dim, + d.contact.friction, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + d.contact.geom, + d.contact.worldid, + ], + block_dim=BOX_BOX_BLOCK_DIM, + ) diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 6e8b5ce3..2f326759 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -709,23 +709,23 @@ def gjk_epa_sparse( geom_condim: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_size: wp.array(dtype=wp.vec3), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: nconmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), @@ -772,6 +772,7 @@ def gjk_epa_sparse( collision_pair_in, collision_pairid_in, tid, + worldid, ) g1 = geoms[0] @@ -783,7 +784,7 @@ def gjk_epa_sparse( geom1 = _geom( geom_type, geom_dataid, - geom_size, + geom_size[worldid], mesh_vertadr, mesh_vertnum, mesh_vert, @@ -796,7 +797,7 @@ def gjk_epa_sparse( geom2 = _geom( geom_type, geom_dataid, - geom_size, + geom_size[worldid], mesh_vertadr, mesh_vertnum, mesh_vert, @@ -806,7 +807,7 @@ def gjk_epa_sparse( g2, ) - margin = wp.max(geom_margin[g1], geom_margin[g2]) + margin = wp.max(geom_margin[worldid, g1], geom_margin[worldid, g2]) simplex, normal = _gjk(mesh_vert, geom1, geom2) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index b50bab96..3bff5279 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -31,8 +31,8 @@ @wp.func def _sphere_filter( # Model: - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), @@ -41,12 +41,12 @@ def _sphere_filter( geom2: int, worldid: int, ) -> bool: - margin1 = geom_margin[geom1] - margin2 = geom_margin[geom2] + margin1 = geom_margin[worldid, geom1] + margin2 = geom_margin[worldid, geom2] pos1 = geom_xpos_in[worldid, geom1] pos2 = geom_xpos_in[worldid, geom2] - size1 = geom_rbound[geom1] - size2 = geom_rbound[geom2] + size1 = geom_rbound[worldid, geom1] + size2 = geom_rbound[worldid, geom2] bound = size1 + size2 + wp.max(margin1, margin2) dif = pos2 - pos1 @@ -123,8 +123,8 @@ def _upper_tri_index(n: int, i: int, j: int) -> int: @wp.kernel def _sap_project( # Model: - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), # In: @@ -137,13 +137,13 @@ def _sap_project( worldid, geomid = wp.tid() xpos = geom_xpos_in[worldid, geomid] - rbound = geom_rbound[geomid] + rbound = geom_rbound[worldid, geomid] if rbound == 0.0: # geom is a plane rbound = MJ_MAXVAL - radius = rbound + geom_margin[geomid] + radius = rbound + geom_margin[worldid, geomid] center = wp.dot(direction_in, xpos) sap_projection_lower_out[worldid, geomid] = center - radius @@ -181,8 +181,8 @@ def _sap_broadphase( # Model: ngeom: int, geom_type: wp.array(dtype=int), - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), nxn_pairid: wp.array(dtype=int), # Data in: nworld_in: int, @@ -342,8 +342,8 @@ def sap_broadphase(m: Model, d: Data): def _nxn_broadphase( # Model: geom_type: wp.array(dtype=int), - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), nxn_geom_pair: wp.array(dtype=wp.vec2i), nxn_pairid: wp.array(dtype=int), # Data in: diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 58957caa..ee60523d 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -56,7 +56,7 @@ def _geom( # Model: geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), - geom_size: wp.array(dtype=wp.vec3), + geom_size: wp.array2d(dtype=wp.vec3), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), @@ -71,7 +71,7 @@ def _geom( geom.pos = geom_xpos_in[worldid, gid] rot = geom_xmat_in[worldid, gid] geom.rot = rot - geom.size = geom_size[gid] + geom.size = geom_size[worldid, gid] geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane dataid = geom_dataid[gid] @@ -1268,36 +1268,37 @@ def contact_params( # Model: geom_condim: wp.array(dtype=int), geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: collision_pair_in: wp.array(dtype=wp.vec2i), collision_pairid_in: wp.array(dtype=int), # In: cid: int, + worldid: int, ): geoms = collision_pair_in[cid] pairid = collision_pairid_in[cid] if pairid > -1: - margin = pair_margin[pairid] - gap = pair_gap[pairid] + margin = pair_margin[worldid, pairid] + gap = pair_gap[worldid, pairid] condim = pair_dim[pairid] - friction = pair_friction[pairid] - solref = pair_solref[pairid] - solreffriction = pair_solreffriction[pairid] - solimp = pair_solimp[pairid] + friction = pair_friction[worldid, pairid] + solref = pair_solref[worldid, pairid] + solreffriction = pair_solreffriction[worldid, pairid] + solimp = pair_solimp[worldid, pairid] else: g1 = geoms[0] g2 = geoms[1] @@ -1305,8 +1306,8 @@ def contact_params( p1 = geom_priority[g1] p2 = geom_priority[g2] - solmix1 = geom_solmix[g1] - solmix2 = geom_solmix[g2] + solmix1 = geom_solmix[worldid, g1] + solmix2 = geom_solmix[worldid, g2] mix = solmix1 / (solmix1 + solmix2) mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix) @@ -1314,14 +1315,14 @@ def contact_params( mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) - margin = wp.max(geom_margin[g1], geom_margin[g2]) - gap = wp.max(geom_gap[g1], geom_gap[g2]) + margin = wp.max(geom_margin[worldid, g1], geom_margin[worldid, g2]) + gap = wp.max(geom_gap[worldid, g1], geom_gap[worldid, g2]) condim1 = geom_condim[g1] condim2 = geom_condim[g2] condim = wp.where(p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2)) - max_geom_friction = wp.max(geom_friction[g1], geom_friction[g2]) + max_geom_friction = wp.max(geom_friction[worldid, g1], geom_friction[worldid, g2]) friction = vec5( max_geom_friction[0], max_geom_friction[0], @@ -1330,14 +1331,14 @@ def contact_params( max_geom_friction[2], ) - if geom_solref[g1].x > 0.0 and geom_solref[g2].x > 0.0: - solref = mix * geom_solref[g1] + (1.0 - mix) * geom_solref[g2] + if geom_solref[worldid, g1].x > 0.0 and geom_solref[worldid, g2].x > 0.0: + solref = mix * geom_solref[worldid, g1] + (1.0 - mix) * geom_solref[worldid, g2] else: - solref = wp.min(geom_solref[g1], geom_solref[g2]) + solref = wp.min(geom_solref[worldid, g1], geom_solref[worldid, g2]) solreffriction = wp.vec2(0.0, 0.0) - solimp = mix * geom_solimp[g1] + (1.0 - mix) * geom_solimp[g2] + solimp = mix * geom_solimp[worldid, g1] + (1.0 - mix) * geom_solimp[worldid, g2] return geoms, margin, gap, condim, friction, solref, solreffriction, solimp @@ -2380,23 +2381,23 @@ def _primitive_narrowphase( geom_condim: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_size: wp.array(dtype=wp.vec3), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: nconmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), @@ -2424,6 +2425,8 @@ def _primitive_narrowphase( if tid >= ncollision_in[0]: return + worldid = collision_worldid_in[tid] + geoms, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( geom_condim, geom_priority, @@ -2443,12 +2446,11 @@ def _primitive_narrowphase( collision_pair_in, collision_pairid_in, tid, + worldid, ) g1 = geoms[0] g2 = geoms[1] - worldid = collision_worldid_in[tid] - geom1 = _geom( geom_type, geom_dataid, diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index f7b5481b..6a09bdfe 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -99,15 +99,15 @@ def _efc_equality_connect( opt_timestep: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), eq_objtype: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), - eq_data: wp.array(dtype=vec11), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), eq_connect_adr: wp.array(dtype=int), # Data in: nefc_in: wp.array(dtype=int), @@ -142,7 +142,7 @@ def _efc_equality_connect( necid = wp.atomic_add(ne_connect_out, 0, 3) efcid = nefc_in[0] + necid - data = eq_data[i_eq] + data = eq_data[worldid, i_eq] anchor1 = wp.vec3f(data[0], data[1], data[2]) anchor2 = wp.vec3f(data[3], data[4], data[5]) @@ -195,11 +195,11 @@ def _efc_equality_connect( efc_J_out[efcid + 2, dofid] = j1mj2[2] Jqvel += j1mj2 * qvel_in[worldid, dofid] - invweight = body_invweight0[body1id, 0] + body_invweight0[body2id, 0] + invweight = body_invweight0[worldid, body1id, 0] + body_invweight0[worldid, body2id, 0] pos_imp = wp.length(pos) - solref = eq_solref[i_eq] - solimp = eq_solimp[i_eq] + solref = eq_solref[worldid, i_eq] + solimp = eq_solimp[worldid, i_eq] for i in range(3): efcidi = efcid + i @@ -231,15 +231,15 @@ def _efc_equality_connect( def _efc_equality_joint( # Model: opt_timestep: float, - qpos0: wp.array(dtype=float), + qpos0: wp.array2d(dtype=float), jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - dof_invweight0: wp.array(dtype=float), + dof_invweight0: wp.array2d(dtype=float), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), - eq_data: wp.array(dtype=vec11), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), eq_jnt_adr: wp.array(dtype=int), # Data in: ne_connect_in: wp.array(dtype=int), @@ -272,7 +272,7 @@ def _efc_equality_joint( jntid_1 = eq_obj1id[i_eq] jntid_2 = eq_obj2id[i_eq] - data = eq_data[i_eq] + data = eq_data[worldid, i_eq] dofadr1 = jnt_dofadr[jntid_1] qposadr1 = jnt_qposadr[jntid_1] efc_J_out[efcid, dofadr1] = 1.0 @@ -281,22 +281,22 @@ def _efc_equality_joint( # Two joint constraint qposadr2 = jnt_qposadr[jntid_2] dofadr2 = jnt_dofadr[jntid_2] - dif = qpos_in[worldid, qposadr2] - qpos0[qposadr2] + dif = qpos_in[worldid, qposadr2] - qpos0[worldid, qposadr2] # Horner's method for polynomials rhs = data[0] + dif * (data[1] + dif * (data[2] + dif * (data[3] + dif * data[4]))) deriv_2 = data[1] + dif * (2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4])) - pos = qpos_in[worldid, qposadr1] - qpos0[qposadr1] - rhs + pos = qpos_in[worldid, qposadr1] - qpos0[worldid, qposadr1] - rhs Jqvel = qvel_in[worldid, dofadr1] - qvel_in[worldid, dofadr2] * deriv_2 - invweight = dof_invweight0[dofadr1] + dof_invweight0[dofadr2] + invweight = dof_invweight0[worldid, dofadr1] + dof_invweight0[worldid, dofadr2] efc_J_out[efcid, dofadr2] = -deriv_2 else: # Single joint constraint - pos = qpos_in[worldid, qposadr1] - qpos0[qposadr1] - data[0] + pos = qpos_in[worldid, qposadr1] - qpos0[worldid, qposadr1] - data[0] Jqvel = qvel_in[worldid, dofadr1] - invweight = dof_invweight0[dofadr1] + invweight = dof_invweight0[worldid, dofadr1] # Update constraint parameters _update_efc_row( @@ -306,8 +306,8 @@ def _efc_equality_joint( pos, pos, invweight, - eq_solref[i_eq], - eq_solimp[i_eq], + eq_solref[worldid, i_eq], + eq_solimp[worldid, i_eq], 0.0, Jqvel, 0.0, @@ -328,12 +328,12 @@ def _efc_equality_tendon( opt_timestep: float, eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), - eq_data: wp.array(dtype=vec11), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), eq_ten_adr: wp.array(dtype=int), - tendon_length0: wp.array(dtype=float), - tendon_invweight0: wp.array(dtype=float), + tendon_length0: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), # Data in: ne_connect_in: wp.array(dtype=int), ne_weld_in: wp.array(dtype=int), @@ -368,16 +368,16 @@ def _efc_equality_tendon( obj1id = eq_obj1id[eqid] obj2id = eq_obj2id[eqid] - data = eq_data[eqid] - solref = eq_solref[eqid] - solimp = eq_solimp[eqid] - pos1 = ten_length_in[worldid, obj1id] - tendon_length0[obj1id] - pos2 = ten_length_in[worldid, obj2id] - tendon_length0[obj2id] + data = eq_data[worldid, eqid] + solref = eq_solref[worldid, eqid] + solimp = eq_solimp[worldid, eqid] + pos1 = ten_length_in[worldid, obj1id] - tendon_length0[worldid, obj1id] + pos2 = ten_length_in[worldid, obj2id] - tendon_length0[worldid, obj2id] jac1 = ten_J_in[worldid, obj1id] jac2 = ten_J_in[worldid, obj2id] if obj2id > -1: - invweight = tendon_invweight0[obj1id] + tendon_invweight0[obj2id] + invweight = tendon_invweight0[worldid, obj1id] + tendon_invweight0[worldid, obj2id] dif = pos2 dif2 = dif * dif @@ -387,7 +387,7 @@ def _efc_equality_tendon( pos = pos1 - (data[0] + data[1] * dif + data[2] * dif2 + data[3] * dif3 + data[4] * dif4) deriv = data[1] + 2.0 * data[2] * dif + 3.0 * data[3] * dif2 + 4.0 * data[4] * dif3 else: - invweight = tendon_invweight0[obj1id] + invweight = tendon_invweight0[worldid, obj1id] pos = pos1 - data[0] deriv = 0.0 @@ -426,10 +426,10 @@ def _efc_equality_tendon( def _efc_friction( # Model: opt_timestep: float, - dof_invweight0: wp.array(dtype=float), - dof_frictionloss: wp.array(dtype=float), - dof_solimp: wp.array(dtype=vec5), - dof_solref: wp.array(dtype=wp.vec2), + dof_invweight0: wp.array2d(dtype=float), + dof_frictionloss: wp.array2d(dtype=float), + dof_solimp: wp.array2d(dtype=vec5), + dof_solref: wp.array2d(dtype=wp.vec2), # Data in: qvel_in: wp.array2d(dtype=float), # In: @@ -449,7 +449,7 @@ def _efc_friction( # TODO(team): tendon worldid, dofid = wp.tid() - if dof_frictionloss[dofid] <= 0.0: + if dof_frictionloss[worldid, dofid] <= 0.0: return efcid = wp.atomic_add(nefc_out, 0, 1) @@ -465,12 +465,12 @@ def _efc_friction( efcid, 0.0, 0.0, - dof_invweight0[dofid], - dof_solref[dofid], - dof_solimp[dofid], + dof_invweight0[worldid, dofid], + dof_solref[worldid, dofid], + dof_solimp[worldid, dofid], 0.0, Jqvel, - dof_frictionloss[dofid], + dof_frictionloss[worldid, dofid], dofid, efc_id_out, efc_pos_out, @@ -489,16 +489,16 @@ def _efc_equality_weld( opt_timestep: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), - site_quat: wp.array(dtype=wp.quat), + site_quat: wp.array2d(dtype=wp.quat), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), eq_objtype: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), - eq_data: wp.array(dtype=vec11), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), + eq_data: wp.array2d(dtype=vec11), eq_wld_adr: wp.array(dtype=int), # Data in: ne_connect_in: wp.array(dtype=int), @@ -539,7 +539,7 @@ def _efc_equality_weld( obj1id = eq_obj1id[i_eq] obj2id = eq_obj2id[i_eq] - data = eq_data[i_eq] + data = eq_data[worldid, i_eq] anchor1 = wp.vec3(data[0], data[1], data[2]) anchor2 = wp.vec3(data[3], data[4], data[5]) relpose = wp.quat(data[6], data[7], data[8], data[9]) @@ -552,8 +552,8 @@ def _efc_equality_weld( pos1 = site_xpos_in[worldid, obj1id] pos2 = site_xpos_in[worldid, obj2id] - quat = math.mul_quat(xquat_in[worldid, body1id], site_quat[obj1id]) - quat1 = math.quat_inv(math.mul_quat(xquat_in[worldid, body2id], site_quat[obj2id])) + quat = math.mul_quat(xquat_in[worldid, body1id], site_quat[worldid, obj1id]) + quat1 = math.quat_inv(math.mul_quat(xquat_in[worldid, body2id], site_quat[worldid, obj2id])) else: body1id = obj1id @@ -612,12 +612,12 @@ def _efc_equality_weld( crotq = math.mul_quat(quat1, quat) # copy axis components crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale - invweight_t = body_invweight0[body1id, 0] + body_invweight0[body2id, 0] + invweight_t = body_invweight0[worldid, body1id, 0] + body_invweight0[worldid, body2id, 0] pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) - solref = eq_solref[i_eq] - solimp = eq_solimp[i_eq] + solref = eq_solref[worldid, i_eq] + solimp = eq_solimp[worldid, i_eq] for i in range(3): _update_efc_row( @@ -641,7 +641,7 @@ def _efc_equality_weld( efc_frictionloss_out, ) - invweight_r = body_invweight0[body1id, 1] + body_invweight0[body2id, 1] + invweight_r = body_invweight0[worldid, body1id, 1] + body_invweight0[worldid, body2id, 1] for i in range(3): _update_efc_row( @@ -672,12 +672,12 @@ def _efc_limit_slide_hinge( opt_timestep: float, jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - jnt_solref: wp.array(dtype=wp.vec2), - jnt_solimp: wp.array(dtype=vec5), - jnt_range: wp.array2d(dtype=float), - jnt_margin: wp.array(dtype=float), + jnt_solref: wp.array2d(dtype=wp.vec2), + jnt_solimp: wp.array2d(dtype=vec5), + jnt_range: wp.array3d(dtype=float), + jnt_margin: wp.array2d(dtype=float), jnt_limited_slide_hinge_adr: wp.array(dtype=int), - dof_invweight0: wp.array(dtype=float), + dof_invweight0: wp.array2d(dtype=float), # Data in: nefc_in: wp.array(dtype=int), qpos_in: wp.array2d(dtype=float), @@ -697,10 +697,10 @@ def _efc_limit_slide_hinge( ): worldid, jntlimitedid = wp.tid() jntid = jnt_limited_slide_hinge_adr[jntlimitedid] - jntrange = jnt_range[jntid] + jntrange = jnt_range[worldid, jntid] qpos = qpos_in[worldid, jnt_qposadr[jntid]] - jntmargin = jnt_margin[jntid] + jntmargin = jnt_margin[worldid, jntid] dist_min, dist_max = qpos - jntrange[0], jntrange[1] - qpos pos = wp.min(dist_min, dist_max) - jntmargin active = pos < 0 @@ -722,9 +722,9 @@ def _efc_limit_slide_hinge( efcid, pos, pos, - dof_invweight0[dofadr], - jnt_solref[jntid], - jnt_solimp[jntid], + dof_invweight0[worldid, dofadr], + jnt_solref[worldid, jntid], + jnt_solimp[worldid, jntid], jntmargin, Jqvel, 0.0, @@ -744,12 +744,12 @@ def _efc_limit_ball( opt_timestep: float, jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - jnt_solref: wp.array(dtype=wp.vec2), - jnt_solimp: wp.array(dtype=vec5), - jnt_range: wp.array2d(dtype=float), - jnt_margin: wp.array(dtype=float), + jnt_solref: wp.array2d(dtype=wp.vec2), + jnt_solimp: wp.array2d(dtype=vec5), + jnt_range: wp.array3d(dtype=float), + jnt_margin: wp.array2d(dtype=float), jnt_limited_ball_adr: wp.array(dtype=int), - dof_invweight0: wp.array(dtype=float), + dof_invweight0: wp.array2d(dtype=float), # Data in: nefc_in: wp.array(dtype=int), qpos_in: wp.array2d(dtype=float), @@ -774,9 +774,9 @@ def _efc_limit_ball( qpos = qpos_in[worldid] jnt_quat = wp.quat(qpos[qposadr + 0], qpos[qposadr + 1], qpos[qposadr + 2], qpos[qposadr + 3]) axis_angle = math.quat_to_vel(jnt_quat) - jntrange = jnt_range[jntid] + jntrange = jnt_range[worldid, jntid] axis, angle = math.normalize_with_norm(axis_angle) - jntmargin = jnt_margin[jntid] + jntmargin = jnt_margin[worldid, jntid] pos = wp.max(jntrange[0], jntrange[1]) - angle - jntmargin active = pos < 0 @@ -802,9 +802,9 @@ def _efc_limit_ball( efcid, pos, pos, - dof_invweight0[dofadr], - jnt_solref[jntid], - jnt_solimp[jntid], + dof_invweight0[worldid, dofadr], + jnt_solref[worldid, jntid], + jnt_solimp[worldid, jntid], jntmargin, Jqvel, 0.0, @@ -827,11 +827,11 @@ def _efc_limit_tendon( tendon_adr: wp.array(dtype=int), tendon_num: wp.array(dtype=int), tendon_limited_adr: wp.array(dtype=int), - tendon_solref_lim: wp.array(dtype=wp.vec2), - tendon_solimp_lim: wp.array(dtype=vec5), - tendon_range: wp.array(dtype=wp.vec2), - tendon_margin: wp.array(dtype=float), - tendon_invweight0: wp.array(dtype=float), + tendon_solref_lim: wp.array2d(dtype=wp.vec2), + tendon_solimp_lim: wp.array2d(dtype=vec5), + tendon_range: wp.array2d(dtype=wp.vec2), + tendon_margin: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), wrap_objid: wp.array(dtype=int), wrap_type: wp.array(dtype=int), # Data in: @@ -855,10 +855,10 @@ def _efc_limit_tendon( worldid, tenlimitedid = wp.tid() tenid = tendon_limited_adr[tenlimitedid] - tenrange = tendon_range[tenid] + tenrange = tendon_range[worldid, tenid] length = ten_length_in[worldid, tenid] dist_min, dist_max = length - tenrange[0], tenrange[1] - length - tenmargin = tendon_margin[tenid] + tenmargin = tendon_margin[worldid, tenid] pos = wp.min(dist_min, dist_max) - tenmargin active = pos < 0 @@ -890,9 +890,9 @@ def _efc_limit_tendon( efcid, pos, pos, - tendon_invweight0[tenid], - tendon_solref_lim[tenid], - tendon_solimp_lim[tenid], + tendon_invweight0[worldid, tenid], + tendon_solref_lim[worldid, tenid], + tendon_solimp_lim[worldid, tenid], tenmargin, Jqvel, 0.0, @@ -914,7 +914,7 @@ def _efc_contact_pyramidal( opt_impratio: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), # Data in: @@ -974,7 +974,7 @@ def _efc_contact_pyramidal( frame = frame_in[conid] # pyramidal has common invweight across all edges - invweight = body_invweight0[body1, 0] + body_invweight0[body2, 0] + invweight = body_invweight0[worldid, body1, 0] + body_invweight0[worldid, body2, 0] if condim > 1: dimid2 = dimid / 2 + 1 @@ -1060,7 +1060,7 @@ def _efc_contact_elliptic( opt_impratio: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), # Data in: @@ -1158,7 +1158,7 @@ def _efc_contact_elliptic( efc_J_out[efcid, i] = J Jqvel += J * qvel_in[worldid, i] - invweight = body_invweight0[body1, 0] + body_invweight0[body2, 0] + invweight = body_invweight0[worldid, body1, 0] + body_invweight0[worldid, body2, 0] ref = solref_in[conid] pos_aref = pos diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 0b104c75..c3c9cb7e 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -141,8 +141,8 @@ def _next_activation( opt_timestep: float, actuator_dyntype: wp.array(dtype=int), actuator_actlimited: wp.array(dtype=bool), - actuator_dynprm: wp.array(dtype=vec10f), - actuator_actrange: wp.array(dtype=wp.vec2), + actuator_dynprm: wp.array2d(dtype=vec10f), + actuator_actrange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), act_dot_in: wp.array2d(dtype=float), @@ -159,7 +159,7 @@ def _next_activation( # advance the actuation if actuator_dyntype[actid] == wp.static(DynType.FILTEREXACT.value): - dyn_prm = actuator_dynprm[actid] + dyn_prm = actuator_dynprm[worldid, actid] tau = wp.max(MJ_MINVAL, dyn_prm[0]) act += act_dot_scale_in * act_dot * tau * (1.0 - wp.exp(-opt_timestep / tau)) else: @@ -167,7 +167,7 @@ def _next_activation( # clamp to actrange if limit and actuator_actlimited[actid]: - actrange = actuator_actrange[actid] + actrange = actuator_actrange[worldid, actid] act = wp.clamp(act, actrange[0], actrange[1]) act_out[worldid, actid] = act @@ -267,7 +267,7 @@ def _euler_damp_qfrc_sparse( # Model: opt_timestep: float, dof_Madr: wp.array(dtype=int), - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), # Data in: qfrc_smooth_in: wp.array2d(dtype=float), qfrc_constraint_in: wp.array2d(dtype=float), @@ -278,7 +278,7 @@ def _euler_damp_qfrc_sparse( worldid, tid = wp.tid() adr = dof_Madr[tid] - qM_integration_out[worldid, 0, adr] += opt_timestep * dof_damping[tid] + qM_integration_out[worldid, 0, adr] += opt_timestep * dof_damping[worldid, tid] qfrc_integration_out[worldid, tid] = qfrc_smooth_in[worldid, tid] + qfrc_constraint_in[worldid, tid] @@ -314,7 +314,7 @@ def _tile_euler_dense(tile: TileSet): @nested_kernel def euler_dense( # Model: - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), opt_timestep: float, # Data in: qM_in: wp.array3d(dtype=float), @@ -330,7 +330,7 @@ def euler_dense( dofid = adr_in[nodeid] M_tile = wp.tile_load(qM_in[worldid], shape=(TILE_SIZE, TILE_SIZE), offset=(dofid, dofid)) - damping_tile = wp.tile_load(dof_damping, shape=(TILE_SIZE,), offset=(dofid,)) + damping_tile = wp.tile_load(dof_damping[worldid], shape=(TILE_SIZE,), offset=(dofid,)) damping_scaled = damping_tile * opt_timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) @@ -481,8 +481,8 @@ def _implicit_actuator_bias_gain_vel( actuator_dyntype: wp.array(dtype=int), actuator_gaintype: wp.array(dtype=int), actuator_biastype: wp.array(dtype=int), - actuator_gainprm: wp.array(dtype=vec10f), - actuator_biasprm: wp.array(dtype=vec10f), + actuator_gainprm: wp.array2d(dtype=vec10f), + actuator_biasprm: wp.array2d(dtype=vec10f), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), @@ -492,12 +492,12 @@ def _implicit_actuator_bias_gain_vel( worldid, actid = wp.tid() if actuator_biastype[actid] == wp.static(BiasType.AFFINE.value): - bias_vel = actuator_biasprm[actid][2] + bias_vel = actuator_biasprm[worldid, actid][2] else: bias_vel = 0.0 if actuator_gaintype[actid] == wp.static(GainType.AFFINE.value): - gain_vel = actuator_gainprm[actid][2] + gain_vel = actuator_gainprm[worldid, actid][2] else: gain_vel = 0.0 @@ -523,7 +523,7 @@ def subtract_multiply(x: float, y: float): @nested_kernel def implicit_actuator_qderiv( # Model: - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), # Data in: actuator_moment_in: wp.array3d(dtype=float), qM_in: wp.array3d(dtype=float), @@ -562,7 +562,7 @@ def implicit_actuator_qderiv( qderiv_tile = wp.tile_zeros(shape=(TILE_NV_SIZE, TILE_NV_SIZE), dtype=wp.float32) if wp.static(passive_enabled): - dof_damping_tile = wp.tile_load(dof_damping, shape=TILE_NV_SIZE, offset=offset_nv) + dof_damping_tile = wp.tile_load(dof_damping[worldid], shape=TILE_NV_SIZE, offset=offset_nv) negative = wp.neg(dof_damping_tile) qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) @@ -818,11 +818,11 @@ def _actuator_force( actuator_actnum: wp.array(dtype=int), actuator_ctrllimited: wp.array(dtype=bool), actuator_forcelimited: wp.array(dtype=bool), - actuator_dynprm: wp.array(dtype=vec10f), - actuator_gainprm: wp.array(dtype=vec10f), - actuator_biasprm: wp.array(dtype=vec10f), - actuator_ctrlrange: wp.array(dtype=wp.vec2), - actuator_forcerange: wp.array(dtype=wp.vec2), + actuator_dynprm: wp.array2d(dtype=vec10f), + actuator_gainprm: wp.array2d(dtype=vec10f), + actuator_biasprm: wp.array2d(dtype=vec10f), + actuator_ctrlrange: wp.array2d(dtype=wp.vec2), + actuator_forcerange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), @@ -839,7 +839,7 @@ def _actuator_force( ctrl = ctrl_in[worldid, uid] if actuator_ctrllimited[uid] and not dsbl_clampctrl: - ctrlrange = actuator_ctrlrange[uid] + ctrlrange = actuator_ctrlrange[worldid, uid] ctrl = wp.clamp(ctrl, ctrlrange[0], ctrlrange[1]) if na: @@ -848,7 +848,7 @@ def _actuator_force( if dyntype == int(DynType.INTEGRATOR.value): act_dot_out[worldid, actuator_actadr[uid]] = ctrl elif dyntype == int(DynType.FILTER.value) or dyntype == int(DynType.FILTEREXACT.value): - dynprm = actuator_dynprm[uid] + dynprm = actuator_dynprm[worldid, uid] actadr = actuator_actadr[uid] act = act_in[worldid, actadr] act_dot_out[worldid, actadr] = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL) @@ -867,7 +867,7 @@ def _actuator_force( # gain gaintype = actuator_gaintype[uid] - gainprm = actuator_gainprm[uid] + gainprm = actuator_gainprm[worldid, uid] gain = 0.0 if gaintype == int(GainType.FIXED.value): @@ -879,7 +879,7 @@ def _actuator_force( # bias biastype = actuator_biastype[uid] - biasprm = actuator_biasprm[uid] + biasprm = actuator_biasprm[worldid, uid] bias = 0.0 # BiasType.NONE if biastype == int(BiasType.AFFINE.value): @@ -892,7 +892,7 @@ def _actuator_force( # TODO(team): tendon total force clamping if actuator_forcelimited[uid]: - forcerange = actuator_forcerange[uid] + forcerange = actuator_forcerange[worldid, uid] force = wp.clamp(force, forcerange[0], forcerange[1]) actuator_force_out[worldid, uid] = force @@ -904,7 +904,7 @@ def _qfrc_actuator_sparse( nu: int, ngravcomp: int, jnt_actfrclimited: wp.array(dtype=bool), - jnt_actfrcrange: wp.array(dtype=wp.vec2), + jnt_actfrcrange: wp.array2d(dtype=wp.vec2), jnt_actgravcomp: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), # Data in: @@ -928,7 +928,7 @@ def _qfrc_actuator_sparse( qfrc += qfrc_gravcomp_in[worldid, dofid] if jnt_actfrclimited[jntid]: - frcrange = jnt_actfrcrange[jntid] + frcrange = jnt_actfrcrange[worldid, jntid] qfrc = wp.clamp(qfrc, frcrange[0], frcrange[1]) qfrc_actuator_out[worldid, dofid] = qfrc @@ -939,7 +939,7 @@ def _qfrc_actuator_limited( # Model: ngravcomp: int, jnt_actfrclimited: wp.array(dtype=bool), - jnt_actfrcrange: wp.array(dtype=wp.vec2), + jnt_actfrcrange: wp.array2d(dtype=wp.vec2), jnt_actgravcomp: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), # Data in: @@ -957,7 +957,7 @@ def _qfrc_actuator_limited( qfrc_dof += qfrc_gravcomp_in[worldid, dofid] if jnt_actfrclimited[jntid]: - frcrange = jnt_actfrcrange[jntid] + frcrange = jnt_actfrcrange[worldid, jntid] qfrc_dof = wp.clamp(qfrc_dof, frcrange[0], frcrange[1]) qfrc_actuator_out[worldid, dofid] = qfrc_dof diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 43dd60a7..8df48a2d 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -248,6 +248,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: nxn_geom_pair.append((geom1, geom2)) nxn_pairid.append(pairid) + def create_nmodel_batched_array(mjm_array, dtype): + array = wp.array(mjm_array, dtype=dtype) + array.ndim += 1 + array.shape = (1,) + array.shape + array.strides = (0,) + array.strides + return array + m = types.Model( nq=mjm.nq, nv=mjm.nv, @@ -297,8 +304,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: stat=types.Statistic( meaninertia=mjm.stat.meaninertia, ), - qpos0=wp.array(mjm.qpos0, dtype=float), - qpos_spring=wp.array(mjm.qpos_spring, dtype=float), + qpos0=create_nmodel_batched_array(mjm.qpos0, dtype=float), + qpos_spring=create_nmodel_batched_array(mjm.qpos_spring, dtype=float), qM_fullm_i=wp.array(qM_fullm_i, dtype=int), qM_fullm_j=wp.array(qM_fullm_j, dtype=int), qM_mulm_i=wp.array(qM_mulm_i, dtype=int), @@ -322,32 +329,32 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: body_dofadr=wp.array(mjm.body_dofadr, dtype=int), body_geomnum=wp.array(mjm.body_geomnum, dtype=int), body_geomadr=wp.array(mjm.body_geomadr, dtype=int), - body_pos=wp.array(mjm.body_pos, dtype=wp.vec3), - body_quat=wp.array(mjm.body_quat, dtype=wp.quat), - body_ipos=wp.array(mjm.body_ipos, dtype=wp.vec3), - body_iquat=wp.array(mjm.body_iquat, dtype=wp.quat), - body_mass=wp.array(mjm.body_mass, dtype=float), - body_subtreemass=wp.array(mjm.body_subtreemass, dtype=float), - subtree_mass=wp.array(subtree_mass, dtype=float), - body_inertia=wp.array(mjm.body_inertia, dtype=wp.vec3), - body_invweight0=wp.array(mjm.body_invweight0, dtype=float), + body_pos=create_nmodel_batched_array(mjm.body_pos, dtype=wp.vec3), + body_quat=create_nmodel_batched_array(mjm.body_quat, dtype=wp.quat), + body_ipos=create_nmodel_batched_array(mjm.body_ipos, dtype=wp.vec3), + body_iquat=create_nmodel_batched_array(mjm.body_iquat, dtype=wp.quat), + body_mass=create_nmodel_batched_array(mjm.body_mass, dtype=float), + body_subtreemass=create_nmodel_batched_array(mjm.body_subtreemass, dtype=float), + subtree_mass=create_nmodel_batched_array(subtree_mass, dtype=float), + body_inertia=create_nmodel_batched_array(mjm.body_inertia, dtype=wp.vec3), + body_invweight0=create_nmodel_batched_array(mjm.body_invweight0, dtype=float), body_contype=wp.array(mjm.body_contype, dtype=int), body_conaffinity=wp.array(mjm.body_conaffinity, dtype=int), - body_gravcomp=wp.array(mjm.body_gravcomp, dtype=float), + body_gravcomp=create_nmodel_batched_array(mjm.body_gravcomp, dtype=float), jnt_type=wp.array(mjm.jnt_type, dtype=int), jnt_qposadr=wp.array(mjm.jnt_qposadr, dtype=int), jnt_dofadr=wp.array(mjm.jnt_dofadr, dtype=int), jnt_bodyid=wp.array(mjm.jnt_bodyid, dtype=int), jnt_limited=wp.array(mjm.jnt_limited, dtype=int), jnt_actfrclimited=wp.array(mjm.jnt_actfrclimited, dtype=bool), - jnt_solref=wp.array(mjm.jnt_solref, dtype=wp.vec2), - jnt_solimp=wp.array(mjm.jnt_solimp, dtype=types.vec5), - jnt_pos=wp.array(mjm.jnt_pos, dtype=wp.vec3), - jnt_axis=wp.array(mjm.jnt_axis, dtype=wp.vec3), - jnt_stiffness=wp.array(mjm.jnt_stiffness, dtype=float), - jnt_range=wp.array(mjm.jnt_range, dtype=float), - jnt_actfrcrange=wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2), - jnt_margin=wp.array(mjm.jnt_margin, dtype=float), + jnt_solref=create_nmodel_batched_array(mjm.jnt_solref, dtype=wp.vec2), + jnt_solimp=create_nmodel_batched_array(mjm.jnt_solimp, dtype=types.vec5), + jnt_pos=create_nmodel_batched_array(mjm.jnt_pos, dtype=wp.vec3), + jnt_axis=create_nmodel_batched_array(mjm.jnt_axis, dtype=wp.vec3), + jnt_stiffness=create_nmodel_batched_array(mjm.jnt_stiffness, dtype=float), + jnt_range=create_nmodel_batched_array(mjm.jnt_range, dtype=float), + jnt_actfrcrange=create_nmodel_batched_array(mjm.jnt_actfrcrange, dtype=wp.vec2), + jnt_margin=create_nmodel_batched_array(mjm.jnt_margin, dtype=float), # these jnt_limited adrs are used in constraint.py jnt_limited_slide_hinge_adr=wp.array( np.nonzero( @@ -364,12 +371,12 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: dof_jntid=wp.array(mjm.dof_jntid, dtype=int), dof_parentid=wp.array(mjm.dof_parentid, dtype=int), dof_Madr=wp.array(mjm.dof_Madr, dtype=int), - dof_armature=wp.array(mjm.dof_armature, dtype=float), - dof_damping=wp.array(mjm.dof_damping, dtype=float), - dof_invweight0=wp.array(mjm.dof_invweight0, dtype=float), - dof_frictionloss=wp.array(mjm.dof_frictionloss, dtype=float), - dof_solimp=wp.array(mjm.dof_solimp, dtype=types.vec5), - dof_solref=wp.array(mjm.dof_solref, dtype=wp.vec2), + dof_armature=create_nmodel_batched_array(mjm.dof_armature, dtype=float), + dof_damping=create_nmodel_batched_array(mjm.dof_damping, dtype=float), + dof_invweight0=create_nmodel_batched_array(mjm.dof_invweight0, dtype=float), + dof_frictionloss=create_nmodel_batched_array(mjm.dof_frictionloss, dtype=float), + dof_solimp=create_nmodel_batched_array(mjm.dof_solimp, dtype=types.vec5), + dof_solref=create_nmodel_batched_array(mjm.dof_solref, dtype=wp.vec2), dof_tri_row=wp.array(dof_tri_row, dtype=int), dof_tri_col=wp.array(dof_tri_col, dtype=int), geom_type=wp.array(mjm.geom_type, dtype=int), @@ -379,30 +386,30 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: geom_bodyid=wp.array(mjm.geom_bodyid, dtype=int), geom_dataid=wp.array(mjm.geom_dataid, dtype=int), geom_group=wp.array(mjm.geom_group, dtype=int), - geom_matid=wp.array(mjm.geom_matid, dtype=int), + geom_matid=create_nmodel_batched_array(mjm.geom_matid, dtype=int), geom_priority=wp.array(mjm.geom_priority, dtype=int), - geom_solmix=wp.array(mjm.geom_solmix, dtype=float), - geom_solref=wp.array(mjm.geom_solref, dtype=wp.vec2), - geom_solimp=wp.array(mjm.geom_solimp, dtype=types.vec5), - geom_size=wp.array(mjm.geom_size, dtype=wp.vec3), + geom_solmix=create_nmodel_batched_array(mjm.geom_solmix, dtype=float), + geom_solref=create_nmodel_batched_array(mjm.geom_solref, dtype=wp.vec2), + geom_solimp=create_nmodel_batched_array(mjm.geom_solimp, dtype=types.vec5), + geom_size=create_nmodel_batched_array(mjm.geom_size, dtype=wp.vec3), geom_aabb=wp.array(mjm.geom_aabb, dtype=wp.vec3), - geom_rbound=wp.array(mjm.geom_rbound, dtype=float), - geom_pos=wp.array(mjm.geom_pos, dtype=wp.vec3), - geom_quat=wp.array(mjm.geom_quat, dtype=wp.quat), - geom_friction=wp.array(mjm.geom_friction, dtype=wp.vec3), - geom_margin=wp.array(mjm.geom_margin, dtype=float), - geom_gap=wp.array(mjm.geom_gap, dtype=float), - geom_rgba=wp.array(mjm.geom_rgba, dtype=wp.vec4), + geom_rbound=create_nmodel_batched_array(mjm.geom_rbound, dtype=float), + geom_pos=create_nmodel_batched_array(mjm.geom_pos, dtype=wp.vec3), + geom_quat=create_nmodel_batched_array(mjm.geom_quat, dtype=wp.quat), + geom_friction=create_nmodel_batched_array(mjm.geom_friction, dtype=wp.vec3), + geom_margin=create_nmodel_batched_array(mjm.geom_margin, dtype=float), + geom_gap=create_nmodel_batched_array(mjm.geom_gap, dtype=float), + geom_rgba=create_nmodel_batched_array(mjm.geom_rgba, dtype=wp.vec4), site_bodyid=wp.array(mjm.site_bodyid, dtype=int), - site_pos=wp.array(mjm.site_pos, dtype=wp.vec3), - site_quat=wp.array(mjm.site_quat, dtype=wp.quat), + site_pos=create_nmodel_batched_array(mjm.site_pos, dtype=wp.vec3), + site_quat=create_nmodel_batched_array(mjm.site_quat, dtype=wp.quat), cam_mode=wp.array(mjm.cam_mode, dtype=int), cam_bodyid=wp.array(mjm.cam_bodyid, dtype=int), cam_targetbodyid=wp.array(mjm.cam_targetbodyid, dtype=int), - cam_pos=wp.array(mjm.cam_pos, dtype=wp.vec3), - cam_quat=wp.array(mjm.cam_quat, dtype=wp.quat), - cam_poscom0=wp.array(mjm.cam_poscom0, dtype=wp.vec3), - cam_pos0=wp.array(mjm.cam_pos0, dtype=wp.vec3), + cam_pos=create_nmodel_batched_array(mjm.cam_pos, dtype=wp.vec3), + cam_quat=create_nmodel_batched_array(mjm.cam_quat, dtype=wp.quat), + cam_poscom0=create_nmodel_batched_array(mjm.cam_poscom0, dtype=wp.vec3), + cam_pos0=create_nmodel_batched_array(mjm.cam_pos0, dtype=wp.vec3), cam_fovy=wp.array(mjm.cam_fovy, dtype=float), cam_resolution=wp.array(mjm.cam_resolution, dtype=wp.vec2i), cam_sensorsize=wp.array(mjm.cam_sensorsize, dtype=wp.vec2), @@ -410,10 +417,10 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: light_mode=wp.array(mjm.light_mode, dtype=int), light_bodyid=wp.array(mjm.light_bodyid, dtype=int), light_targetbodyid=wp.array(mjm.light_targetbodyid, dtype=int), - light_pos=wp.array(mjm.light_pos, dtype=wp.vec3), - light_dir=wp.array(mjm.light_dir, dtype=wp.vec3), - light_poscom0=wp.array(mjm.light_poscom0, dtype=wp.vec3), - light_pos0=wp.array(mjm.light_pos0, dtype=wp.vec3), + light_pos=create_nmodel_batched_array(mjm.light_pos, dtype=wp.vec3), + light_dir=create_nmodel_batched_array(mjm.light_dir, dtype=wp.vec3), + light_poscom0=create_nmodel_batched_array(mjm.light_poscom0, dtype=wp.vec3), + light_pos0=create_nmodel_batched_array(mjm.light_pos0, dtype=wp.vec3), mesh_vertadr=wp.array(mjm.mesh_vertadr, dtype=int), mesh_vertnum=wp.array(mjm.mesh_vertnum, dtype=int), mesh_vert=wp.array(mjm.mesh_vert, dtype=wp.vec3), @@ -424,9 +431,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: eq_obj2id=wp.array(mjm.eq_obj2id, dtype=int), eq_objtype=wp.array(mjm.eq_objtype, dtype=int), eq_active0=wp.array(mjm.eq_active0, dtype=bool), - eq_solref=wp.array(mjm.eq_solref, dtype=wp.vec2), - eq_solimp=wp.array(mjm.eq_solimp, dtype=types.vec5), - eq_data=wp.array(mjm.eq_data, dtype=types.vec11), + eq_solref=create_nmodel_batched_array(mjm.eq_solref, dtype=wp.vec2), + eq_solimp=create_nmodel_batched_array(mjm.eq_solimp, dtype=types.vec5), + eq_data=create_nmodel_batched_array(mjm.eq_data, dtype=types.vec11), # pre-compute indices of equality constraints eq_connect_adr=wp.array(np.nonzero(mjm.eq_type == types.EqType.CONNECT.value)[0], dtype=int), eq_wld_adr=wp.array(np.nonzero(mjm.eq_type == types.EqType.WELD.value)[0], dtype=int), @@ -444,13 +451,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: actuator_ctrllimited=wp.array(mjm.actuator_ctrllimited, dtype=bool), actuator_forcelimited=wp.array(mjm.actuator_forcelimited, dtype=bool), actuator_actlimited=wp.array(mjm.actuator_actlimited, dtype=bool), - actuator_dynprm=wp.array(mjm.actuator_dynprm, dtype=types.vec10f), - actuator_gainprm=wp.array(mjm.actuator_gainprm, dtype=types.vec10f), - actuator_biasprm=wp.array(mjm.actuator_biasprm, dtype=types.vec10f), - actuator_ctrlrange=wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2), - actuator_forcerange=wp.array(mjm.actuator_forcerange, dtype=wp.vec2), - actuator_actrange=wp.array(mjm.actuator_actrange, dtype=wp.vec2), - actuator_gear=wp.array(mjm.actuator_gear, dtype=wp.spatial_vector), + actuator_dynprm=create_nmodel_batched_array(mjm.actuator_dynprm, dtype=types.vec10f), + actuator_gainprm=create_nmodel_batched_array(mjm.actuator_gainprm, dtype=types.vec10f), + actuator_biasprm=create_nmodel_batched_array(mjm.actuator_biasprm, dtype=types.vec10f), + actuator_ctrlrange=create_nmodel_batched_array(mjm.actuator_ctrlrange, dtype=wp.vec2), + actuator_forcerange=create_nmodel_batched_array(mjm.actuator_forcerange, dtype=wp.vec2), + actuator_actrange=create_nmodel_batched_array(mjm.actuator_actrange, dtype=wp.vec2), + actuator_gear=create_nmodel_batched_array(mjm.actuator_gear, dtype=wp.spatial_vector), exclude_signature=wp.array(mjm.exclude_signature, dtype=int), # short-circuiting here allows us to skip a lot of code in implicit integration actuator_affine_bias_gain=bool( @@ -462,23 +469,23 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: pair_dim=wp.array(mjm.pair_dim, dtype=int), pair_geom1=wp.array(mjm.pair_geom1, dtype=int), pair_geom2=wp.array(mjm.pair_geom2, dtype=int), - pair_solref=wp.array(mjm.pair_solref, dtype=wp.vec2), - pair_solreffriction=wp.array(mjm.pair_solreffriction, dtype=wp.vec2), - pair_solimp=wp.array(mjm.pair_solimp, dtype=types.vec5), - pair_margin=wp.array(mjm.pair_margin, dtype=float), - pair_gap=wp.array(mjm.pair_gap, dtype=float), - pair_friction=wp.array(mjm.pair_friction, dtype=types.vec5), + pair_solref=create_nmodel_batched_array(mjm.pair_solref, dtype=wp.vec2), + pair_solreffriction=create_nmodel_batched_array(mjm.pair_solreffriction, dtype=wp.vec2), + pair_solimp=create_nmodel_batched_array(mjm.pair_solimp, dtype=types.vec5), + pair_margin=create_nmodel_batched_array(mjm.pair_margin, dtype=float), + pair_gap=create_nmodel_batched_array(mjm.pair_gap, dtype=float), + pair_friction=create_nmodel_batched_array(mjm.pair_friction, dtype=types.vec5), condim_max=np.max(mjm.pair_dim) if mjm.npair else np.max(mjm.geom_condim), # TODO(team): get max after filtering, tendon_adr=wp.array(mjm.tendon_adr, dtype=int), tendon_num=wp.array(mjm.tendon_num, dtype=int), tendon_limited=wp.array(mjm.tendon_limited, dtype=int), tendon_limited_adr=wp.array(np.nonzero(mjm.tendon_limited)[0], dtype=wp.int32, ndim=1), - tendon_solref_lim=wp.array(mjm.tendon_solref_lim, dtype=wp.vec2f), - tendon_solimp_lim=wp.array(mjm.tendon_solimp_lim, dtype=types.vec5), - tendon_range=wp.array(mjm.tendon_range, dtype=wp.vec2f), - tendon_margin=wp.array(mjm.tendon_margin, dtype=float), - tendon_length0=wp.array(mjm.tendon_length0, dtype=float), - tendon_invweight0=wp.array(mjm.tendon_invweight0, dtype=float), + tendon_solref_lim=create_nmodel_batched_array(mjm.tendon_solref_lim, dtype=wp.vec2f), + tendon_solimp_lim=create_nmodel_batched_array(mjm.tendon_solimp_lim, dtype=types.vec5), + tendon_range=create_nmodel_batched_array(mjm.tendon_range, dtype=wp.vec2f), + tendon_margin=create_nmodel_batched_array(mjm.tendon_margin, dtype=float), + tendon_length0=create_nmodel_batched_array(mjm.tendon_length0, dtype=float), + tendon_invweight0=create_nmodel_batched_array(mjm.tendon_invweight0, dtype=float), wrap_objid=wp.array(mjm.wrap_objid, dtype=int), wrap_prm=wp.array(mjm.wrap_prm, dtype=float), wrap_type=wp.array(mjm.wrap_type, dtype=int), @@ -525,7 +532,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: mujoco.mjtSensor.mjSENS_FRAMEANGACC, ], ).any(), - mat_rgba=wp.array(mjm.mat_rgba, dtype=wp.vec4), + mat_rgba=create_nmodel_batched_array(mjm.mat_rgba, dtype=wp.vec4), ) return m diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index 80e58eed..e5adacee 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -29,18 +29,18 @@ @wp.kernel def _spring_passive( # Model: - qpos_spring: wp.array(dtype=float), + qpos_spring: wp.array2d(dtype=float), jnt_type: wp.array(dtype=int), jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - jnt_stiffness: wp.array(dtype=float), + jnt_stiffness: wp.array2d(dtype=float), # Data in: qpos_in: wp.array2d(dtype=float), # Data out: qfrc_spring_out: wp.array2d(dtype=float), ): worldid, jntid = wp.tid() - stiffness = jnt_stiffness[jntid] + stiffness = jnt_stiffness[worldid, jntid] dofid = jnt_dofadr[jntid] if stiffness == 0.0: @@ -51,9 +51,9 @@ def _spring_passive( if jnttype == wp.static(JointType.FREE.value): dif = wp.vec3( - qpos_in[worldid, qposid + 0] - qpos_spring[qposid + 0], - qpos_in[worldid, qposid + 1] - qpos_spring[qposid + 1], - qpos_in[worldid, qposid + 2] - qpos_spring[qposid + 2], + qpos_in[worldid, qposid + 0] - qpos_spring[worldid, qposid + 0], + qpos_in[worldid, qposid + 1] - qpos_spring[worldid, qposid + 1], + qpos_in[worldid, qposid + 2] - qpos_spring[worldid, qposid + 2], ) qfrc_spring_out[worldid, dofid + 0] = -stiffness * dif[0] qfrc_spring_out[worldid, dofid + 1] = -stiffness * dif[1] @@ -65,10 +65,10 @@ def _spring_passive( qpos_in[worldid, qposid + 6], ) ref = wp.quat( - qpos_spring[qposid + 3], - qpos_spring[qposid + 4], - qpos_spring[qposid + 5], - qpos_spring[qposid + 6], + qpos_spring[worldid, qposid + 3], + qpos_spring[worldid, qposid + 4], + qpos_spring[worldid, qposid + 5], + qpos_spring[worldid, qposid + 6], ) dif = math.quat_sub(rot, ref) qfrc_spring_out[worldid, dofid + 3] = -stiffness * dif[0] @@ -82,17 +82,17 @@ def _spring_passive( qpos_in[worldid, qposid + 3], ) ref = wp.quat( - qpos_spring[qposid + 0], - qpos_spring[qposid + 1], - qpos_spring[qposid + 2], - qpos_spring[qposid + 3], + qpos_spring[worldid, qposid + 0], + qpos_spring[worldid, qposid + 1], + qpos_spring[worldid, qposid + 2], + qpos_spring[worldid, qposid + 3], ) dif = math.quat_sub(rot, ref) qfrc_spring_out[worldid, dofid + 0] = -stiffness * dif[0] qfrc_spring_out[worldid, dofid + 1] = -stiffness * dif[1] qfrc_spring_out[worldid, dofid + 2] = -stiffness * dif[2] else: # mjJNT_SLIDE, mjJNT_HINGE - fdif = qpos_in[worldid, qposid] - qpos_spring[qposid] + fdif = qpos_in[worldid, qposid] - qpos_spring[worldid, qposid] qfrc_spring_out[worldid, dofid] = -stiffness * fdif @@ -102,8 +102,8 @@ def _gravity_force( opt_gravity: wp.vec3, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_gravcomp: wp.array(dtype=float), + body_mass: wp.array2d(dtype=float), + body_gravcomp: wp.array2d(dtype=float), dof_bodyid: wp.array(dtype=int), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), @@ -114,10 +114,10 @@ def _gravity_force( ): worldid, bodyid, dofid = wp.tid() bodyid += 1 # skip world body - gravcomp = body_gravcomp[bodyid] + gravcomp = body_gravcomp[worldid, bodyid] if gravcomp: - force = -opt_gravity * body_mass[bodyid] * gravcomp + force = -opt_gravity * body_mass[worldid, bodyid] * gravcomp pos = xipos_in[worldid, bodyid] jac, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos, bodyid, dofid, worldid) @@ -132,8 +132,8 @@ def _box_fluid( opt_density: float, opt_viscosity: float, body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_inertia: wp.array(dtype=wp.vec3), + body_mass: wp.array2d(dtype=float), + body_inertia: wp.array2d(dtype=wp.vec3), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), ximat_in: wp.array2d(dtype=wp.mat33), @@ -175,8 +175,8 @@ def _box_fluid( density = opt_density > 0.0 if viscosity or density: - inertia = body_inertia[bodyid] - mass = body_mass[bodyid] + inertia = body_inertia[worldid, bodyid] + mass = body_mass[worldid, bodyid] scl = 6.0 / mass box0 = wp.sqrt(wp.max(MJ_MINVAL, inertia[1] + inertia[2] - inertia[0]) * scl) box1 = wp.sqrt(wp.max(MJ_MINVAL, inertia[0] + inertia[2] - inertia[1]) * scl) @@ -249,7 +249,7 @@ def _qfrc_passive( # Model: jnt_actgravcomp: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), # Data in: qvel_in: wp.array2d(dtype=float), qfrc_spring_in: wp.array2d(dtype=float), @@ -268,7 +268,7 @@ def _qfrc_passive( qfrc_passive = qfrc_spring_in[worldid, dofid] # damper - qfrc_damper = -dof_damping[dofid] * qvel_in[worldid, dofid] + qfrc_damper = -dof_damping[worldid, dofid] * qvel_in[worldid, dofid] qfrc_damper_out[worldid, dofid] = qfrc_damper qfrc_passive += qfrc_damper diff --git a/mujoco_warp/_src/ray.py b/mujoco_warp/_src/ray.py index eb713526..f877d1fd 100644 --- a/mujoco_warp/_src/ray.py +++ b/mujoco_warp/_src/ray.py @@ -425,7 +425,7 @@ def _ray_geom_with_mesh( nmeshface: int, geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), - geom_size: wp.array(dtype=wp.vec3), + geom_size: wp.array2d(dtype=wp.vec3), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), @@ -435,9 +435,10 @@ def _ray_geom_with_mesh( geom_id: int, pnt: wp.vec3, vec: wp.vec3, + worldid: int, ) -> DistanceWithId: type = geom_type[geom_id] - size = geom_size[geom_id] + size = geom_size[worldid, geom_id] # TODO(team): static loop unrolling to remove unnecessary branching if type == int(GeomType.PLANE.value): @@ -497,15 +498,15 @@ def _ray_all_geom( geom_bodyid: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_group: wp.array(dtype=int), - geom_matid: wp.array(dtype=int), - geom_size: wp.array(dtype=wp.vec3), - geom_rgba: wp.array(dtype=wp.vec4), + geom_matid: wp.array2d(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + geom_rgba: wp.array2d(dtype=wp.vec4), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), mesh_faceadr: wp.array(dtype=int), mesh_face: wp.array(dtype=wp.vec3i), - mat_rgba: wp.array(dtype=wp.vec4), + mat_rgba: wp.array2d(dtype=wp.vec4), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), @@ -546,11 +547,11 @@ def _ray_all_geom( geom_filter = geom_filter and (geomgroup[group] != 0) # RGBA filter - matid = geom_matid[geom_id] - geom_alpha = geom_rgba[geom_id][3] + matid = geom_matid[worldid, geom_id] + geom_alpha = geom_rgba[worldid, geom_id][3] mat_alpha = wp.float32(0.0) if matid != -1: - mat_alpha = mat_rgba[matid][3] + mat_alpha = mat_rgba[worldid, matid][3] # Geom is visible if either: # 1. No material and non-zero geom alpha, or @@ -581,6 +582,7 @@ def _ray_all_geom( geom_id, local_pnt, local_vec, + worldid, ) cur_dist = result.dist else: @@ -612,15 +614,15 @@ def _ray_all_geom_kernel( geom_bodyid: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_group: wp.array(dtype=int), - geom_matid: wp.array(dtype=int), - geom_size: wp.array(dtype=wp.vec3), - geom_rgba: wp.array(dtype=wp.vec4), + geom_matid: wp.array2d(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + geom_rgba: wp.array2d(dtype=wp.vec4), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), mesh_faceadr: wp.array(dtype=int), mesh_face: wp.array(dtype=wp.vec3i), - mat_rgba: wp.array(dtype=wp.vec4), + mat_rgba: wp.array2d(dtype=wp.vec4), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), diff --git a/mujoco_warp/_src/ray_test.py b/mujoco_warp/_src/ray_test.py index 4bddd676..61655320 100644 --- a/mujoco_warp/_src/ray_test.py +++ b/mujoco_warp/_src/ray_test.py @@ -268,7 +268,7 @@ def test_ray_invisible(self): mjm, mjd, m, d = test_util.fixture("ray.xml") # nothing hit with transparent geoms - m.geom_rgba = wp.array([wp.vec4(0.0, 0.0, 0.0, 0.0)], dtype=wp.vec4) + m.geom_rgba = wp.array2d([[wp.vec4(0.0, 0.0, 0.0, 0.0)]], dtype=wp.vec4) mujoco.mj_forward(mjm, mjd) pnt = wp.array([wp.vec3(2.0, 1.0, 3.0)], dtype=wp.vec3) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index a1e3a385..2a5f43c0 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -292,13 +292,13 @@ def _frame_axis( @wp.func def _frame_quat( # Model: - body_iquat: wp.array(dtype=wp.quat), + body_iquat: wp.array2d(dtype=wp.quat), geom_bodyid: wp.array(dtype=int), - geom_quat: wp.array(dtype=wp.quat), + geom_quat: wp.array2d(dtype=wp.quat), site_bodyid: wp.array(dtype=int), - site_quat: wp.array(dtype=wp.quat), + site_quat: wp.array2d(dtype=wp.quat), cam_bodyid: wp.array(dtype=int), - cam_quat: wp.array(dtype=wp.quat), + cam_quat: wp.array2d(dtype=wp.quat), # Data in: xquat_in: wp.array2d(dtype=wp.quat), # In: @@ -309,15 +309,15 @@ def _frame_quat( reftype: int, ) -> wp.quat: if objtype == int(ObjType.BODY.value): - quat = math.mul_quat(xquat_in[worldid, objid], body_iquat[objid]) + quat = math.mul_quat(xquat_in[worldid, objid], body_iquat[worldid, objid]) elif objtype == int(ObjType.XBODY.value): quat = xquat_in[worldid, objid] elif objtype == int(ObjType.GEOM.value): - quat = math.mul_quat(xquat_in[worldid, geom_bodyid[objid]], geom_quat[objid]) + quat = math.mul_quat(xquat_in[worldid, geom_bodyid[objid]], geom_quat[worldid, objid]) elif objtype == int(ObjType.SITE.value): - quat = math.mul_quat(xquat_in[worldid, site_bodyid[objid]], site_quat[objid]) + quat = math.mul_quat(xquat_in[worldid, site_bodyid[objid]], site_quat[worldid, objid]) elif objtype == int(ObjType.CAMERA.value): - quat = math.mul_quat(xquat_in[worldid, cam_bodyid[objid]], cam_quat[objid]) + quat = math.mul_quat(xquat_in[worldid, cam_bodyid[objid]], cam_quat[worldid, objid]) else: # UNKNOWN quat = wp.quat(1.0, 0.0, 0.0, 0.0) @@ -325,15 +325,15 @@ def _frame_quat( return quat if reftype == int(ObjType.BODY.value): - refquat = math.mul_quat(xquat_in[worldid, refid], body_iquat[refid]) + refquat = math.mul_quat(xquat_in[worldid, refid], body_iquat[worldid, refid]) elif reftype == int(ObjType.XBODY.value): refquat = xquat_in[worldid, refid] elif reftype == int(ObjType.GEOM.value): - refquat = math.mul_quat(xquat_in[worldid, geom_bodyid[refid]], geom_quat[refid]) + refquat = math.mul_quat(xquat_in[worldid, geom_bodyid[refid]], geom_quat[worldid, refid]) elif reftype == int(ObjType.SITE.value): - refquat = math.mul_quat(xquat_in[worldid, site_bodyid[refid]], site_quat[refid]) + refquat = math.mul_quat(xquat_in[worldid, site_bodyid[refid]], site_quat[worldid, refid]) elif reftype == int(ObjType.CAMERA.value): - refquat = math.mul_quat(xquat_in[worldid, cam_bodyid[refid]], cam_quat[refid]) + refquat = math.mul_quat(xquat_in[worldid, cam_bodyid[refid]], cam_quat[worldid, refid]) else: # UNKNOWN refquat = wp.quat(1.0, 0.0, 0.0, 0.0) @@ -353,14 +353,14 @@ def _clock(time_in: wp.array(dtype=float), worldid: int) -> float: @wp.kernel def _sensor_pos( # Model: - body_iquat: wp.array(dtype=wp.quat), + body_iquat: wp.array2d(dtype=wp.quat), jnt_qposadr: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), - geom_quat: wp.array(dtype=wp.quat), + geom_quat: wp.array2d(dtype=wp.quat), site_bodyid: wp.array(dtype=int), - site_quat: wp.array(dtype=wp.quat), + site_quat: wp.array2d(dtype=wp.quat), cam_bodyid: wp.array(dtype=int), - cam_quat: wp.array(dtype=wp.quat), + cam_quat: wp.array2d(dtype=wp.quat), cam_fovy: wp.array(dtype=float), cam_resolution: wp.array(dtype=wp.vec2i), cam_sensorsize: wp.array(dtype=wp.vec2), diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 06f6dfa1..b7db29d9 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -60,18 +60,18 @@ def _kinematics_root( @wp.kernel def _kinematics_level( # Model: - qpos0: wp.array(dtype=float), + qpos0: wp.array2d(dtype=float), body_parentid: wp.array(dtype=int), body_jntnum: wp.array(dtype=int), body_jntadr: wp.array(dtype=int), - body_pos: wp.array(dtype=wp.vec3), - body_quat: wp.array(dtype=wp.quat), - body_ipos: wp.array(dtype=wp.vec3), - body_iquat: wp.array(dtype=wp.quat), + body_pos: wp.array2d(dtype=wp.vec3), + body_quat: wp.array2d(dtype=wp.quat), + body_ipos: wp.array2d(dtype=wp.vec3), + body_iquat: wp.array2d(dtype=wp.quat), jnt_type: wp.array(dtype=int), jnt_qposadr: wp.array(dtype=int), - jnt_pos: wp.array(dtype=wp.vec3), - jnt_axis: wp.array(dtype=wp.vec3), + jnt_pos: wp.array2d(dtype=wp.vec3), + jnt_axis: wp.array2d(dtype=wp.vec3), # Data in: qpos_in: wp.array2d(dtype=float), xpos_in: wp.array2d(dtype=wp.vec3), @@ -97,27 +97,27 @@ def _kinematics_level( if jntnum == 0: # no joints - apply fixed translation and rotation relative to parent pid = body_parentid[bodyid] - xpos = (xmat_in[worldid, pid] * body_pos[bodyid]) + xpos_in[worldid, pid] - xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[bodyid]) + xpos = (xmat_in[worldid, pid] * body_pos[worldid, bodyid]) + xpos_in[worldid, pid] + xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[worldid, bodyid]) elif jntnum == 1 and jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = jnt_qposadr[jntadr] xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) xanchor_out[worldid, jntadr] = xpos - xaxis_out[worldid, jntadr] = jnt_axis[jntadr] + xaxis_out[worldid, jntadr] = jnt_axis[worldid, jntadr] else: # regular or no joints # apply fixed translation and rotation relative to parent pid = body_parentid[bodyid] - xpos = (xmat_in[worldid, pid] * body_pos[bodyid]) + xpos_in[worldid, pid] - xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[bodyid]) + xpos = (xmat_in[worldid, pid] * body_pos[worldid, bodyid]) + xpos_in[worldid, pid] + xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[worldid, bodyid]) for _ in range(jntnum): qadr = jnt_qposadr[jntadr] jnt_type_ = jnt_type[jntadr] - jnt_axis_ = jnt_axis[jntadr] - xanchor = math.rot_vec_quat(jnt_pos[jntadr], xquat) + xpos + jnt_axis_ = jnt_axis[worldid, jntadr] + xanchor = math.rot_vec_quat(jnt_pos[worldid, jntadr], xquat) + xpos xaxis = math.rot_vec_quat(jnt_axis_, xquat) if jnt_type_ == wp.static(JointType.BALL.value): @@ -129,15 +129,15 @@ def _kinematics_level( ) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(jnt_pos[jntadr], xquat) + xpos = xanchor - math.rot_vec_quat(jnt_pos[worldid, jntadr], xquat) elif jnt_type_ == wp.static(JointType.SLIDE.value): - xpos += xaxis * (qpos[qadr] - qpos0[qadr]) + xpos += xaxis * (qpos[qadr] - qpos0[worldid, qadr]) elif jnt_type_ == wp.static(JointType.HINGE.value): - qpos0_ = qpos0[qadr] + qpos0_ = qpos0[worldid, qadr] qloc_ = math.axis_angle_to_quat(jnt_axis_, qpos[qadr] - qpos0_) xquat = math.mul_quat(xquat, qloc_) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(jnt_pos[jntadr], xquat) + xpos = xanchor - math.rot_vec_quat(jnt_pos[worldid, jntadr], xquat) xanchor_out[worldid, jntadr] = xanchor xaxis_out[worldid, jntadr] = xaxis @@ -146,16 +146,16 @@ def _kinematics_level( xpos_out[worldid, bodyid] = xpos xquat_out[worldid, bodyid] = wp.normalize(xquat) xmat_out[worldid, bodyid] = math.quat_to_mat(xquat) - xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[bodyid], xquat) - ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(xquat, body_iquat[bodyid])) + xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[worldid, bodyid], xquat) + ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(xquat, body_iquat[worldid, bodyid])) @wp.kernel def _geom_local_to_global( # Model: geom_bodyid: wp.array(dtype=int), - geom_pos: wp.array(dtype=wp.vec3), - geom_quat: wp.array(dtype=wp.quat), + geom_pos: wp.array2d(dtype=wp.vec3), + geom_quat: wp.array2d(dtype=wp.quat), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -167,16 +167,16 @@ def _geom_local_to_global( bodyid = geom_bodyid[geomid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - geom_xpos_out[worldid, geomid] = xpos + math.rot_vec_quat(geom_pos[geomid], xquat) - geom_xmat_out[worldid, geomid] = math.quat_to_mat(math.mul_quat(xquat, geom_quat[geomid])) + geom_xpos_out[worldid, geomid] = xpos + math.rot_vec_quat(geom_pos[worldid, geomid], xquat) + geom_xmat_out[worldid, geomid] = math.quat_to_mat(math.mul_quat(xquat, geom_quat[worldid, geomid])) @wp.kernel def _site_local_to_global( # Model: site_bodyid: wp.array(dtype=int), - site_pos: wp.array(dtype=wp.vec3), - site_quat: wp.array(dtype=wp.quat), + site_pos: wp.array2d(dtype=wp.vec3), + site_quat: wp.array2d(dtype=wp.quat), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -188,15 +188,15 @@ def _site_local_to_global( bodyid = site_bodyid[siteid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - site_xpos_out[worldid, siteid] = xpos + math.rot_vec_quat(site_pos[siteid], xquat) - site_xmat_out[worldid, siteid] = math.quat_to_mat(math.mul_quat(xquat, site_quat[siteid])) + site_xpos_out[worldid, siteid] = xpos + math.rot_vec_quat(site_pos[worldid, siteid], xquat) + site_xmat_out[worldid, siteid] = math.quat_to_mat(math.mul_quat(xquat, site_quat[worldid, siteid])) @wp.kernel def _mocap( # Model: - body_ipos: wp.array(dtype=wp.vec3), - body_iquat: wp.array(dtype=wp.quat), + body_ipos: wp.array2d(dtype=wp.vec3), + body_iquat: wp.array2d(dtype=wp.quat), mocap_bodyid: wp.array(dtype=int), # Data in: mocap_pos_in: wp.array2d(dtype=wp.vec3), @@ -215,8 +215,8 @@ def _mocap( xpos_out[worldid, bodyid] = xpos xquat_out[worldid, bodyid] = mocap_quat xmat_out[worldid, bodyid] = math.quat_to_mat(mocap_quat) - xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[bodyid], mocap_quat) - ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(mocap_quat, body_iquat[bodyid])) + xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[worldid, bodyid], mocap_quat) + ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(mocap_quat, body_iquat[worldid, bodyid])) @event_scope @@ -285,14 +285,14 @@ def kinematics(m: Model, d: Data): @wp.kernel def _subtree_com_init( # Model: - body_mass: wp.array(dtype=float), + body_mass: wp.array2d(dtype=float), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), # Data out: xipos_out: wp.array2d(dtype=wp.vec3), ): worldid, bodyid = wp.tid() - xipos_out[worldid, bodyid] = xipos_in[worldid, bodyid] * body_mass[bodyid] + xipos_out[worldid, bodyid] = xipos_in[worldid, bodyid] * body_mass[worldid, bodyid] @wp.kernel @@ -315,20 +315,20 @@ def _subtree_com_acc( @wp.kernel def _subtree_div( # Model: - subtree_mass: wp.array(dtype=float), + subtree_mass: wp.array2d(dtype=float), # Data out: subtree_com_out: wp.array2d(dtype=wp.vec3), ): worldid, bodyid = wp.tid() - subtree_com_out[worldid, bodyid] /= subtree_mass[bodyid] + subtree_com_out[worldid, bodyid] /= subtree_mass[worldid, bodyid] @wp.kernel def _cinert( # Model: body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_inertia: wp.array(dtype=wp.vec3), + body_mass: wp.array2d(dtype=float), + body_inertia: wp.array2d(dtype=wp.vec3), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), ximat_in: wp.array2d(dtype=wp.mat33), @@ -338,8 +338,8 @@ def _cinert( ): worldid, bodyid = wp.tid() mat = ximat_in[worldid, bodyid] - inert = body_inertia[bodyid] - mass = body_mass[bodyid] + inert = body_inertia[worldid, bodyid] + mass = body_mass[worldid, bodyid] dif = xipos_in[worldid, bodyid] - subtree_com_in[worldid, body_rootid[bodyid]] # express inertia in com-based frame (mju_inertCom) @@ -472,8 +472,8 @@ def com_pos(m: Model, d: Data): def _cam_local_to_global( # Model: cam_bodyid: wp.array(dtype=int), - cam_pos: wp.array(dtype=wp.vec3), - cam_quat: wp.array(dtype=wp.quat), + cam_pos: wp.array2d(dtype=wp.vec3), + cam_quat: wp.array2d(dtype=wp.quat), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -486,8 +486,8 @@ def _cam_local_to_global( bodyid = cam_bodyid[camid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - cam_xpos_out[worldid, camid] = xpos + math.rot_vec_quat(cam_pos[camid], xquat) - cam_xmat_out[worldid, camid] = math.quat_to_mat(math.mul_quat(xquat, cam_quat[camid])) + cam_xpos_out[worldid, camid] = xpos + math.rot_vec_quat(cam_pos[worldid, camid], xquat) + cam_xmat_out[worldid, camid] = math.quat_to_mat(math.mul_quat(xquat, cam_quat[worldid, camid])) @wp.kernel @@ -496,8 +496,8 @@ def _cam_fn( cam_mode: wp.array(dtype=int), cam_bodyid: wp.array(dtype=int), cam_targetbodyid: wp.array(dtype=int), - cam_poscom0: wp.array(dtype=wp.vec3), - cam_pos0: wp.array(dtype=wp.vec3), + cam_poscom0: wp.array2d(dtype=wp.vec3), + cam_pos0: wp.array2d(dtype=wp.vec3), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -514,9 +514,9 @@ def _cam_fn( return elif cam_mode[camid] == wp.static(CamLightType.TRACK.value): body_xpos = xpos_in[worldid, cam_bodyid[camid]] - cam_xpos_out[worldid, camid] = body_xpos + cam_pos0[camid] + cam_xpos_out[worldid, camid] = body_xpos + cam_pos0[worldid, camid] elif cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): - cam_xpos_out[worldid, camid] = subtree_com_in[worldid, cam_bodyid[camid]] + cam_poscom0[camid] + cam_xpos_out[worldid, camid] = subtree_com_in[worldid, cam_bodyid[camid]] + cam_poscom0[worldid, camid] elif cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or cam_mode[camid] == wp.static( CamLightType.TARGETBODYCOM.value ): @@ -541,8 +541,8 @@ def _cam_fn( def _light_local_to_global( # Model: light_bodyid: wp.array(dtype=int), - light_pos: wp.array(dtype=wp.vec3), - light_dir: wp.array(dtype=wp.vec3), + light_pos: wp.array2d(dtype=wp.vec3), + light_dir: wp.array2d(dtype=wp.vec3), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -555,8 +555,8 @@ def _light_local_to_global( bodyid = light_bodyid[lightid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - light_xpos_out[worldid, lightid] = xpos + math.rot_vec_quat(light_pos[lightid], xquat) - light_xdir_out[worldid, lightid] = math.rot_vec_quat(light_dir[lightid], xquat) + light_xpos_out[worldid, lightid] = xpos + math.rot_vec_quat(light_pos[worldid, lightid], xquat) + light_xdir_out[worldid, lightid] = math.rot_vec_quat(light_dir[worldid, lightid], xquat) @wp.kernel @@ -565,8 +565,8 @@ def _light_fn( light_mode: wp.array(dtype=int), light_bodyid: wp.array(dtype=int), light_targetbodyid: wp.array(dtype=int), - light_poscom0: wp.array(dtype=wp.vec3), - light_pos0: wp.array(dtype=wp.vec3), + light_poscom0: wp.array2d(dtype=wp.vec3), + light_pos0: wp.array2d(dtype=wp.vec3), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), light_xpos_in: wp.array2d(dtype=wp.vec3), @@ -584,9 +584,9 @@ def _light_fn( return elif light_mode[lightid] == wp.static(CamLightType.TRACK.value): body_xpos = xpos_in[worldid, light_bodyid[lightid]] - light_xpos_out[worldid, lightid] = body_xpos + light_pos0[lightid] + light_xpos_out[worldid, lightid] = body_xpos + light_pos0[worldid, lightid] elif light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): - light_xpos_out[worldid, lightid] = subtree_com_in[worldid, light_bodyid[lightid]] + light_poscom0[lightid] + light_xpos_out[worldid, lightid] = subtree_com_in[worldid, light_bodyid[lightid]] + light_poscom0[worldid, lightid] elif light_mode[lightid] == wp.static(CamLightType.TARGETBODY.value) or light_mode[lightid] == wp.static( CamLightType.TARGETBODYCOM.value ): @@ -670,7 +670,7 @@ def _qM_sparse( dof_bodyid: wp.array(dtype=int), dof_parentid: wp.array(dtype=int), dof_Madr: wp.array(dtype=int), - dof_armature: wp.array(dtype=float), + dof_armature: wp.array2d(dtype=float), # Data in: cdof_in: wp.array2d(dtype=wp.spatial_vector), crb_in: wp.array2d(dtype=vec10), @@ -682,7 +682,7 @@ def _qM_sparse( bodyid = dof_bodyid[dofid] # init M(i,i) with armature inertia - qM_out[worldid, 0, madr_ij] = dof_armature[dofid] + qM_out[worldid, 0, madr_ij] = dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(crb_in[worldid, bodyid], cdof_in[worldid, dofid]) @@ -699,7 +699,7 @@ def _qM_dense( # Model: dof_bodyid: wp.array(dtype=int), dof_parentid: wp.array(dtype=int), - dof_armature: wp.array(dtype=float), + dof_armature: wp.array2d(dtype=float), # Data in: cdof_in: wp.array2d(dtype=wp.spatial_vector), crb_in: wp.array2d(dtype=vec10), @@ -709,7 +709,7 @@ def _qM_dense( worldid, dofid = wp.tid() bodyid = dof_bodyid[dofid] # init M(i,i) with armature inertia - M = dof_armature[dofid] + M = dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(crb_in[worldid, bodyid], cdof_in[worldid, dofid]) @@ -1140,11 +1140,11 @@ def _cfrc_ext_equality( # Model: body_rootid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), - site_pos: wp.array(dtype=wp.vec3), + site_pos: wp.array2d(dtype=wp.vec3), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), eq_objtype: wp.array(dtype=int), - eq_data: wp.array(dtype=vec11), + eq_data: wp.array2d(dtype=vec11), # Data in: ne_connect_in: wp.array(dtype=int), ne_weld_in: wp.array(dtype=int), @@ -1182,7 +1182,7 @@ def _cfrc_ext_equality( worldid = efc_worldid_in[efcid] id = efc_id_in[efcid] - eq_data_ = eq_data[id] + eq_data_ = eq_data[worldid, id] body_semantic = eq_objtype[id] == wp.static(ObjType.BODY.value) obj1 = eq_obj1id[id] @@ -1203,7 +1203,7 @@ def _cfrc_ext_equality( else: offset = wp.vec3(eq_data_[3], eq_data_[4], eq_data_[5]) else: - offset = site_pos[obj1] + offset = site_pos[worldid, obj1] # transform point on body1: local -> global pos = xmat_in[worldid, bodyid1] @ offset + xpos_in[worldid, bodyid1] @@ -1225,7 +1225,7 @@ def _cfrc_ext_equality( else: offset = wp.vec3(eq_data_[0], eq_data_[1], eq_data_[2]) else: - offset = site_pos[obj2] + offset = site_pos[worldid, obj2] # transform point on body2: local -> global pos = xmat_in[worldid, bodyid2] @ offset + xpos_in[worldid, bodyid2] @@ -1480,7 +1480,7 @@ def _transmission( jnt_dofadr: wp.array(dtype=int), actuator_trntype: wp.array(dtype=int), actuator_trnid: wp.array(dtype=wp.vec2i), - actuator_gear: wp.array(dtype=wp.spatial_vector), + actuator_gear: wp.array2d(dtype=wp.spatial_vector), tendon_adr: wp.array(dtype=int), tendon_num: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), @@ -1495,7 +1495,7 @@ def _transmission( ): worldid, actid = wp.tid() trntype = actuator_trntype[actid] - gear = actuator_gear[actid] + gear = actuator_gear[worldid, actid] if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static(TrnType.JOINTINPARENT.value): qpos = qpos_in[worldid] jntid = actuator_trnid[actid][0] @@ -1747,8 +1747,8 @@ def factor_solve_i(m, d, M, L, D, x, y): def _subtree_vel_forward( # Model: body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_inertia: wp.array(dtype=wp.vec3), + body_mass: wp.array2d(dtype=float), + body_inertia: wp.array2d(dtype=wp.vec3), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), ximat_in: wp.array2d(dtype=wp.mat33), @@ -1771,11 +1771,11 @@ def _subtree_vel_forward( # update linear velocity lin -= wp.cross(xipos - subtree_com_root, ang) - subtree_linvel_out[worldid, bodyid] = body_mass[bodyid] * lin + subtree_linvel_out[worldid, bodyid] = body_mass[worldid, bodyid] * lin dv = wp.transpose(ximat) @ ang - dv[0] *= body_inertia[bodyid][0] - dv[1] *= body_inertia[bodyid][1] - dv[2] *= body_inertia[bodyid][2] + dv[0] *= body_inertia[worldid, bodyid][0] + dv[1] *= body_inertia[worldid, bodyid][1] + dv[2] *= body_inertia[worldid, bodyid][2] subtree_angmom_out[worldid, bodyid] = ximat @ dv subtree_bodyvel_out[worldid, bodyid] = wp.spatial_vector(ang, lin) @@ -1784,7 +1784,7 @@ def _subtree_vel_forward( def _linear_momentum( # Model: body_parentid: wp.array(dtype=int), - body_subtreemass: wp.array(dtype=float), + body_subtreemass: wp.array2d(dtype=float), # Data in: subtree_linvel_in: wp.array2d(dtype=wp.vec3), # In: @@ -1797,15 +1797,15 @@ def _linear_momentum( if bodyid: pid = body_parentid[bodyid] wp.atomic_add(subtree_linvel_out[worldid], pid, subtree_linvel_in[worldid, bodyid]) - subtree_linvel_out[worldid, bodyid] /= wp.max(MJ_MINVAL, body_subtreemass[bodyid]) + subtree_linvel_out[worldid, bodyid] /= wp.max(MJ_MINVAL, body_subtreemass[worldid, bodyid]) @wp.kernel def _angular_momentum( # Model: body_parentid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_subtreemass: wp.array(dtype=float), + body_mass: wp.array2d(dtype=float), + body_subtreemass: wp.array2d(dtype=float), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -1830,8 +1830,8 @@ def _angular_momentum( vel = subtree_bodyvel_in[worldid, bodyid] linvel = subtree_linvel_in[worldid, bodyid] linvel_parent = subtree_linvel_in[worldid, pid] - mass = body_mass[bodyid] - subtreemass = body_subtreemass[bodyid] + mass = body_mass[worldid, bodyid] + subtreemass = body_subtreemass[worldid, bodyid] # momentum wrt body i dx = xipos - com diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 0293d3b7..3d3dd3fc 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -600,8 +600,8 @@ class Model: npair: number of predefined geom pairs () opt: physics options stat: model statistics - qpos0: qpos values at default pose (nq,) - qpos_spring: reference pose for springs (nq,) + qpos0: qpos values at default pose (nworld, nq) + qpos_spring: reference pose for springs (nworld, nq) qM_fullm_i: sparse mass matrix addressing qM_fullm_j: sparse mass matrix addressing qM_mulm_i: sparse mass matrix addressing @@ -625,32 +625,32 @@ class Model: body_dofadr: start addr of dofs; -1: no dofs (nbody,) body_geomnum: number of geoms (nbody,) body_geomadr: start addr of geoms; -1: no geoms (nbody,) - body_pos: position offset rel. to parent body (nbody, 3) - body_quat: orientation offset rel. to parent body (nbody, 4) - body_ipos: local position of center of mass (nbody, 3) - body_iquat: local orientation of inertia ellipsoid (nbody, 4) - body_mass: mass (nbody,) - body_subtreemass: mass of subtree starting at this body (nbody,) - subtree_mass: mass of subtree (nbody,) - body_inertia: diagonal inertia in ipos/iquat frame (nbody, 3) - body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2) + body_pos: position offset rel. to parent body (nworld, nbody, 3) + body_quat: orientation offset rel. to parent body (nworld, nbody, 4) + body_ipos: local position of center of mass (nworld, nbody, 3) + body_iquat: local orientation of inertia ellipsoid (nworld, nbody, 4) + body_mass: mass (nworld, nbody,) + body_subtreemass: mass of subtree starting at this body (nworld, nbody,) + subtree_mass: mass of subtree (nworld, nbody,) + body_inertia: diagonal inertia in ipos/iquat frame (nworld, nbody, 3) + body_invweight0: mean inv inert in qpos0 (trn, rot) (nworld, nbody, 2) body_contype: OR over all geom contypes (nbody,) body_conaffinity: OR over all geom conaffinities (nbody,) - body_gravcomp: antigravity force, units of body weight (nbody,) + body_gravcomp: antigravity force, units of body weight (nworld, nbody) jnt_type: type of joint (mjtJoint) (njnt,) jnt_qposadr: start addr in 'qpos' for joint's data (njnt,) jnt_dofadr: start addr in 'qvel' for joint's data (njnt,) jnt_bodyid: id of joint's body (njnt,) jnt_limited: does joint have limits (njnt,) jnt_actfrclimited: does joint have actuator force limits (njnt,) - jnt_solref: constraint solver reference: limit (njnt, mjNREF) - jnt_solimp: constraint solver impedance: limit (njnt, mjNIMP) - jnt_pos: local anchor position (njnt, 3) - jnt_axis: local joint axis (njnt, 3) - jnt_stiffness: stiffness coefficient (njnt,) - jnt_range: joint limits (njnt, 2) - jnt_actfrcrange: range of total actuator force (njnt, 2) - jnt_margin: min distance for limit detection (njnt,) + jnt_solref: constraint solver reference: limit (nworld, njnt, mjNREF) + jnt_solimp: constraint solver impedance: limit (nworld, njnt, mjNIMP) + jnt_pos: local anchor position (nworld, njnt, 3) + jnt_axis: local joint axis (nworld, njnt, 3) + jnt_stiffness: stiffness coefficient (nworld, njnt) + jnt_range: joint limits (nworld, njnt, 2) + jnt_actfrcrange: range of total actuator force (nworld, njnt, 2) + jnt_margin: min distance for limit detection (nworld, njnt) jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr jnt_limited_ball_adr: limited/ball jntadr jnt_actgravcomp: is gravcomp force applied via actuators (njnt,) @@ -658,12 +658,12 @@ class Model: dof_jntid: id of dof's joint (nv,) dof_parentid: id of dof's parent; -1: none (nv,) dof_Madr: dof address in M-diagonal (nv,) - dof_armature: dof armature inertia/mass (nv,) - dof_damping: damping coefficient (nv,) - dof_invweight0: diag. inverse inertia in qpos0 (nv,) - dof_frictionloss: dof friction loss (nv,) - dof_solimp: constraint solver impedance: frictionloss (nv, NIMP) - dof_solref: constraint solver reference: frictionloss (nv, NREF) + dof_armature: dof armature inertia/mass (nworld, nv) + dof_damping: damping coefficient (nworld, nv) + dof_invweight0: diag. inverse inertia in qpos0 (nworld, nv) + dof_frictionloss: dof friction loss (nworld, nv) + dof_solimp: constraint solver impedance: frictionloss (nworld, nv, NIMP) + dof_solref: constraint solver reference: frictionloss (nworld, nv, NREF) dof_tri_row: np.tril_indices (mjm.nv)[0] dof_tri_col: np.tril_indices (mjm.nv)[1] geom_type: geometric type (mjtGeom) (ngeom,) @@ -673,30 +673,30 @@ class Model: geom_bodyid: id of geom's body (ngeom,) geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,) geom_group: geom group inclusion/exclusion mask (ngeom,) - geom_matid: material id for rendering (ngeom,) + geom_matid: material id for rendering (nworld, ngeom,) geom_priority: geom contact priority (ngeom,) - geom_solmix: mixing coef for solref/imp in geom pair (ngeom,) - geom_solref: constraint solver reference: contact (ngeom, mjNREF) - geom_solimp: constraint solver impedance: contact (ngeom, mjNIMP) + geom_solmix: mixing coef for solref/imp in geom pair (nworld, ngeom,) + geom_solref: constraint solver reference: contact (nworld, ngeom, mjNREF) + geom_solimp: constraint solver impedance: contact (nworld, ngeom, mjNIMP) geom_size: geom-specific size parameters (ngeom, 3) geom_aabb: bounding box, (center, size) (ngeom, 6) - geom_rbound: radius of bounding sphere (ngeom,) - geom_pos: local position offset rel. to body (ngeom, 3) - geom_quat: local orientation offset rel. to body (ngeom, 4) - geom_friction: friction for (slide, spin, roll) (ngeom, 3) - geom_margin: detect contact if dist