diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py index d62853f7..2df62836 100644 --- a/mujoco_warp/_src/collision_driver_test.py +++ b/mujoco_warp/_src/collision_driver_test.py @@ -225,6 +225,50 @@ class CollisionTest(parameterized.TestCase): """, + "capsule_box_edge": """ + + + + + + + + + + """, + "capsule_box_corner": """ + + + + + + + + + + """, + "capsule_box_face_tip": """ + + + + + + + + + + """, + "capsule_box_face_flat": """ + + + + + + + + + + """, } # Temporarily disabled @@ -250,6 +294,10 @@ def test_collision(self, fixture): """Tests convex collision with different geometries.""" mjm, mjd, m, d = test_util.fixture(xml=self._FIXTURES[fixture]) + # Exempt GJK collisions from exact contact count check + # because GJK generates more contacts + allow_different_contact_count = False + mujoco.mj_collision(mjm, mjd) mjwarp.collision(m, d) @@ -257,7 +305,6 @@ def test_collision(self, fixture): actual_dist = mjd.contact.dist[i] actual_pos = mjd.contact.pos[i] actual_frame = mjd.contact.frame[i] - # This is because Gjk generates more contact result = False for j in range(d.ncon.numpy()[0]): test_dist = d.contact.dist.numpy()[j] @@ -271,6 +318,9 @@ def test_collision(self, fixture): break np.testing.assert_equal(result, True, f"Contact {i} not found in Gjk results") + if not allow_different_contact_count: + self.assertEqual(d.ncon.numpy()[0], mjd.ncon) + def test_contact_exclude(self): """Tests contact exclude.""" mjm = mujoco.MjModel.from_xml_string(""" diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 426ec093..ffef2655 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -796,9 +796,12 @@ def contact_params(m: Model, d: Data, cid: int): @wp.func -def sphere_box( - sphere: Geom, - box: Geom, +def _sphere_box( + sphere_pos: wp.vec3, + sphere_size: float, + box_pos: wp.vec3, + box_rot: wp.mat33, + box_size: wp.vec3, worldid: int, d: Data, margin: float, @@ -810,37 +813,37 @@ def sphere_box( solimp: vec5, geoms: wp.vec2i, ): - center = wp.transpose(box.rot) @ (sphere.pos - box.pos) + center = wp.transpose(box_rot) @ (sphere_pos - box_pos) - clamped = wp.max(-box.size, wp.min(box.size, center)) + clamped = wp.max(-box_size, wp.min(box_size, center)) clamped_dir, dist = normalize_with_norm(clamped - center) - if dist - sphere.size[0] > margin: + if dist - sphere_size > margin: return # sphere center inside box if dist <= MJ_MINVAL: - closest = 2.0 * (box.size[0] + box.size[1] + box.size[2]) + closest = 2.0 * (box_size[0] + box_size[1] + box_size[2]) k = wp.int32(0) for i in range(6): - face_dist = wp.abs(wp.where(i % 2, 1.0, -1.0) * box.size[i / 2] - center[i / 2]) + face_dist = wp.abs(wp.where(i % 2, 1.0, -1.0) * box_size[i / 2] - center[i / 2]) if closest > face_dist: closest = face_dist k = i nearest = wp.vec3(0.0) nearest[k / 2] = wp.where(k % 2, -1.0, 1.0) - pos = center + nearest * (sphere.size[0] - closest) / 2.0 - contact_normal = box.rot @ nearest - contact_dist = -closest - sphere.size[0] + pos = center + nearest * (sphere_size - closest) / 2.0 + contact_normal = box_rot @ nearest + contact_dist = -closest - sphere_size else: - deepest = center + clamped_dir * sphere.size[0] + deepest = center + clamped_dir * sphere_size pos = 0.5 * (clamped + deepest) - contact_normal = box.rot @ clamped_dir - contact_dist = dist - sphere.size[0] + contact_normal = box_rot @ clamped_dir + contact_dist = dist - sphere_size - contact_pos = box.pos + box.rot @ pos + contact_pos = box_pos + box_rot @ pos write_contact( d, contact_dist, @@ -858,6 +861,393 @@ def sphere_box( ) +@wp.func +def sphere_box( + sphere: Geom, + box: Geom, + worldid: int, + d: Data, + margin: float, + gap: float, + condim: int, + friction: vec5, + solref: wp.vec2f, + solreffriction: wp.vec2f, + solimp: vec5, + geoms: wp.vec2i, +): + _sphere_box( + sphere.pos, + sphere.size[0], + box.pos, + box.rot, + box.size, + worldid, + d, + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + ) + + +@wp.func +def capsule_box( + cap: Geom, + box: Geom, + worldid: int, + d: Data, + margin: float, + gap: float, + condim: int, + friction: vec5, + solref: wp.vec2f, + solreffriction: wp.vec2f, + solimp: vec5, + geoms: wp.vec2i, +): + """Calculates contacts between a capsule and a box.""" + # Based on the mjc implementation + pos = wp.transpose(box.rot) @ (cap.pos - box.pos) + axis = wp.vec3(cap.rot[0, 2], cap.rot[1, 2], cap.rot[2, 2]) + halfaxis = axis * cap.size[1] # halfaxis is the capsule direction + axisdir = ( + wp.int32(axis[0] > 0.0) + 2 * wp.int32(axis[1] > 0.0) + 4 * wp.int32(axis[2] > 0.0) + ) + + bestdistmax = margin + 2.0 * ( + cap.size[0] + cap.size[1] + box.size[0] + box.size[1] + box.size[2] + ) + + # keep track of closest point + bestdist = wp.float32(bestdistmax) + bestsegmentpos = wp.float32(-12) + + # cltype: encoded collision configuration + # cltype / 3 == 0 : lower corner is closest to the capsule + # == 2 : upper corner is closest to the capsule + # == 1 : middle of the edge is closest to the capsule + # cltype % 3 == 0 : lower corner is closest to the box + # == 2 : upper corner is closest to the box + # == 1 : middle of the capsule is closest to the box + cltype = wp.int32(-4) + + # clface: index of the closest face of the box to the capsule + # -1: no face is closest (edge or corner is closest) + # 0, 1, 2: index of the axis perpendicular to the closest face + clface = wp.int32(-12) + + # first: consider cases where a face of the box is closest + for i in range(-1, 2, 2): + axisTip = pos + wp.float32(i) * halfaxis + boxPoint = wp.vec3(axisTip) + + n_out = wp.int32(0) + ax_out = wp.int32(-1) + + for j in range(3): + if boxPoint[j] < -box.size[j]: + n_out += 1 + ax_out = j + boxPoint[j] = -box.size[j] + elif boxPoint[j] > box.size[j]: + n_out += 1 + ax_out = j + boxPoint[j] = box.size[j] + + if n_out > 1: + continue + + dist = wp.length_sq(boxPoint - axisTip) + + if dist < bestdist: + bestdist = dist + bestsegmentpos = wp.float32(i) + cltype = -2 + i + clface = ax_out + + # second: consider cases where an edge of the box is closest + clcorner = wp.int32(-123) # which corner is the closest + cledge = wp.int32(-123) # which axis + bestboxpos = wp.float32(0.0) + + for i in range(8): + for j in range(3): + if i & (1 << j) != 0: + continue + + c2 = wp.int32(-123) + + # box_pt is the starting point (corner) on the box + box_pt = wp.cw_mul( + wp.vec3( + wp.where(i & 1, 1.0, -1.0), + wp.where(i & 2, 1.0, -1.0), + wp.where(i & 4, 1.0, -1.0), + ), + box.size, + ) + box_pt[j] = 0.0 + + # find closest point between capsule and the edge + dif = box_pt - pos + + u = -box.size[j] * dif[j] + v = wp.dot(halfaxis, dif) + ma = box.size[j] * box.size[j] + mb = -box.size[j] * halfaxis[j] + mc = cap.size[1] * cap.size[1] + det = ma * mc - mb * mb + if wp.abs(det) < MJ_MINVAL: + continue + + idet = 1.0 / det + # sX : X=1 means middle of segment. X=0 or 2 one or the other end + + x1 = wp.float32((mc * u - mb * v) * idet) + x2 = wp.float32((ma * v - mb * u) * idet) + + s1 = wp.int32(1) + s2 = wp.int32(1) + + if x1 > 1: + x1 = 1.0 + s1 = 2 + x2 = (v - mb) / mc + elif x1 < -1: + x1 = -1.0 + s1 = 0 + x2 = (v + mb) / mc + + x2_over = x2 > 1.0 + if x2_over or x2 < -1.0: + if x2_over: + x2 = 1.0 + s2 = 2 + x1 = (u - mb) / ma + else: + x2 = -1.0 + s2 = 0 + x1 = (u + mb) / ma + + if x1 > 1: + x1 = 1.0 + s1 = 2 + elif x1 < -1: + x1 = -1.0 + s1 = 0 + + dif -= halfaxis * x2 + dif[j] += box.size[j] * x1 + + # encode relative positions of the closest points + ct = s1 * 3 + s2 + + dif_sq = wp.length_sq(dif) + if dif_sq < bestdist - MJ_MINVAL: + bestdist = dif_sq + bestsegmentpos = x2 + bestboxpos = x1 + # ct<6 means closest point on box is at lower end or middle of edge + c2 = ct / 6 + + clcorner = i + (1 << j) * c2 # index of closest box corner + cledge = j # axis index of closest box edge + cltype = ct # encoded collision configuration + + best = wp.float32(0.0) + l = wp.float32(0.0) + + p = wp.vec2(pos.x, pos.y) + dd = wp.vec2(halfaxis.x, halfaxis.y) + s = wp.vec2(box.size.x, box.size.y) + secondpos = wp.float32(-4.0) + + l = wp.length_sq(dd) + + uu = dd.x * s.y + vv = dd.y * s.x + w_neg = dd.x * p.y - dd.y * p.x < 0 + + best = wp.float32(-1.0) + + ee1 = uu - vv + ee2 = uu + vv + + if wp.abs(ee1) > best: + best = wp.abs(ee1) + c1 = wp.where((ee1 < 0) == w_neg, 0, 3) + + if wp.abs(ee2) > best: + best = wp.abs(ee2) + c1 = wp.where((ee2 > 0) == w_neg, 1, 2) + + if cltype == -4: # invalid type + return + + if cltype >= 0 and cltype / 3 != 1: # closest to a corner of the box + c1 = axisdir ^ clcorner + # Calculate relative orientation between capsule and corner + # There are two possible configurations: + # 1. Capsule axis points toward/away from corner + # 2. Capsule axis aligns with a face or edge + if c1 != 0 and c1 != 7: # create second contact point + if c1 == 1 or c1 == 2 or c1 == 4: + mul = 1 + else: + mul = -1 + c1 = 7 - c1 + + # "de" and "dp" distance from first closest point on the capsule to both ends of it + # mul is a direction along the capsule's axis + + if c1 == 1: + ax = 0 + ax1 = 1 + ax2 = 2 + elif c1 == 2: + ax = 1 + ax1 = 2 + ax2 = 0 + elif c1 == 4: + ax = 2 + ax1 = 0 + ax2 = 1 + + if axis[ax] * axis[ax] > 0.5: # second point along the edge of the box + m = 2.0 * box.size[ax] / wp.abs(halfaxis[ax]) + secondpos = min(1.0 - wp.float32(mul) * bestsegmentpos, m) + else: # second point along a face of the box + # check for overshoot again + m = 2.0 * min( + box.size[ax1] / wp.abs(halfaxis[ax1]), box.size[ax2] / wp.abs(halfaxis[ax2]) + ) + secondpos = -min(1.0 + wp.float32(mul) * bestsegmentpos, m) + secondpos *= wp.float32(mul) + + elif cltype >= 0 and cltype / 3 == 1: # we are on box's edge + # Calculate relative orientation between capsule and edge + # Two possible configurations: + # - T configuration: c1 = 2^n (no additional contacts) + # - X configuration: c1 != 2^n (potential additional contacts) + c1 = axisdir ^ clcorner + c1 &= 7 - (1 << cledge) # mask out edge axis to determine configuration + + if c1 == 1 or c1 == 2 or c1 == 4: # create second contact point + if cledge == 0: + ax1 = 1 + ax2 = 2 + if cledge == 1: + ax1 = 2 + ax2 = 0 + if cledge == 2: + ax1 = 0 + ax2 = 1 + ax = cledge + + # Then it finds with which face the capsule has a lower angle and switches the axis names + if wp.abs(axis[ax1]) > wp.abs(axis[ax2]): + ax1 = ax2 + ax2 = 3 - ax - ax1 + + # mul determines direction along capsule axis for second contact point + if c1 & (1 << ax2): + mul = 1 + secondpos = 1.0 - bestsegmentpos + else: + mul = -1 + secondpos = 1.0 + bestsegmentpos + + # now we have to find out whether we point towards the opposite side or towards one of the + # sides and also find the farthest point along the capsule that is above the box + + e1 = 2.0 * box.size[ax2] / wp.abs(halfaxis[ax2]) + secondpos = min(e1, secondpos) + + if ((axisdir & (1 << ax)) != 0) == ((c1 & (1 << ax2)) != 0): + e2 = 1.0 - bestboxpos + else: + e2 = 1.0 + bestboxpos + + e1 = box.size[ax] * e2 / wp.abs(halfaxis[ax]) + + secondpos = min(e1, secondpos) + secondpos *= wp.float32(mul) + + elif cltype < 0: + # similarly we handle the case when one capsule's end is closest to a face of the box + # and find where is the other end pointing to and clamping to the farthest point + # of the capsule that's above the box + # if the closest point is inside the box there's no need for a second point + + if clface != -1: # create second contact point + mul = wp.where(cltype == -3, 1, -1) + secondpos = 2.0 + + tmp1 = pos - halfaxis * wp.float32(mul) + + for i in range(3): + if i != clface: + ha_r = wp.float32(mul) / halfaxis[i] + e1 = (box.size[i] - tmp1[i]) * ha_r + if 0 < e1 and e1 < secondpos: + secondpos = e1 + + e1 = (-box.size[i] - tmp1[i]) * ha_r + if 0 < e1 and e1 < secondpos: + secondpos = e1 + + secondpos *= wp.float32(mul) + + # create sphere in original orientation at first contact point + s1_pos_l = pos + halfaxis * bestsegmentpos + s1_pos_g = box.rot @ s1_pos_l + box.pos + + # collide with sphere + _sphere_box( + s1_pos_g, + cap.size[0], + box.pos, + box.rot, + box.size, + worldid, + d, + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + ) + + if secondpos > -3: # secondpos was modified + s2_pos_l = pos + halfaxis * (secondpos + bestsegmentpos) + s2_pos_g = box.rot @ s2_pos_l + box.pos + _sphere_box( + s2_pos_g, + cap.size[0], + box.pos, + box.rot, + box.size, + worldid, + d, + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + ) + + @wp.kernel def _primitive_narrowphase( m: Model, @@ -1018,6 +1408,21 @@ def _primitive_narrowphase( solimp, geoms, ) + elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.BOX.value): + capsule_box( + geom1, + geom2, + worldid, + d, + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + ) def primitive_narrowphase(m: Model, d: Data):