Skip to content

Model batching v2 #195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2e09003
add get_modelid helper
adenzler-nvidia Apr 25, 2025
d7d6a9e
start implementing things with helper
adenzler-nvidia Apr 25, 2025
6ba92b0
update ast_analyzer_test
adenzler-nvidia Apr 25, 2025
b21f290
add test + dof_damping
adenzler-nvidia Apr 25, 2025
819d9c1
fixes
adenzler-nvidia Apr 25, 2025
db068dc
small fix
adenzler-nvidia Apr 25, 2025
173679a
qpos_spring
adenzler-nvidia Apr 25, 2025
bf6f809
body_pos
adenzler-nvidia Apr 25, 2025
3af4cf9
ipos/iquat/mass etc
adenzler-nvidia Apr 25, 2025
171d8d5
more conversions
adenzler-nvidia Apr 25, 2025
3270655
committing WIP to get things working
adenzler-nvidia Apr 25, 2025
ca4a632
fixes
adenzler-nvidia Apr 28, 2025
555c815
fix geom_size
adenzler-nvidia Apr 28, 2025
d7facd1
formatting and linting
adenzler-nvidia Apr 28, 2025
7a8bb01
remove get_batched_arrays
adenzler-nvidia Apr 30, 2025
ea5c629
first parts of get_batched_value
adenzler-nvidia Apr 30, 2025
c7cb2e7
collision driver
adenzler-nvidia Apr 30, 2025
a4ee24e
primitive collisions
adenzler-nvidia Apr 30, 2025
a0ae5c1
constraint
adenzler-nvidia Apr 30, 2025
6ed86cb
forward
adenzler-nvidia Apr 30, 2025
72fd650
passive
adenzler-nvidia Apr 30, 2025
991d6dd
sensor
adenzler-nvidia Apr 30, 2025
40fac97
smooth
adenzler-nvidia Apr 30, 2025
b8bf457
smooth part 2
adenzler-nvidia Apr 30, 2025
fdbc94b
remove infra
adenzler-nvidia Apr 30, 2025
760dfe5
Merge branch 'main' into dev/adenzler/model-batching-v3
adenzler-nvidia Apr 30, 2025
55389b5
fix an issue in gjk
adenzler-nvidia Apr 30, 2025
badb919
subtree mass
adenzler-nvidia Apr 30, 2025
25369ea
geom for body
adenzler-nvidia Apr 30, 2025
00ea73b
geomadr
adenzler-nvidia Apr 30, 2025
d5c7d81
rbound
adenzler-nvidia Apr 30, 2025
16caa15
formatting
adenzler-nvidia Apr 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def test_all_issues(
@kernel
def test_no_issues(
# Model:
qpos0: wp.array(dtype=wp.float32, ndim=1),
geom_pos: wp.array(dtype=wp.vec3, ndim=1),
qpos0: wp.array(dtype=wp.float32, ndim=2),
geom_pos: wp.array(dtype=wp.vec3, ndim=2),
# Data in:
qpos_in: wp.array(dtype=wp.float32, ndim=2),
qvel_in: wp.array(dtype=wp.float32, ndim=2),
Expand Down
11 changes: 7 additions & 4 deletions mujoco_warp/_src/collision_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def box_box_kernel(
worldid = d.collision_worldid[bp_idx]

geoms, margin, gap, condim, friction, solref, solreffriction, solimp = (
contact_params(m, d, tid)
contact_params(m, d, tid, worldid)
)

# transformations
Expand All @@ -220,8 +220,8 @@ def box_box_kernel(
trans_atob = b_mat_inv @ (a_pos - b_pos)
rot_atob = b_mat_inv @ a_mat

a_size = m.geom_size[ga]
b_size = m.geom_size[gb]
a_size = m.geom_size[worldid, ga]
b_size = m.geom_size[worldid, gb]
a = box(rot_atob, trans_atob, a_size)
b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size)

Expand Down Expand Up @@ -312,7 +312,10 @@ def box_box_kernel(
for i in range(4):
pos[i] = pos[idx]

margin = wp.max(m.geom_margin[ga], m.geom_margin[gb])
margin = wp.max(
m.geom_margin[worldid, ga],
m.geom_margin[worldid, gb],
)
for i in range(4):
pos_glob = b_mat @ pos[i] + b_pos
n_glob = b_mat @ sep_axis
Expand Down
11 changes: 7 additions & 4 deletions mujoco_warp/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def _gjk_epa_sparse(m: Model, d: Data):

worldid = d.collision_worldid[tid]
geoms, margin, gap, condim, friction, solref, solreffriction, solimp = (
contact_params(m, d, tid)
contact_params(m, d, tid, worldid)
)

g1 = geoms[0]
Expand All @@ -763,10 +763,13 @@ def _gjk_epa_sparse(m: Model, d: Data):
if m.geom_type[g1] != geomtype1 or m.geom_type[g2] != geomtype2:
return

geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid)
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid)

margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
margin = wp.max(
m.geom_margin[worldid, g1],
m.geom_margin[worldid, g2],
)

simplex, normal = _gjk(m, geom1, geom2)

Expand Down
12 changes: 6 additions & 6 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@

@wp.func
def _sphere_filter(m: Model, d: Data, geom1: int, geom2: int, worldid: int) -> bool:
margin1 = m.geom_margin[geom1]
margin2 = m.geom_margin[geom2]
margin1 = m.geom_margin[worldid, geom1]
margin2 = m.geom_margin[worldid, geom2]
pos1 = d.geom_xpos[worldid, geom1]
pos2 = d.geom_xpos[worldid, geom2]
size1 = m.geom_rbound[geom1]
size2 = m.geom_rbound[geom2]
size1 = m.geom_rbound[worldid, geom1]
size2 = m.geom_rbound[worldid, geom2]

bound = size1 + size2 + wp.max(margin1, margin2)
dif = pos2 - pos1
Expand Down Expand Up @@ -105,13 +105,13 @@ def _sap_project(m: Model, d: Data, direction: wp.vec3):
worldid, geomid = wp.tid()

xpos = d.geom_xpos[worldid, geomid]
rbound = m.geom_rbound[geomid]
rbound = m.geom_rbound[worldid, geomid]

if rbound == 0.0:
# geom is a plane
rbound = MJ_MAXVAL

radius = rbound + m.geom_margin[geomid]
radius = rbound + m.geom_margin[worldid, geomid]
center = wp.dot(direction, xpos)

d.sap_projection_lower[worldid, geomid] = center - radius
Expand Down
65 changes: 40 additions & 25 deletions mujoco_warp/_src/collision_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def _geom(
m: Model,
geom_xpos: wp.array(dtype=wp.vec3),
geom_xmat: wp.array(dtype=wp.mat33),
worldid: int,
) -> Geom:
geom = Geom()
geom.pos = geom_xpos[gid]
rot = geom_xmat[gid]
geom.rot = rot
geom.size = m.geom_size[gid]
geom.size = m.geom_size[worldid, gid]
geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane
dataid = m.geom_dataid[gid]
dataid = m.geom_dataid[worldid, gid]
if dataid >= 0:
geom.vertadr = m.mesh_vertadr[dataid]
geom.vertnum = m.mesh_vertnum[dataid]
Expand Down Expand Up @@ -737,44 +738,53 @@ def plane_cylinder(


@wp.func
def contact_params(m: Model, d: Data, cid: int):
def contact_params(m: Model, d: Data, cid: int, worldid: int):
geoms = d.collision_pair[cid]
pairid = d.collision_pairid[cid]

if pairid > -1:
margin = m.pair_margin[pairid]
gap = m.pair_gap[pairid]
margin = m.pair_margin[worldid, pairid]
gap = m.pair_gap[worldid, pairid]
condim = m.pair_dim[pairid]
friction = m.pair_friction[pairid]
solref = m.pair_solref[pairid]
solreffriction = m.pair_solreffriction[pairid]
solimp = m.pair_solimp[pairid]
friction = m.pair_friction[worldid, pairid]
solref = m.pair_solref[worldid, pairid]
solreffriction = m.pair_solreffriction[worldid, pairid]
solimp = m.pair_solimp[worldid, pairid]
else:
g1 = geoms[0]
g2 = geoms[1]

p1 = m.geom_priority[g1]
p2 = m.geom_priority[g2]
p1 = m.geom_priority[worldid, g1]
p2 = m.geom_priority[worldid, g2]

solmix1 = m.geom_solmix[g1]
solmix2 = m.geom_solmix[g2]
solmix1 = m.geom_solmix[worldid, g1]
solmix2 = m.geom_solmix[worldid, g2]

mix = solmix1 / (solmix1 + solmix2)
mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix)
mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 >= MJ_MINVAL), 0.0, mix)
mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix)
mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0))

margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
gap = wp.max(m.geom_gap[g1], m.geom_gap[g2])
margin = wp.max(
m.geom_margin[worldid, g1],
m.geom_margin[worldid, g2],
)
gap = wp.max(
m.geom_gap[worldid, g1],
m.geom_gap[worldid, g2],
)

condim1 = m.geom_condim[g1]
condim2 = m.geom_condim[g2]
condim = wp.where(
p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2)
)

geom_friction = wp.max(m.geom_friction[g1], m.geom_friction[g2])
geom_friction = wp.max(
m.geom_friction[worldid, g1],
m.geom_friction[worldid, g2],
)
friction = vec5(
geom_friction[0],
geom_friction[0],
Expand All @@ -783,14 +793,19 @@ def contact_params(m: Model, d: Data, cid: int):
geom_friction[2],
)

if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0:
solref = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2]
if m.geom_solref[worldid, g1].x > 0.0 and m.geom_solref[worldid, g2].x > 0.0:
solref = (
mix * m.geom_solref[worldid, g1] + (1.0 - mix) * m.geom_solref[worldid, g2]
)
else:
solref = wp.min(m.geom_solref[g1], m.geom_solref[g2])
solref = wp.min(
m.geom_solref[worldid, g1],
m.geom_solref[worldid, g2],
)

solreffriction = wp.vec2(0.0, 0.0)

solimp = mix * m.geom_solimp[g1] + (1.0 - mix) * m.geom_solimp[g2]
solimp = mix * m.geom_solimp[worldid, g1] + (1.0 - mix) * m.geom_solimp[worldid, g2]

return geoms, margin, gap, condim, friction, solref, solreffriction, solimp

Expand Down Expand Up @@ -1258,16 +1273,16 @@ def _primitive_narrowphase(
if tid >= d.ncollision[0]:
return

worldid = d.collision_worldid[tid]

geoms, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params(
m, d, tid
m, d, tid, worldid
)
g1 = geoms[0]
g2 = geoms[1]

worldid = d.collision_worldid[tid]

geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid)
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], worldid)

type1 = m.geom_type[g1]
type2 = m.geom_type[g2]
Expand Down
Loading