import jax
import torch

from mujoco_playground import registry, wrapper_torch
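
# Persist JAX compilation artifacts to disk so repeated runs skip XLA
# recompilation; the -1 / 0 thresholds cache every entry regardless of its
# size or compile time.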
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
|
|
|
|
|


class PlaygroundEvalEnvWrapper:
    def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed):
        """Wrapper used for evaluation / rendering environments.

        Note that this differs from the training environments, which are
        wrapped with RSLRLBraxWrapper.
        """
        self.env = eval_env
        self.env_name = env_name
        self.num_envs = num_eval_envs
        # Batched reset/step: vmap over the eval envs, then JIT-compile.
        self.jit_reset = jax.jit(jax.vmap(self.env.reset))
        self.jit_step = jax.jit(jax.vmap(self.env.step))
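
        # Dict observations indicate asymmetric actor/critic observations;
        # in that case only the actor's "state" entry is returned to the
        # policy in reset/step below.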
        self.asymmetric_obs = isinstance(self.env.unwrapped.observation_size, dict)

        self.key = jax.random.PRNGKey(seed)
        self.key_reset = jax.random.split(self.key, num_eval_envs)
        self.max_episode_steps = max_episode_steps

    def reset(self):
        # The same reset keys are reused on every call, so evaluation resets
        # are deterministic across rollouts.
        self.state = self.jit_reset(self.key_reset)
        if self.asymmetric_obs:
            obs = wrapper_torch._jax_to_torch(self.state.obs["state"])
        else:
            obs = wrapper_torch._jax_to_torch(self.state.obs)
        return obs

    def step(self, actions):
        self.state = self.jit_step(self.state, wrapper_torch._torch_to_jax(actions))
        if self.asymmetric_obs:
            next_obs = wrapper_torch._jax_to_torch(self.state.obs["state"])
        else:
            next_obs = wrapper_torch._jax_to_torch(self.state.obs)
        rewards = wrapper_torch._jax_to_torch(self.state.reward)
        dones = wrapper_torch._jax_to_torch(self.state.done)
        # Gymnasium-style 5-tuple: `dones` doubles as the truncation flag,
        # and no info dict is returned.
        return next_obs, rewards, dones, dones, None


class RandomizeInitialWrapper(wrapper_torch.RSLRLBraxWrapper):
    """Wrapper to randomize the initial state of the environment.

    On reset, each env's episode step counter is drawn uniformly at random,
    which is useful for domain randomization experiments.
    """

    def reset(self):
        print("Resetting environment with randomization")
        obs = super().reset()
        # Stagger per-env episode step counters so the envs do not hit
        # time-based events in lockstep. Note that `self.key` is reused, so
        # the sampled offsets are identical across resets.
        self.env_state.info["steps"] = jax.random.randint(
            self.key, self.env_state.info["steps"].shape, 0, 1000
        ).astype(jax.numpy.float32)
        return obs
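
    # A variant that decorrelates the offsets across resets would split the
    # key first (hypothetical sketch, not what the methods above/below do):
    #
    #     self.key, subkey = jax.random.split(self.key)
    #     steps = jax.random.randint(subkey, steps_shape, 0, 1000)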

    def reset_with_critic_obs(self):
        print("Resetting environment with randomization and critic obs")
        obs, critic_obs = super().reset_with_critic_obs()
        # Same step-counter randomization as in `reset` above.
        self.env_state.info["steps"] = jax.random.randint(
            self.key, self.env_state.info["steps"].shape, 0, 1000
        ).astype(jax.numpy.float32)
        return obs, critic_obs

    def step(self, action):
        obs, reward, done, info = super().step(action)
        # Duplicate `done` as the truncation flag to match the 5-tuple
        # interface used by PlaygroundEvalEnvWrapper.step.
        return obs, reward, done, done, info


def make_env(
    env_name,
    seed,
    num_envs,
    num_eval_envs,
    device_rank,
    use_tuned_reward=False,
    use_domain_randomization=False,
    use_push_randomization=False,
):
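    """Build a RandomizeInitialWrapper training env and a
    PlaygroundEvalEnvWrapper evaluation env for the given MuJoCo
    Playground task.
    """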
    # Make training environment
    train_env_cfg = registry.get_default_config(env_name)
    is_humanoid_task = env_name in [
        "G1JoystickRoughTerrain",
        "G1JoystickFlatTerrain",
        "T1JoystickRoughTerrain",
        "T1JoystickFlatTerrain",
    ]

    if use_tuned_reward and is_humanoid_task:
        # NOTE: Tuned reward for G1, used for producing Figure 7 in the paper.
        # Somehow it works reasonably well for T1 as well; however, see
        # `sim2real.md` for sim-to-real RL with the Booster T1.
        train_env_cfg.reward_config.scales.energy = -5e-5
        train_env_cfg.reward_config.scales.action_rate = -1e-1
        train_env_cfg.reward_config.scales.torques = -1e-3
        train_env_cfg.reward_config.scales.pose = -1.0
        train_env_cfg.reward_config.scales.tracking_ang_vel = 1.25
        train_env_cfg.reward_config.scales.tracking_lin_vel = 1.25
        train_env_cfg.reward_config.scales.feet_phase = 1.0
        train_env_cfg.reward_config.scales.ang_vel_xy = -0.3
        train_env_cfg.reward_config.scales.orientation = -5.0

    if is_humanoid_task and not use_push_randomization:
        # Disable random pushes entirely for the humanoid tasks.
        train_env_cfg.push_config.enable = False
        train_env_cfg.push_config.magnitude_range = [0.0, 0.0]

    # The registry supplies the task's model-randomization function, if any.
    randomizer = (
        registry.get_domain_randomizer(env_name) if use_domain_randomization else None
    )
    raw_env = registry.load(env_name, config=train_env_cfg)
    train_env = RandomizeInitialWrapper(
        raw_env,
        num_envs,
        seed,
        train_env_cfg.episode_length,
        train_env_cfg.action_repeat,
        randomization_fn=randomizer,
        device_rank=device_rank,
    )

    # Make evaluation environment
    eval_env_cfg = registry.get_default_config(env_name)
    if is_humanoid_task and not use_push_randomization:
        eval_env_cfg.push_config.enable = False
        eval_env_cfg.push_config.magnitude_range = [0.0, 0.0]
    eval_env = registry.load(env_name, config=eval_env_cfg)
    eval_env = PlaygroundEvalEnvWrapper(
        eval_env, eval_env_cfg.episode_length, env_name, num_eval_envs, seed
    )

    return train_env, eval_env
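

# Minimal smoke-test sketch (hypothetical usage: the task name and sizes are
# illustrative, and the underlying env is assumed to expose `action_size`).
if __name__ == "__main__":
    train_env, eval_env = make_env(
        "T1JoystickFlatTerrain",
        seed=0,
        num_envs=8,
        num_eval_envs=4,
        device_rank=0,
    )
    obs = eval_env.reset()
    for _ in range(10):
        # Zero actions are enough to exercise the batched step path.
        actions = torch.zeros(eval_env.num_envs, eval_env.env.action_size)
        obs, rewards, dones, _, _ = eval_env.step(actions)
    print("eval obs shape:", tuple(obs.shape))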