Nested function #18271
Answered by anh-tong
HyunyoungJung asked this question in Q&A
Hi,

```python
import jax
from mujoco import mjx

def pipeline_PD_step(
    self, data: mjx.Data, action: jax.Array
) -> mjx.Data:
    """Takes a physics step using the physics pipeline."""
    def PD_substep(data, _):
        ctrl = self._pd_control(data, action)
        data = data.replace(ctrl=ctrl)
        return (
            mjx.step(self.sys, data),
            None,
        )
    data, _ = jax.lax.scan(PD_substep, data, (), self._physics_steps_per_control_step)
    return data

def _pd_control(self, data: mjx.Data, action: jax.Array) -> jax.Array:
    """PD control for each joint."""
    ctrl = (self._kp * (action - data.qpos[self._motor_qpos_idx])
            - self._kd * data.qvel[self._motor_qvel_idx])
    return ctrl
```

Here, should I jit `_pd_control` as well?
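For readers without MuJoCo, the substep pattern in the question can be tried on a toy system. This is a sketch under stated assumptions, not the asker's code: `toy_step` (a unit-mass point integrated with semi-implicit Euler) stands in for `mjx.step`, plain arrays `q`/`qd` stand in for `mjx.Data`, and `KP`, `KD`, `DT`, `SUBSTEPS` are hypothetical values for `self._kp`, `self._kd`, the physics timestep and `self._physics_steps_per_control_step`. Only the outer control step is wrapped in `jax.jit`; the PD helper stays an ordinary Python function.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for self._kp, self._kd and the physics timestep.
KP, KD, DT = 50.0, 5.0, 0.01
SUBSTEPS = 10  # stand-in for self._physics_steps_per_control_step


def pd_control(q, qd, action):
    """PD torque toward target positions `action` (plain Python, not jitted)."""
    return KP * (action - q) - KD * qd


def toy_step(q, qd, u):
    """Toy replacement for mjx.step: unit-mass point, semi-implicit Euler."""
    qd = qd + DT * u
    q = q + DT * qd
    return q, qd


@jax.jit
def pd_control_step(q, qd, action):
    """One control step = SUBSTEPS physics substeps, mirroring pipeline_PD_step."""
    def pd_substep(carry, _):
        q, qd = carry
        u = pd_control(q, qd, action)
        return toy_step(q, qd, u), None

    # No per-step inputs, so xs=None and the loop length is given explicitly.
    (q, qd), _ = jax.lax.scan(pd_substep, (q, qd), None, length=SUBSTEPS)
    return q, qd


q, qd = jnp.zeros(2), jnp.zeros(2)
action = jnp.array([1.0, -0.5])  # target joint positions
for _ in range(50):  # 50 control steps = 5 s of simulated time
    q, qd = pd_control_step(q, qd, action)
print(q)  # close to `action` once the transient has decayed
```

Passing `None` with `length=` is the usual way to run a `scan` that has no per-iteration inputs; the `()` in the question has no array leaves, so `scan` also falls back to the explicit length there.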
Answered by anh-tong on Oct 25, 2023
Answer selected by HyunyoungJung
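Here, `PD_substep` is compiled automatically: `jax.lax.scan` traces its body function and compiles it as part of the loop, which is also why the function you pass to `jax.lax.scan` must be jittable. Because `_pd_control` is called inside `PD_substep`, it is traced and compiled along with it, so you do not need to jit the nested function `_pd_control` unless you also want to call `_pd_control` on its own elsewhere.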
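A small sketch makes this concrete. The names `helper`, `body` and `run` are hypothetical and not from the thread. It records every time Python actually executes a plain helper called from a `jax.lax.scan` body. Such a side effect fires only while JAX traces the function, so if the count after a 100-step scan stays tiny (typically 1) and does not grow on a repeat call, the helper was compiled into the loop rather than run per iteration.

```python
import jax
import jax.numpy as jnp

traced = []  # one entry per time Python actually runs `helper`


def helper(x):
    """Plain Python helper, analogous to _pd_control: not wrapped in jax.jit."""
    traced.append("helper")  # side effect: runs only while JAX traces, not per compiled iteration
    return 0.5 * x + 1.0


def body(carry, _):
    return helper(carry), None


@jax.jit
def run(x0):
    out, _ = jax.lax.scan(body, x0, None, length=100)
    return out


x0 = jnp.float32(0.0)
first = run(x0)
traces_after_first_call = len(traced)
second = run(x0)  # same shape/dtype: reuses the cached compiled program, no retrace
print(first, traces_after_first_call, len(traced))
```

Wrapping `helper` in `jax.jit` would not make the loop faster: inside an outer trace, a nested `jax.jit` call is simply traced into the same compiled program. A separate `jax.jit` only pays off when the helper is called on its own, outside any traced function.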