- Change in isaaclab_env wrapper to explicitly state the GPU for each simulation
- Removing the jax cache to support multi-GPU environment launch in MuJoCo Playground
- Removing .train() and .eval() in evaluation and rendering to avoid deadlock in multi-GPU training
- Supporting synchronous normalization for multi-GPU training
157 lines · 5.3 KiB · Python
from mujoco_playground import registry
|
|
from mujoco_playground import wrapper_torch
|
|
|
|
import jax
|
|
import mujoco
|
|
|
|
|
|
class PlaygroundEvalEnvWrapper:
    """Torch-facing wrapper around a MuJoCo Playground env for evaluation
    and rendering.

    Note that this is different from training environments, which are
    wrapped with ``RSLRLBraxWrapper``.
    """

    def __init__(
        self,
        eval_env,
        max_episode_steps,
        env_name,
        num_eval_envs,
        seed,
        device_rank=None,
    ):
        """
        Wrapper used for evaluation / rendering environments.

        Args:
            eval_env: the raw Playground environment to wrap.
            max_episode_steps: episode length exposed to the caller.
            env_name: registry name of the task (used to pick a camera).
            num_eval_envs: batch size for the vectorized reset/step.
            seed: PRNG seed for the JAX key.
            device_rank: if given, index into ``jax.devices("gpu")`` used to
                pin this wrapper's PRNG key (and all arrays derived from it)
                to a single GPU in a multi-GPU launch.
        """
        self.env = eval_env
        self.env_name = env_name
        self.num_envs = num_eval_envs
        self.max_episode_steps = max_episode_steps

        # JIT-compiled, batched reset/step over the evaluation envs.
        self.jit_reset = jax.jit(jax.vmap(self.env.reset))
        self.jit_step = jax.jit(jax.vmap(self.env.step))

        # A dict-valued observation size signals asymmetric (actor/critic)
        # observations; the policy only consumes the "state" entry.
        self.asymmetric_obs = isinstance(self.env.unwrapped.observation_size, dict)

        self.key = jax.random.PRNGKey(seed)
        if device_rank is not None:
            # Explicitly place the key on the requested GPU so each process
            # in a multi-GPU run stays on its own device.
            self.key = jax.device_put(self.key, jax.devices("gpu")[device_rank])
        self.key_reset = jax.random.split(self.key, num_eval_envs)

    def reset(self):
        """Reset all eval envs and return the initial policy observation
        as a torch tensor."""
        self.state = self.jit_reset(self.key_reset)
        obs = self.state.obs["state"] if self.asymmetric_obs else self.state.obs
        return wrapper_torch._jax_to_torch(obs)

    def step(self, actions):
        """Step all eval envs with torch ``actions``.

        Returns:
            ``(next_obs, rewards, dones, None)`` as torch tensors; the final
            ``None`` stands in for an info dict.
        """
        self.state = self.jit_step(self.state, wrapper_torch._torch_to_jax(actions))
        obs = self.state.obs["state"] if self.asymmetric_obs else self.state.obs
        next_obs = wrapper_torch._jax_to_torch(obs)
        rewards = wrapper_torch._jax_to_torch(self.state.reward)
        dones = wrapper_torch._jax_to_torch(self.state.done)
        return next_obs, rewards, dones, None

    def render_trajectory(self, trajectory):
        """Render ``trajectory`` to a list of RGB frames (640x480)."""
        scene_option = mujoco.MjvOption()
        # Disable visual clutter (transparency and force arrows) in renders.
        for flag in (
            mujoco.mjtVisFlag.mjVIS_TRANSPARENT,
            mujoco.mjtVisFlag.mjVIS_PERTFORCE,
            mujoco.mjtVisFlag.mjVIS_CONTACTFORCE,
        ):
            scene_option.flags[flag] = False

        # Joystick tasks ship a tracking camera; other tasks use the default.
        camera = "track" if "Joystick" in self.env_name else None
        return self.env.render(
            trajectory,
            camera=camera,
            height=480,
            width=640,
            scene_option=scene_option,
        )
|
|
|
|
|
|
def make_env(
    env_name,
    seed,
    num_envs,
    num_eval_envs,
    device_rank,
    use_tuned_reward=False,
    use_domain_randomization=False,
    use_push_randomization=False,
):
    """Build the training, evaluation, and rendering environments for a task.

    Args:
        env_name: MuJoCo Playground registry name of the task.
        seed: PRNG seed shared by all three environments.
        num_envs: number of parallel training environments.
        num_eval_envs: number of parallel evaluation environments.
        device_rank: GPU index for multi-GPU launches (forwarded to the
            wrappers; may be None for single-device runs).
        use_tuned_reward: apply the hand-tuned humanoid reward scales.
        use_domain_randomization: enable the task's domain randomizer
            for training.
        use_push_randomization: keep push perturbations enabled for
            humanoid tasks (disabled by default).

    Returns:
        ``(train_env, eval_env, render_env)`` — an ``RSLRLBraxWrapper``
        training env and two ``PlaygroundEvalEnvWrapper`` instances
        (batched eval, single-env render).
    """
    is_humanoid_task = env_name in [
        "G1JoystickRoughTerrain",
        "G1JoystickFlatTerrain",
        "T1JoystickRoughTerrain",
        "T1JoystickFlatTerrain",
    ]

    def _default_cfg():
        # Shared config construction for the train/eval/render envs:
        # all three start from the registry default, with push
        # perturbations turned off for humanoids unless requested.
        cfg = registry.get_default_config(env_name)
        if is_humanoid_task and not use_push_randomization:
            cfg.push_config.enable = False
            cfg.push_config.magnitude_range = [0.0, 0.0]
        return cfg

    # Make training environment.
    train_env_cfg = _default_cfg()
    if use_tuned_reward and is_humanoid_task:
        # NOTE: Tuned reward for G1. Used for producing Figure 7 in the paper.
        # Somehow it works reasonably for T1 as well.
        # However, see `sim2real.md` for sim-to-real RL with Booster T1
        scales = train_env_cfg.reward_config.scales
        scales.energy = -5e-5
        scales.action_rate = -1e-1
        scales.torques = -1e-3
        scales.pose = -1.0
        scales.tracking_ang_vel = 1.25
        scales.tracking_lin_vel = 1.25
        scales.feet_phase = 1.0
        scales.ang_vel_xy = -0.3
        scales.orientation = -5.0

    randomizer = (
        registry.get_domain_randomizer(env_name) if use_domain_randomization else None
    )
    raw_env = registry.load(env_name, config=train_env_cfg)
    train_env = wrapper_torch.RSLRLBraxWrapper(
        raw_env,
        num_envs,
        seed,
        train_env_cfg.episode_length,
        train_env_cfg.action_repeat,
        randomization_fn=randomizer,
        device_rank=device_rank,
    )

    # Make evaluation environment (batched, no domain randomization).
    eval_env_cfg = _default_cfg()
    eval_env = PlaygroundEvalEnvWrapper(
        registry.load(env_name, config=eval_env_cfg),
        eval_env_cfg.episode_length,
        env_name,
        num_eval_envs,
        seed,
        device_rank=device_rank,
    )

    # Make rendering environment (single env for video rollouts).
    render_env_cfg = _default_cfg()
    render_env = PlaygroundEvalEnvWrapper(
        registry.load(env_name, config=render_env_cfg),
        render_env_cfg.episode_length,
        env_name,
        1,
        seed,
        device_rank=device_rank,
    )

    return train_env, eval_env, render_env
|