Named element accessors for mjx.Model #2806

@hartikainen

The feature, motivation and pitch

Both mujoco.MjModel and mujoco.MjSpec have a nice API where elements can be accessed by name with methods like mujoco.MjModel.{body,geom,joint,actuator}. I constantly find myself missing a similar API for mjx.Model. For example, most recently I tried using the playground environment interface together with domain randomization. Unfortunately, BraxDomainRandomizationVmapWrapper only passes self.mjx_model into the randomization_fn, which makes it hard to access the model's elements by name.

Alternatives

I can access the elements with mjx.Model.bind(mujoco.MjSpec), but as the previous example shows, it's quite clumsy to always have to carry both the mjx.Model and mujoco.MjSpec objects around together. Edit: I can also do mjx.name2id(name), which solves my most immediate issue. Still, it would be much cleaner and more consistent to be able to access mjx.Model elements with the same pattern as mujoco.MjModel.
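To make the name2id workaround a bit closer to the requested interface, it can be wrapped in a thin helper. The following is only a minimal sketch, not an existing MJX API: NamedAccess, ElementRef, and for_mjx_model are hypothetical names, and the wiring assumes mjx.name2id takes (model, mujoco.mjtObj, name) and returns -1 for an unknown name, as mujoco.mj_name2id does.

```python
"""Sketch: name-based accessors layered over an id-lookup function."""

from dataclasses import dataclass
from typing import Callable

# Element kinds covered by this sketch, mapped to mujoco.mjtObj member names.
_KIND_TO_MJTOBJ = {
    "body": "mjOBJ_BODY",
    "geom": "mjOBJ_GEOM",
    "joint": "mjOBJ_JOINT",
    "actuator": "mjOBJ_ACTUATOR",
}


@dataclass(frozen=True)
class ElementRef:
    """Hypothetical stand-in for what model.body(name) returns; only carries the id."""

    kind: str
    name: str
    id: int


class NamedAccess:
    """Hypothetical wrapper exposing body/geom/joint/actuator lookups by name."""

    def __init__(self, lookup: Callable[[str, str], int]) -> None:
        # lookup(kind, name) must return the element id, or -1 if the name is unknown.
        self._lookup = lookup

    def _get(self, kind: str, name: str) -> ElementRef:
        idx = self._lookup(kind, name)
        if idx < 0:
            raise KeyError(f"no {kind} named {name!r}")
        return ElementRef(kind, name, idx)

    def body(self, name: str) -> ElementRef:
        return self._get("body", name)

    def geom(self, name: str) -> ElementRef:
        return self._get("geom", name)

    def joint(self, name: str) -> ElementRef:
        return self._get("joint", name)

    def actuator(self, name: str) -> ElementRef:
        return self._get("actuator", name)


def for_mjx_model(modelx) -> NamedAccess:
    """Wires the wrapper to mjx.name2id (signature assumed, see the note above)."""
    # Imported lazily so the rest of the sketch loads without MuJoCo installed.
    import mujoco
    from mujoco import mjx

    def lookup(kind: str, name: str) -> int:
        obj_type = getattr(mujoco.mjtObj, _KIND_TO_MJTOBJ[kind])
        return mjx.name2id(modelx, obj_type, name)

    return NamedAccess(lookup)
```

With this, a randomization_fn that only receives the mjx.Model could call for_mjx_model(mjx_model).geom("box").id. It still only resolves ids rather than returning full element views, so it is a stopgap rather than a replacement for native accessors on mjx.Model.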

Additional context

Here's a simple test case I wish would pass:

"""Test for `mjx.Model` named element access."""

import textwrap

from absl.testing import absltest
from absl.testing import parameterized
import mujoco
from mujoco import mjx


class MjxModelTests(parameterized.TestCase):
    def test_named_element_access(self) -> None:
        model = mujoco.MjModel.from_xml_string(textwrap.dedent("""
            <mujoco model="test">
              <worldbody>
                <geom name="plane" type="plane" size="1 1 1"/>
                <body name="body" pos="0 0 0">
                  <joint name="joint" type="slide" axis="1 0 0" range="-5 5"/>
                  <geom name="box" type="box" size=".2 .1 .1" rgba=".9 .3 .3 1"/>
                </body>
              </worldbody>
              <actuator>
                <motor joint="joint" name="motor"/>
              </actuator>
            </mujoco>
        """))
        modelx = mjx.put_model(model)

        self.assertEqual(model.body("body").id, modelx.body("body").id)
        self.assertEqual(model.joint("joint").id, modelx.joint("joint").id)
        self.assertEqual(model.actuator("motor").id, modelx.actuator("motor").id)
        self.assertEqual(model.geom("plane").id, modelx.geom("plane").id)
        self.assertEqual(model.geom("box").id, modelx.geom("box").id)


if __name__ == "__main__":
    absltest.main()

Running it currently fails:

$ python ./mjx_model_test.py
Running tests under Python 3.12.11: /home/user/.venv/bin/python
[ RUN      ] UnitTests.test_add
[  FAILED  ] UnitTests.test_add
======================================================================
ERROR: test_add (__main__.UnitTests)
UnitTests.test_add
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/user/.venv/lib/python3.12/site-packages/mujoco/mjx/_src/types.py", line 935, in __getattr__
    val = getattr(impl_instsance, name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ModelJAX' object has no attribute 'body'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/mjx_model_test.py", line 30, in test_add
    self.assertEqual(model.body("body").id, modelx.body("body").id)
                                            ^^^^^^^^^^^
  File "/home/user/.venv/lib/python3.12/site-packages/mujoco/mjx/_src/types.py", line 944, in __getattr__
    raise AttributeError(  # pylint: disable=raise-missing-from
AttributeError: 'Model' object has no attribute 'body'. Did you mean: 'nbody'?

----------------------------------------------------------------------
Ran 1 test in 0.329s

FAILED (errors=1)

Labels: enhancement (New feature or request)