import logging
import math
import time
import typing
from typing import Callable, Optional

import distrax
import hydra
import jax
import jax.flatten_util  # explicit import: ravel_pytree is used for gradient norms below
import optax
import plotly.graph_objs as go
from flax import nnx, struct
from flax.struct import PyTreeNode
from gymnax.environments.environment import Environment, EnvParams, EnvState
from jax import numpy as jnp
from jax.experimental import checkify
from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf

import wandb
from reppo_alg.env_utils.jax_wrappers import (
    BraxGymnaxWrapper,
    ClipAction,
    LogWrapper,
    MjxGymnaxWrapper,
)
from reppo_alg.jaxrl import utils
from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer

logging.basicConfig(level=logging.INFO)


## INITIALIZE CLASS STRUCTURES (NETWORKS, STATES, ...)
class Policy(typing.Protocol):
    def __call__(
        self,
        key: jax.random.PRNGKey,
        obs: PyTreeNode,
        state: Optional[PyTreeNode] = None,
    ) -> tuple[PyTreeNode, PyTreeNode]:
        pass


class PPOConfig(struct.PyTreeNode):
    lr: float
    gamma: float
    lmbda: float
    clip_ratio: float
    value_coef: float
    entropy_coef: float
    total_time_steps: int
    num_steps: int
    num_mini_batches: int
    num_envs: int
    num_epochs: int
    max_grad_norm: float | None
    normalize_advantages: bool
    normalize_env: bool
    anneal_lr: bool
    num_eval: int = 25
    max_episode_steps: int = 1000


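# Illustrative sketch (not used by the training code): one way to instantiate PPOConfig.
# The hyperparameter values below are assumptions for demonstration, not tuned defaults
# shipped with this repository.
def _example_ppo_config() -> PPOConfig:
    return PPOConfig(
        lr=3e-4,
        gamma=0.99,
        lmbda=0.95,
        clip_ratio=0.2,
        value_coef=0.5,
        entropy_coef=0.01,
        total_time_steps=1_000_000,
        num_steps=128,
        num_mini_batches=4,
        num_envs=64,
        num_epochs=4,
        max_grad_norm=0.5,
        normalize_advantages=True,
        normalize_env=True,
        anneal_lr=True,
    )

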
class Transition(struct.PyTreeNode):
    obs: jax.Array
    critic_obs: jax.Array
    action: jax.Array
    reward: jax.Array
    log_prob: jax.Array
    value: jax.Array
    done: jax.Array
    truncated: jax.Array
    info: dict[str, jax.Array]


class PPOTrainState(nnx.TrainState):
    iteration: int
    time_steps: int
    last_env_state: EnvState
    last_obs: jax.Array
    last_critic_obs: jax.Array
    normalization_state: NormalizationState | None = None
    critic_normalization_state: NormalizationState | None = None


class PPONetworks(nnx.Module):
    def __init__(
        self,
        obs_dim: int,
        critic_obs_dim: int,
        action_dim: int,
        hidden_dim: int = 64,
        *,
        rngs: nnx.Rngs,
    ):
        def linear_layer(in_features, out_features, scale=jnp.sqrt(2)):
            return nnx.Linear(
                in_features=in_features,
                out_features=out_features,
                kernel_init=nnx.initializers.orthogonal(scale=scale),
                bias_init=nnx.initializers.zeros_init(),
                rngs=rngs,
            )

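        # Orthogonal initialization scales follow common PPO practice: sqrt(2) for hidden
        # layers, 0.01 for the policy output (near-uniform initial actions), and 1.0 for
        # the value head.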
        self.actor_module = nnx.Sequential(
            linear_layer(obs_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, action_dim, scale=0.01),
        )
        self.log_std = nnx.Param(jnp.zeros(action_dim))
        self.critic_module = nnx.Sequential(
            linear_layer(critic_obs_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, hidden_dim),
            nnx.tanh,
            linear_layer(hidden_dim, 1, scale=1.0),
        )

    def critic(self, obs: jax.Array) -> jax.Array:
        return self.critic_module(obs).squeeze()

    def actor(self, obs: jax.Array) -> distrax.Distribution:
        loc = self.actor_module(obs)
        pi = distrax.MultivariateNormalDiag(
            loc=loc, scale_diag=jnp.exp(self.log_std.value)
        )
        return pi


def make_policy(train_state: PPOTrainState) -> Policy:
    normalizer = Normalizer()

    def policy(
        key: PRNGKey, obs: jax.Array, state: struct.PyTreeNode = None
    ) -> tuple[jax.Array, jax.Array]:
        if train_state.normalization_state is not None:
            obs = normalizer.normalize(train_state.normalization_state, obs)
        model = nnx.merge(train_state.graphdef, train_state.params)
        pi = model.actor(obs)
        value = model.critic(obs)
        action = pi.sample(seed=key)
        log_prob = pi.log_prob(action)
        return action, dict(log_prob=log_prob, value=value)

    return policy


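# Illustrative usage sketch (assumes a fully initialized `train_state`): the returned
# policy maps a PRNG key and a batch of observations to actions plus auxiliary outputs.
#
#   policy = make_policy(train_state)
#   action, aux = policy(jax.random.PRNGKey(0), obs)  # aux contains "log_prob" and "value"

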
def make_eval_fn(
    env: Environment, max_episode_steps: int
) -> Callable[[jax.random.PRNGKey, Policy], dict[str, float]]:
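    # Rolls out the policy for max_episode_steps steps in env.num_envs parallel
    # environments and aggregates returns/lengths over the episodes that finished
    # during the rollout (masked by the `returned_episode` flag in the step infos).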
    def evaluation_fn(key: jax.random.PRNGKey, policy: Policy):
        def step_env(carry, _):
            key, env_state, obs = carry
            key, act_key, env_key = jax.random.split(key, 3)
            action, _ = policy(act_key, obs)
            env_key = jax.random.split(env_key, env.num_envs)
            obs, _, env_state, reward, done, info = env.step(
                env_key, env_state, action.clip(-1.0 + 1e-4, 1.0 - 1e-4)
            )
            return (key, env_state, obs), info

        key, init_key = jax.random.split(key)
        init_key = jax.random.split(init_key, env.num_envs)
        obs, _, env_state = env.reset(init_key)
        _, infos = jax.lax.scan(
            f=step_env,
            init=(key, env_state, obs),
            xs=None,
            length=max_episode_steps,
        )

        return {
            "episode_return": infos["returned_episode_returns"].mean(
                where=infos["returned_episode"]
            ),
            "episode_return_std": infos["returned_episode_returns"].std(
                where=infos["returned_episode"]
            ),
            "episode_length": infos["returned_episode_lengths"].mean(
                where=infos["returned_episode"]
            ),
            "episode_length_std": infos["returned_episode_lengths"].std(
                where=infos["returned_episode"]
            ),
            "num_episodes": infos["returned_episode"].sum(),
        }

    return evaluation_fn


def make_init(
    cfg: PPOConfig,
    env: Environment,
    env_params: EnvParams = None,
) -> PPOTrainState:
    def init(key: jax.random.PRNGKey) -> PPOTrainState:
        # Number of calls to train_step
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        # Number of train steps between evaluations
        eval_interval = int(
            (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
        )
        # Number of calls to train_iter; add 1 if num_train_steps is not divisible by eval_interval
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, model_key = jax.random.split(key)
        # Initialize the model
        networks = PPONetworks(
            obs_dim=env.observation_space(env_params)[0].shape[0],
            critic_obs_dim=env.observation_space(env_params)[1].shape[0],
            action_dim=env.action_space(env_params).shape[0],
            rngs=nnx.Rngs(model_key),
        )

        # Set initial learning rate
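        # When anneal_lr is set, `lr` becomes an optax schedule that decays linearly from
        # cfg.lr to 1e-6 over the total number of optimizer updates
        # (train iterations * epochs * mini-batches); otherwise a constant rate is used.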
        if not cfg.anneal_lr:
            lr = cfg.lr
        else:
            num_iterations = cfg.total_time_steps // cfg.num_steps // cfg.num_envs
            num_updates = num_iterations * cfg.num_epochs * cfg.num_mini_batches
            lr = optax.linear_schedule(cfg.lr, 1e-6, num_updates)

        # Initialize the optimizer
        if cfg.max_grad_norm is not None:
            optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                optax.adam(lr),
            )
        else:
            optimizer = optax.adam(lr)

        # Reset and fully initialize the environment
        key, env_key = jax.random.split(key)
        env_key = jax.random.split(env_key, cfg.num_envs)
        obs, critic_obs, env_state = env.reset(env_key)
        # randomize initial time step to prevent all envs stepping in tandem
        _env_state = env_state.unwrapped()
        key, randomize_steps_key = jax.random.split(key)
        _env_state.info["steps"] = jax.random.randint(
            randomize_steps_key,
            _env_state.info["steps"].shape,
            0,
            cfg.max_episode_steps,
        ).astype(jnp.float32)
        env_state.set_env_state(_env_state)

        if cfg.normalize_env:
            normalizer = Normalizer()
            norm_state = normalizer.init(obs)
            critic_normalizer = Normalizer()
            critic_norm_state = critic_normalizer.init(critic_obs)
            obs = normalizer.normalize(norm_state, obs)
            critic_obs = critic_normalizer.normalize(critic_norm_state, critic_obs)
        else:
            norm_state = None
            critic_norm_state = None

        # Create the initial train state
        return PPOTrainState.create(
            iteration=0,
            time_steps=0,
            graphdef=nnx.graphdef(networks),
            params=nnx.state(networks),
            tx=optimizer,
            last_env_state=env_state,
            last_obs=obs,
            last_critic_obs=critic_obs,
            normalization_state=norm_state,
            critic_normalization_state=critic_norm_state,
        )

    return init


def make_train_fn(
    cfg: PPOConfig,
    env: Environment,
    env_params: EnvParams = None,
    log_callback: Callable[[PPOTrainState, dict[str, jax.Array]], None] = None,
    num_seeds: int = 1,
):
    # Initialize the environment and wrap it to admit vectorized behavior.
    env_params = env_params or env.default_params
    env = ClipAction(env)
    env = LogWrapper(env, cfg.num_envs)
    eval_fn = make_eval_fn(env, cfg.max_episode_steps)
    normalizer = Normalizer()
    eval_interval = int(
        (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
    )

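    # Training is organized as: collect_rollout gathers num_steps * num_envs transitions,
    # learn_step runs num_epochs passes of num_mini_batches clipped-PPO updates on that
    # batch, and train_fn interleaves eval_interval such train steps with one evaluation.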
    def collect_rollout(
        key: PRNGKey, train_state: PPOTrainState
    ) -> tuple[Transition, PPOTrainState]:
        model = nnx.merge(train_state.graphdef, train_state.params)

        # Take a step in the environment
        def step_env(carry, _) -> tuple[tuple, Transition]:
            key, env_state, train_state, obs, critic_obs = carry

            if cfg.normalize_env:
                norm_state = normalizer.update(train_state.normalization_state, obs)
                obs = normalizer.normalize(norm_state, obs)
                train_state = train_state.replace(normalization_state=norm_state)
                critic_obs = normalizer.normalize(
                    train_state.critic_normalization_state, critic_obs
                )
            # Select action
            key, act_key, step_key = jax.random.split(key, 3)
            pi = model.actor(obs)
            action = pi.sample(seed=act_key)
            # Take a step in the environment
            step_key = jax.random.split(step_key, cfg.num_envs)
            next_obs, next_critic_obs, next_env_state, reward, done, info = env.step(
                step_key, env_state, action.clip(-1.0 + 1e-4, 1.0 - 1e-4)
            )
            # Record the transition
            transition = Transition(
                obs=obs,
                critic_obs=critic_obs,
                action=action,
                reward=reward,
                log_prob=pi.log_prob(action),
                value=model.critic(critic_obs),
                done=done,
                truncated=next_env_state.truncated,
                info=info,
            )
            return (
                key,
                next_env_state,
                train_state,
                next_obs,
                next_critic_obs,
            ), transition

        # Collect rollout via lax.scan taking steps in the environment
        rollout_state, transitions = jax.lax.scan(
            f=step_env,
            init=(
                key,
                train_state.last_env_state,
                train_state,
                train_state.last_obs,
                train_state.last_critic_obs,
            ),
            length=cfg.num_steps,
        )
        # Carry the final rollout state over so the next iteration resumes where this one ended
        _, last_env_state, train_state, last_obs, last_critic_obs = rollout_state
        train_state = train_state.replace(
            last_env_state=last_env_state,
            last_obs=last_obs,
            last_critic_obs=last_critic_obs,
            time_steps=train_state.time_steps + cfg.num_steps * cfg.num_envs,
        )

        return transitions, train_state

    def learn_step(
        key: PRNGKey, train_state: PPOTrainState, batch: Transition
    ) -> tuple[PPOTrainState, dict[str, jax.Array]]:
        # Compute advantages and target values
        model = nnx.merge(train_state.graphdef, train_state.params)
        if cfg.normalize_env:
            last_critic_obs = normalizer.normalize(
                train_state.critic_normalization_state, train_state.last_critic_obs
            )
        else:
            last_critic_obs = train_state.last_critic_obs
        last_value = model.critic(last_critic_obs)

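        # Generalized Advantage Estimation, computed by scanning backwards in time:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        # On truncated (time-limit) steps the episode was cut rather than terminated, so
        # the advantage falls back to a one-step bootstrapped estimate instead of zeroing
        # the bootstrap value.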
        def compute_advantage(carry, transition):
            gae, next_value = carry
            done = transition.done
            truncated = transition.truncated
            reward = transition.reward
            value = transition.value
            delta = reward + cfg.gamma * next_value * (1 - done) - value
            gae = delta + cfg.gamma * cfg.lmbda * (1 - done) * gae
            truncated_gae = reward + cfg.gamma * next_value - value
            gae = jnp.where(truncated, truncated_gae, gae)
            return (gae, value), gae

        # Compute the advantage using GAE
        _, advantages = jax.lax.scan(
            compute_advantage,
            (jnp.zeros_like(last_value), last_value),
            batch,
            reverse=True,
        )
        target_values = advantages + batch.value

        data = (batch, advantages, target_values)
        # Reshape data to (num_steps * num_envs, ...)
        data = jax.tree.map(
            lambda x: x.reshape(
                (math.floor(cfg.num_steps * cfg.num_envs), *x.shape[2:])
            ),
            data,
        )

        def update(train_state, key) -> tuple[PPOTrainState, dict[str, jax.Array]]:
            def minibatch_update(carry, indices):
                idx, train_state = carry
                # Sample data at indices from the batch
                minibatch, advantages, target_values = jax.tree.map(
                    lambda x: jnp.take(x, indices, axis=0), data
                )
                if cfg.normalize_advantages:
                    advantages = (advantages - jnp.mean(advantages)) / (
                        jnp.std(advantages) + 1e-8
                    )

                # Define the loss function
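                # The loss combines the standard PPO terms:
                #   value loss: 0.5 * max((V - V_target)^2, (V_clipped - V_target)^2),
                #     where V_clipped keeps the new value within clip_ratio of the rollout value
                #   policy loss: -min(r * A, clip(r, 1 - clip_ratio, 1 + clip_ratio) * A),
                #     with importance ratio r = exp(log_prob_new - log_prob_old)
                #   entropy bonus: encourages exploration
                # Truncated steps are masked out of both the actor and the critic loss.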
                def loss_fn(params):
                    model = nnx.merge(train_state.graphdef, params)
                    pi = model.actor(minibatch.obs)
                    value = model.critic(minibatch.critic_obs)
                    log_prob = pi.log_prob(minibatch.action)
                    value_pred_clipped = minibatch.value + (
                        value - minibatch.value
                    ).clip(-cfg.clip_ratio, cfg.clip_ratio)
                    value_error = jnp.square(value - target_values)
                    value_error_clipped = jnp.square(value_pred_clipped - target_values)
                    value_loss = 0.5 * jnp.mean(
                        (1.0 - minibatch.truncated)
                        * jnp.maximum(value_error, value_error_clipped)
                    )

                    ratio = jnp.exp(log_prob - minibatch.log_prob)
                    checkify.check(
                        jnp.allclose(ratio, 1.0) | (idx != 1),
                        debug=True,
                        msg="Ratio not equal to 1 on first iteration: {r}",
                        r=ratio,
                    )

                    actor_loss1 = ratio * advantages
                    actor_loss2 = (
                        jnp.clip(ratio, 1 - cfg.clip_ratio, 1 + cfg.clip_ratio)
                        * advantages
                    )
                    actor_loss = -jnp.mean(
                        (1.0 - minibatch.truncated)
                        * jnp.minimum(actor_loss1, actor_loss2)
                    )
                    entropy_loss = jnp.mean(pi.entropy())

                    loss = (
                        actor_loss
                        + cfg.value_coef * value_loss
                        - cfg.entropy_coef * entropy_loss
                    )

                    return loss, dict(
                        actor_loss=actor_loss,
                        value_loss=value_loss,
                        entropy_loss=entropy_loss,
                        loss=loss,
                        mean_value=value.mean(),
                        mean_log_prob=log_prob.mean(),
                        mean_advantages=advantages.mean(),
                        mean_action=minibatch.action.mean(),
                        mean_reward=minibatch.reward.mean(),
                    )

                grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
                output, grads = grad_fn(train_state.params)

                # Global gradient norm (all parameters combined)
                flat_grads, _ = jax.flatten_util.ravel_pytree(grads)
                global_grad_norm = jnp.linalg.norm(flat_grads)

                metrics = output[1]
                metrics["advantages"] = advantages
                metrics["global_grad_norm"] = global_grad_norm
                # pass grads by keyword; apply_gradients expects a keyword argument
                train_state = train_state.apply_gradients(grads=grads)
                return (idx + 1, train_state), metrics

            # Shuffle data and split into mini-batches
            key, shuffle_key = jax.random.split(key)

            mini_batch_size = (
                math.floor(cfg.num_steps * cfg.num_envs) // cfg.num_mini_batches
            )
            indices = jax.random.permutation(shuffle_key, cfg.num_steps * cfg.num_envs)
            minibatch_idxs = jax.tree.map(
                lambda x: x.reshape(
                    (cfg.num_mini_batches, mini_batch_size, *x.shape[1:])
                ),
                indices,
            )

            # Run model update for each mini-batch
            train_state, metrics = jax.lax.scan(
                minibatch_update, train_state, minibatch_idxs
            )
            # Compute mean metrics across mini-batches
            metrics = jax.tree.map(lambda x: x.mean(0), metrics)
            return train_state, metrics

        # Update the model for a number of epochs
        # (the scan carry is (update_index, train_state); the index lets checkify assert
        # that the importance ratio is numerically 1 on the very first mini-batch update)
        key, train_key = jax.random.split(key)
        (_, train_state), update_metrics = jax.lax.scan(
            f=update,
            init=(1, train_state),
            xs=jax.random.split(train_key, cfg.num_epochs),
        )
        # Get metrics from the last epoch
        update_metrics = jax.tree.map(lambda x: x[-1], update_metrics)

        return train_state, update_metrics

    # Define the training loop
    def train_fn(key: PRNGKey) -> tuple[PPOTrainState, dict]:
        def train_eval_step(key, train_state):
            def train_step(
                state: PPOTrainState, key: PRNGKey
            ) -> tuple[PPOTrainState, dict[str, jax.Array]]:
                key, rollout_key, learn_key = jax.random.split(key, 3)
                # Collect trajectories from `state`
                transitions, state = collect_rollout(key=rollout_key, train_state=state)
                # Execute an update to the policy with `transitions`
                state, update_metrics = learn_step(
                    key=learn_key, train_state=state, batch=transitions
                )
                metrics = update_metrics
                state = state.replace(iteration=state.iteration + 1)
                return state, metrics

            train_key, eval_key = jax.random.split(key)
            train_state, train_metrics = jax.lax.scan(
                f=train_step,
                init=train_state,
                xs=jax.random.split(train_key, eval_interval),
            )
            train_metrics = jax.tree.map(lambda x: x[-1], train_metrics)
            policy = make_policy(train_state)
            eval_metrics = eval_fn(eval_key, policy)
            metrics = {
                "time_step": train_state.time_steps,
                **utils.prefix_dict("train", train_metrics),
                **utils.prefix_dict("eval", eval_metrics),
            }

            return train_state, metrics

        def loop_body(
            train_state: PPOTrainState, key: PRNGKey
        ) -> tuple[PPOTrainState, dict]:
            # Map execution of the train+eval step across num_seeds (will be looped using jax.lax.scan)
            key, subkey = jax.random.split(key)
            train_state, metrics = jax.vmap(train_eval_step)(
                jax.random.split(subkey, num_seeds), train_state
            )
            jax.debug.callback(log_callback, train_state, metrics)
            return train_state, metrics

        # Initialize the train state, mapped across the number of random seeds
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, init_key = jax.random.split(key)
        train_state = jax.vmap(make_init(cfg, env, env_params))(
            jax.random.split(init_key, num_seeds)
        )
        keys = jax.random.split(key, num_iterations)
        # Run the training and evaluation loop from the initialized training state
        state, metrics = jax.lax.scan(f=loop_body, init=train_state, xs=keys)
        return state, metrics

    return train_fn


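# Illustrative usage sketch (mirrors run() below; the environment name is an assumption):
#
#   env = BraxGymnaxWrapper("ant")
#   train_fn = make_train_fn(cfg=PPOConfig(...), env=env, log_callback=print, num_seeds=1)
#   state, metrics = jax.jit(train_fn)(jax.random.PRNGKey(0))

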
def plot_history(history: list[dict[str, jax.Array]]):
    steps = jnp.array([m["time_step"][0] for m in history])
    eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
    eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
    fig = go.Figure(
        [
            go.Scatter(
                x=steps,
                y=eval_return,
                name="Mean Episode Return",
                mode="lines",
                line=dict(color="blue"),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return + eval_return_std,
                name="Upper Bound",
                mode="lines",
                line=dict(width=0),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return - eval_return_std,
                name="Lower Bound",
                mode="lines",
                line=dict(width=0),
                fill="tonexty",
                fillcolor="rgba(50, 127, 168, 0.3)",
                showlegend=False,
            ),
        ]
    )
    fig.update_layout(
        xaxis=dict(title=dict(text="Environment Steps")),
    )

    return fig


def run(cfg: DictConfig):
    metric_history = []

    # Define callback to log metrics during training
    def log_callback(state, metrics):
        metrics["sys_time"] = time.perf_counter()
        if len(metric_history) > 0:
            num_env_steps = state.time_steps[0] - metric_history[-1]["time_step"][0]
            seconds = metrics["sys_time"] - metric_history[-1]["sys_time"]
            sps = num_env_steps / seconds
        else:
            sps = 0

        metric_history.append(metrics)
        episode_return = metrics["eval/episode_return"].mean()
        # Use pop() with a default of None in case the 'advantages' key doesn't exist
        advantages = metrics.pop("train/advantages", None)
        logging.info(
            f"step={state.time_steps[0]} episode_return={episode_return:.3f}, sps={sps:.2f}"
        )
        log_data = {
            "eval/episode_return": episode_return,
            **jax.tree.map(jnp.mean, utils.filter_prefix("train", metrics)),
        }
        # Only log the histogram when advantages were actually reported
        if advantages is not None:
            log_data["train/advantages"] = wandb.Histogram(advantages)
        # Push log data to WandB
        wandb.log(log_data, step=state.time_steps[0])

    logging.info(OmegaConf.to_yaml(cfg))

    # Set up the experimental environment
    if cfg.env.type == "brax":
        env = BraxGymnaxWrapper(
            cfg.env.name
        )  # , episode_length=cfg.env.max_episode_steps
    elif cfg.env.type == "mjx":
        env = MjxGymnaxWrapper(cfg.env.name, episode_length=cfg.env.max_episode_steps)
    else:
        raise ValueError(f"Unknown environment type: {cfg.env.type}")

    key = jax.random.PRNGKey(cfg.seed)
    train_fn = make_train_fn(
        cfg=PPOConfig(**cfg.hyperparameters),
        env=env,
        log_callback=log_callback,
        num_seeds=cfg.num_seeds,
    )
    for i in range(cfg.trials):
        # Initialize WandB reporting
        key, train_key = jax.random.split(key)
        wandb.init(
            mode=cfg.wandb.mode,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            tags=[cfg.name, cfg.env.name, cfg.env.type, *cfg.tags],
            config=OmegaConf.to_container(cfg),
            name=f"ppo-{cfg.name}-{cfg.env.name.lower()}",
            save_code=True,
        )
        start = time.perf_counter()
        train_state, metrics = jax.jit(train_fn)(train_key)
        jax.block_until_ready(metrics)
        duration = time.perf_counter() - start

        # Save metrics and finish the run
        logging.info(f"Training took {duration:.2f} seconds.")
        # jnp.savez("metrics.npz", **metrics)  # TODO: fix the directory here to save to a unique output directory
        wandb.finish()


def tune(cfg: DictConfig):
    def log_callback(state, metrics):
        episode_return = metrics["eval/episode_return"].mean()
        t = state.time_steps[0]
        wandb.log(
            {
                "episode_return": episode_return,
            },
            step=t,
        )

    env = MjxGymnaxWrapper(cfg.env.name, episode_length=cfg.env.max_episode_steps)

    def train_agent():
        wandb.init(project=cfg.wandb.project)
        run_cfg = OmegaConf.to_container(cfg)
        for k, v in dict(wandb.config).items():
            run_cfg["experiment"]["hyperparameters"][k] = v
        ppo_cfg = PPOConfig(**run_cfg["experiment"]["hyperparameters"])
        train_fn = make_train_fn(
            cfg=ppo_cfg,
            env=env,
            log_callback=log_callback,
            num_seeds=cfg.num_seeds,
        )
        train_fn = jax.jit(train_fn)
        logging.info(f"Running experiment with params: \n {run_cfg}")
        key = jax.random.PRNGKey(cfg.seed)
        train_state, metrics = train_fn(key)
        jax.block_until_ready(metrics)

    sweep_id = wandb.sweep(
        sweep={
            "name": f"{cfg.name}-{cfg.env.name}",
            "method": "bayes",
            "metric": {"name": "episode_return", "goal": "maximize"},
            "parameters": {
                "lr": {
                    "values": [1e-4, 3e-4, 1e-3],
                },
                "normalize_env": {
                    "values": [True, False],
                },
            },
        },
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
    )
    wandb.agent(sweep_id, function=train_agent, count=cfg.tune.num_runs)


@hydra.main(version_base=None, config_path="../../config", config_name="ppo")
def main(cfg: DictConfig):
    if cfg.tune:
        tune(cfg)
    else:
        run(cfg)


if __name__ == "__main__":
    main()