reppo/reppo_alg/env_utils/jax_wrappers.py
2025-07-21 18:31:20 -04:00

376 lines
12 KiB
Python

from functools import partial
from typing import Any, Tuple, Union
import chex
import gymnax
import jax
import jax.numpy as jnp
from brax import envs
from brax.envs.wrappers.training import AutoResetWrapper, EpisodeWrapper
from flax import struct
from gymnax.environments import environment, spaces
from gymnax.environments.environment import Environment
from gymnax.environments.spaces import Box
from ml_collections import ConfigDict
from mujoco_playground import MjxEnv, registry
from mujoco_playground._src.wrapper import wrap_for_brax_training, Wrapper
class MjxGymnaxWrapper(Environment):
    """Gymnax-style facade over a MuJoCo Playground (MJX) environment.

    ``reset`` and ``step`` hand back two observations side by side: the one
    fed to the policy and a second "critic" observation. When the wrapped
    environment exposes dictionary observations, the policy reads
    ``obs["state"]`` and the critic reads ``obs[self.dict_obs_key]``; with
    flat observations both entries are the same array.
    """

    def __init__(
        self,
        env_or_name: str | MjxEnv,
        episode_length: int = 1000,
        action_repeat: int = 1,
        reward_scale: float = 1.0,
        push_distractions: bool = False,
        config: dict = None,
        asymmetric_observation: bool = False,
    ):
        """Build the wrapper from a registry name or a ready-made environment.

        Args:
            env_or_name: Registry name to load, or an environment instance
                that is used as-is (no config handling, no extra wrapping).
            episode_length: When loading by name and not ``None``, the loaded
                environment is passed through ``wrap_for_brax_training`` with
                this episode length.
            action_repeat: Forwarded to ``wrap_for_brax_training``.
            reward_scale: Factor applied to the reward returned by ``step``.
            push_distractions: Written to ``config.push_config.enable`` when
                the default config is used for one of the listed humanoid
                joystick tasks.
            config: Optional configuration; converted to a ``ConfigDict``
                when given, otherwise the registry default is used.
            asymmetric_observation: Selects ``"privileged_state"`` instead of
                ``"state"`` as the critic observation key.
        """
        if not isinstance(env_or_name, str):
            self.env = env_or_name
        else:
            if config is not None:
                config = ConfigDict(config)
            else:
                config = registry.get_default_config(env_or_name)
                humanoid_tasks = (
                    "G1JoystickRoughTerrain",
                    "G1JoystickFlatTerrain",
                    "T1JoystickRoughTerrain",
                    "T1JoystickFlatTerrain",
                )
                if env_or_name in humanoid_tasks:
                    config.push_config.enable = push_distractions
            loaded = registry.load(env_or_name, config=config)
            if episode_length is not None:
                loaded = wrap_for_brax_training(
                    loaded,
                    episode_length=episode_length,
                    action_repeat=action_repeat,
                )
            self.env = loaded

        self.reward_scale = reward_scale
        # An integer observation size means a single flat observation vector;
        # anything else is treated as a dictionary of observations.
        self.dict_obs = not isinstance(self.env.observation_size, int)
        self.dict_obs_key = "privileged_state" if asymmetric_observation else "state"
        print(self.dict_obs_key)
        super().__init__()

    def _policy_and_critic_obs(self, state):
        """Split ``state.obs`` into the (policy, critic) observation pair."""
        if self.dict_obs:
            return state.obs["state"], state.obs[self.dict_obs_key]
        return state.obs, state.obs

    def action_space(self, params):
        """Return a ``[-1, 1]`` box with one entry per action dimension."""
        return Box(low=-1.0, high=1.0, shape=(self.env.action_size,))

    def observation_space(self, params):
        """Return ``(policy_space, critic_space)`` as two unbounded boxes."""
        if self.dict_obs:
            policy_shape = self.env.observation_size["state"]
            critic_shape = self.env.observation_size[self.dict_obs_key]
        else:
            policy_shape = (self.env.observation_size,)
            critic_shape = (self.env.observation_size,)
        limit = float("inf")
        return (
            Box(low=-limit, high=limit, shape=policy_shape),
            Box(low=-limit, high=limit, shape=critic_shape),
        )

    @property
    def default_params(self) -> gymnax.EnvParams:
        """Return a default-constructed ``gymnax.EnvParams``."""
        return gymnax.EnvParams()

    def reset(self, key):
        """Reset the wrapped environment.

        Returns:
            ``(obs, critic_obs, state)``.
        """
        state = self.env.reset(key)
        policy_obs, critic_obs = self._policy_and_critic_obs(state)
        return policy_obs, critic_obs, state

    def step(self, key, state, action):
        """Advance the wrapped environment by one step (``key`` is unused).

        Returns:
            ``(obs, critic_obs, state, scaled_reward, done, info)`` where
            ``done`` is ``state.done > 0.5`` and ``info`` is an empty dict.
        """
        state = self.env.step(state, action)
        policy_obs, critic_obs = self._policy_and_critic_obs(state)
        scaled_reward = state.reward * self.reward_scale
        finished = state.done > 0.5
        return policy_obs, critic_obs, state, scaled_reward, finished, {}
@struct.dataclass
class LogEnvState:
    """State carried by ``LogWrapper``: inner env state plus episode stats."""

    # State of the environment being wrapped.
    env_state: environment.EnvState
    # Reward sum and step count of the episode currently in progress.
    episode_returns: jnp.ndarray
    episode_lengths: jnp.ndarray
    # Return and length recorded when an episode last finished.
    returned_episode_returns: jnp.ndarray
    returned_episode_lengths: jnp.ndarray
    # Step counter, incremented on every call to ``LogWrapper.step``.
    timestep: jnp.ndarray
    # Truncation flags copied from the inner state's ``info["truncation"]``.
    truncated: jnp.ndarray
    # Logging dictionary produced by the most recent reset/step.
    info: Any = None

    def unwrapped(self):
        """Return the inner environment state."""
        inner = self.env_state
        return inner

    def set_env_state(self, env_state):
        """Return a copy of this state with ``env_state`` swapped in."""
        updated = self.replace(env_state=env_state)
        return updated
class LogWrapper(Wrapper):
    """Log the episode returns and lengths.

    Wraps an environment whose ``reset`` returns ``(obs, critic_obs, state)``
    and whose ``step`` returns ``(obs, critic_obs, state, reward, done, info)``
    and keeps per-environment running episode statistics in a ``LogEnvState``.
    """

    def __init__(self, env: environment.Environment, num_envs: int):
        """Store the wrapped env and the number of parallel environments.

        Args:
            env: Environment to wrap.
            num_envs: Length of every per-environment statistics vector.
        """
        super().__init__(env)
        self.num_envs = num_envs

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, chex.Array, "LogEnvState"]:
        """Reset the wrapped env and start with zeroed statistics.

        Returns:
            ``(obs, critic_obs, state)`` where ``state`` is a ``LogEnvState``.
        """
        obs, critic_obs, env_state = self.env.reset(key)
        per_env = (self.num_envs,)
        state = LogEnvState(
            env_state=env_state,
            episode_returns=jnp.zeros(per_env),
            episode_lengths=jnp.zeros(per_env, dtype=jnp.int32),
            returned_episode_returns=jnp.zeros(per_env),
            returned_episode_lengths=jnp.zeros(per_env, dtype=jnp.int32),
            timestep=jnp.zeros(per_env, dtype=jnp.int32),
            truncated=jnp.ones(per_env, dtype=jnp.float32),
            # Same keys as the dictionary produced by ``step``.
            info={
                "returned_episode": jnp.zeros(per_env, dtype=jnp.bool_),
                "returned_episode_returns": jnp.zeros(per_env),
                "timestep": jnp.zeros(per_env, dtype=jnp.int32),
                "returned_episode_lengths": jnp.zeros(per_env, dtype=jnp.int32),
            },
        )
        return obs, critic_obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: "LogEnvState",
        action: chex.Array,
    ) -> Tuple[chex.Array, chex.Array, "LogEnvState", chex.Array, chex.Array, dict]:
        """Step the wrapped env and update the episode statistics.

        Returns:
            ``(obs, critic_obs, state, reward, done, info)``. ``info`` holds
            the inner env's entries plus ``returned_episode_returns``,
            ``returned_episode_lengths``, ``timestep`` (the counter *before*
            this step) and ``returned_episode`` (the ``done`` flags).
        """
        obs, critic_obs, env_state, reward, done, inner_info = self.env.step(
            key, state.env_state, action
        )
        not_done = 1 - done
        episode_return = state.episode_returns + reward
        episode_length = state.episode_lengths + 1

        # Where an episode just ended, publish its totals; elsewhere keep the
        # totals of the previously finished episode. Computed once and shared
        # between ``info`` and the new state.
        returned_returns = (
            state.returned_episode_returns * not_done + episode_return * done
        )
        returned_lengths = (
            state.returned_episode_lengths * not_done + episode_length * done
        )

        # Build a fresh dict instead of mutating the one owned by the inner env.
        info = {
            **inner_info,
            "returned_episode_returns": returned_returns,
            "returned_episode_lengths": returned_lengths,
            "timestep": state.timestep,
            "returned_episode": done,
        }

        state = LogEnvState(
            env_state=env_state,
            # Running totals restart from zero for environments that finished.
            episode_returns=episode_return * not_done,
            episode_lengths=episode_length * not_done,
            returned_episode_returns=returned_returns,
            returned_episode_lengths=returned_lengths,
            timestep=state.timestep + 1,
            truncated=env_state.info["truncation"],
            info=info,
        )
        return obs, critic_obs, state, reward, done, info
class BraxGymnaxWrapper:
    """Expose a Brax environment through the interface used in this module.

    ``reset`` returns ``(obs, critic_obs, state)`` and ``step`` returns
    ``(obs, critic_obs, state, reward, done, info)``. Brax provides a single
    observation, so the policy and critic observations are the same array.
    """

    def __init__(
        self,
        env_name,
        backend="generalized",
        episode_length=1000,
        reward_scaling=1.0,
        terminate=True,
    ):
        """Create the Brax env and wrap it for episodic, auto-resetting use.

        Args:
            env_name: Name passed to ``envs.get_environment``.
            backend: Brax backend name passed to ``envs.get_environment``.
            episode_length: Episode length given to ``EpisodeWrapper``.
            reward_scaling: Factor applied to the reward returned by ``step``.
            terminate: Forwarded as ``terminate_when_unhealthy``.
        """
        env = envs.get_environment(
            env_name=env_name, backend=backend, terminate_when_unhealthy=terminate
        )
        env = EpisodeWrapper(env, episode_length=episode_length, action_repeat=1)
        env = AutoResetWrapper(env)
        self.env = env
        self.action_size = self.env.action_size
        self.observation_size = (self.env.observation_size,)
        self.default_params = ()
        self.reward_scaling = reward_scaling

    def reset(self, key):
        """Reset the environment.

        Returns:
            ``(obs, critic_obs, state)``. The observation is returned twice so
            the result has the same layout as ``step`` and can be unpacked by
            the wrappers in this module, which all expect three values.
        """
        state = self.env.reset(key)
        return state.obs, state.obs, state

    def step(self, key, state, action):
        """Advance the environment by one step (``key`` is unused).

        Returns:
            ``(obs, critic_obs, state, scaled_reward, done, info)`` where
            ``done`` is ``state.done > 0.5`` and ``info`` is an empty dict.
        """
        next_state = self.env.step(state, action)
        return (
            next_state.obs,
            next_state.obs,
            next_state,
            next_state.reward * self.reward_scaling,
            next_state.done > 0.5,
            {},
        )

    def observation_space(self, params=None):
        """Return ``(policy_space, critic_space)`` as two unbounded boxes.

        ``params`` is accepted and ignored so the method can also be called
        gymnax-style with an environment-params argument.
        """
        shape = (self.env.observation_size,)
        return (
            spaces.Box(low=-jnp.inf, high=jnp.inf, shape=shape),
            spaces.Box(low=-jnp.inf, high=jnp.inf, shape=shape),
        )

    def action_space(self, params=None):
        """Return a ``[-1, 1]`` box with one entry per action dimension.

        ``params`` is accepted and ignored (see ``observation_space``).
        """
        return spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self.env.action_size,),
        )
class ClipAction(Wrapper):
    """Clamp every action into ``[low, high]`` before stepping the env."""

    def __init__(self, env, low=-0.999, high=0.999):
        """Wrap ``env`` and remember the clipping bounds."""
        super().__init__(env)
        self.low, self.high = low, high

    def step(self, key, state, action):
        """Clip ``action`` to the stored bounds and forward the step call."""
        # TODO: In theory the bounds should be taken from the wrapped
        # environment's action space instead of the fixed constructor values.
        bounded_action = jnp.clip(action, self.low, self.high)
        return self.env.step(key, state, bounded_action)
@struct.dataclass
class NormalizeVecObsEnvState:
    """State carried by ``NormalizeVec``: running stats plus the inner state."""

    # Running mean / variance of the policy observations.
    mean: jnp.ndarray
    var: jnp.ndarray
    # Running mean / variance of the critic observations.
    critic_mean: jnp.ndarray
    critic_var: jnp.ndarray
    # Number of observation rows folded into the statistics so far.
    count: float
    # State of the wrapped environment.
    env_state: environment.EnvState
    # ``truncated`` and ``info`` are copied from the wrapped state.
    truncated: float
    info: Any = None

    def unwrapped(self):
        """Return whatever the wrapped state's own ``unwrapped()`` returns."""
        inner = self.env_state
        return inner.unwrapped()

    def set_env_state(self, env_state):
        """Return a copy whose wrapped state has ``env_state`` set on it."""
        new_inner = self.env_state.set_env_state(env_state)
        return self.replace(env_state=new_inner)
class NormalizeVec(Wrapper):
    """Normalize policy and critic observations with running statistics.

    Mean and variance are taken over axis 0 of each observation batch and are
    updated on every ``step``. Observations are returned as
    ``(obs - mean) / sqrt(var + 1e-2)``.
    """

    def __init__(self, env):
        """Wrap ``env``; no additional configuration is stored."""
        super().__init__(env)

    @staticmethod
    def _normalize(obs, mean, var):
        """Standardize ``obs`` with the given statistics.

        The ``1e-2`` added to the variance keeps the divisor away from zero.
        """
        return (obs - mean) / jnp.sqrt(var + 1e-2)

    def _init_state(self, key):
        """Reset the wrapped env and build a state from that first batch.

        ``truncated`` is a required field of ``NormalizeVecObsEnvState``, so
        it is filled from the wrapped state exactly as ``reset`` does (along
        with ``info``); omitting it would make the constructor raise.
        """
        obs, critic_obs, env_state = self.env.reset(key)
        return NormalizeVecObsEnvState(
            mean=jnp.mean(obs, axis=0),
            var=jnp.var(obs, axis=0),
            critic_mean=jnp.mean(critic_obs, axis=0),
            critic_var=jnp.var(critic_obs, axis=0),
            count=obs.shape[0],
            env_state=env_state,
            truncated=env_state.truncated,
            info=env_state.info,
        )

    def _compute_stats(self, mean, var, count, obs):
        """Merge running ``(mean, var, count)`` with the statistics of ``obs``.

        Uses the pairwise update for combining the mean and variance of two
        sample sets: the combined sum of squared deviations is the sum of both
        parts plus a correction term for the difference between their means.

        Returns:
            ``(new_mean, new_var)``; the caller is responsible for updating
            the count.
        """
        batch_mean = jnp.mean(obs, axis=0)
        batch_var = jnp.var(obs, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
        new_var = M2 / tot_count
        return new_mean, new_var

    def reset(self, key, params=None):
        """Reset the wrapped env and return normalized observations.

        Args:
            key: PRNG key forwarded to the wrapped env.
            params: Optional object exposing ``mean``, ``var``,
                ``critic_mean``, ``critic_var`` and ``count``. When given,
                those statistics are reused; otherwise they are initialised
                from the freshly reset observation batch.

        Returns:
            ``(normalized_obs, normalized_critic_obs, state)``.
        """
        obs, critic_obs, env_state = self.env.reset(key)
        if params is None:
            mean = jnp.mean(obs, axis=0)
            var = jnp.var(obs, axis=0)
            critic_mean = jnp.mean(critic_obs, axis=0)
            critic_var = jnp.var(critic_obs, axis=0)
            count = obs.shape[0]
        else:
            mean = params.mean
            var = params.var
            critic_mean = params.critic_mean
            critic_var = params.critic_var
            count = params.count

        state = NormalizeVecObsEnvState(
            mean=mean,
            var=var,
            critic_mean=critic_mean,
            critic_var=critic_var,
            count=count,
            env_state=env_state,
            truncated=env_state.truncated,
            info=env_state.info,
        )
        return (
            self._normalize(obs, state.mean, state.var),
            self._normalize(critic_obs, state.critic_mean, state.critic_var),
            state,
        )

    def step(self, key, state, action):
        """Step the wrapped env, update the statistics, and normalize.

        The new batch is folded into the statistics first, so the returned
        observations are normalized with statistics that already include them.

        Returns:
            ``(normalized_obs, normalized_critic_obs, state, reward, done,
            info)``.
        """
        obs, critic_obs, env_state, reward, done, info = self.env.step(
            key, state.env_state, action
        )
        new_mean, new_var = self._compute_stats(
            state.mean, state.var, state.count, obs
        )
        new_critic_mean, new_critic_var = self._compute_stats(
            state.critic_mean, state.critic_var, state.count, critic_obs
        )
        state = NormalizeVecObsEnvState(
            mean=new_mean,
            var=new_var,
            critic_mean=new_critic_mean,
            critic_var=new_critic_var,
            count=state.count + obs.shape[0],
            env_state=env_state,
            truncated=env_state.truncated,
            info=env_state.info,
        )
        return (
            self._normalize(obs, state.mean, state.var),
            self._normalize(critic_obs, state.critic_mean, state.critic_var),
            state,
            reward,
            done,
            info,
        )