# reppo/reppo_alg/jaxrl/ppo_mjx.py
import logging
import math
import time
import typing
from typing import Callable, Optional
import distrax
import hydra
import jax
import jax.flatten_util  # explicit import for ravel_pytree, used for the global grad norm
import optax
import plotly.graph_objs as go
from flax import nnx, struct
from flax.struct import PyTreeNode
from gymnax.environments.environment import Environment, EnvParams, EnvState
from jax import numpy as jnp
from jax.experimental import checkify
from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf
import wandb
from reppo_alg.env_utils.jax_wrappers import (
    BraxGymnaxWrapper,
    ClipAction,
    LogWrapper,
    MjxGymnaxWrapper,
)
from reppo_alg.jaxrl import utils
from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer
logging.basicConfig(level=logging.INFO)
## INITIALIZE CLASS STRUCTURES (NETWORKS, STATES, ...)
class Policy(typing.Protocol):
    def __call__(
        self,
        key: jax.random.PRNGKey,
        obs: PyTreeNode,
        state: Optional[PyTreeNode] = None,
    ) -> tuple[PyTreeNode, PyTreeNode]:
        pass


class PPOConfig(struct.PyTreeNode):
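    # Hyperparameters for PPO. `lmbda` is the GAE lambda, `clip_ratio` is the
    # clipping epsilon used for both the policy ratio and the value prediction,
    # and `num_eval` is roughly the number of evaluation rounds over training.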
    lr: float
    gamma: float
    lmbda: float
    clip_ratio: float
    value_coef: float
    entropy_coef: float
    total_time_steps: int
    num_steps: int
    num_mini_batches: int
    num_envs: int
    num_epochs: int
    max_grad_norm: float | None
    normalize_advantages: bool
    normalize_env: bool
    anneal_lr: bool
    num_eval: int = 25
    max_episode_steps: int = 1000


class Transition(struct.PyTreeNode):
    obs: jax.Array
    critic_obs: jax.Array
    action: jax.Array
    reward: jax.Array
    log_prob: jax.Array
    value: jax.Array
    done: jax.Array
    truncated: jax.Array
    info: dict[str, jax.Array]


class PPOTrainState(nnx.TrainState):
    iteration: int
    time_steps: int
    last_env_state: EnvState
    last_obs: jax.Array
    last_critic_obs: jax.Array
    normalization_state: NormalizationState | None = None
    critic_normalization_state: NormalizationState | None = None


class PPONetworks(nnx.Module):
    def __init__(
        self,
        obs_dim: int,
        critic_obs_dim: int,
        action_dim: int,
        hidden_dim: int = 64,
        *,
        rngs: nnx.Rngs,
    ):
        def linear_layer(in_features, out_features, scale=jnp.sqrt(2)):
            return nnx.Linear(
                in_features=in_features,
                out_features=out_features,
                kernel_init=nnx.initializers.orthogonal(scale=scale),
                bias_init=nnx.initializers.zeros_init(),
                rngs=rngs,
            )

        self.actor_module = nnx.Sequential(
            linear_layer(obs_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, action_dim, scale=0.01),
        )
        self.log_std = nnx.Param(jnp.zeros(action_dim))
        self.critic_module = nnx.Sequential(
            linear_layer(critic_obs_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, 1, scale=1.0),
        )

    def critic(self, obs: jax.Array) -> jax.Array:
        return self.critic_module(obs).squeeze()

    def actor(self, obs: jax.Array) -> distrax.Distribution:
        loc = self.actor_module(obs)
        pi = distrax.MultivariateNormalDiag(
            loc=loc, scale_diag=jnp.exp(self.log_std.value)
        )
        return pi
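

# Example (illustrative sketch; the dimensions below are made up, not taken
# from a specific environment):
#
#   nets = PPONetworks(obs_dim=8, critic_obs_dim=10, action_dim=2, rngs=nnx.Rngs(0))
#   pi = nets.actor(jnp.zeros((8,)))              # diagonal-Gaussian policy
#   action = pi.sample(seed=jax.random.PRNGKey(0))
#   value = nets.critic(jnp.zeros((10,)))         # scalar value estimate

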
def make_policy(train_state: PPOTrainState) -> Policy:
    normalizer = Normalizer()

    def policy(
        key: PRNGKey, obs: jax.Array, state: struct.PyTreeNode = None
    ) -> tuple[jax.Array, jax.Array]:
        if train_state.normalization_state is not None:
            obs = normalizer.normalize(train_state.normalization_state, obs)
        model = nnx.merge(train_state.graphdef, train_state.params)
        pi = model.actor(obs)
        value = model.critic(obs)
        action = pi.sample(seed=key)
        log_prob = pi.log_prob(action)
        return action, dict(log_prob=log_prob, value=value)

    return policy


def make_eval_fn(
    env: Environment, max_episode_steps: int
) -> Callable[[jax.random.PRNGKey, Policy], dict[str, float]]:
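    # Roll the policy out for `max_episode_steps` steps in the vectorized
    # environment and aggregate the episode statistics recorded by LogWrapper
    # (the `returned_episode*` entries of `info`).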
    def evaluation_fn(key: jax.random.PRNGKey, policy: Policy):
        def step_env(carry, _):
            key, env_state, obs = carry
            key, act_key, env_key = jax.random.split(key, 3)
            action, _ = policy(act_key, obs)
            env_key = jax.random.split(env_key, env.num_envs)
            obs, _, env_state, reward, done, info = env.step(
                env_key, env_state, action.clip(-1.0 + 1e-4, 1.0 - 1e-4)
            )
            return (key, env_state, obs), info

        key, init_key = jax.random.split(key)
        init_key = jax.random.split(init_key, env.num_envs)
        obs, _, env_state = env.reset(init_key)
        _, infos = jax.lax.scan(
            f=step_env,
            init=(key, env_state, obs),
            xs=None,
            length=max_episode_steps,
        )
        return {
            "episode_return": infos["returned_episode_returns"].mean(
                where=infos["returned_episode"]
            ),
            "episode_return_std": infos["returned_episode_returns"].std(
                where=infos["returned_episode"]
            ),
            "episode_length": infos["returned_episode_lengths"].mean(
                where=infos["returned_episode"]
            ),
            "episode_length_std": infos["returned_episode_lengths"].std(
                where=infos["returned_episode"]
            ),
            "num_episodes": infos["returned_episode"].sum(),
        }

    return evaluation_fn


def make_init(
    cfg: PPOConfig,
    env: Environment,
    env_params: EnvParams = None,
) -> PPOTrainState:
    def init(key: jax.random.PRNGKey) -> PPOTrainState:
        # Number of calls to train_step
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        # Number of calls to train_iter, add 1 if not divisible by eval_interval
        eval_interval = int(
            (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
        )
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, model_key = jax.random.split(key)
        # Initialize the model
        networks = PPONetworks(
            obs_dim=env.observation_space(env_params)[0].shape[0],
            critic_obs_dim=env.observation_space(env_params)[1].shape[0],
            action_dim=env.action_space(env_params).shape[0],
            rngs=nnx.Rngs(model_key),
        )
        # Set initial learning rate
        if not cfg.anneal_lr:
            lr = cfg.lr
        else:
            num_iterations = cfg.total_time_steps // cfg.num_steps // cfg.num_envs
            num_updates = num_iterations * cfg.num_epochs * cfg.num_mini_batches
            lr = optax.linear_schedule(cfg.lr, 1e-6, num_updates)
        # Initialize the optimizer
        if cfg.max_grad_norm is not None:
            optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                optax.adam(lr),
            )
        else:
            optimizer = optax.adam(lr)
        # Reset and fully initialize the environment
        key, env_key = jax.random.split(key)
        env_key = jax.random.split(env_key, cfg.num_envs)
        obs, critic_obs, env_state = env.reset(env_key)
        # randomize initial time step to prevent all envs stepping in tandem
        _env_state = env_state.unwrapped()
        key, randomize_steps_key = jax.random.split(key)
        _env_state.info["steps"] = jax.random.randint(
            randomize_steps_key,
            _env_state.info["steps"].shape,
            0,
            cfg.max_episode_steps,
        ).astype(jnp.float32)
        env_state.set_env_state(_env_state)
        if cfg.normalize_env:
            normalizer = Normalizer()
            norm_state = normalizer.init(obs)
            critic_normalizer = Normalizer()
            critic_norm_state = critic_normalizer.init(critic_obs)
            obs = normalizer.normalize(norm_state, obs)
            critic_obs = critic_normalizer.normalize(critic_norm_state, critic_obs)
        else:
            norm_state = None
            critic_norm_state = None
        # Initialize the state observations of the environment
        return PPOTrainState.create(
            iteration=0,
            time_steps=0,
            graphdef=nnx.graphdef(networks),
            params=nnx.state(networks),
            tx=optimizer,
            last_env_state=env_state,
            last_obs=obs,
            last_critic_obs=critic_obs,
            normalization_state=norm_state,
            critic_normalization_state=critic_norm_state,
        )

    return init


def make_train_fn(
    cfg: PPOConfig,
    env: Environment,
    env_params: EnvParams = None,
    log_callback: Callable[[PPOTrainState, dict[str, jax.Array]], None] = None,
    num_seeds: int = 1,
):
    # Initialize the environment and wrap it to admit vectorized behavior.
    env_params = env_params or env.default_params
    env = ClipAction(env)
    env = LogWrapper(env, cfg.num_envs)
    eval_fn = make_eval_fn(env, cfg.max_episode_steps)
    normalizer = Normalizer()
    eval_interval = int(
        (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
    )
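    # `eval_interval` train steps are run between evaluations so that roughly
    # `num_eval` evaluation rounds are spread over `total_time_steps` env steps.
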
    def collect_rollout(
        key: PRNGKey, train_state: PPOTrainState
    ) -> tuple[Transition, PPOTrainState]:
        model = nnx.merge(train_state.graphdef, train_state.params)

        # Take a step in the environment
        def step_env(carry, _) -> tuple[tuple, Transition]:
            key, env_state, train_state, obs, critic_obs = carry
            if cfg.normalize_env:
                norm_state = normalizer.update(train_state.normalization_state, obs)
                obs = normalizer.normalize(norm_state, obs)
                train_state = train_state.replace(normalization_state=norm_state)
                critic_obs = normalizer.normalize(
                    train_state.critic_normalization_state, critic_obs
                )
            # Select action
            key, act_key, step_key = jax.random.split(key, 3)
            pi = model.actor(obs)
            action = pi.sample(seed=act_key)
            # Take a step in the environment
            step_key = jax.random.split(step_key, cfg.num_envs)
            next_obs, next_critic_obs, next_env_state, reward, done, info = env.step(
                step_key, env_state, action.clip(-1.0 + 1e-4, 1.0 - 1e-4)
            )
            # Record the transition
            transition = Transition(
                obs=obs,
                critic_obs=critic_obs,
                action=action,
                reward=reward,
                log_prob=pi.log_prob(action),
                value=model.critic(critic_obs),
                done=done,
                truncated=next_env_state.truncated,
                info=info,
            )
            return (
                key,
                next_env_state,
                train_state,
                next_obs,
                next_critic_obs,
            ), transition

        # Collect rollout via lax.scan taking steps in the environment
        rollout_state, transitions = jax.lax.scan(
            f=step_env,
            init=(
                key,
                train_state.last_env_state,
                train_state,
                train_state.last_obs,
                train_state.last_critic_obs,
            ),
            length=cfg.num_steps,
        )
        # Aggregate the transitions across all the environments to reset for the next iteration
        _, last_env_state, train_state, last_obs, last_critic_obs = rollout_state
        train_state = train_state.replace(
            last_env_state=last_env_state,
            last_obs=last_obs,
            last_critic_obs=last_critic_obs,
            time_steps=train_state.time_steps + cfg.num_steps * cfg.num_envs,
        )
        return transitions, train_state
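
    # `collect_rollout` returns transitions with leading axes (num_steps, num_envs);
    # `learn_step` computes GAE over the time axis and then flattens both axes
    # before minibatching.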
    def learn_step(
        key: PRNGKey, train_state: PPOTrainState, batch: Transition
    ) -> tuple[PPOTrainState, dict[str, jax.Array]]:
        # Compute advantages and target values
        model = nnx.merge(train_state.graphdef, train_state.params)
        if cfg.normalize_env:
            last_critic_obs = normalizer.normalize(
                train_state.critic_normalization_state, train_state.last_critic_obs
            )
        else:
            last_critic_obs = train_state.last_critic_obs
        last_value = model.critic(last_critic_obs)
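
        # Generalized Advantage Estimation, scanned backwards over the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        # On truncation (time limit) the advantage falls back to a one-step
        # bootstrapped error so the cut-off episode is not treated as terminal.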
        def compute_advantage(carry, transition):
            gae, next_value = carry
            done = transition.done
            truncated = transition.truncated
            reward = transition.reward
            value = transition.value
            delta = reward + cfg.gamma * next_value * (1 - done) - value
            gae = delta + cfg.gamma * cfg.lmbda * (1 - done) * gae
            truncated_gae = reward + cfg.gamma * next_value - value
            gae = jnp.where(truncated, truncated_gae, gae)
            return (gae, value), gae

        # Compute the advantage using GAE
        _, advantages = jax.lax.scan(
            compute_advantage,
            (jnp.zeros_like(last_value), last_value),
            batch,
            reverse=True,
        )
        target_values = advantages + batch.value
        data = (batch, advantages, target_values)
        # Reshape data to (num_steps * num_envs, ...)
        data = jax.tree.map(
            lambda x: x.reshape(
                (math.floor(cfg.num_steps * cfg.num_envs), *x.shape[2:])
            ),
            data,
        )

        def update(train_state, key) -> tuple[PPOTrainState, dict[str, jax.Array]]:
            def minibatch_update(carry, indices):
                idx, train_state = carry
                # Sample data at indices from the batch
                minibatch, advantages, target_values = jax.tree.map(
                    lambda x: jnp.take(x, indices, axis=0), data
                )
                if cfg.normalize_advantages:
                    advantages = (advantages - jnp.mean(advantages)) / (
                        jnp.std(advantages) + 1e-8
                    )

                # Define the loss function
                def loss_fn(params):
                    model = nnx.merge(train_state.graphdef, params)
                    pi = model.actor(minibatch.obs)
                    value = model.critic(minibatch.critic_obs)
                    log_prob = pi.log_prob(minibatch.action)
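                    # Clipped value loss: the new value prediction is clipped to
                    # stay within `clip_ratio` of the rollout-time value, and the
                    # larger of the clipped/unclipped squared errors is used.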
                    value_pred_clipped = minibatch.value + (
                        value - minibatch.value
                    ).clip(-cfg.clip_ratio, cfg.clip_ratio)
                    value_error = jnp.square(value - target_values)
                    value_error_clipped = jnp.square(value_pred_clipped - target_values)
                    value_loss = 0.5 * jnp.mean(
                        (1.0 - minibatch.truncated)
                        * jnp.maximum(value_error, value_error_clipped)
                    )
                    ratio = jnp.exp(log_prob - minibatch.log_prob)
                    checkify.check(
                        jnp.allclose(ratio, 1.0) | (idx != 1),
                        debug=True,
                        msg="Ratio not equal to 1 on first iteration: {r}",
                        r=ratio,
                    )
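                    # PPO clipped surrogate objective: take the elementwise
                    # minimum of the unclipped and clipped importance-weighted
                    # advantages, masking transitions that ended by truncation.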
                    actor_loss1 = ratio * advantages
                    actor_loss2 = (
                        jnp.clip(ratio, 1 - cfg.clip_ratio, 1 + cfg.clip_ratio)
                        * advantages
                    )
                    actor_loss = -jnp.mean(
                        (1.0 - minibatch.truncated)
                        * jnp.minimum(actor_loss1, actor_loss2)
                    )
                    entropy_loss = jnp.mean(pi.entropy())
                    loss = (
                        actor_loss
                        + cfg.value_coef * value_loss
                        - cfg.entropy_coef * entropy_loss
                    )
                    return loss, dict(
                        actor_loss=actor_loss,
                        value_loss=value_loss,
                        entropy_loss=entropy_loss,
                        loss=loss,
                        mean_value=value.mean(),
                        mean_log_prob=log_prob.mean(),
                        mean_advantages=advantages.mean(),
                        mean_action=minibatch.action.mean(),
                        mean_reward=minibatch.reward.mean(),
                    )

                grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
                output, grads = grad_fn(train_state.params)
                # Global gradient norm (all parameters combined)
                flat_grads, _ = jax.flatten_util.ravel_pytree(grads)
                global_grad_norm = jnp.linalg.norm(flat_grads)
                metrics = output[1]
                metrics["advantages"] = advantages
                metrics["global_grad_norm"] = global_grad_norm
                train_state = train_state.apply_gradients(grads)
                return (idx + 1, train_state), metrics

            # Shuffle data and split into mini-batches
            key, shuffle_key = jax.random.split(key)
            mini_batch_size = (
                math.floor(cfg.num_steps * cfg.num_envs) // cfg.num_mini_batches
            )
            indices = jax.random.permutation(shuffle_key, cfg.num_steps * cfg.num_envs)
            minibatch_idxs = jax.tree.map(
                lambda x: x.reshape(
                    (cfg.num_mini_batches, mini_batch_size, *x.shape[1:])
                ),
                indices,
            )
            # Run model update for each mini-batch
            train_state, metrics = jax.lax.scan(
                minibatch_update, train_state, minibatch_idxs
            )
            # Compute mean metrics across mini-batches
            metrics = jax.tree.map(lambda x: x.mean(0), metrics)
            return train_state, metrics

        # Update the model for a number of epochs
        key, train_key = jax.random.split(key)
        (_, train_state), update_metrics = jax.lax.scan(
            f=update,
            init=(1, train_state),
            xs=jax.random.split(train_key, cfg.num_epochs),
        )
        # Get metrics from the last epoch
        update_metrics = jax.tree.map(lambda x: x[-1], update_metrics)
        return train_state, update_metrics

    # Define the training loop
    def train_fn(key: PRNGKey) -> tuple[PPOTrainState, dict]:
        def train_eval_step(key, train_state):
            def train_step(
                state: PPOTrainState, key: PRNGKey
            ) -> tuple[PPOTrainState, dict[str, jax.Array]]:
                key, rollout_key, learn_key = jax.random.split(key, 3)
                # Collect trajectories from `state`
                transitions, state = collect_rollout(key=rollout_key, train_state=state)
                # Execute an update to the policy with `transitions`
                state, update_metrics = learn_step(
                    key=learn_key, train_state=state, batch=transitions
                )
                metrics = {**update_metrics}
                state = state.replace(iteration=state.iteration + 1)
                return state, metrics

            train_key, eval_key = jax.random.split(key)
            train_state, train_metrics = jax.lax.scan(
                f=train_step,
                init=train_state,
                xs=jax.random.split(train_key, eval_interval),
            )
            train_metrics = jax.tree.map(lambda x: x[-1], train_metrics)
            policy = make_policy(train_state)
            eval_metrics = eval_fn(eval_key, policy)
            metrics = {
                "time_step": train_state.time_steps,
                **utils.prefix_dict("train", train_metrics),
                **utils.prefix_dict("eval", eval_metrics),
            }
            return train_state, metrics

        def loop_body(
            train_state: PPOTrainState, key: PRNGKey
        ) -> tuple[PPOTrainState, dict]:
            # Map execution of the train+eval step across num_seeds (will be looped using jax.lax.scan)
            key, subkey = jax.random.split(key)
            train_state, metrics = jax.vmap(train_eval_step)(
                jax.random.split(subkey, num_seeds), train_state
            )
            jax.debug.callback(log_callback, train_state, metrics)
            return train_state, metrics

        # Initialize the policy, environment and map that across the number of random seeds
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, init_key = jax.random.split(key)
        train_state = jax.vmap(make_init(cfg, env, env_params))(
            jax.random.split(init_key, num_seeds)
        )
        keys = jax.random.split(key, num_iterations)
        # Run the training and evaluation loop from the initialized training state
        state, metrics = jax.lax.scan(f=loop_body, init=train_state, xs=keys)
        return state, metrics

    return train_fn
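

# Usage sketch (illustrative hyperparameter values; `run` below wires this up
# from the Hydra config and reports to W&B):
#
#   cfg = PPOConfig(lr=3e-4, gamma=0.99, lmbda=0.95, clip_ratio=0.2,
#                   value_coef=0.5, entropy_coef=0.0, total_time_steps=1_000_000,
#                   num_steps=32, num_mini_batches=4, num_envs=256, num_epochs=4,
#                   max_grad_norm=0.5, normalize_advantages=True,
#                   normalize_env=True, anneal_lr=False)
#   train_fn = make_train_fn(cfg, env, log_callback=print, num_seeds=1)
#   train_state, metrics = jax.jit(train_fn)(jax.random.PRNGKey(0))

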
def plot_history(history: list[dict[str, jax.Array]]):
    steps = jnp.array([m["time_step"][0] for m in history])
    eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
    eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
    fig = go.Figure(
        [
            go.Scatter(
                x=steps,
                y=eval_return,
                name="Mean Episode Return",
                mode="lines",
                line=dict(color="blue"),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return + eval_return_std,
                name="Upper Bound",
                mode="lines",
                line=dict(width=0),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return - eval_return_std,
                name="Lower Bound",
                mode="lines",
                line=dict(width=0),
                fill="tonexty",
                fillcolor="rgba(50, 127, 168, 0.3)",
                showlegend=False,
            ),
        ]
    )
    fig.update_layout(
        xaxis=dict(title=dict(text="Environment Steps")),
    )
    return fig


def run(cfg: DictConfig):
    metric_history = []

    # Define callback to log metrics during training
    def log_callback(state, metrics):
        metrics["sys_time"] = time.perf_counter()
        if len(metric_history) > 0:
            num_env_steps = state.time_steps[0] - metric_history[-1]["time_step"][0]
            seconds = metrics["sys_time"] - metric_history[-1]["sys_time"]
            sps = num_env_steps / seconds
        else:
            sps = 0
        metric_history.append(metrics)
        episode_return = metrics["eval/episode_return"].mean()
        # Use pop() with a default value of None in case 'advantages' key doesn't exist
        advantages = metrics.pop("train/advantages", None)
        logging.info(
            f"step={state.time_steps[0]} episode_return={episode_return:.3f}, sps={sps:.2f}"
        )
        log_data = {
            "eval/episode_return": episode_return,
            **jax.tree.map(jnp.mean, utils.filter_prefix("train", metrics)),
        }
        # Only log the advantage histogram if the metric was actually present
        if advantages is not None:
            log_data["train/advantages"] = wandb.Histogram(advantages)
        # Push log data to WandB
        wandb.log(log_data, step=state.time_steps[0])
    logging.info(OmegaConf.to_yaml(cfg))
    # Set up the experimental environment
    if cfg.env.type == "brax":
        env = BraxGymnaxWrapper(
            cfg.env.name
        )  # , episode_length=cfg.env.max_episode_steps
    elif cfg.env.type == "mjx":
        env = MjxGymnaxWrapper(cfg.env.name, episode_length=cfg.env.max_episode_steps)
    else:
        raise ValueError(f"Unknown environment type: {cfg.env.type}")
    key = jax.random.PRNGKey(cfg.seed)
    train_fn = make_train_fn(
        cfg=PPOConfig(**cfg.hyperparameters),
        env=env,
        log_callback=log_callback,
        num_seeds=cfg.num_seeds,
    )
    for i in range(cfg.trials):
        # Initialize WandB reporting
        key, train_key = jax.random.split(key)
        wandb.init(
            mode=cfg.wandb.mode,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            tags=[cfg.name, cfg.env.name, cfg.env.type, *cfg.tags],
            config=OmegaConf.to_container(cfg),
            name=f"ppo-{cfg.name}-{cfg.env.name.lower()}",
            save_code=True,
        )
        start = time.perf_counter()
        train_state, metrics = jax.jit(train_fn)(train_key)
        jax.block_until_ready(metrics)
        duration = time.perf_counter() - start
        # Save metrics and finish the run
        logging.info(f"Training took {duration:.2f} seconds.")
        # jnp.savez("metrics.npz", **metrics)  # TODO: fix the directory here to save to a unique output directory
        wandb.finish()


def tune(cfg: DictConfig):
    def log_callback(state, metrics):
        episode_return = metrics["eval/episode_return"].mean()
        t = state.time_steps[0]
        wandb.log(
            {
                "episode_return": episode_return,
            },
            step=t,
        )

    env = MjxGymnaxWrapper(cfg.env.name, episode_length=cfg.env.max_episode_steps)

    def train_agent():
        wandb.init(project=cfg.wandb.project)
        run_cfg = OmegaConf.to_container(cfg)
        for k, v in dict(wandb.config).items():
            run_cfg["experiment"]["hyperparameters"][k] = v
        ppo_cfg = PPOConfig(**run_cfg["experiment"]["hyperparameters"])
        train_fn = make_train_fn(
            cfg=ppo_cfg,
            env=env,
            log_callback=log_callback,
            num_seeds=cfg.num_seeds,
        )
        train_fn = jax.jit(train_fn)
        logging.info(f"Running experiment with params: \n {run_cfg}")
        key = jax.random.PRNGKey(cfg.seed)
        train_state, metrics = train_fn(key)
        jax.block_until_ready(metrics)

    sweep_id = wandb.sweep(
        sweep={
            "name": f"{cfg.name}-{cfg.env.name}",
            "method": "bayes",
            "metric": {"name": "episode_return", "goal": "maximize"},
            "parameters": {
                "lr": {
                    "values": [1e-4, 3e-4, 1e-3],
                },
                "normalize_env": {
                    "values": [True, False],
                },
            },
        },
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
    )
    wandb.agent(sweep_id, function=train_agent, count=cfg.tune.num_runs)


@hydra.main(version_base=None, config_path="../../config", config_name="ppo")
def main(cfg: DictConfig):
    if cfg.tune:
        tune(cfg)
    else:
        run(cfg)


if __name__ == "__main__":
    main()