diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py
index d62853f7..2df62836 100644
--- a/mujoco_warp/_src/collision_driver_test.py
+++ b/mujoco_warp/_src/collision_driver_test.py
@@ -225,6 +225,50 @@ class CollisionTest(parameterized.TestCase):
""",
+ "capsule_box_edge": """
+
+
+
+
+
+
+
+
+
+ """,
+ "capsule_box_corner": """
+
+
+
+
+
+
+
+
+
+ """,
+ "capsule_box_face_tip": """
+
+
+
+
+
+
+
+
+
+ """,
+ "capsule_box_face_flat": """
+
+
+
+
+
+
+
+
+
+ """,
}
# Temporarily disabled
@@ -250,6 +294,10 @@ def test_collision(self, fixture):
"""Tests convex collision with different geometries."""
mjm, mjd, m, d = test_util.fixture(xml=self._FIXTURES[fixture])
+ # Exempt GJK collisions from exact contact count check
+ # because GJK generates more contacts
+ allow_different_contact_count = False
+
mujoco.mj_collision(mjm, mjd)
mjwarp.collision(m, d)
@@ -257,7 +305,6 @@ def test_collision(self, fixture):
actual_dist = mjd.contact.dist[i]
actual_pos = mjd.contact.pos[i]
actual_frame = mjd.contact.frame[i]
- # This is because Gjk generates more contact
result = False
for j in range(d.ncon.numpy()[0]):
test_dist = d.contact.dist.numpy()[j]
@@ -271,6 +318,9 @@ def test_collision(self, fixture):
break
np.testing.assert_equal(result, True, f"Contact {i} not found in Gjk results")
+ if not allow_different_contact_count:
+ self.assertEqual(d.ncon.numpy()[0], mjd.ncon)
+
def test_contact_exclude(self):
"""Tests contact exclude."""
mjm = mujoco.MjModel.from_xml_string("""
diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py
index 426ec093..ffef2655 100644
--- a/mujoco_warp/_src/collision_primitive.py
+++ b/mujoco_warp/_src/collision_primitive.py
@@ -796,9 +796,12 @@ def contact_params(m: Model, d: Data, cid: int):
@wp.func
-def sphere_box(
- sphere: Geom,
- box: Geom,
+def _sphere_box(
+ sphere_pos: wp.vec3,
+ sphere_size: float,
+ box_pos: wp.vec3,
+ box_rot: wp.mat33,
+ box_size: wp.vec3,
worldid: int,
d: Data,
margin: float,
@@ -810,37 +813,37 @@ def sphere_box(
solimp: vec5,
geoms: wp.vec2i,
):
- center = wp.transpose(box.rot) @ (sphere.pos - box.pos)
+ center = wp.transpose(box_rot) @ (sphere_pos - box_pos)
- clamped = wp.max(-box.size, wp.min(box.size, center))
+ clamped = wp.max(-box_size, wp.min(box_size, center))
clamped_dir, dist = normalize_with_norm(clamped - center)
- if dist - sphere.size[0] > margin:
+ if dist - sphere_size > margin:
return
# sphere center inside box
if dist <= MJ_MINVAL:
- closest = 2.0 * (box.size[0] + box.size[1] + box.size[2])
+ closest = 2.0 * (box_size[0] + box_size[1] + box_size[2])
k = wp.int32(0)
for i in range(6):
- face_dist = wp.abs(wp.where(i % 2, 1.0, -1.0) * box.size[i / 2] - center[i / 2])
+ face_dist = wp.abs(wp.where(i % 2, 1.0, -1.0) * box_size[i / 2] - center[i / 2])
if closest > face_dist:
closest = face_dist
k = i
nearest = wp.vec3(0.0)
nearest[k / 2] = wp.where(k % 2, -1.0, 1.0)
- pos = center + nearest * (sphere.size[0] - closest) / 2.0
- contact_normal = box.rot @ nearest
- contact_dist = -closest - sphere.size[0]
+ pos = center + nearest * (sphere_size - closest) / 2.0
+ contact_normal = box_rot @ nearest
+ contact_dist = -closest - sphere_size
else:
- deepest = center + clamped_dir * sphere.size[0]
+ deepest = center + clamped_dir * sphere_size
pos = 0.5 * (clamped + deepest)
- contact_normal = box.rot @ clamped_dir
- contact_dist = dist - sphere.size[0]
+ contact_normal = box_rot @ clamped_dir
+ contact_dist = dist - sphere_size
- contact_pos = box.pos + box.rot @ pos
+ contact_pos = box_pos + box_rot @ pos
write_contact(
d,
contact_dist,
@@ -858,6 +861,393 @@ def sphere_box(
)
+@wp.func
+def sphere_box(
+ sphere: Geom,
+ box: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ gap: float,
+ condim: int,
+ friction: vec5,
+ solref: wp.vec2f,
+ solreffriction: wp.vec2f,
+ solimp: vec5,
+ geoms: wp.vec2i,
+):
+ _sphere_box(
+ sphere.pos,
+ sphere.size[0],
+ box.pos,
+ box.rot,
+ box.size,
+ worldid,
+ d,
+ margin,
+ gap,
+ condim,
+ friction,
+ solref,
+ solreffriction,
+ solimp,
+ geoms,
+ )
+
+
+@wp.func
+def capsule_box(
+ cap: Geom,
+ box: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ gap: float,
+ condim: int,
+ friction: vec5,
+ solref: wp.vec2f,
+ solreffriction: wp.vec2f,
+ solimp: vec5,
+ geoms: wp.vec2i,
+):
+ """Calculates contacts between a capsule and a box."""
+ # Based on the mjc implementation
+ pos = wp.transpose(box.rot) @ (cap.pos - box.pos)
+ axis = wp.vec3(cap.rot[0, 2], cap.rot[1, 2], cap.rot[2, 2])
+ halfaxis = axis * cap.size[1] # halfaxis is the capsule direction
+ axisdir = (
+ wp.int32(axis[0] > 0.0) + 2 * wp.int32(axis[1] > 0.0) + 4 * wp.int32(axis[2] > 0.0)
+ )
+
+ bestdistmax = margin + 2.0 * (
+ cap.size[0] + cap.size[1] + box.size[0] + box.size[1] + box.size[2]
+ )
+
+ # keep track of closest point
+ bestdist = wp.float32(bestdistmax)
+ bestsegmentpos = wp.float32(-12)
+
+ # cltype: encoded collision configuration
+ # cltype / 3 == 0 : lower corner is closest to the capsule
+ # == 2 : upper corner is closest to the capsule
+ # == 1 : middle of the edge is closest to the capsule
+ # cltype % 3 == 0 : lower corner is closest to the box
+ # == 2 : upper corner is closest to the box
+ # == 1 : middle of the capsule is closest to the box
+ cltype = wp.int32(-4)
+
+ # clface: index of the closest face of the box to the capsule
+ # -1: no face is closest (edge or corner is closest)
+ # 0, 1, 2: index of the axis perpendicular to the closest face
+ clface = wp.int32(-12)
+
+ # first: consider cases where a face of the box is closest
+ for i in range(-1, 2, 2):
+ axisTip = pos + wp.float32(i) * halfaxis
+ boxPoint = wp.vec3(axisTip)
+
+ n_out = wp.int32(0)
+ ax_out = wp.int32(-1)
+
+ for j in range(3):
+ if boxPoint[j] < -box.size[j]:
+ n_out += 1
+ ax_out = j
+ boxPoint[j] = -box.size[j]
+ elif boxPoint[j] > box.size[j]:
+ n_out += 1
+ ax_out = j
+ boxPoint[j] = box.size[j]
+
+ if n_out > 1:
+ continue
+
+ dist = wp.length_sq(boxPoint - axisTip)
+
+ if dist < bestdist:
+ bestdist = dist
+ bestsegmentpos = wp.float32(i)
+ cltype = -2 + i
+ clface = ax_out
+
+ # second: consider cases where an edge of the box is closest
+ clcorner = wp.int32(-123) # which corner is the closest
+ cledge = wp.int32(-123) # which axis
+ bestboxpos = wp.float32(0.0)
+
+ for i in range(8):
+ for j in range(3):
+ if i & (1 << j) != 0:
+ continue
+
+ c2 = wp.int32(-123)
+
+ # box_pt is the starting point (corner) on the box
+ box_pt = wp.cw_mul(
+ wp.vec3(
+ wp.where(i & 1, 1.0, -1.0),
+ wp.where(i & 2, 1.0, -1.0),
+ wp.where(i & 4, 1.0, -1.0),
+ ),
+ box.size,
+ )
+ box_pt[j] = 0.0
+
+ # find closest point between capsule and the edge
+ dif = box_pt - pos
+
+ u = -box.size[j] * dif[j]
+ v = wp.dot(halfaxis, dif)
+ ma = box.size[j] * box.size[j]
+ mb = -box.size[j] * halfaxis[j]
+ mc = cap.size[1] * cap.size[1]
+ det = ma * mc - mb * mb
+ if wp.abs(det) < MJ_MINVAL:
+ continue
+
+ idet = 1.0 / det
+ # sX : X=1 means middle of segment. X=0 or 2 one or the other end
+
+ x1 = wp.float32((mc * u - mb * v) * idet)
+ x2 = wp.float32((ma * v - mb * u) * idet)
+
+ s1 = wp.int32(1)
+ s2 = wp.int32(1)
+
+ if x1 > 1:
+ x1 = 1.0
+ s1 = 2
+ x2 = (v - mb) / mc
+ elif x1 < -1:
+ x1 = -1.0
+ s1 = 0
+ x2 = (v + mb) / mc
+
+ x2_over = x2 > 1.0
+ if x2_over or x2 < -1.0:
+ if x2_over:
+ x2 = 1.0
+ s2 = 2
+ x1 = (u - mb) / ma
+ else:
+ x2 = -1.0
+ s2 = 0
+ x1 = (u + mb) / ma
+
+ if x1 > 1:
+ x1 = 1.0
+ s1 = 2
+ elif x1 < -1:
+ x1 = -1.0
+ s1 = 0
+
+ dif -= halfaxis * x2
+ dif[j] += box.size[j] * x1
+
+ # encode relative positions of the closest points
+ ct = s1 * 3 + s2
+
+ dif_sq = wp.length_sq(dif)
+ if dif_sq < bestdist - MJ_MINVAL:
+ bestdist = dif_sq
+ bestsegmentpos = x2
+ bestboxpos = x1
+ # ct<6 means closest point on box is at lower end or middle of edge
+ c2 = ct / 6
+
+ clcorner = i + (1 << j) * c2 # index of closest box corner
+ cledge = j # axis index of closest box edge
+ cltype = ct # encoded collision configuration
+
+ best = wp.float32(0.0)
+ l = wp.float32(0.0)
+
+ p = wp.vec2(pos.x, pos.y)
+ dd = wp.vec2(halfaxis.x, halfaxis.y)
+ s = wp.vec2(box.size.x, box.size.y)
+ secondpos = wp.float32(-4.0)
+
+ l = wp.length_sq(dd)
+
+ uu = dd.x * s.y
+ vv = dd.y * s.x
+ w_neg = dd.x * p.y - dd.y * p.x < 0
+
+ best = wp.float32(-1.0)
+
+ ee1 = uu - vv
+ ee2 = uu + vv
+
+ if wp.abs(ee1) > best:
+ best = wp.abs(ee1)
+ c1 = wp.where((ee1 < 0) == w_neg, 0, 3)
+
+ if wp.abs(ee2) > best:
+ best = wp.abs(ee2)
+ c1 = wp.where((ee2 > 0) == w_neg, 1, 2)
+
+ if cltype == -4: # invalid type
+ return
+
+ if cltype >= 0 and cltype / 3 != 1: # closest to a corner of the box
+ c1 = axisdir ^ clcorner
+ # Calculate relative orientation between capsule and corner
+ # There are two possible configurations:
+ # 1. Capsule axis points toward/away from corner
+ # 2. Capsule axis aligns with a face or edge
+ if c1 != 0 and c1 != 7: # create second contact point
+ if c1 == 1 or c1 == 2 or c1 == 4:
+ mul = 1
+ else:
+ mul = -1
+ c1 = 7 - c1
+
+ # "de" and "dp" distance from first closest point on the capsule to both ends of it
+ # mul is a direction along the capsule's axis
+
+ if c1 == 1:
+ ax = 0
+ ax1 = 1
+ ax2 = 2
+ elif c1 == 2:
+ ax = 1
+ ax1 = 2
+ ax2 = 0
+ elif c1 == 4:
+ ax = 2
+ ax1 = 0
+ ax2 = 1
+
+ if axis[ax] * axis[ax] > 0.5: # second point along the edge of the box
+ m = 2.0 * box.size[ax] / wp.abs(halfaxis[ax])
+ secondpos = min(1.0 - wp.float32(mul) * bestsegmentpos, m)
+ else: # second point along a face of the box
+ # check for overshoot again
+ m = 2.0 * min(
+ box.size[ax1] / wp.abs(halfaxis[ax1]), box.size[ax2] / wp.abs(halfaxis[ax2])
+ )
+ secondpos = -min(1.0 + wp.float32(mul) * bestsegmentpos, m)
+ secondpos *= wp.float32(mul)
+
+ elif cltype >= 0 and cltype / 3 == 1: # we are on box's edge
+ # Calculate relative orientation between capsule and edge
+ # Two possible configurations:
+ # - T configuration: c1 = 2^n (no additional contacts)
+ # - X configuration: c1 != 2^n (potential additional contacts)
+ c1 = axisdir ^ clcorner
+ c1 &= 7 - (1 << cledge) # mask out edge axis to determine configuration
+
+ if c1 == 1 or c1 == 2 or c1 == 4: # create second contact point
+ if cledge == 0:
+ ax1 = 1
+ ax2 = 2
+ if cledge == 1:
+ ax1 = 2
+ ax2 = 0
+ if cledge == 2:
+ ax1 = 0
+ ax2 = 1
+ ax = cledge
+
+ # Then it finds with which face the capsule has a lower angle and switches the axis names
+ if wp.abs(axis[ax1]) > wp.abs(axis[ax2]):
+ ax1 = ax2
+ ax2 = 3 - ax - ax1
+
+ # mul determines direction along capsule axis for second contact point
+ if c1 & (1 << ax2):
+ mul = 1
+ secondpos = 1.0 - bestsegmentpos
+ else:
+ mul = -1
+ secondpos = 1.0 + bestsegmentpos
+
+ # now we have to find out whether we point towards the opposite side or towards one of the
+ # sides and also find the farthest point along the capsule that is above the box
+
+ e1 = 2.0 * box.size[ax2] / wp.abs(halfaxis[ax2])
+ secondpos = min(e1, secondpos)
+
+ if ((axisdir & (1 << ax)) != 0) == ((c1 & (1 << ax2)) != 0):
+ e2 = 1.0 - bestboxpos
+ else:
+ e2 = 1.0 + bestboxpos
+
+ e1 = box.size[ax] * e2 / wp.abs(halfaxis[ax])
+
+ secondpos = min(e1, secondpos)
+ secondpos *= wp.float32(mul)
+
+ elif cltype < 0:
+ # similarly we handle the case when one capsule's end is closest to a face of the box
+ # and find where is the other end pointing to and clamping to the farthest point
+ # of the capsule that's above the box
+ # if the closest point is inside the box there's no need for a second point
+
+ if clface != -1: # create second contact point
+ mul = wp.where(cltype == -3, 1, -1)
+ secondpos = 2.0
+
+ tmp1 = pos - halfaxis * wp.float32(mul)
+
+ for i in range(3):
+ if i != clface:
+ ha_r = wp.float32(mul) / halfaxis[i]
+ e1 = (box.size[i] - tmp1[i]) * ha_r
+ if 0 < e1 and e1 < secondpos:
+ secondpos = e1
+
+ e1 = (-box.size[i] - tmp1[i]) * ha_r
+ if 0 < e1 and e1 < secondpos:
+ secondpos = e1
+
+ secondpos *= wp.float32(mul)
+
+ # create sphere in original orientation at first contact point
+ s1_pos_l = pos + halfaxis * bestsegmentpos
+ s1_pos_g = box.rot @ s1_pos_l + box.pos
+
+ # collide with sphere
+ _sphere_box(
+ s1_pos_g,
+ cap.size[0],
+ box.pos,
+ box.rot,
+ box.size,
+ worldid,
+ d,
+ margin,
+ gap,
+ condim,
+ friction,
+ solref,
+ solreffriction,
+ solimp,
+ geoms,
+ )
+
+ if secondpos > -3: # secondpos was modified
+ s2_pos_l = pos + halfaxis * (secondpos + bestsegmentpos)
+ s2_pos_g = box.rot @ s2_pos_l + box.pos
+ _sphere_box(
+ s2_pos_g,
+ cap.size[0],
+ box.pos,
+ box.rot,
+ box.size,
+ worldid,
+ d,
+ margin,
+ gap,
+ condim,
+ friction,
+ solref,
+ solreffriction,
+ solimp,
+ geoms,
+ )
+
+
@wp.kernel
def _primitive_narrowphase(
m: Model,
@@ -1018,6 +1408,21 @@ def _primitive_narrowphase(
solimp,
geoms,
)
+ elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.BOX.value):
+ capsule_box(
+ geom1,
+ geom2,
+ worldid,
+ d,
+ margin,
+ gap,
+ condim,
+ friction,
+ solref,
+ solreffriction,
+ solimp,
+ geoms,
+ )
def primitive_narrowphase(m: Model, d: Data):