From 6953516e91ce4e686fcccb7d1baa4a56b66184e1 Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 30 Apr 2025 19:12:13 +0200 Subject: [PATCH 01/12] Sequential box-box --- mujoco_warp/_src/collision_primitive.py | 588 ++++++++++++++++++++++++ 1 file changed, 588 insertions(+) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index ffef2655..96113848 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -25,9 +25,22 @@ from .types import Model from .types import vec5 + wp.set_module_options({"enable_backward": False}) +class vec8f(wp.types.vector(length=8, dtype=wp.float32)): + pass + + +class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)): + pass + + +class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)): + pass + + @wp.struct class Geom: pos: wp.vec3 @@ -1248,6 +1261,581 @@ def capsule_box( ) +@wp.func +def compute_rotmore(face_idx: wp.int32) -> wp.mat33: + rotmore = wp.mat33(0.0) + + if face_idx == 0: + rotmore[0, 2] = -1.0 # rotmore[2] + rotmore[1, 1] = +1.0 # rotmore[4] + rotmore[2, 0] = +1.0 # rotmore[6] + elif face_idx == 1: + rotmore[0, 0] = +1.0 # rotmore[0] + rotmore[1, 2] = -1.0 # rotmore[5] + rotmore[2, 1] = +1.0 # rotmore[7] + elif face_idx == 2: + rotmore[0, 0] = +1.0 # rotmore[0] + rotmore[1, 1] = +1.0 # rotmore[4] + rotmore[2, 2] = +1.0 # rotmore[8] + elif face_idx == 3: + rotmore[0, 2] = +1.0 # rotmore[2] + rotmore[1, 1] = +1.0 # rotmore[4] + rotmore[2, 0] = -1.0 # rotmore[6] + elif face_idx == 4: + rotmore[0, 0] = +1.0 # rotmore[0] + rotmore[1, 2] = +1.0 # rotmore[5] + rotmore[2, 1] = -1.0 # rotmore[7] + elif face_idx == 5: + rotmore[0, 0] = -1.0 # rotmore[0] + rotmore[1, 1] = +1.0 # rotmore[4] + rotmore[2, 2] = -1.0 # rotmore[8] + + return rotmore + + +@wp.func +def box_box( + box1: Geom, + box2: Geom, + worldid: int, + d: Data, + margin: float, + gap: float, + condim: int, + friction: vec5, + solref: wp.vec2f, + solreffriction: wp.vec2f, + solimp: vec5, + geoms: wp.vec2i, +): + # Compute transforms between box's frames + + pos21 = wp.transpose(box1.rot) @ (box2.pos - box1.pos) + pos12 = wp.transpose(box2.rot) @ (box1.pos - box2.pos) + + rot21 = wp.transpose(box1.rot) @ box2.rot + rot12 = wp.transpose(rot21) + + rot21abs = wp.transpose( + wp.mat33(wp.abs(rot21[0]), wp.abs(rot21[1]), wp.abs(rot21[2])) + ) + + rot12abs = wp.transpose( + wp.mat33(wp.abs(rot12[0]), wp.abs(rot12[1]), wp.abs(rot12[2])) + ) + + plen2 = rot21abs @ box2.size + plen1 = rot12abs @ box1.size + + # Compute axis of maximum separation + + s_sum_3 = 3.0 * (box1.size + box2.size) + separation = wp.float32(margin + s_sum_3[0] + s_sum_3[1] + s_sum_3[2]) + axis_code = wp.int32(-1) + + # First test: consider boxes' face normals + for i in range(3): + c1 = -wp.abs(pos21[i]) + box1.size[i] + plen2[i] + + c2 = -wp.abs(pos12[i]) + box2.size[i] + plen1[i] + + if c1 < -margin or c2 < -margin: + return + + if c1 < separation: + separation = c1 + axis_code = i + 3 * wp.int32(pos21[i] < 0) + 0 # Face of box1 + if c2 < separation: + separation = c2 + axis_code = i + 3 * wp.int32(pos12[i] < 0) + 6 # Face of box2 + + clnorm = wp.vec3(0.0) + inv = wp.bool(False) + cle1 = wp.int32(0) + cle2 = wp.int32(0) + + # Second test: consider cross products of boxes' edges + for i in range(3): + for j in range(3): + # Compute cross product of box edges (potential separating axis) + + if i == 0: + cross_axis = wp.vec3(0.0, -rot12[j, 2], rot12[j, 1]) + elif i == 1: + cross_axis = wp.vec3(rot12[j, 2], 0.0, -rot12[j, 0]) + else: + cross_axis = wp.vec3(-rot12[j, 1], rot12[j, 0], 0.0) + + cross_length = wp.length(cross_axis) + + if cross_length < MJ_MINVAL: + continue + + cross_axis /= cross_length + + box_dist = wp.dot(pos21, cross_axis) + c3 = wp.float32(0.0) + + # Project box half-sizes onto the potential separating axis + for k in range(3): + if k != i: + c3 += box1.size[k] * wp.abs(cross_axis[k]) + if k != j: + c3 += box2.size[k] * rot21abs[i, 3 - k - j] / cross_length + + c3 -= wp.abs(box_dist) + + # Early exit: no collision if separated along this axis + if c3 < -margin: + return + + # Track minimum separation and which edge-edge pair it occurs on + + if c3 < separation * (1.0 - 1e-12): + separation = c3 + # Determine which corners/edges are closest + cle1 = 0 + cle2 = 0 + + for k in range(3): + if k != i and (int(cross_axis[k] > 0) ^ int(box_dist < 0)): + cle1 += 1 << k + if k != j and ( + int(rot21[i, 3 - k - j] > 0) ^ int(box_dist < 0) ^ int((k - j + 3) % 3 == 1) + ): + cle2 += 1 << k + + axis_code = 12 + i * 3 + j + clnorm = cross_axis + inv = box_dist < 0 + + # No axis with separation < margin found + if axis_code == -1: + return + + if axis_code < 12: + # Handle face-vertex collision + face_idx = axis_code % 6 + box_idx = axis_code / 6 + # TODO(camevor): consider single rotmore calc, or using i0,i1,i2 and f0,f1,f2 as rotmore substitute + rotmore = compute_rotmore(face_idx) + + r = rotmore @ wp.where(box_idx, rot12, rot21) + p = rotmore @ wp.where(box_idx, pos12, pos21) + ss = wp.abs(rotmore @ wp.where(box_idx, box2.size, box1.size)) + s = wp.where(box_idx, box1.size, box2.size) + rt = wp.transpose(r) + + lx, ly, hz = ss[0], ss[1], ss[2] + p[2] -= hz + + clcorner = wp.int32(0) # corner of non-face box with least axis separation + + for i in range(3): + if r[2, i] < 0: + clcorner += 1 << i + + lp = p + for i in range(wp.static(3)): + lp += rt[i] * s[i] * wp.where(clcorner & 1 << i, 1.0, -1.0) + + m = wp.int32(1) + dirs = wp.int32(0) + + cn1 = wp.vec3(0.0) + cn2 = wp.vec3(0.0) + + for i in range(3): + if wp.abs(r[2, i]) < 0.5: + if not dirs: + cn1 = rt[i] * s[i] * wp.where(clcorner & (1 << i), -2.0, 2.0) + else: + cn2 = rt[i] * s[i] * wp.where(clcorner & (1 << i), -2.0, 2.0) + + dirs += 1 + + # Compute additional contact points + # compute the other box's corner points + # TODO(camevor) potentially unnecessary use of registers + # pts are now the corners of the other closest face + + k = dirs * dirs + + # Find potential contact points + lines_a = mat43f() + lines_b = mat43f() + + if dirs: + lines_a[0] = lp + lines_b[0] = cn1 + + if dirs == 2: + lines_a[1] = lp + lines_b[1] = cn2 + + lines_a[2] = lp + cn1 + lines_b[2] = cn2 + + lines_a[3] = lp + cn2 + lines_b[3] = cn1 + + points = mat83f() + + n = wp.int32(0) + + for i in range(k): + for q in range(2): + la = lines_a[i, q] + lb = lines_b[i, q] + lc = lines_a[i, 1 - q] + ld = lines_b[i, 1 - q] + + if wp.abs(lb) > MJ_MINVAL: + br = 1.0 / lb + for j in range(-1, 2, 2): + l = ss[q] * wp.float32(j) + c1 = (l - la) * br + if c1 < 0 or c1 > 1: + continue + c2 = lc + ld * c1 + if wp.abs(c2) > ss[1 - q]: + continue + + points[n] = lines_a[i] + c1 * lines_b[i] + n += 1 + + if dirs == 2: + ax = cn1[0] + bx = cn2[0] + ay = cn1[1] + by = cn2[1] + C = 1.0 / (ax * by - bx * ay) + + for i in range(4): + llx = wp.where(i / 2, lx, -lx) + lly = wp.where(i % 2, ly, -ly) + + x = llx - lp[0] + y = lly - lp[1] + + u = (x * by - y * bx) * C + v = (y * ax - x * ay) * C + + if u <= 0 or v <= 0 or u >= 1 or v >= 1: + continue + + points[n] = wp.vec3(llx, lly, lp[2] + u * cn1[2] + v * cn2[2]) + + n += 1 + + points[n] = lp # TODO(camevor) consider doing this sooner or later + n += 1 + + for i in range(1, k): + tmpv = lp + wp.float32(i & 1) * cn1 + wp.float32(i & 2) * cn2 + if tmpv[0] <= -lx or tmpv[0] >= lx: + continue + if tmpv[1] <= -ly or tmpv[1] >= ly: + continue + points[n] = tmpv + n += 1 + + m = n + n = wp.int32(0) + + for i in range(m): + if points[i][2] > margin: + continue + if i != n: + points[n] = points[i] + + points[n, 2] *= 0.5 + n += 1 + + # Set up contact frame + rw = wp.where(box_idx, box2.rot, box1.rot) @ wp.transpose(rotmore) + pw = wp.where(box_idx, box2.pos, box1.pos) + + normal = wp.where(box_idx, -1.0, 1.0) * wp.transpose(rw)[2] + + # TODO(camevor): explicit frame + frame = make_frame(normal) + + coff = wp.atomic_add(d.ncon, 0, n) + + for i in range(min(d.nconmax - coff, n)): + dist = points[i, 2] + + # Transform contact point to world frame + points[i, 2] += hz + pos = rw @ points[i] + pw + + cid = coff + i + + d.contact.dist[cid] = dist + d.contact.pos[cid] = pos + d.contact.frame[cid] = frame + d.contact.geom[cid] = geoms + d.contact.worldid[cid] = worldid + d.contact.includemargin[cid] = margin - gap + d.contact.dim[cid] = condim + d.contact.friction[cid] = friction + d.contact.solref[cid] = solref + d.contact.solreffriction[cid] = solreffriction + d.contact.solimp[cid] = solimp + + # return + + else: + # Handle edge-edge collision + edge1 = (axis_code - 12) / 3 + edge2 = (axis_code - 12) % 3 + + # Set up non-contacting edges ax1, ax2 for box2 and pax1, pax2 for box 1 + ax1 = wp.int(1 - (edge2 & 1)) + ax2 = wp.int(2 - (edge2 & 2)) + + pax1 = wp.int(1 - (edge1 & 1)) + pax2 = wp.int(2 - (edge1 & 2)) + + if rot21abs[edge1, ax1] < rot21abs[edge1, ax2]: + ax_tmp = ax1 + ax1 = ax2 + ax2 = ax_tmp + + if rot12abs[edge2, pax1] < rot12abs[edge2, pax2]: + pax_tmp = pax1 + pax1 = pax2 + pax2 = pax_tmp + + rotmore = compute_rotmore(wp.where(cle1 & (1 << pax2), pax2, pax2 + 3)) + + # Transform coordinates for edge-edge contact calculation + p = rotmore @ pos21 + rnorm = rotmore @ clnorm + r = rotmore @ rot21 + rt = wp.transpose(r) + s = wp.abs(wp.transpose(rotmore) @ box1.size) + + lx, ly, hz = s[0], s[1], s[2] + p[2] -= hz + + # Calculate closest box2 face + points = mat83f() + + points[0] = ( + p + + rt[ax1] * box2.size[ax1] * wp.where(cle2 & (1 << ax1), 1.0, -1.0) + + rt[ax2] * box2.size[ax2] * wp.where(cle2 & (1 << ax2), 1.0, -1.0) + ) + points[1] = points[0] - rt[edge2] * box2.size[edge2] + points[0] += rt[edge2] * box2.size[edge2] + + points[2] = ( + p + + rt[ax1] * box2.size[ax1] * wp.where(cle2 & (1 << ax1), -1.0, 1.0) + + rt[ax2] * box2.size[ax2] * wp.where(cle2 & (1 << ax2), 1.0, -1.0) + ) + + points[3] = points[2] - rt[edge2] * box2.size[edge2] + points[2] += rt[edge2] * box2.size[edge2] + + n = 4 + + # Set up coordinate axes for contact face of box2 + axi_lp = points[0] + axi_cn1 = points[1] - points[0] + axi_cn2 = points[2] - points[0] + + # Check if contact normal is valid + if wp.abs(rnorm[2]) < MJ_MINVAL: + return # Shouldn't happen + + # Calculate inverse normal for projection + innorm = wp.where(inv, -1.0, 1.0) / rnorm[2] + + pu = mat43f() + + # Project points onto contact plane + for i in range(4): + pu[i] = points[i] + c_scl = points[i, 2] * wp.where(inv, -1.0, 1.0) * innorm + points[i] -= rnorm * c_scl + + pts_lp = points[0] + pts_cn1 = points[1] - points[0] + pts_cn2 = points[2] - points[0] + + lines_a = mat43f() + lines_b = mat43f() + linesu_a = mat43f() + linesu_b = mat43f() + + lines_a[0] = pts_lp + lines_b[0] = pts_cn1 + linesu_a[0] = axi_lp + linesu_b[0] = axi_cn1 + + lines_a[1] = pts_lp + lines_b[1] = pts_cn2 + linesu_a[1] = axi_lp + linesu_b[1] = axi_cn2 + + lines_a[2] = pts_lp + pts_cn1 + lines_b[2] = pts_cn2 + linesu_a[2] = axi_lp + axi_cn1 + linesu_b[2] = axi_cn2 + + lines_a[3] = pts_lp + pts_cn2 + lines_b[3] = pts_cn1 + linesu_a[3] = axi_lp + axi_cn2 + linesu_b[3] = axi_cn1 + + k = 4 + n = wp.int32(0) + + max_con_pair = 8 + depth = vec8f() + + for i in range(k): + for q in range(2): + la = lines_a[i, q] + lb = lines_b[i, q] + lc = lines_a[i, 1 - q] + ld = lines_b[i, 1 - q] + + if wp.abs(lb) > MJ_MINVAL: + br = 1.0 / lb + for j in range(-1, 2, 2): + if n == max_con_pair: + break + l = s[q] * wp.float32(j) + c1 = (l - la) * br + if c1 < 0 or c1 > 1: + continue + c2 = lc + ld * c1 + if wp.abs(c2) > s[1 - q]: + continue + if (linesu_a[i, 2] + linesu_b[i][2] * c1) * innorm > margin: + continue + + points[n] = linesu_a[i] * 0.5 + c1 * linesu_b[i] + points[n, q] += 0.5 * l + points[n, 1 - q] += 0.5 * c2 + depth[n] = points[n, 2] * innorm * 2.0 + n += 1 + + nl = n + + ax = pts_cn1[0] + bx = pts_cn2[0] + ay = pts_cn1[1] + by = pts_cn2[1] + C = 1.0 / (ax * by - bx * ay) + + for i in range(4): + if n == max_con_pair: + break + llx = wp.where(i / 2, lx, -lx) + lly = wp.where(i % 2, ly, -ly) + + x = llx - pts_lp[0] + y = lly - pts_lp[1] + + u = (x * by - y * bx) * C + v = (y * ax - x * ay) * C + + if nl == 0: + if (u < 0 or u > 0) and (v < 0 or v > 1): + continue + elif u < 0 or v < 0 or u > 1 or v > 1: + continue + + u = wp.clamp(u, 0.0, 1.0) + v = wp.clamp(v, 0.0, 1.0) + w = 1.0 - u - v + vtmp = pu[0] * w + pu[1] * u + pu[2] * v + + points[n] = wp.vec3(llx, lly, 0.0) + + vtmp2 = points[n] - vtmp + tc1 = wp.length_sq(vtmp2) + if vtmp[2] > 0 and tc1 > margin * margin: + continue + + points[n] = 0.5 * (points[n] + vtmp) + + depth[n] = wp.sqrt(tc1) * wp.where(vtmp[2] < 0, -1.0, 1.0) + n += 1 + + nf = n + + for i in range(4): + if n >= max_con_pair: + break + x = pu[i, 0] + y = pu[i, 1] + if nl == 0 and nf != 0: + if (x < -lx or x > lx) and (y < -ly or y > ly): + continue + elif x < -lx or x > lx or y < -ly or y > ly: + continue + + c1 = wp.float32(0) + + for j in range(2): + if pu[i, j] < -s[j]: + c1 += (pu[i, j] + s[j]) * (pu[i, j] + s[j]) + elif pu[i, j] > s[j]: + c1 += (pu[i, j] - s[j]) * (pu[i, j] - s[j]) + + c1 += pu[i, 2] * innorm * pu[i, 2] * innorm + + if pu[i, 2] > 0 and c1 > margin * margin: + continue + + tmp_p = wp.vec3(pu[i, 0], pu[i, 1], 0.0) + + for j in range(2): + if pu[i, j] < -s[j]: + tmp_p[j] = -s[j] * 0.5 + elif pu[i, j] > s[j]: + tmp_p[j] = +s[j] * 0.5 + + tmp_p += pu[i] + points[n] = tmp_p * 0.5 + + depth[n] = wp.sqrt(c1) * wp.where(pu[i, 2] < 0, -1.0, 1.0) + n += 1 + + # Set up contact data for all points + rw = box1.rot @ wp.transpose(rotmore) + normal = r @ rnorm + frame = make_frame(wp.where(inv, -1.0, 1.0) * normal) + + coff = wp.atomic_add(d.ncon, 0, n) + + for i in range(min(d.nconmax - coff, n)): + dist = depth[i] + + points[i, 2] += hz + pos = rw @ points[i] + box1.pos + + cid = coff + i + + d.contact.dist[cid] = dist + d.contact.pos[cid] = pos + d.contact.frame[cid] = frame + d.contact.geom[cid] = geoms + d.contact.worldid[cid] = worldid + d.contact.includemargin[cid] = margin - gap + d.contact.dim[cid] = condim + d.contact.friction[cid] = friction + d.contact.solref[cid] = solref + d.contact.solreffriction[cid] = solreffriction + d.contact.solimp[cid] = solimp + + @wp.kernel def _primitive_narrowphase( m: Model, From bd061c2219b4f22410317abecf077daaeae61d8c Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 30 Apr 2025 18:35:52 +0200 Subject: [PATCH 02/12] Switch to sequential box-box --- mujoco_warp/_src/collision_box.py | 586 ------------------------ mujoco_warp/_src/collision_driver.py | 2 - mujoco_warp/_src/collision_primitive.py | 15 + 3 files changed, 15 insertions(+), 588 deletions(-) delete mode 100644 mujoco_warp/_src/collision_box.py diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py deleted file mode 100644 index ae3f2849..00000000 --- a/mujoco_warp/_src/collision_box.py +++ /dev/null @@ -1,586 +0,0 @@ -# Copyright 2025 The Newton Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -import math -from typing import Any - -import warp as wp - -from .collision_primitive import contact_params -from .collision_primitive import write_contact -from .math import make_frame -from .types import Data -from .types import GeomType -from .types import Model - -BOX_BOX_BLOCK_DIM = 32 - - -_HUGE_VAL = 1e6 -_TINY_VAL = 1e-6 - - -class vec16b(wp.types.vector(length=16, dtype=wp.int8)): - pass - - -class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)): - pass - - -class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)): - pass - - -class mat16_3f(wp.types.matrix(shape=(16, 3), dtype=wp.float32)): - pass - - -Box = mat83f - - -@wp.func -def _argmin(a: Any) -> wp.int32: - amin = wp.int32(0) - vmin = wp.float32(a[0]) - for i in range(1, len(a)): - if a[i] < vmin: - amin = i - vmin = a[i] - return amin - - -@wp.func -def box_normals(i: int) -> wp.vec3: - direction = wp.where(i < 3, -1.0, 1.0) - mod = i % 3 - if mod == 0: - return wp.vec3(0.0, direction, 0.0) - if mod == 1: - return wp.vec3(0.0, 0.0, direction) - return wp.vec3(-direction, 0.0, 0.0) - - -@wp.func -def box(R: wp.mat33, t: wp.vec3, geom_size: wp.vec3) -> Box: - """Get a transformed box""" - x = geom_size[0] - y = geom_size[1] - z = geom_size[2] - m = Box() - for i in range(8): - ix = wp.where(i & 4, x, -x) - iy = wp.where(i & 2, y, -y) - iz = wp.where(i & 1, z, -z) - m[i] = R @ wp.vec3(ix, iy, iz) + t - return m - - -@wp.func -def box_face_verts(box: Box, idx: wp.int32) -> mat43f: - """Get the quad corresponding to a box face""" - if idx == 0: - verts = wp.vec4i(0, 4, 5, 1) - if idx == 1: - verts = wp.vec4i(0, 2, 6, 4) - if idx == 2: - verts = wp.vec4i(6, 7, 5, 4) - if idx == 3: - verts = wp.vec4i(2, 3, 7, 6) - if idx == 4: - verts = wp.vec4i(1, 5, 7, 3) - if idx == 5: - verts = wp.vec4i(0, 1, 3, 2) - - m = mat43f() - for i in range(4): - m[i] = box[verts[i]] - return m - - -@wp.func -def get_box_axis( - axis_idx: int, - R: wp.mat33, -): - """Get the axis at index axis_idx. - R: rotation matrix from a to b - Axes 0-12 are face normals of boxes a & b - Axes 12-21 are edge cross products.""" - if axis_idx < 6: # a faces - axis = R @ wp.vec3(box_normals(axis_idx)) - is_degenerate = False - elif axis_idx < 12: # b faces - axis = wp.vec3(box_normals(axis_idx - 6)) - is_degenerate = False - else: # edges cross products - assert axis_idx < 21 - edges = axis_idx - 12 - axis_a, axis_b = edges / 3, edges % 3 - edge_a = wp.transpose(R)[axis_a] - if axis_b == 0: - axis = wp.vec3(0.0, -edge_a[2], edge_a[1]) - elif axis_b == 1: - axis = wp.vec3(edge_a[2], 0.0, -edge_a[0]) - else: - axis = wp.vec3(-edge_a[1], edge_a[0], 0.0) - is_degenerate = wp.length_sq(axis) < _TINY_VAL - return wp.normalize(axis), is_degenerate - - -@wp.func -def get_box_axis_support(axis: wp.vec3, degenerate_axis: bool, a: Box, b: Box): - """Get the overlap (or separating distance if negative) along `axis`, and the sign.""" - axis_d = wp.vec3d(axis) - support_a_max, support_b_max = wp.float32(-_HUGE_VAL), wp.float32(-_HUGE_VAL) - support_a_min, support_b_min = wp.float32(_HUGE_VAL), wp.float32(_HUGE_VAL) - for i in range(8): - vert_a = wp.vec3d(a[i]) - vert_b = wp.vec3d(b[i]) - proj_a = wp.float32(wp.dot(vert_a, axis_d)) - proj_b = wp.float32(wp.dot(vert_b, axis_d)) - support_a_max = wp.max(support_a_max, proj_a) - support_b_max = wp.max(support_b_max, proj_b) - support_a_min = wp.min(support_a_min, proj_a) - support_b_min = wp.min(support_b_min, proj_b) - dist1 = support_a_max - support_b_min - dist2 = support_b_max - support_a_min - dist = wp.where(degenerate_axis, _HUGE_VAL, wp.min(dist1, dist2)) - sign = wp.where(dist1 > dist2, -1, 1) - return dist, sign - - -@wp.struct -class AxisSupport: - best_dist: wp.float32 - best_sign: wp.int8 - best_idx: wp.int8 - - -@wp.func -def reduce_axis_support(a: AxisSupport, b: AxisSupport): - return wp.where(a.best_dist > b.best_dist, b, a) - - -@wp.func -def face_axis_alignment(a: wp.vec3, R: wp.mat33) -> wp.int32: - """Find the box faces most aligned with the axis `a`""" - max_dot = wp.float32(0.0) - max_idx = wp.int32(0) - for i in range(6): - d = wp.dot(R @ box_normals(i), a) - if d > max_dot: - max_dot = d - max_idx = i - return max_idx - - -@wp.kernel(enable_backward=False) -def box_box_kernel( - m: Model, - d: Data, - num_kernels: int, -): - """Calculates contacts between pairs of boxes.""" - tid, axis_idx = wp.tid() - - for bp_idx in range(tid, min(d.ncollision[0], d.nconmax), num_kernels): - geoms = d.collision_pair[bp_idx] - - ga, gb = geoms[0], geoms[1] - - if m.geom_type[ga] != int(GeomType.BOX.value) or m.geom_type[gb] != int( - GeomType.BOX.value - ): - continue - - worldid = d.collision_worldid[bp_idx] - - geoms, margin, gap, condim, friction, solref, solreffriction, solimp = ( - contact_params(m, d, tid) - ) - - # transformations - a_pos, b_pos = d.geom_xpos[worldid, ga], d.geom_xpos[worldid, gb] - a_mat, b_mat = d.geom_xmat[worldid, ga], d.geom_xmat[worldid, gb] - b_mat_inv = wp.transpose(b_mat) - trans_atob = b_mat_inv @ (a_pos - b_pos) - rot_atob = b_mat_inv @ a_mat - - a_size = m.geom_size[ga] - b_size = m.geom_size[gb] - a = box(rot_atob, trans_atob, a_size) - b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) - - # box-box implementation - - # Inlined def collision_axis_tiled( a: Box, b: Box, R: wp.mat33, axis_idx: wp.int32,): - # Finds the axis of minimum separation. - # a: Box a vertices, in frame b - # b: Box b vertices, in frame b - # R: rotation matrix from a to b - # Returns: - # best_axis: vec3 - # best_sign: int32 - # best_idx: int32 - R = rot_atob - - # launch tiled with block_dim=21 - if axis_idx > 20: - continue - - axis, degenerate_axis = get_box_axis(axis_idx, R) - axis_dist, axis_sign = get_box_axis_support(axis, degenerate_axis, a, b) - - supports = wp.tile(AxisSupport(axis_dist, wp.int8(axis_sign), wp.int8(axis_idx))) - - face_supports = wp.tile_view(supports, offset=(0,), shape=(12,)) - edge_supports = wp.tile_view(supports, offset=(12,), shape=(9,)) - - face_supports_red = wp.tile_reduce(reduce_axis_support, face_supports) - edge_supports_red = wp.tile_reduce(reduce_axis_support, edge_supports) - - face = face_supports_red[0] - edge = edge_supports_red[0] - - if axis_idx > 0: # single thread - continue - - # choose the best separating axis - face_axis, _ = get_box_axis(wp.int32(face.best_idx), R) - best_axis = wp.vec3(face_axis) - best_sign = wp.int32(face.best_sign) - best_idx = wp.int32(face.best_idx) - best_dist = wp.float32(face.best_dist) - - if edge.best_dist < face.best_dist: - edge_axis, _ = get_box_axis(wp.int32(edge.best_idx), R) - if wp.abs(wp.dot(face_axis, edge_axis)) < 0.99: - best_axis = edge_axis - best_sign = wp.int32(edge.best_sign) - best_idx = wp.int32(edge.best_idx) - best_dist = wp.float32(edge.best_dist) - # end inlined collision_axis_tiled - - # if axis_idx != 0: - # continue - if best_dist < 0: - continue - - # get the (reference) face most aligned with the separating axis - a_max = face_axis_alignment(best_axis, rot_atob) - b_max = face_axis_alignment(best_axis, wp.identity(3, wp.float32)) - - sep_axis = wp.float32(best_sign) * best_axis - - if best_sign > 0: - b_min = (b_max + 3) % 6 - dist, pos = _create_contact_manifold( - box_face_verts(a, a_max), - rot_atob @ box_normals(a_max), - box_face_verts(b, b_min), - box_normals(b_min), - ) - else: - a_min = (a_max + 3) % 6 - dist, pos = _create_contact_manifold( - box_face_verts(b, b_max), - box_normals(b_max), - box_face_verts(a, a_min), - rot_atob @ box_normals(a_min), - ) - - # For edge contacts, we use the clipped face point, mainly for performance - # reasons. For small penetration, the clipped face point is roughly the edge - # contact point. - if best_idx > 11: # is_edge_contact - idx = _argmin(dist) - dist = wp.vec4f(dist[idx], 1.0, 1.0, 1.0) - for i in range(4): - pos[i] = pos[idx] - - margin = wp.max(m.geom_margin[ga], m.geom_margin[gb]) - for i in range(4): - pos_glob = b_mat @ pos[i] + b_pos - n_glob = b_mat @ sep_axis - write_contact( - d, - dist[i], - pos_glob, - make_frame(n_glob), - margin, - gap, - condim, - friction, - solref, - solreffriction, - solimp, - geoms, - worldid, - ) - - -@wp.func -def _closest_segment_point_plane( - a: wp.vec3, b: wp.vec3, p0: wp.vec3, plane_normal: wp.vec3 -) -> wp.vec3: - """Gets the closest point between a line segment and a plane. - - Args: - a: first line segment point - b: second line segment point - p0: point on plane - plane_normal: plane normal - - Returns: - closest point between the line segment and the plane - """ - # Parametrize a line segment as S(t) = a + t * (b - a), plug it into the plane - # equation dot(n, S(t)) - d = 0, then solve for t to get the line-plane - # intersection. We then clip t to be in [0, 1] to be on the line segment. - n = plane_normal - d = wp.dot(p0, n) # shortest distance from origin to plane - denom = wp.dot(n, (b - a)) - t = (d - wp.dot(n, a)) / (denom + wp.where(denom == 0.0, _TINY_VAL, 0.0)) - t = wp.clamp(t, 0.0, 1.0) - segment_point = a + t * (b - a) - - return segment_point - - -@wp.func -def _project_poly_onto_plane( - poly: Any, - poly_n: wp.vec3, - plane_n: wp.vec3, - plane_pt: wp.vec3, -): - """Projects poly1 onto the poly2 plane along poly2's normal.""" - d = wp.dot(plane_pt, plane_n) - denom = wp.dot(poly_n, plane_n) - qn_scaled = poly_n / (denom + wp.where(denom == 0.0, _TINY_VAL, 0.0)) - - for i in range(len(poly)): - poly[i] = poly[i] + (d - wp.dot(poly[i], plane_n)) * qn_scaled - return poly - - -@wp.func -def _clip_edge_to_quad( - subject_poly: mat43f, - clipping_poly: mat43f, - clipping_normal: wp.vec3, -): - p0 = mat43f() - p1 = mat43f() - mask = wp.vec4b() - for edge_idx in range(4): - subject_p0 = subject_poly[(edge_idx + 3) % 4] - subject_p1 = subject_poly[edge_idx] - - any_both_in_front = wp.int32(0) - clipped0_dist_max = wp.float32(-_HUGE_VAL) - clipped1_dist_max = wp.float32(-_HUGE_VAL) - clipped_p0_distmax = wp.vec3(0.0) - clipped_p1_distmax = wp.vec3(0.0) - - for clipping_edge_idx in range(4): - clipping_p0 = clipping_poly[(clipping_edge_idx + 3) % 4] - clipping_p1 = clipping_poly[clipping_edge_idx] - edge_normal = wp.cross(clipping_p1 - clipping_p0, clipping_normal) - - p0_in_front = wp.dot(subject_p0 - clipping_p0, edge_normal) > _TINY_VAL - p1_in_front = wp.dot(subject_p1 - clipping_p0, edge_normal) > _TINY_VAL - candidate_clipped_p = _closest_segment_point_plane( - subject_p0, subject_p1, clipping_p1, edge_normal - ) - clipped_p0 = wp.where(p0_in_front, candidate_clipped_p, subject_p0) - clipped_p1 = wp.where(p1_in_front, candidate_clipped_p, subject_p1) - clipped_dist_p0 = wp.dot(clipped_p0 - subject_p0, subject_p1 - subject_p0) - clipped_dist_p1 = wp.dot(clipped_p1 - subject_p1, subject_p0 - subject_p1) - any_both_in_front |= wp.int32(p0_in_front and p1_in_front) - - if clipped_dist_p0 > clipped0_dist_max: - clipped0_dist_max = clipped_dist_p0 - clipped_p0_distmax = clipped_p0 - - if clipped_dist_p1 > clipped1_dist_max: - clipped1_dist_max = clipped_dist_p1 - clipped_p1_distmax = clipped_p1 - new_p0 = wp.where(any_both_in_front, subject_p0, clipped_p0_distmax) - new_p1 = wp.where(any_both_in_front, subject_p1, clipped_p1_distmax) - - mask_val = wp.int8( - wp.where( - wp.dot(subject_p0 - subject_p1, new_p0 - new_p1) < 0, - 0, - wp.int32(not any_both_in_front), - ) - ) - - p0[edge_idx] = new_p0 - p1[edge_idx] = new_p1 - mask[edge_idx] = mask_val - return p0, p1, mask - - -@wp.func -def _clip_quad( - subject_quad: mat43f, - subject_normal: wp.vec3, - clipping_quad: mat43f, - clipping_normal: wp.vec3, -): - """Clips a subject quad against a clipping quad. - Serial implementation. - """ - - subject_clipped_p0, subject_clipped_p1, subject_mask = _clip_edge_to_quad( - subject_quad, clipping_quad, clipping_normal - ) - clipping_proj = _project_poly_onto_plane( - clipping_quad, clipping_normal, subject_normal, subject_quad[0] - ) - clipping_clipped_p0, clipping_clipped_p1, clipping_mask = _clip_edge_to_quad( - clipping_proj, subject_quad, subject_normal - ) - - clipped = mat16_3f() - mask = vec16b() - for i in range(4): - clipped[i] = subject_clipped_p0[i] - clipped[i + 4] = clipping_clipped_p0[i] - clipped[i + 8] = subject_clipped_p1[i] - clipped[i + 12] = clipping_clipped_p1[i] - mask[i] = subject_mask[i] - mask[i + 4] = clipping_mask[i] - mask[i + 8] = subject_mask[i] - mask[i + 8 + 4] = clipping_mask[i] - - return clipped, mask - - -# TODO(ca): tiling variant -@wp.func -def _manifold_points( - poly: Any, - mask: Any, - clipping_norm: wp.vec3, -) -> wp.vec4b: - """Chooses four points on the polygon with approximately maximal area. Return the indices""" - n = len(poly) - - a_idx = wp.int32(0) - a_mask = wp.int8(mask[0]) - for i in range(n): - if mask[i] >= a_mask: - a_idx = i - a_mask = mask[i] - a = poly[a_idx] - - b_idx = wp.int32(0) - b_dist = wp.float32(-_HUGE_VAL) - for i in range(n): - dist = wp.length_sq(poly[i] - a) + wp.where(mask[i], 0.0, -_HUGE_VAL) - if dist >= b_dist: - b_idx = i - b_dist = dist - b = poly[b_idx] - - ab = wp.cross(clipping_norm, a - b) - - c_idx = wp.int32(0) - c_dist = wp.float32(-_HUGE_VAL) - for i in range(n): - ap = a - poly[i] - dist = wp.abs(wp.dot(ap, ab)) + wp.where(mask[i], 0.0, -_HUGE_VAL) - if dist >= c_dist: - c_idx = i - c_dist = dist - c = poly[c_idx] - - ac = wp.cross(clipping_norm, a - c) - bc = wp.cross(clipping_norm, b - c) - - d_idx = wp.int32(0) - d_dist = wp.float32(-2.0 * _HUGE_VAL) - for i in range(n): - ap = a - poly[i] - dist_ap = wp.abs(wp.dot(ap, ac)) + wp.where(mask[i], 0.0, -_HUGE_VAL) - bp = b - poly[i] - dist_bp = wp.abs(wp.dot(bp, bc)) + wp.where(mask[i], 0.0, -_HUGE_VAL) - if dist_ap + dist_bp >= d_dist: - d_idx = i - d_dist = dist_ap + dist_bp - d = poly[d_idx] - return wp.vec4b(wp.int8(a_idx), wp.int8(b_idx), wp.int8(c_idx), wp.int8(d_idx)) - - -@wp.func -def _create_contact_manifold( - clipping_quad: mat43f, - clipping_normal: wp.vec3, - subject_quad: mat43f, - subject_normal: wp.vec3, -): - # Clip the subject (incident) face onto the clipping (reference) face. - # The incident points are clipped points on the subject polygon. - incident, mask = _clip_quad( - subject_quad, subject_normal, clipping_quad, clipping_normal - ) - - clipping_normal_neg = -clipping_normal - d = wp.dot(clipping_quad[0], clipping_normal_neg) + _TINY_VAL - - for i in range(16): - if wp.dot(incident[i], clipping_normal_neg) < d: - mask[i] = wp.int8(0) - - ref = _project_poly_onto_plane( - incident, clipping_normal, clipping_normal, clipping_quad[0] - ) - - # Choose four contact points. - best = _manifold_points(ref, mask, clipping_normal) - contact_pts = mat43f() - dist = wp.vec4f() - - for i in range(4): - idx = wp.int32(best[i]) - contact_pt = ref[idx] - contact_pts[i] = contact_pt - penetration_dir = incident[idx] - contact_pt - penetration = wp.dot(penetration_dir, clipping_normal) - dist[i] = wp.where(mask[idx], penetration, 1.0) - - return dist, contact_pts - - -def box_box_narrowphase( - m: Model, - d: Data, -): - """Calculates contacts between pairs of boxes.""" - kernel_ratio = 16 - num_threads = math.ceil( - d.nconmax / kernel_ratio - ) # parallel threads excluding tile dim - wp.launch_tiled( - kernel=box_box_kernel, - dim=num_threads, - inputs=[m, d, num_threads], - block_dim=BOX_BOX_BLOCK_DIM, - ) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 00fdc5b1..82f640ec 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -17,7 +17,6 @@ import warp as wp -from .collision_box import box_box_narrowphase from .collision_convex import gjk_narrowphase from .collision_primitive import primitive_narrowphase from .types import MJ_MAXVAL @@ -271,4 +270,3 @@ def collision(m: Model, d: Data): # TODO(team) switch between collision functions and GJK/EPA here gjk_narrowphase(m, d) primitive_narrowphase(m, d) - box_box_narrowphase(m, d) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 96113848..2fd9c3d2 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1996,6 +1996,21 @@ def _primitive_narrowphase( solimp, geoms, ) + elif type1 == int(GeomType.BOX.value) and type2 == int(GeomType.BOX.value): + box_box( + geom1, + geom2, + worldid, + d, + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + ) elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.BOX.value): capsule_box( geom1, From 248964c727016a12110314d5c9923f5ae5d04f01 Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 30 Apr 2025 19:24:38 +0200 Subject: [PATCH 03/12] Add boxbox tests --- mujoco_warp/_src/collision_driver_test.py | 39 +++++++++++++++++++++++ mujoco_warp/_src/collision_primitive.py | 1 - 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py index 2df62836..ed33d72b 100644 --- a/mujoco_warp/_src/collision_driver_test.py +++ b/mujoco_warp/_src/collision_driver_test.py @@ -40,6 +40,45 @@ class CollisionTest(parameterized.TestCase): """, + "box_box_vf": """ + + + + + + + + + + + + """, + "box_box_vf_flat": """ + + + + + + + + + + + + """, + "box_box_ee": """ + + + + + + + + + + + + """, "plane_sphere": """ diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 2fd9c3d2..2e8d5f3b 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -25,7 +25,6 @@ from .types import Model from .types import vec5 - wp.set_module_options({"enable_backward": False}) From 0dfec6dbf743705b2bca1bc90f36f5d6a9b62d66 Mon Sep 17 00:00:00 2001 From: camevor Date: Mon, 5 May 2025 16:33:12 +0200 Subject: [PATCH 04/12] Fix edge-edge errors --- mujoco_warp/_src/collision_primitive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 2e8d5f3b..3f0d4ff1 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1718,7 +1718,7 @@ def box_box( if (linesu_a[i, 2] + linesu_b[i][2] * c1) * innorm > margin: continue - points[n] = linesu_a[i] * 0.5 + c1 * linesu_b[i] + points[n] = linesu_a[i] * 0.5 + c1 * linesu_b[i] * 0.5 points[n, q] += 0.5 * l points[n, 1 - q] += 0.5 * c2 depth[n] = points[n, 2] * innorm * 2.0 @@ -1809,7 +1809,7 @@ def box_box( # Set up contact data for all points rw = box1.rot @ wp.transpose(rotmore) - normal = r @ rnorm + normal = rw @ rnorm frame = make_frame(wp.where(inv, -1.0, 1.0) * normal) coff = wp.atomic_add(d.ncon, 0, n) From c4b27e42542fbba20029b7d725fa705411a5c227 Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 7 May 2025 10:26:00 +0200 Subject: [PATCH 05/12] Bugfixes, simplifications, remove confusing comments --- mujoco_warp/_src/collision_primitive.py | 101 ++++++++---------------- 1 file changed, 35 insertions(+), 66 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 3f0d4ff1..c37b5cef 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1265,29 +1265,29 @@ def compute_rotmore(face_idx: wp.int32) -> wp.mat33: rotmore = wp.mat33(0.0) if face_idx == 0: - rotmore[0, 2] = -1.0 # rotmore[2] - rotmore[1, 1] = +1.0 # rotmore[4] - rotmore[2, 0] = +1.0 # rotmore[6] + rotmore[0, 2] = -1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 0] = +1.0 elif face_idx == 1: - rotmore[0, 0] = +1.0 # rotmore[0] - rotmore[1, 2] = -1.0 # rotmore[5] - rotmore[2, 1] = +1.0 # rotmore[7] + rotmore[0, 0] = +1.0 + rotmore[1, 2] = -1.0 + rotmore[2, 1] = +1.0 elif face_idx == 2: - rotmore[0, 0] = +1.0 # rotmore[0] - rotmore[1, 1] = +1.0 # rotmore[4] - rotmore[2, 2] = +1.0 # rotmore[8] + rotmore[0, 0] = +1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 2] = +1.0 elif face_idx == 3: - rotmore[0, 2] = +1.0 # rotmore[2] - rotmore[1, 1] = +1.0 # rotmore[4] - rotmore[2, 0] = -1.0 # rotmore[6] + rotmore[0, 2] = +1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 0] = -1.0 elif face_idx == 4: - rotmore[0, 0] = +1.0 # rotmore[0] - rotmore[1, 2] = +1.0 # rotmore[5] - rotmore[2, 1] = -1.0 # rotmore[7] + rotmore[0, 0] = +1.0 + rotmore[1, 2] = +1.0 + rotmore[2, 1] = -1.0 elif face_idx == 5: - rotmore[0, 0] = -1.0 # rotmore[0] - rotmore[1, 1] = +1.0 # rotmore[4] - rotmore[2, 2] = -1.0 # rotmore[8] + rotmore[0, 0] = -1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 2] = -1.0 return rotmore @@ -1315,13 +1315,8 @@ def box_box( rot21 = wp.transpose(box1.rot) @ box2.rot rot12 = wp.transpose(rot21) - rot21abs = wp.transpose( - wp.mat33(wp.abs(rot21[0]), wp.abs(rot21[1]), wp.abs(rot21[2])) - ) - - rot12abs = wp.transpose( - wp.mat33(wp.abs(rot12[0]), wp.abs(rot12[1]), wp.abs(rot12[2])) - ) + rot21abs = wp.matrix_from_rows(wp.abs(rot21[0]), wp.abs(rot21[1]), wp.abs(rot21[2])) + rot12abs = wp.transpose(rot21abs) plen2 = rot21abs @ box2.size plen1 = rot12abs @ box1.size @@ -1412,11 +1407,13 @@ def box_box( if axis_code == -1: return + points = mat83f() + max_con_pair = 8 + if axis_code < 12: # Handle face-vertex collision face_idx = axis_code % 6 box_idx = axis_code / 6 - # TODO(camevor): consider single rotmore calc, or using i0,i1,i2 and f0,f1,f2 as rotmore substitute rotmore = compute_rotmore(face_idx) r = rotmore @ wp.where(box_idx, rot12, rot21) @@ -1453,11 +1450,6 @@ def box_box( dirs += 1 - # Compute additional contact points - # compute the other box's corner points - # TODO(camevor) potentially unnecessary use of registers - # pts are now the corners of the other closest face - k = dirs * dirs # Find potential contact points @@ -1478,8 +1470,6 @@ def box_box( lines_a[3] = lp + cn2 lines_b[3] = cn1 - points = mat83f() - n = wp.int32(0) for i in range(k): @@ -1520,24 +1510,15 @@ def box_box( u = (x * by - y * bx) * C v = (y * ax - x * ay) * C - if u <= 0 or v <= 0 or u >= 1 or v >= 1: - continue - - points[n] = wp.vec3(llx, lly, lp[2] + u * cn1[2] + v * cn2[2]) - - n += 1 - - points[n] = lp # TODO(camevor) consider doing this sooner or later - n += 1 + if u > 0 and v > 0 and u < 1 and v < 1: + points[n] = wp.vec3(llx, lly, lp[2] + u * cn1[2] + v * cn2[2]) + n += 1 - for i in range(1, k): + for i in range(k): tmpv = lp + wp.float32(i & 1) * cn1 + wp.float32(i & 2) * cn2 - if tmpv[0] <= -lx or tmpv[0] >= lx: - continue - if tmpv[1] <= -ly or tmpv[1] >= ly: - continue - points[n] = tmpv - n += 1 + if tmpv[0] > -lx and tmpv[0] < lx and tmpv[1] > -ly and tmpv[1] < ly: + points[n] = tmpv + n += 1 m = n n = wp.int32(0) @@ -1556,8 +1537,6 @@ def box_box( pw = wp.where(box_idx, box2.pos, box1.pos) normal = wp.where(box_idx, -1.0, 1.0) * wp.transpose(rw)[2] - - # TODO(camevor): explicit frame frame = make_frame(normal) coff = wp.atomic_add(d.ncon, 0, n) @@ -1583,8 +1562,6 @@ def box_box( d.contact.solreffriction[cid] = solreffriction d.contact.solimp[cid] = solimp - # return - else: # Handle edge-edge collision edge1 = (axis_code - 12) / 3 @@ -1598,14 +1575,10 @@ def box_box( pax2 = wp.int(2 - (edge1 & 2)) if rot21abs[edge1, ax1] < rot21abs[edge1, ax2]: - ax_tmp = ax1 - ax1 = ax2 - ax2 = ax_tmp + ax1, ax2 = ax2, ax1 if rot12abs[edge2, pax1] < rot12abs[edge2, pax2]: - pax_tmp = pax1 - pax1 = pax2 - pax2 = pax_tmp + pax1, pax2 = pax2, pax1 rotmore = compute_rotmore(wp.where(cle1 & (1 << pax2), pax2, pax2 + 3)) @@ -1620,7 +1593,6 @@ def box_box( p[2] -= hz # Calculate closest box2 face - points = mat83f() points[0] = ( p @@ -1690,13 +1662,10 @@ def box_box( linesu_a[3] = axi_lp + axi_cn2 linesu_b[3] = axi_cn1 - k = 4 n = wp.int32(0) - max_con_pair = 8 depth = vec8f() - - for i in range(k): + for i in range(4): for q in range(2): la = lines_a[i, q] lb = lines_b[i, q] @@ -1809,8 +1778,8 @@ def box_box( # Set up contact data for all points rw = box1.rot @ wp.transpose(rotmore) - normal = rw @ rnorm - frame = make_frame(wp.where(inv, -1.0, 1.0) * normal) + normal = wp.where(inv, -1.0, 1.0) * rw @ rnorm + frame = make_frame(normal) coff = wp.atomic_add(d.ncon, 0, n) From 7b0a494b6742d06edee3e09c2264353dfec37162 Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 7 May 2025 11:31:30 +0200 Subject: [PATCH 06/12] Merge after fv and ee branches --- mujoco_warp/_src/collision_driver_test.py | 19 +++++- mujoco_warp/_src/collision_primitive.py | 76 +++++++---------------- 2 files changed, 40 insertions(+), 55 deletions(-) diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py index ed33d72b..3cd13a5c 100644 --- a/mujoco_warp/_src/collision_driver_test.py +++ b/mujoco_warp/_src/collision_driver_test.py @@ -56,10 +56,10 @@ class CollisionTest(parameterized.TestCase): "box_box_vf_flat": """ - - + + - + @@ -79,6 +79,19 @@ class CollisionTest(parameterized.TestCase): """, + "box_box_ee_deep": """ + + + + + + + + + + + + """, "plane_sphere": """ diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index c37b5cef..2f38a822 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1322,7 +1322,6 @@ def box_box( plen1 = rot12abs @ box1.size # Compute axis of maximum separation - s_sum_3 = 3.0 * (box1.size + box2.size) separation = wp.float32(margin + s_sum_3[0] + s_sum_3[1] + s_sum_3[2]) axis_code = wp.int32(-1) @@ -1352,7 +1351,6 @@ def box_box( for i in range(3): for j in range(3): # Compute cross product of box edges (potential separating axis) - if i == 0: cross_axis = wp.vec3(0.0, -rot12[j, 2], rot12[j, 1]) elif i == 1: @@ -1361,7 +1359,6 @@ def box_box( cross_axis = wp.vec3(-rot12[j, 1], rot12[j, 0], 0.0) cross_length = wp.length(cross_axis) - if cross_length < MJ_MINVAL: continue @@ -1384,7 +1381,6 @@ def box_box( return # Track minimum separation and which edge-edge pair it occurs on - if c3 < separation * (1.0 - 1e-12): separation = c3 # Determine which corners/edges are closest @@ -1408,7 +1404,9 @@ def box_box( return points = mat83f() + depth = vec8f() max_con_pair = 8 + # 8 contacts should suffice for most configurations if axis_code < 12: # Handle face-vertex collision @@ -1530,37 +1528,13 @@ def box_box( points[n] = points[i] points[n, 2] *= 0.5 + depth[n] = points[n, 2] n += 1 # Set up contact frame rw = wp.where(box_idx, box2.rot, box1.rot) @ wp.transpose(rotmore) pw = wp.where(box_idx, box2.pos, box1.pos) - normal = wp.where(box_idx, -1.0, 1.0) * wp.transpose(rw)[2] - frame = make_frame(normal) - - coff = wp.atomic_add(d.ncon, 0, n) - - for i in range(min(d.nconmax - coff, n)): - dist = points[i, 2] - - # Transform contact point to world frame - points[i, 2] += hz - pos = rw @ points[i] + pw - - cid = coff + i - - d.contact.dist[cid] = dist - d.contact.pos[cid] = pos - d.contact.frame[cid] = frame - d.contact.geom[cid] = geoms - d.contact.worldid[cid] = worldid - d.contact.includemargin[cid] = margin - gap - d.contact.dim[cid] = condim - d.contact.friction[cid] = friction - d.contact.solref[cid] = solref - d.contact.solreffriction[cid] = solreffriction - d.contact.solimp[cid] = solimp else: # Handle edge-edge collision @@ -1664,7 +1638,6 @@ def box_box( n = wp.int32(0) - depth = vec8f() for i in range(4): for q in range(2): la = lines_a[i, q] @@ -1778,30 +1751,29 @@ def box_box( # Set up contact data for all points rw = box1.rot @ wp.transpose(rotmore) + pw = box1.pos normal = wp.where(inv, -1.0, 1.0) * rw @ rnorm - frame = make_frame(normal) - - coff = wp.atomic_add(d.ncon, 0, n) - - for i in range(min(d.nconmax - coff, n)): - dist = depth[i] - points[i, 2] += hz - pos = rw @ points[i] + box1.pos - - cid = coff + i - - d.contact.dist[cid] = dist - d.contact.pos[cid] = pos - d.contact.frame[cid] = frame - d.contact.geom[cid] = geoms - d.contact.worldid[cid] = worldid - d.contact.includemargin[cid] = margin - gap - d.contact.dim[cid] = condim - d.contact.friction[cid] = friction - d.contact.solref[cid] = solref - d.contact.solreffriction[cid] = solreffriction - d.contact.solimp[cid] = solimp + frame = make_frame(normal) + coff = wp.atomic_add(d.ncon, 0, n) + + for i in range(min(d.nconmax - coff, n)): + points[i, 2] += hz + pos = rw @ points[i] + pw + + cid = coff + i + + d.contact.dist[cid] = depth[i] + d.contact.pos[cid] = pos + d.contact.frame[cid] = frame + d.contact.geom[cid] = geoms + d.contact.worldid[cid] = worldid + d.contact.includemargin[cid] = margin - gap + d.contact.dim[cid] = condim + d.contact.friction[cid] = friction + d.contact.solref[cid] = solref + d.contact.solreffriction[cid] = solreffriction + d.contact.solimp[cid] = solimp @wp.kernel From a575a7ae198243fffbcb92541e2398834a2671ca Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 7 May 2025 12:53:05 +0200 Subject: [PATCH 07/12] Remove unnecessary temporary matrices --- mujoco_warp/_src/collision_primitive.py | 74 ++++++------------------- 1 file changed, 18 insertions(+), 56 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 2f38a822..d919270a 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1451,44 +1451,27 @@ def box_box( k = dirs * dirs # Find potential contact points - lines_a = mat43f() - lines_b = mat43f() - - if dirs: - lines_a[0] = lp - lines_b[0] = cn1 - - if dirs == 2: - lines_a[1] = lp - lines_b[1] = cn2 - - lines_a[2] = lp + cn1 - lines_b[2] = cn2 - - lines_a[3] = lp + cn2 - lines_b[3] = cn1 n = wp.int32(0) for i in range(k): for q in range(2): - la = lines_a[i, q] - lb = lines_b[i, q] - lc = lines_a[i, 1 - q] - ld = lines_b[i, 1 - q] + # lines_a and lines_b (lines between corners) computed on the fly + lav = lp + wp.where(i < 2, wp.vec3(0.0), wp.where(i == 2, cn1, cn2)) + lbv = wp.where(i == 0 or i == 3, cn1, cn2) - if wp.abs(lb) > MJ_MINVAL: - br = 1.0 / lb + if wp.abs(lbv[q]) > MJ_MINVAL: + br = 1.0 / lbv[q] for j in range(-1, 2, 2): l = ss[q] * wp.float32(j) - c1 = (l - la) * br + c1 = (l - lav[q]) * br if c1 < 0 or c1 > 1: continue - c2 = lc + ld * c1 + c2 = lav[1 - q] + lbv[1 - q] * c1 if wp.abs(c2) > ss[1 - q]: continue - points[n] = lines_a[i] + c1 * lines_b[i] + points[n] = lav + c1 * lbv n += 1 if dirs == 2: @@ -1611,39 +1594,18 @@ def box_box( pts_cn1 = points[1] - points[0] pts_cn2 = points[2] - points[0] - lines_a = mat43f() - lines_b = mat43f() - linesu_a = mat43f() - linesu_b = mat43f() - - lines_a[0] = pts_lp - lines_b[0] = pts_cn1 - linesu_a[0] = axi_lp - linesu_b[0] = axi_cn1 - - lines_a[1] = pts_lp - lines_b[1] = pts_cn2 - linesu_a[1] = axi_lp - linesu_b[1] = axi_cn2 - - lines_a[2] = pts_lp + pts_cn1 - lines_b[2] = pts_cn2 - linesu_a[2] = axi_lp + axi_cn1 - linesu_b[2] = axi_cn2 - - lines_a[3] = pts_lp + pts_cn2 - lines_b[3] = pts_cn1 - linesu_a[3] = axi_lp + axi_cn2 - linesu_b[3] = axi_cn1 - n = wp.int32(0) for i in range(4): for q in range(2): - la = lines_a[i, q] - lb = lines_b[i, q] - lc = lines_a[i, 1 - q] - ld = lines_b[i, 1 - q] + la = pts_lp[q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[q], pts_cn2[q])) + lb = wp.where(i == 0 or i == 3, pts_cn1[q], pts_cn2[q]) + lc = pts_lp[1-q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[1-q], pts_cn2[1-q])) + ld = wp.where(i == 0 or i == 3, pts_cn1[1-q], pts_cn2[1-q]) + + # linesu_a and linesu_b (lines between corners) computed on the fly + lua = axi_lp + wp.where(i < 2, wp.vec3(0.0), wp.where(i == 2, axi_cn1, axi_cn2)) + lub = wp.where(i == 0 or i == 3, axi_cn1, axi_cn2) if wp.abs(lb) > MJ_MINVAL: br = 1.0 / lb @@ -1657,10 +1619,10 @@ def box_box( c2 = lc + ld * c1 if wp.abs(c2) > s[1 - q]: continue - if (linesu_a[i, 2] + linesu_b[i][2] * c1) * innorm > margin: + if (lua[2]+ lub[2] * c1) * innorm > margin: continue - points[n] = linesu_a[i] * 0.5 + c1 * linesu_b[i] * 0.5 + points[n] = lua * 0.5 + c1 * lub * 0.5 points[n, q] += 0.5 * l points[n, 1 - q] += 0.5 * c2 depth[n] = points[n, 2] * innorm * 2.0 From badff64e734792440dd9214ac592c882262c4344 Mon Sep 17 00:00:00 2001 From: camevor Date: Wed, 7 May 2025 13:28:54 +0200 Subject: [PATCH 08/12] Autoformat --- mujoco_warp/_src/collision_primitive.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index d919270a..8e8ba713 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1600,8 +1600,10 @@ def box_box( for q in range(2): la = pts_lp[q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[q], pts_cn2[q])) lb = wp.where(i == 0 or i == 3, pts_cn1[q], pts_cn2[q]) - lc = pts_lp[1-q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[1-q], pts_cn2[1-q])) - ld = wp.where(i == 0 or i == 3, pts_cn1[1-q], pts_cn2[1-q]) + lc = pts_lp[1 - q] + wp.where( + i < 2, 0.0, wp.where(i == 2, pts_cn1[1 - q], pts_cn2[1 - q]) + ) + ld = wp.where(i == 0 or i == 3, pts_cn1[1 - q], pts_cn2[1 - q]) # linesu_a and linesu_b (lines between corners) computed on the fly lua = axi_lp + wp.where(i < 2, wp.vec3(0.0), wp.where(i == 2, axi_cn1, axi_cn2)) @@ -1619,7 +1621,7 @@ def box_box( c2 = lc + ld * c1 if wp.abs(c2) > s[1 - q]: continue - if (lua[2]+ lub[2] * c1) * innorm > margin: + if (lua[2] + lub[2] * c1) * innorm > margin: continue points[n] = lua * 0.5 + c1 * lub * 0.5 From c96c3296e2ffdaf4bf9e9d2d7d6f464360defd7b Mon Sep 17 00:00:00 2001 From: camevor Date: Fri, 9 May 2025 16:45:44 +0200 Subject: [PATCH 09/12] Fix vf errors --- mujoco_warp/_src/collision_primitive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 8e8ba713..7050afd7 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1495,8 +1495,8 @@ def box_box( points[n] = wp.vec3(llx, lly, lp[2] + u * cn1[2] + v * cn2[2]) n += 1 - for i in range(k): - tmpv = lp + wp.float32(i & 1) * cn1 + wp.float32(i & 2) * cn2 + for i in range(1< -lx and tmpv[0] < lx and tmpv[1] > -ly and tmpv[1] < ly: points[n] = tmpv n += 1 From ea165a7028fe206a09112c9623a30591843c7472 Mon Sep 17 00:00:00 2001 From: camevor Date: Fri, 9 May 2025 18:39:37 +0200 Subject: [PATCH 10/12] Implement new func signature format --- mujoco_warp/_src/collision_primitive.py | 59 ++++++++++++++++++------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index cff26583..b1f8b12d 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1731,7 +1731,7 @@ def capsule_box( @wp.func -def compute_rotmore(face_idx: wp.int32) -> wp.mat33: +def compute_rotmore(face_idx: int) -> wp.mat33: rotmore = wp.mat33(0.0) if face_idx == 0: @@ -1764,10 +1764,12 @@ def compute_rotmore(face_idx: wp.int32) -> wp.mat33: @wp.func def box_box( + # Data in: + nconmax_in: int, + # In: box1: Geom, box2: Geom, worldid: int, - d: Data, margin: float, gap: float, condim: int, @@ -1776,6 +1778,19 @@ def box_box( solreffriction: wp.vec2f, solimp: vec5, geoms: wp.vec2i, + # Data out: + ncon_out: wp.array(dtype=int), + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), ): # Compute transforms between box's frames @@ -2189,25 +2204,25 @@ def box_box( normal = wp.where(inv, -1.0, 1.0) * rw @ rnorm frame = make_frame(normal) - coff = wp.atomic_add(d.ncon, 0, n) + coff = wp.atomic_add(ncon_out, 0, n) - for i in range(min(d.nconmax - coff, n)): + for i in range(min(nconmax_in - coff, n)): points[i, 2] += hz pos = rw @ points[i] + pw cid = coff + i - d.contact.dist[cid] = depth[i] - d.contact.pos[cid] = pos - d.contact.frame[cid] = frame - d.contact.geom[cid] = geoms - d.contact.worldid[cid] = worldid - d.contact.includemargin[cid] = margin - gap - d.contact.dim[cid] = condim - d.contact.friction[cid] = friction - d.contact.solref[cid] = solref - d.contact.solreffriction[cid] = solreffriction - d.contact.solimp[cid] = solimp + contact_dist_out[cid] = depth[i] + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + contact_geom_out[cid] = geoms + contact_worldid_out[cid] = worldid + contact_includemargin_out[cid] = margin - gap + contact_dim_out[cid] = condim + contact_friction_out[cid] = friction + contact_solref_out[cid] = solref + contact_solreffriction_out[cid] = solreffriction + contact_solimp_out[cid] = solimp @wp.kernel @@ -2555,10 +2570,10 @@ def _primitive_narrowphase( ) elif type1 == int(GeomType.BOX.value) and type2 == int(GeomType.BOX.value): box_box( + nconmax_in, geom1, geom2, worldid, - d, margin, gap, condim, @@ -2567,6 +2582,18 @@ def _primitive_narrowphase( solreffriction, solimp, geoms, + ncon_out, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_worldid_out, ) elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.BOX.value): capsule_box( From 0abed55b4595aeb635bc8fe6aa320b4288e1b319 Mon Sep 17 00:00:00 2001 From: camevor Date: Fri, 9 May 2025 18:46:13 +0200 Subject: [PATCH 11/12] Ruff format --- mujoco_warp/_src/collision_primitive.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index b1f8b12d..9f293e1e 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1875,9 +1875,7 @@ def box_box( for k in range(3): if k != i and (int(cross_axis[k] > 0) ^ int(box_dist < 0)): cle1 += 1 << k - if k != j and ( - int(rot21[i, 3 - k - j] > 0) ^ int(box_dist < 0) ^ int((k - j + 3) % 3 == 1) - ): + if k != j and (int(rot21[i, 3 - k - j] > 0) ^ int(box_dist < 0) ^ int((k - j + 3) % 3 == 1)): cle2 += 1 << k axis_code = 12 + i * 3 + j @@ -1980,8 +1978,8 @@ def box_box( points[n] = wp.vec3(llx, lly, lp[2] + u * cn1[2] + v * cn2[2]) n += 1 - for i in range(1< -lx and tmpv[0] < lx and tmpv[1] > -ly and tmpv[1] < ly: points[n] = tmpv n += 1 @@ -2085,9 +2083,7 @@ def box_box( for q in range(2): la = pts_lp[q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[q], pts_cn2[q])) lb = wp.where(i == 0 or i == 3, pts_cn1[q], pts_cn2[q]) - lc = pts_lp[1 - q] + wp.where( - i < 2, 0.0, wp.where(i == 2, pts_cn1[1 - q], pts_cn2[1 - q]) - ) + lc = pts_lp[1 - q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[1 - q], pts_cn2[1 - q])) ld = wp.where(i == 0 or i == 3, pts_cn1[1 - q], pts_cn2[1 - q]) # linesu_a and linesu_b (lines between corners) computed on the fly From 915a924b61e0aecec6db654431dae841ed7f3c49 Mon Sep 17 00:00:00 2001 From: camevor Date: Mon, 12 May 2025 10:00:42 +0200 Subject: [PATCH 12/12] Make `compute_rotmore` module private --- mujoco_warp/_src/collision_primitive.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 9f293e1e..692feed7 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1731,7 +1731,7 @@ def capsule_box( @wp.func -def compute_rotmore(face_idx: int) -> wp.mat33: +def _compute_rotmore(face_idx: int) -> wp.mat33: rotmore = wp.mat33(0.0) if face_idx == 0: @@ -1895,7 +1895,7 @@ def box_box( # Handle face-vertex collision face_idx = axis_code % 6 box_idx = axis_code / 6 - rotmore = compute_rotmore(face_idx) + rotmore = _compute_rotmore(face_idx) r = rotmore @ wp.where(box_idx, rot12, rot21) p = rotmore @ wp.where(box_idx, pos12, pos21) @@ -2020,7 +2020,7 @@ def box_box( if rot12abs[edge2, pax1] < rot12abs[edge2, pax2]: pax1, pax2 = pax2, pax1 - rotmore = compute_rotmore(wp.where(cle1 & (1 << pax2), pax2, pax2 + 3)) + rotmore = _compute_rotmore(wp.where(cle1 & (1 << pax2), pax2, pax2 + 3)) # Transform coordinates for edge-edge contact calculation p = rotmore @ pos21