FastTD3/fast_td3/environments/mujoco_playground_env.py

import jax
import mujoco

from mujoco_playground import registry
from mujoco_playground import wrapper_torch


class PlaygroundEvalEnvWrapper:
    def __init__(
        self,
        eval_env,
        max_episode_steps,
        env_name,
        num_eval_envs,
        seed,
        device_rank=None,
    ):
        """
        Wrapper used for evaluation / rendering environments.

        Note that this differs from the training environments, which are
        wrapped with RSLRLBraxWrapper.
        """
        self.env = eval_env
        self.env_name = env_name
        self.num_envs = num_eval_envs

        # JIT-compile batched reset/step once; vmap vectorizes over the
        # leading (environment) axis.
        self.jit_reset = jax.jit(jax.vmap(self.env.reset))
        self.jit_step = jax.jit(jax.vmap(self.env.step))

        # A dict observation space signals asymmetric observations
        # (privileged critic inputs alongside the policy's "state" entry).
        self.asymmetric_obs = isinstance(self.env.unwrapped.observation_size, dict)

        self.key = jax.random.PRNGKey(seed)
        if device_rank is not None:
            # Pin the RNG key to this process's GPU so that multi-GPU
            # training keeps each simulation on its own device.
            gpu_devices = jax.devices("gpu")
            self.key = jax.device_put(self.key, gpu_devices[device_rank])
        self.key_reset = jax.random.split(self.key, num_eval_envs)
        self.max_episode_steps = max_episode_steps

    def reset(self):
        self.state = self.jit_reset(self.key_reset)
        if self.asymmetric_obs:
            obs = wrapper_torch._jax_to_torch(self.state.obs["state"])
        else:
            obs = wrapper_torch._jax_to_torch(self.state.obs)
        return obs

    def step(self, actions):
        self.state = self.jit_step(self.state, wrapper_torch._torch_to_jax(actions))
        if self.asymmetric_obs:
            next_obs = wrapper_torch._jax_to_torch(self.state.obs["state"])
        else:
            next_obs = wrapper_torch._jax_to_torch(self.state.obs)
        rewards = wrapper_torch._jax_to_torch(self.state.reward)
        dones = wrapper_torch._jax_to_torch(self.state.done)
        return next_obs, rewards, dones, None
def render_trajectory(self, trajectory):
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False
frames = self.env.render(
trajectory,
camera="track" if "Joystick" in self.env_name else None,
height=480,
width=640,
scene_option=scene_option,
)
return frames


def make_env(
    env_name,
    seed,
    num_envs,
    num_eval_envs,
    device_rank,
    use_tuned_reward=False,
    use_domain_randomization=False,
    use_push_randomization=False,
):
    # Make training environment
    train_env_cfg = registry.get_default_config(env_name)
    is_humanoid_task = env_name in [
        "G1JoystickRoughTerrain",
        "G1JoystickFlatTerrain",
        "T1JoystickRoughTerrain",
        "T1JoystickFlatTerrain",
    ]
    if use_tuned_reward and is_humanoid_task:
        # NOTE: Tuned reward for G1, used to produce Figure 7 in the paper.
        # Somehow it also works reasonably well for T1.
        # However, see `sim2real.md` for sim-to-real RL with Booster T1.
        train_env_cfg.reward_config.scales.energy = -5e-5
        train_env_cfg.reward_config.scales.action_rate = -1e-1
        train_env_cfg.reward_config.scales.torques = -1e-3
        train_env_cfg.reward_config.scales.pose = -1.0
        train_env_cfg.reward_config.scales.tracking_ang_vel = 1.25
        train_env_cfg.reward_config.scales.tracking_lin_vel = 1.25
        train_env_cfg.reward_config.scales.feet_phase = 1.0
        train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
        train_env_cfg.reward_config.scales.orientation = -5.0

    # Disable push-based randomization unless explicitly requested.
    if is_humanoid_task and not use_push_randomization:
        train_env_cfg.push_config.enable = False
        train_env_cfg.push_config.magnitude_range = [0.0, 0.0]

    randomizer = (
        registry.get_domain_randomizer(env_name) if use_domain_randomization else None
    )
    raw_env = registry.load(env_name, config=train_env_cfg)
    train_env = wrapper_torch.RSLRLBraxWrapper(
        raw_env,
        num_envs,
        seed,
        train_env_cfg.episode_length,
        train_env_cfg.action_repeat,
        randomization_fn=randomizer,
        device_rank=device_rank,
    )

    # Make evaluation environment
    eval_env_cfg = registry.get_default_config(env_name)
    if is_humanoid_task and not use_push_randomization:
        eval_env_cfg.push_config.enable = False
        eval_env_cfg.push_config.magnitude_range = [0.0, 0.0]
    eval_env = registry.load(env_name, config=eval_env_cfg)
    eval_env = PlaygroundEvalEnvWrapper(
        eval_env,
        eval_env_cfg.episode_length,
        env_name,
        num_eval_envs,
        seed,
        device_rank=device_rank,
    )

    # Make rendering environment (a single env suffices for video capture)
    render_env_cfg = registry.get_default_config(env_name)
    if is_humanoid_task and not use_push_randomization:
        render_env_cfg.push_config.enable = False
        render_env_cfg.push_config.magnitude_range = [0.0, 0.0]
    render_env = registry.load(env_name, config=render_env_cfg)
    render_env = PlaygroundEvalEnvWrapper(
        render_env,
        render_env_cfg.episode_length,
        env_name,
        1,
        seed,
        device_rank=device_rank,
    )

    return train_env, eval_env, render_env
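

# A minimal, illustrative smoke test, not a definitive entry point: the task
# name, env counts, and GPU index below are assumptions for the sketch, and a
# visible GPU is required since `device_rank` selects one via jax.devices("gpu").
if __name__ == "__main__":
    train_env, eval_env, render_env = make_env(
        env_name="T1JoystickFlatTerrain",  # placeholder task name
        seed=0,
        num_envs=1024,
        num_eval_envs=16,
        device_rank=0,  # assumes at least one visible GPU
    )
    obs = eval_env.reset()
    print("eval obs shape:", tuple(obs.shape))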