I am using Python 3.11 with a GPU, but JAX 0.7 seems to be too new for my project. My solution is as follows:
pip install "jax[cuda12]==0.6.2"
pip uninstall mujoco mujoco-mjx
pip install mujoco==3.3.2 mujoco-mjx
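After pinning, it can help to confirm which versions actually ended up in the environment, since pip may pull in something else as a dependency. This is a minimal sketch using only the standard library's importlib.metadata. The expected pins come from the commands above, and the helper name check_pins is hypothetical.

```python
from importlib import metadata

# Pins taken from the pip commands above; mujoco-mjx was left unpinned there.
EXPECTED = {"jax": "0.6.2", "mujoco": "3.3.2"}


def check_pins(expected):
    """Return {package: installed_version_or_None} for every pin that does not match."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[pkg] = have
    return mismatches


if __name__ == "__main__":
    bad = check_pins(EXPECTED)
    print("all pins match" if not bad else f"mismatched: {bad}")
```

To check that JAX actually sees the GPU, jax.devices() lists the visible devices, and jax.print_environment_info() prints the jax/jaxlib versions along with device details.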
I started modifying the code. I changed the return values of the following methods from lists to tuples: get_physics_randomizers, get_events, get_resets, get_observations, get_commands, get_rewards, and get_terminations.
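The post does not say why tuples were needed. One plausible reason (an assumption here, not confirmed by the post) is that jax.jit requires static arguments to be hashable: a tuple of functions is hashable, but a list is not. The following minimal sketch uses hypothetical reward functions as stand-ins for whatever get_rewards actually returns in the project.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical reward terms; the real project's terms will differ.
def upright_reward(x):
    return jnp.sum(x)


def energy_penalty(x):
    return -0.1 * jnp.sum(x ** 2)


def get_rewards():
    # A tuple is hashable, so it can be passed as a static jit argument.
    return (upright_reward, energy_penalty)


@partial(jax.jit, static_argnums=1)
def total_reward(x, reward_fns):
    # reward_fns is static: this Python loop is unrolled at trace time.
    return sum(fn(x) for fn in reward_fns)


x = jnp.array([1.0, 2.0])
print(total_reward(x, get_rewards()))  # tuple: hashable, so this works

try:
    total_reward(x, list(get_rewards()))  # list: unhashable static argument
except (TypeError, ValueError) as err:
    print("list rejected:", type(err).__name__)
```

With the tuple, the call evaluates to 2.5 (3.0 from the first term, -0.5 from the second). The same call with a list fails when JAX tries to hash the static argument.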