import logging
import math
import time
import typing
from typing import Callable, Optional

import distrax
import hydra
import jax
import jax.flatten_util  # explicit import so jax.flatten_util.ravel_pytree is available below
import optax
import plotly.graph_objs as go
from flax import nnx, struct
from flax.struct import PyTreeNode
from gymnax.environments.environment import Environment, EnvParams, EnvState
from jax import numpy as jnp
from jax.experimental import checkify
from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf

import wandb
from reppo_alg.env_utils.jax_wrappers import (
    BraxGymnaxWrapper,
    ClipAction,
    LogWrapper,
    MjxGymnaxWrapper,
)
from reppo_alg.jaxrl import utils
from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer

logging.basicConfig(level=logging.INFO)


## INITIALIZE CLASS STRUCTURES (NETWORKS, STATES, ...)
class Policy(typing.Protocol):
    def __call__(
        self,
        key: jax.random.PRNGKey,
        obs: PyTreeNode,
        state: Optional[PyTreeNode] = None,
    ) -> tuple[PyTreeNode, PyTreeNode]:
        pass


class PPOConfig(struct.PyTreeNode):
    lr: float
    gamma: float
    lmbda: float
    clip_ratio: float
    value_coef: float
    entropy_coef: float
    total_time_steps: int
    num_steps: int
    num_mini_batches: int
    num_envs: int
    num_epochs: int
    max_grad_norm: float | None
    normalize_advantages: bool
    normalize_env: bool
    anneal_lr: bool
    num_eval: int = 25
    max_episode_steps: int = 1000


class Transition(struct.PyTreeNode):
    obs: jax.Array
    critic_obs: jax.Array
    action: jax.Array
    reward: jax.Array
    log_prob: jax.Array
    value: jax.Array
    done: jax.Array
    truncated: jax.Array
    info: dict[str, jax.Array]


class PPOTrainState(nnx.TrainState):
    iteration: int
    time_steps: int
    last_env_state: EnvState
    last_obs: jax.Array
    last_critic_obs: jax.Array
    normalization_state: NormalizationState | None = None
    critic_normalization_state: NormalizationState | None = None


class PPONetworks(nnx.Module):
    def __init__(
        self,
        obs_dim: int,
        critic_obs_dim: int,
        action_dim: int,
        hidden_dim: int = 64,
        *,
        rngs: nnx.Rngs,
    ):
        def linear_layer(in_features, out_features, scale=jnp.sqrt(2)):
            return nnx.Linear(
                in_features=in_features,
                out_features=out_features,
                kernel_init=nnx.initializers.orthogonal(scale=scale),
                bias_init=nnx.initializers.zeros_init(),
                rngs=rngs,
            )

        self.actor_module = nnx.Sequential(
            linear_layer(obs_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, action_dim, scale=0.01),
        )
        self.log_std = nnx.Param(jnp.zeros(action_dim))
        self.critic_module = nnx.Sequential(
            linear_layer(critic_obs_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, 1, scale=1.0),
        )

    def critic(self, obs: jax.Array) -> jax.Array:
        return self.critic_module(obs).squeeze()

    def actor(self, obs: jax.Array) -> distrax.Distribution:
        loc = self.actor_module(obs)
        pi = distrax.MultivariateNormalDiag(
            loc=loc, scale_diag=jnp.exp(self.log_std.value)
        )
        return pi


def make_policy(train_state: PPOTrainState) -> Policy:
    normalizer = Normalizer()

    def policy(
        key: PRNGKey, obs: jax.Array, state: struct.PyTreeNode = None
    ) -> tuple[jax.Array, jax.Array]:
        if train_state.normalization_state is not None:
            obs = normalizer.normalize(train_state.normalization_state, obs)
        model = nnx.merge(train_state.graphdef, train_state.params)
        pi = model.actor(obs)
        value = model.critic(obs)
        action = pi.sample(seed=key)
        log_prob = pi.log_prob(action)
        return action, dict(log_prob=log_prob, value=value)

    return policy


def make_eval_fn(
    env: Environment, max_episode_steps: int
) -> Callable[[jax.random.PRNGKey, Policy], dict[str, float]]:
    def evaluation_fn(key: jax.random.PRNGKey, policy: Policy):
        def step_env(carry, _):
            key, env_state, obs = carry
            key, act_key, env_key = jax.random.split(key, 3)
            action, _ = policy(act_key, obs)
            env_key = jax.random.split(env_key, env.num_envs)
            obs, _, env_state, reward, done, info = env.step(
                env_key, env_state, action.clip(-1.0 + 1e-4, 1.0 - 1e-4)
            )
            return (key, env_state, obs), info

        key, init_key = jax.random.split(key)
        init_key = jax.random.split(init_key, env.num_envs)
        obs, _, env_state = env.reset(init_key)
        _, infos = jax.lax.scan(
            f=step_env,
            init=(key, env_state, obs),
            xs=None,
            length=max_episode_steps,
        )
        return {
            "episode_return": infos["returned_episode_returns"].mean(
                where=infos["returned_episode"]
            ),
            "episode_return_std": infos["returned_episode_returns"].std(
                where=infos["returned_episode"]
            ),
            "episode_length": infos["returned_episode_lengths"].mean(
                where=infos["returned_episode"]
            ),
            "episode_length_std": infos["returned_episode_lengths"].std(
                where=infos["returned_episode"]
            ),
            "num_episodes": infos["returned_episode"].sum(),
        }

    return evaluation_fn


def make_init(
    cfg: PPOConfig,
    env: Environment,
    env_params: EnvParams = None,
) -> Callable[[jax.random.PRNGKey], PPOTrainState]:
    def init(key: jax.random.PRNGKey) -> PPOTrainState:
        # Number of calls to train_step
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        # Number of calls to train_iter, add 1 if not divisible by eval_interval
        eval_interval = int(
            (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
        )
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, model_key = jax.random.split(key)
        # Initialize the model
        networks = PPONetworks(
            obs_dim=env.observation_space(env_params)[0].shape[0],
            critic_obs_dim=env.observation_space(env_params)[1].shape[0],
            action_dim=env.action_space(env_params).shape[0],
            rngs=nnx.Rngs(model_key),
        )
        # Set initial learning rate
        if not cfg.anneal_lr:
            lr = cfg.lr
        else:
            num_iterations = cfg.total_time_steps // cfg.num_steps // cfg.num_envs
            num_updates = num_iterations * cfg.num_epochs * cfg.num_mini_batches
            lr = optax.linear_schedule(cfg.lr, 1e-6, num_updates)
        # Initialize the optimizer
        if cfg.max_grad_norm is not None:
            optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                optax.adam(lr),
            )
        else:
            optimizer = optax.adam(lr)
        # Reset and fully initialize the environment
        key, env_key = jax.random.split(key)
        env_key = jax.random.split(env_key, cfg.num_envs)
        obs, critic_obs, env_state = env.reset(env_key)
        # randomize initial time step to prevent all envs stepping in tandem
        _env_state = env_state.unwrapped()
        key, randomize_steps_key = jax.random.split(key)
        _env_state.info["steps"] = jax.random.randint(
            randomize_steps_key,
            _env_state.info["steps"].shape,
            0,
            cfg.max_episode_steps,
        ).astype(jnp.float32)
        env_state.set_env_state(_env_state)
        if cfg.normalize_env:
            normalizer = Normalizer()
            norm_state = normalizer.init(obs)
            critic_normalizer = Normalizer()
            critic_norm_state = critic_normalizer.init(critic_obs)
            obs = normalizer.normalize(norm_state, obs)
            critic_obs = critic_normalizer.normalize(critic_norm_state, critic_obs)
        else:
            norm_state = None
            critic_norm_state = None
        # Initialize the state observations of the environment
        return PPOTrainState.create(
            iteration=0,
            time_steps=0,
            graphdef=nnx.graphdef(networks),
            params=nnx.state(networks),
            tx=optimizer,
            last_env_state=env_state,
            last_obs=obs,
            last_critic_obs=critic_obs,
            normalization_state=norm_state,
            critic_normalization_state=critic_norm_state,
        )

    return init


def make_train_fn(
    cfg: PPOConfig,
    env: Environment,
    env_params: EnvParams = None,
    log_callback: Callable[[PPOTrainState, dict[str, jax.Array]], None] = None,
    num_seeds: int = 1,
):
    # Initialize the environment and wrap it to admit vectorized behavior.
    env_params = env_params or env.default_params
    env = ClipAction(env)
    env = LogWrapper(env, cfg.num_envs)
    eval_fn = make_eval_fn(env, cfg.max_episode_steps)
    normalizer = Normalizer()
    eval_interval = int(
        (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
    )

    def collect_rollout(
        key: PRNGKey, train_state: PPOTrainState
    ) -> tuple[Transition, PPOTrainState]:
        model = nnx.merge(train_state.graphdef, train_state.params)

        # Take a step in the environment
        def step_env(carry, _) -> tuple[tuple, Transition]:
            key, env_state, train_state, obs, critic_obs = carry
            if cfg.normalize_env:
                norm_state = normalizer.update(train_state.normalization_state, obs)
                obs = normalizer.normalize(norm_state, obs)
                train_state = train_state.replace(normalization_state=norm_state)
                critic_obs = normalizer.normalize(
                    train_state.critic_normalization_state, critic_obs
                )
            # Select action
            key, act_key, step_key = jax.random.split(key, 3)
            pi = model.actor(obs)
            action = pi.sample(seed=act_key)
            # Take a step in the environment
            step_key = jax.random.split(step_key, cfg.num_envs)
            next_obs, next_critic_obs, next_env_state, reward, done, info = env.step(
                step_key, env_state, action.clip(-1.0 + 1e-4, 1.0 - 1e-4)
            )
            # Record the transition
            transition = Transition(
                obs=obs,
                critic_obs=critic_obs,
                action=action,
                reward=reward,
                log_prob=pi.log_prob(action),
                value=model.critic(critic_obs),
                done=done,
                truncated=next_env_state.truncated,
                info=info,
            )
            return (
                key,
                next_env_state,
                train_state,
                next_obs,
                next_critic_obs,
            ), transition

        # Collect rollout via lax.scan taking steps in the environment
        rollout_state, transitions = jax.lax.scan(
            f=step_env,
            init=(
                key,
                train_state.last_env_state,
                train_state,
                train_state.last_obs,
                train_state.last_critic_obs,
            ),
            length=cfg.num_steps,
        )
        # Aggregate the transitions across all the environments to reset for the next iteration
        _, last_env_state, train_state, last_obs, last_critic_obs = rollout_state
        train_state = train_state.replace(
            last_env_state=last_env_state,
            last_obs=last_obs,
            last_critic_obs=last_critic_obs,
            time_steps=train_state.time_steps + cfg.num_steps * cfg.num_envs,
        )
        return transitions, train_state

    def learn_step(
        key: PRNGKey, train_state: PPOTrainState, batch: Transition
    ) -> tuple[PPOTrainState, dict[str, jax.Array]]:
        # Compute advantages and target values
        model = nnx.merge(train_state.graphdef, train_state.params)
        if cfg.normalize_env:
            last_critic_obs = normalizer.normalize(
                train_state.critic_normalization_state, train_state.last_critic_obs
            )
        else:
            last_critic_obs = train_state.last_critic_obs
        last_value = model.critic(last_critic_obs)

        def compute_advantage(carry, transition):
            gae, next_value = carry
            done = transition.done
            truncated = transition.truncated
            reward = transition.reward
            value = transition.value
            delta = reward + cfg.gamma * next_value * (1 - done) - value
            gae = delta + cfg.gamma * cfg.lmbda * (1 - done) * gae
            truncated_gae = reward + cfg.gamma * next_value - value
            gae = jnp.where(truncated, truncated_gae, gae)
            return (gae, value), gae

        # Compute the advantage using GAE
        _, advantages = jax.lax.scan(
            compute_advantage,
            (jnp.zeros_like(last_value), last_value),
            batch,
            reverse=True,
        )
        target_values = advantages + batch.value
        data = (batch, advantages, target_values)
        # Reshape data to (num_steps * num_envs, ...)
        data = jax.tree.map(
            lambda x: x.reshape(
                (math.floor(cfg.num_steps * cfg.num_envs), *x.shape[2:])
            ),
            data,
        )

        def update(
            carry: tuple[int, PPOTrainState], key: PRNGKey
        ) -> tuple[tuple[int, PPOTrainState], dict[str, jax.Array]]:
            # `carry` holds the global minibatch index and the train state.
            def minibatch_update(carry, indices):
                idx, train_state = carry
                # Sample data at indices from the batch
                minibatch, advantages, target_values = jax.tree.map(
                    lambda x: jnp.take(x, indices, axis=0), data
                )
                if cfg.normalize_advantages:
                    advantages = (advantages - jnp.mean(advantages)) / (
                        jnp.std(advantages) + 1e-8
                    )

                # Define the loss function
                def loss_fn(params):
                    model = nnx.merge(train_state.graphdef, params)
                    pi = model.actor(minibatch.obs)
                    value = model.critic(minibatch.critic_obs)
                    log_prob = pi.log_prob(minibatch.action)
                    value_pred_clipped = minibatch.value + (
                        value - minibatch.value
                    ).clip(-cfg.clip_ratio, cfg.clip_ratio)
                    value_error = jnp.square(value - target_values)
                    value_error_clipped = jnp.square(value_pred_clipped - target_values)
                    value_loss = 0.5 * jnp.mean(
                        (1.0 - minibatch.truncated)
                        * jnp.maximum(value_error, value_error_clipped)
                    )
                    ratio = jnp.exp(log_prob - minibatch.log_prob)
                    checkify.check(
                        jnp.allclose(ratio, 1.0) | (idx != 1),
                        debug=True,
                        msg="Ratio not equal to 1 on first iteration: {r}",
                        r=ratio,
                    )
                    actor_loss1 = ratio * advantages
                    actor_loss2 = (
                        jnp.clip(ratio, 1 - cfg.clip_ratio, 1 + cfg.clip_ratio)
                        * advantages
                    )
                    actor_loss = -jnp.mean(
                        (1.0 - minibatch.truncated)
                        * jnp.minimum(actor_loss1, actor_loss2)
                    )
                    entropy_loss = jnp.mean(pi.entropy())
                    loss = (
                        actor_loss
                        + cfg.value_coef * value_loss
                        - cfg.entropy_coef * entropy_loss
                    )
                    return loss, dict(
                        actor_loss=actor_loss,
                        value_loss=value_loss,
                        entropy_loss=entropy_loss,
                        loss=loss,
                        mean_value=value.mean(),
                        mean_log_prob=log_prob.mean(),
                        mean_advantages=advantages.mean(),
                        mean_action=minibatch.action.mean(),
                        mean_reward=minibatch.reward.mean(),
                    )

                grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
                output, grads = grad_fn(train_state.params)
                # Global gradient norm (all parameters combined)
                flat_grads, _ = jax.flatten_util.ravel_pytree(grads)
                global_grad_norm = jnp.linalg.norm(flat_grads)
                metrics = output[1]
                metrics["advantages"] = advantages
                metrics["global_grad_norm"] = global_grad_norm
                train_state = train_state.apply_gradients(grads=grads)
                return (idx + 1, train_state), metrics

            # Shuffle data and split into mini-batches
            key, shuffle_key = jax.random.split(key)
            mini_batch_size = (
                math.floor(cfg.num_steps * cfg.num_envs) // cfg.num_mini_batches
            )
            indices = jax.random.permutation(shuffle_key, cfg.num_steps * cfg.num_envs)
            minibatch_idxs = jax.tree.map(
                lambda x: x.reshape(
                    (cfg.num_mini_batches, mini_batch_size, *x.shape[1:])
                ),
                indices,
            )
            # Run model update for each mini-batch
            carry, metrics = jax.lax.scan(minibatch_update, carry, minibatch_idxs)
            # Compute mean metrics across mini-batches
            metrics = jax.tree.map(lambda x: x.mean(0), metrics)
            return carry, metrics

        # Update the model for a number of epochs
        key, train_key = jax.random.split(key)
        (_, train_state), update_metrics = jax.lax.scan(
            f=update,
            init=(1, train_state),
            xs=jax.random.split(train_key, cfg.num_epochs),
        )
        # Get metrics from the last epoch
        update_metrics = jax.tree.map(lambda x: x[-1], update_metrics)
        return train_state, update_metrics

    # Define the training loop
    def train_fn(key: PRNGKey) -> tuple[PPOTrainState, dict]:
        def train_eval_step(key, train_state):
            def train_step(
                state: PPOTrainState, key: PRNGKey
            ) -> tuple[PPOTrainState, dict[str, jax.Array]]:
                key, rollout_key, learn_key = jax.random.split(key, 3)
                # Collect trajectories from `state`
                transitions, state = collect_rollout(
                    key=rollout_key, train_state=state
                )
                # Execute an update to the policy with `transitions`
                state, update_metrics = learn_step(
                    key=learn_key, train_state=state, batch=transitions
                )
                metrics = {**update_metrics}
                state = state.replace(iteration=state.iteration + 1)
                return state, metrics

            train_key, eval_key = jax.random.split(key)
            train_state, train_metrics = jax.lax.scan(
                f=train_step,
                init=train_state,
                xs=jax.random.split(train_key, eval_interval),
            )
            train_metrics = jax.tree.map(lambda x: x[-1], train_metrics)
            policy = make_policy(train_state)
            eval_metrics = eval_fn(eval_key, policy)
            metrics = {
                "time_step": train_state.time_steps,
                **utils.prefix_dict("train", train_metrics),
                **utils.prefix_dict("eval", eval_metrics),
            }
            return train_state, metrics

        def loop_body(
            train_state: PPOTrainState, key: PRNGKey
        ) -> tuple[PPOTrainState, dict]:
            # Map execution of the train+eval step across num_seeds (will be looped using jax.lax.scan)
            key, subkey = jax.random.split(key)
            train_state, metrics = jax.vmap(train_eval_step)(
                jax.random.split(subkey, num_seeds), train_state
            )
            jax.debug.callback(log_callback, train_state, metrics)
            return train_state, metrics

        # Initialize the policy, environment and map that across the number of random seeds
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, init_key = jax.random.split(key)
        train_state = jax.vmap(make_init(cfg, env, env_params))(
            jax.random.split(init_key, num_seeds)
        )
        keys = jax.random.split(key, num_iterations)
        # Run the training and evaluation loop from the initialized training state
        state, metrics = jax.lax.scan(f=loop_body, init=train_state, xs=keys)
        return state, metrics

    return train_fn


def plot_history(history: list[dict[str, jax.Array]]):
    steps = jnp.array([m["time_step"][0] for m in history])
    eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
    eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
    fig = go.Figure(
        [
            go.Scatter(
                x=steps,
                y=eval_return,
                name="Mean Episode Return",
                mode="lines",
                line=dict(color="blue"),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return + eval_return_std,
                name="Upper Bound",
                mode="lines",
                line=dict(width=0),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return - eval_return_std,
                name="Lower Bound",
                mode="lines",
                line=dict(width=0),
                fill="tonexty",
                fillcolor="rgba(50, 127, 168, 0.3)",
                showlegend=False,
            ),
        ]
    )
    fig.update_layout(
        xaxis=dict(title=dict(text="Environment Steps")),
    )
    return fig


def run(cfg: DictConfig):
    metric_history = []

    # Define callback to log metrics during training
    def log_callback(state, metrics):
        metrics["sys_time"] = time.perf_counter()
        if len(metric_history) > 0:
            num_env_steps = state.time_steps[0] - metric_history[-1]["time_step"][0]
            seconds = metrics["sys_time"] - metric_history[-1]["sys_time"]
            sps = num_env_steps / seconds
        else:
            sps = 0
        metric_history.append(metrics)
        episode_return = metrics["eval/episode_return"].mean()
        # Use pop() with a default value of None in case 'advantages' key doesn't exist
        advantages = metrics.pop("train/advantages", None)
        logging.info(
            f"step={state.time_steps[0]} episode_return={episode_return:.3f}, sps={sps:.2f}"
        )
        log_data = {
            "eval/episode_return": episode_return,
            **jax.tree.map(jnp.mean, utils.filter_prefix("train", metrics)),
        }
        if advantages is not None:
            log_data["train/advantages"] = wandb.Histogram(advantages)
        # Push log data to WandB
        wandb.log(log_data, step=state.time_steps[0])
    logging.info(OmegaConf.to_yaml(cfg))
    # Set up the experimental environment
    if cfg.env.type == "brax":
        env = BraxGymnaxWrapper(
            cfg.env.name
        )  # , episode_length=cfg.env.max_episode_steps
    elif cfg.env.type == "mjx":
        env = MjxGymnaxWrapper(cfg.env.name, episode_length=cfg.env.max_episode_steps)
    else:
        raise ValueError(f"Unknown environment type: {cfg.env.type}")

    key = jax.random.PRNGKey(cfg.seed)
    train_fn = make_train_fn(
        cfg=PPOConfig(**cfg.hyperparameters),
        env=env,
        log_callback=log_callback,
        num_seeds=cfg.num_seeds,
    )
    for i in range(cfg.trials):
        # Initialize WandB reporting
        key, train_key = jax.random.split(key)
        wandb.init(
            mode=cfg.wandb.mode,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            tags=[cfg.name, cfg.env.name, cfg.env.type, *cfg.tags],
            config=OmegaConf.to_container(cfg),
            name=f"ppo-{cfg.name}-{cfg.env.name.lower()}",
            save_code=True,
        )
        start = time.perf_counter()
        train_state, metrics = jax.jit(train_fn)(train_key)
        jax.block_until_ready(metrics)
        duration = time.perf_counter() - start
        # Save metrics and finish the run
        logging.info(f"Training took {duration:.2f} seconds.")
        # jnp.savez("metrics.npz", **metrics)
        # TODO: fix the directory here to save to a unique output directory
        wandb.finish()


def tune(cfg: DictConfig):
    def log_callback(state, metrics):
        episode_return = metrics["eval/episode_return"].mean()
        t = state.time_steps[0]
        wandb.log(
            {
                "episode_return": episode_return,
            },
            step=t,
        )

    env = MjxGymnaxWrapper(cfg.env.name, episode_length=cfg.env.max_episode_steps)

    def train_agent():
        wandb.init(project=cfg.wandb.project)
        run_cfg = OmegaConf.to_container(cfg)
        for k, v in dict(wandb.config).items():
            run_cfg["experiment"]["hyperparameters"][k] = v
        ppo_cfg = PPOConfig(**run_cfg["experiment"]["hyperparameters"])
        train_fn = make_train_fn(
            cfg=ppo_cfg,
            env=env,
            log_callback=log_callback,
            num_seeds=cfg.num_seeds,
        )
        train_fn = jax.jit(train_fn)
        logging.info(f"Running experiment with params: \n {run_cfg}")
        key = jax.random.PRNGKey(cfg.seed)
        train_state, metrics = train_fn(key)
        jax.block_until_ready(metrics)

    sweep_id = wandb.sweep(
        sweep={
            "name": f"{cfg.name}-{cfg.env.name}",
            "method": "bayes",
            "metric": {"name": "episode_return", "goal": "maximize"},
            "parameters": {
                "lr": {
                    "values": [1e-4, 3e-4, 1e-3],
                },
                "normalize_env": {
                    "values": [True, False],
                },
            },
        },
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
    )
    wandb.agent(sweep_id, function=train_agent, count=cfg.tune.num_runs)


@hydra.main(version_base=None, config_path="../../config", config_name="ppo")
def main(cfg: DictConfig):
    if cfg.tune:
        tune(cfg)
    else:
        run(cfg)


if __name__ == "__main__":
    main()
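

# Minimal sketch (kept as comments so the module stays importable) of driving the
# training function directly, without Hydra or W&B. The hyperparameter values and the
# "ant" environment name below are illustrative assumptions, not tuned or shipped
# defaults; only the function signatures used here come from the code above.
#
#   env = BraxGymnaxWrapper("ant")
#   cfg = PPOConfig(
#       lr=3e-4, gamma=0.99, lmbda=0.95, clip_ratio=0.2, value_coef=0.5,
#       entropy_coef=0.0, total_time_steps=1_000_000, num_steps=32,
#       num_mini_batches=4, num_envs=512, num_epochs=4, max_grad_norm=0.5,
#       normalize_advantages=True, normalize_env=True, anneal_lr=False,
#   )
#   train_fn = make_train_fn(cfg, env, log_callback=lambda state, metrics: None)
#   state, metrics = jax.jit(train_fn)(jax.random.PRNGKey(0))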