The feature, motivation and pitch
Both mujoco.MjModel and mujoco.MjSpec have a nice API for accessing elements by name, via methods like mujoco.MjModel.{body,geom,joint,actuator}. I constantly find myself missing a similar API on mjx.Model. Most recently, for example, I tried combining the Playground environment interface with domain randomization. Unfortunately, BraxDomainRandomizationVmapWrapper passes only self.mjx_model into randomization_fn, which makes it hard to access the model's elements by name.
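For context, mujoco.MjModel stores every element name in a single buffer, `names`, as null-terminated strings, and keeps per-type offset arrays into it such as `name_bodyadr`, `name_jntadr`, `name_geomadr` and `name_actuatoradr`. A `modelx.body("body")`-style accessor mostly needs to resolve a name to an index through arrays like these. The sketch below shows that resolution in plain Python. `NamedModel` and `_Element` are hypothetical classes that only mimic the shape of the requested API, the buffer is hand-built to mirror the XML in the test case rather than copied from a real model, and the linear scan is not necessarily how MuJoCo's own `mj_name2id` works.

```python
from dataclasses import dataclass
from typing import Dict, Sequence


@dataclass(frozen=True)
class _Element:
  """Hypothetical view returned by a name-based accessor."""
  id: int
  name: str


class NamedModel:
  """Hypothetical name-based accessor over a MuJoCo-style names buffer.

  `names` holds every element name as null-terminated bytes; entry i of
  `name_adr[kind]` is the byte offset of element i's name in that buffer.
  """

  def __init__(self, names: bytes, name_adr: Dict[str, Sequence[int]]):
    self._names = names
    self._name_adr = name_adr

  def _name_at(self, adr: int) -> str:
    # A name runs from its offset up to the next null byte.
    end = self._names.index(b"\x00", adr)
    return self._names[adr:end].decode()

  def _lookup(self, kind: str, name: str) -> _Element:
    # Linear scan over the name offsets of one element type.
    for idx, adr in enumerate(self._name_adr[kind]):
      if self._name_at(adr) == name:
        return _Element(id=idx, name=name)
    raise KeyError(f"no {kind} named {name!r}")

  def body(self, name: str) -> _Element:
    return self._lookup("body", name)

  def joint(self, name: str) -> _Element:
    return self._lookup("joint", name)

  def geom(self, name: str) -> _Element:
    return self._lookup("geom", name)

  def actuator(self, name: str) -> _Element:
    return self._lookup("actuator", name)


# Illustrative, hand-built buffer: model name, bodies ("world", "body"),
# joint, geoms, actuator. The real layout of MjModel.names may differ.
names = b"test\x00world\x00body\x00joint\x00plane\x00box\x00motor\x00"
model = NamedModel(
    names,
    {"body": [5, 11], "joint": [16], "geom": [22, 28], "actuator": [32]},
)
print(model.body("body").id)  # 1
print(model.geom("box").id)  # 1
```

Because the lookup needs only the names buffer and offset arrays, the same idea could in principle run on the host from any model object that carries those fields.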
Alternatives
I can access the elements by passing mujoco.MjSpec elements to mjx.Model.bind, but as the example above shows, it's clumsy to always have to carry the mjx.Model and mujoco.MjSpec objects around together.
Edit: I can also use mjx.name2id(m, typ, name), which solves my most immediate issue. Still, it would be much cleaner and more consistent to access mjx.Model elements with the same pattern as mujoco.MjModel.
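One way to avoid carrying both objects into the randomization function is to resolve the IDs once on the host (for example with `mujoco.MjModel.geom("box").id`) and capture them in a closure, so the returned function only ever receives the model. The sketch below shows that pattern in plain Python. `FakeModel` is a hypothetical frozen dataclass standing in for mjx.Model, `dataclasses.replace` stands in for updating an immutable model, and `make_randomization_fn` and the friction scaling are illustrative only. They do not reflect the exact contract of the Playground wrapper's `randomization_fn`.

```python
import dataclasses
from typing import Callable, Tuple

Friction = Tuple[float, float, float]


@dataclasses.dataclass(frozen=True)
class FakeModel:
  """Hypothetical stand-in for mjx.Model: one friction triple per geom."""
  geom_friction: Tuple[Friction, ...]


def make_randomization_fn(
    geom_id: int, scale: float
) -> Callable[[FakeModel], FakeModel]:
  """Returns a function that has the geom ID baked into its closure.

  geom_id is assumed to have been resolved once on the host, so the returned
  function needs only the (stand-in) model, not a spec or name table.
  """

  def randomize(model: FakeModel) -> FakeModel:
    rows = list(model.geom_friction)
    sliding, torsional, rolling = rows[geom_id]
    rows[geom_id] = (sliding * scale, torsional, rolling)
    # Return an updated copy; the input model is left unchanged.
    return dataclasses.replace(model, geom_friction=tuple(rows))

  return randomize


# Hypothetical: pretend "box" resolved to geom ID 1 on the host.
box_id = 1
model = FakeModel(geom_friction=((1.0, 0.005, 0.0001), (1.0, 0.005, 0.0001)))
randomize = make_randomization_fn(box_id, scale=0.5)
print(randomize(model).geom_friction[box_id])  # (0.5, 0.005, 0.0001)
```

This still requires the host-side model when the function is built, which is exactly the bookkeeping that name-based access on mjx.Model would remove.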
Additional context
Here's a simple test case I wish would pass:
```python
"""Test for `mjx.Model` named element access."""

import textwrap

from absl.testing import absltest
from absl.testing import parameterized
import mujoco
from mujoco import mjx


class MjxModelTests(parameterized.TestCase):

  def test_named_element_access(self) -> None:
    model = mujoco.MjModel.from_xml_string(textwrap.dedent("""
        <mujoco model="test">
          <worldbody>
            <geom name="plane" type="plane" size="1 1 1"/>
            <body name="body" pos="0 0 0">
              <joint name="joint" type="slide" axis="1 0 0" range="-5 5"/>
              <geom name="box" type="box" size=".2 .1 .1" rgba=".9 .3 .3 1"/>
            </body>
          </worldbody>
          <actuator>
            <motor joint="joint" name="motor"/>
          </actuator>
        </mujoco>
        """))
    modelx = mjx.put_model(model)

    self.assertEqual(model.body("body").id, modelx.body("body").id)
    self.assertEqual(model.joint("joint").id, modelx.joint("joint").id)
    self.assertEqual(model.actuator("motor").id, modelx.actuator("motor").id)
    self.assertEqual(model.geom("plane").id, modelx.geom("plane").id)
    self.assertEqual(model.geom("box").id, modelx.geom("box").id)


if __name__ == "__main__":
  absltest.main()
```

```
$ python ./mjx_model_test.py
Running tests under Python 3.12.11: /home/user/.venv/bin/python
[ RUN ] UnitTests.test_add
[ FAILED ] UnitTests.test_add
======================================================================
ERROR: test_add (__main__.UnitTests)
UnitTests.test_add
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/user/.venv/lib/python3.12/site-packages/mujoco/mjx/_src/types.py", line 935, in __getattr__
    val = getattr(impl_instsance, name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ModelJAX' object has no attribute 'body'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/mjx_model_test.py", line 30, in test_add
    self.assertEqual(model.body("body").id, modelx.body("body").id)
                                            ^^^^^^^^^^^
  File "/home/user/.venv/lib/python3.12/site-packages/mujoco/mjx/_src/types.py", line 944, in __getattr__
    raise AttributeError(  # pylint: disable=raise-missing-from
AttributeError: 'Model' object has no attribute 'body'. Did you mean: 'nbody'?

----------------------------------------------------------------------
Ran 1 test in 0.329s

FAILED (errors=1)
```