- Fix missing MUON optimizer by replacing with optax.adam
- Fix Hydra configuration parameter paths (env.name instead of env_name)
- Fix BraxGymnaxWrapper method signatures to accept params argument
- Fix training loop division by zero with proper total_time_steps
- Fix incorrect algorithm name in wandb (reppo instead of sac)
- Fix JAX key batching error in BraxGymnaxWrapper reset method
- Add comprehensive HoReKa SLURM integration with wandb logging
- Update README with detailed bug documentation and fixes
import logging
import time
import typing
from typing import Callable

import hydra
import jax
import numpy as np
import optax
import optuna
import plotly.graph_objs as go
from flax import nnx, struct
from flax.struct import PyTreeNode
from gymnax.environments.environment import Environment, EnvParams, EnvState
from jax import numpy as jnp
from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf

import wandb
from reppo_alg.env_utils.jax_wrappers import (
    BraxGymnaxWrapper,
    ClipAction,
    LogWrapper,
    MjxGymnaxWrapper,
    NormalizeVec,
)
from reppo_alg.jaxrl import utils
from reppo_alg.network_utils.jax_models import (
    CategoricalCriticNetwork,
    CriticNetwork,
    SACActorNetworks,
)

logging.basicConfig(level=logging.INFO)


class Policy(typing.Protocol):
    def __call__(
        self,
        key: jax.random.PRNGKey,
        obs: PyTreeNode,
    ) -> tuple[PyTreeNode, PyTreeNode]:
        pass


class Transition(struct.PyTreeNode):
    obs: jax.Array
    critic_obs: jax.Array
    action: jax.Array
    reward: jax.Array
    next_emb: jax.Array
    value: jax.Array
    done: jax.Array
    truncated: jax.Array
    importance_weight: jax.Array
    info: dict[str, jax.Array]


class ReppoConfig(struct.PyTreeNode):
    lr: float
    gamma: float
    total_time_steps: int
    num_steps: int
    lmbda: float
    lmbda_min: float
    num_mini_batches: int
    num_envs: int
    num_epochs: int
    max_grad_norm: float | None
    normalize_env: bool
    polyak: float
    exploration_noise_min: float
    exploration_noise_max: float
    exploration_base_envs: int
    ent_start: float
    ent_target_mult: float
    kl_start: float
    eval_interval: int = 10
    num_eval: int = 25
    max_episode_steps: int = 1000
    critic_hidden_dim: int = 512
    actor_hidden_dim: int = 512
    vmin: int = -100
    vmax: int = 100
    num_bins: int = 250
    hl_gauss: bool = False
    kl_bound: float = 1.0
    aux_loss_mult: float = 0.0
    update_kl_lagrangian: bool = True
    update_entropy_lagrangian: bool = True
    use_critic_norm: bool = True
    num_critic_encoder_layers: int = 1
    num_critic_head_layers: int = 1
    num_critic_pred_layers: int = 1
    use_simplical_embedding: bool = False
    use_actor_norm: bool = True
    num_actor_layers: int = 2
    actor_min_std: float = 0.05
    reduce_kl: bool = True
    reverse_kl: bool = False
    anneal_lr: bool = False
    actor_kl_clip_mode: str = "clipped"


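# A minimal construction sketch (illustrative only; the values below are
# assumptions for demonstration, not tuned defaults — the fields without
# defaults must always be supplied, typically via Hydra's cfg.hyperparameters):
#
#   cfg = ReppoConfig(
#       lr=3e-4, gamma=0.99, total_time_steps=1_000_000, num_steps=128,
#       lmbda=0.95, lmbda_min=0.1, num_mini_batches=4, num_envs=64,
#       num_epochs=4, max_grad_norm=0.5, normalize_env=True, polyak=0.995,
#       exploration_noise_min=0.1, exploration_noise_max=0.8,
#       exploration_base_envs=8, ent_start=0.1, ent_target_mult=0.5,
#       kl_start=1.0,
#   )

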
class SACTrainState(struct.PyTreeNode):
    critic: nnx.TrainState
    actor: nnx.TrainState
    actor_target: nnx.TrainState
    iteration: int
    time_steps: int
    last_env_state: EnvState
    last_obs: jax.Array
    last_critic_obs: jax.Array


def make_policy(
    train_state: SACTrainState,
) -> Callable[[jax.Array, jax.Array], tuple[jax.Array, dict]]:
    def policy(key: PRNGKey, obs: jax.Array) -> tuple[jax.Array, dict]:
        actor_model = nnx.merge(train_state.actor.graphdef, train_state.actor.params)
        action: jax.Array = actor_model.det_action(obs)
        return action, {}

    return policy


def make_eval_fn(
    env: Environment, max_episode_steps: int, reward_scale: float = 1.0
) -> Callable[[jax.random.PRNGKey, Policy, PyTreeNode | None], dict[str, float]]:
    def evaluation_fn(
        key: jax.random.PRNGKey, policy: Policy, norm_state: PyTreeNode | None
    ):
        def step_env(carry, _):
            key, env_state, obs = carry
            key, act_key, env_key = jax.random.split(key, 3)
            action, _ = policy(act_key, obs)
            step_key = jax.random.split(env_key, env.num_envs)
            obs, _, env_state, reward, done, info = env.step(
                step_key, env_state, action
            )
            return (key, env_state, obs), info

        key, init_key = jax.random.split(key)
        init_key = jax.random.split(init_key, env.num_envs)
        obs, _, env_state = env.reset(init_key, norm_state)
        key, env_key = jax.random.split(key)
        _, infos = jax.lax.scan(
            f=step_env,
            init=(key, env_state, obs),
            xs=None,
            length=max_episode_steps,
        )

        return {
            "episode_return": infos["returned_episode_returns"].mean(
                where=infos["returned_episode"]
            )
            * reward_scale,
            "episode_return_std": infos["returned_episode_returns"].std(
                where=infos["returned_episode"]
            ),
            "episode_length": infos["returned_episode_lengths"].mean(
                where=infos["returned_episode"]
            ),
            "episode_length_std": infos["returned_episode_lengths"].std(
                where=infos["returned_episode"]
            ),
            "num_episodes": infos["returned_episode"].sum(),
        }

    return evaluation_fn


def make_init(
    cfg: ReppoConfig,
    env: Environment,
    env_params: EnvParams = None,
) -> Callable[[jax.Array], SACTrainState]:
    def init(key: jax.random.PRNGKey) -> SACTrainState:
        key, model_key = jax.random.split(key)
        actor_networks = SACActorNetworks(
            obs_dim=env.observation_space(env_params)[0].shape[0],
            action_dim=env.action_space(env_params).shape[0],
            hidden_dim=cfg.actor_hidden_dim,
            ent_start=cfg.ent_start,
            kl_start=cfg.kl_start,
            use_norm=cfg.use_actor_norm,
            layers=cfg.num_actor_layers,
            rngs=nnx.Rngs(model_key),
        )
        actor_target_networks = SACActorNetworks(
            obs_dim=env.observation_space(env_params)[0].shape[0],
            action_dim=env.action_space(env_params).shape[0],
            hidden_dim=cfg.actor_hidden_dim,
            ent_start=cfg.ent_start,
            kl_start=cfg.kl_start,
            use_norm=cfg.use_actor_norm,
            layers=cfg.num_actor_layers,
            rngs=nnx.Rngs(model_key),
        )

        if cfg.hl_gauss:
            critic_networks: nnx.Module = CategoricalCriticNetwork(
                obs_dim=env.observation_space(env_params)[1].shape[0],
                action_dim=env.action_space(env_params).shape[0],
                hidden_dim=cfg.critic_hidden_dim,
                num_bins=cfg.num_bins,
                vmin=cfg.vmin,
                vmax=cfg.vmax,
                use_norm=cfg.use_critic_norm,
                encoder_layers=cfg.num_critic_encoder_layers,
                use_simplical_embedding=cfg.use_simplical_embedding,
                head_layers=cfg.num_critic_head_layers,
                pred_layers=cfg.num_critic_pred_layers,
                rngs=nnx.Rngs(model_key),
            )
        else:
            critic_networks: nnx.Module = CriticNetwork(
                obs_dim=env.observation_space(env_params)[1].shape[0],
                action_dim=env.action_space(env_params).shape[0],
                hidden_dim=cfg.critic_hidden_dim,
                use_norm=cfg.use_critic_norm,
                encoder_layers=cfg.num_critic_encoder_layers,
                use_simplical_embedding=cfg.use_simplical_embedding,
                head_layers=cfg.num_critic_head_layers,
                pred_layers=cfg.num_critic_pred_layers,
                rngs=nnx.Rngs(model_key),
            )

        if not cfg.anneal_lr:
            lr = cfg.lr
        else:
            num_iterations = cfg.total_time_steps // cfg.num_steps // cfg.num_envs
            num_updates = num_iterations * cfg.num_epochs * cfg.num_mini_batches
            lr = optax.linear_schedule(cfg.lr, 0, num_updates)

        if cfg.max_grad_norm is not None:
            actor_optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                optax.adam(lr),
            )
            critic_optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                optax.adam(lr),
            )
        else:
            actor_optimizer = optax.adam(lr)
            critic_optimizer = optax.adam(lr)

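        # NOTE: per the change log at the top of this file, the MUON optimizer
        # was replaced with optax.adam. If MUON is wanted again, recent optax
        # releases ship an implementation in optax.contrib (e.g.
        # optax.contrib.muon(lr)); whether that symbol exists depends on the
        # installed optax version, so treat this as an assumption to verify.
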
        actor_trainstate = nnx.TrainState.create(
            graphdef=nnx.graphdef(actor_networks),
            params=nnx.state(actor_networks),
            tx=actor_optimizer,
        )
        actor_target_trainstate = nnx.TrainState.create(
            graphdef=nnx.graphdef(actor_target_networks),
            params=nnx.state(actor_target_networks),
            tx=optax.set_to_zero(),
        )
        critic_trainstate = nnx.TrainState.create(
            graphdef=nnx.graphdef(critic_networks),
            params=nnx.state(critic_networks),
            tx=critic_optimizer,
        )

        key, env_key = jax.random.split(key)
        env_key = jax.random.split(env_key, cfg.num_envs)
        obs, critic_obs, env_state = env.reset(key=env_key, params=env_params)

        # randomize initial time step to prevent all envs stepping in tandem
        _env_state = env_state.unwrapped()
        key, randomize_steps_key = jax.random.split(key)
        _env_state.info["steps"] = jax.random.randint(
            randomize_steps_key,
            _env_state.info["steps"].shape,
            0,
            cfg.max_episode_steps,
        ).astype(jnp.float32)
        env_state.set_env_state(_env_state)

        return SACTrainState(
            actor=actor_trainstate,
            actor_target=actor_target_trainstate,
            critic=critic_trainstate,
            iteration=0,
            time_steps=0,
            last_env_state=env_state,
            last_obs=obs,
            last_critic_obs=critic_obs,
        )

    return init


def make_train_fn(
    cfg: ReppoConfig,
    env: Environment,
    env_params: EnvParams = None,
    log_callback: Callable[[SACTrainState, dict[str, jax.Array]], None] | None = None,
    num_seeds: int = 1,
    reward_scale: float = 1.0,
):
    env_params = env_params  # or env.default_params
    env = LogWrapper(env, cfg.num_envs)
    env = ClipAction(env)
    # env = VecEnv(env, cfg.num_envs)
    if cfg.normalize_env:
        env = NormalizeVec(env)
    eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
    action_size_target = (
        jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
    )

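    # Rollout collection below uses a per-environment exploration schedule:
    # the first `exploration_base_envs` environments act with the minimum
    # noise scale, and the remaining environments ramp linearly from
    # exploration_noise_min toward exploration_noise_max (see `offset`).
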
    def collect_rollout(
        key: PRNGKey, train_state: SACTrainState
    ) -> tuple[Transition, SACTrainState]:
        actor_model = nnx.merge(train_state.actor.graphdef, train_state.actor.params)
        critic_model = nnx.merge(train_state.critic.graphdef, train_state.critic.params)

        offset = (
            jnp.arange(cfg.num_envs - cfg.exploration_base_envs)[:, None]
            * (cfg.exploration_noise_max - cfg.exploration_noise_min)
            / (cfg.num_envs - cfg.exploration_base_envs)
        ) + cfg.exploration_noise_min
        offset = jnp.concatenate(
            [
                jnp.ones((cfg.exploration_base_envs, 1)) * cfg.exploration_noise_min,
                offset,
            ],
            axis=0,
        )

        def step_env(carry, _) -> tuple[tuple, Transition]:
            key, env_state, train_state, obs, critic_obs = carry
            key, act_key, step_key = jax.random.split(key, 3)
            step_key = jax.random.split(step_key, cfg.num_envs)

            # get policy action
            og_pi = actor_model.actor(obs)
            pi = actor_model.actor(obs, scale=offset)
            action = pi.sample(seed=act_key)

            next_obs, next_critic_obs, next_env_state, reward, done, info = env.step(
                step_key, env_state, action
            )

            # compute importance weights
            action = jnp.clip(action, -0.999, 0.999)
            raw_importance_weight = jnp.nan_to_num(
                og_pi.log_prob(action).sum(-1) - pi.log_prob(action).sum(-1),
                nan=jnp.log(cfg.lmbda_min),
            )
            importance_weight = jnp.clip(
                raw_importance_weight, min=jnp.log(cfg.lmbda_min), max=jnp.log(1.0)
            )

            # compute next state embedding and value
            next_action, log_prob = actor_model.actor(next_obs).sample_and_log_prob(
                seed=act_key
            )
            next_emb, value = critic_model.forward(next_critic_obs, next_action)
            reward = (
                reward
                - cfg.gamma * log_prob.sum(-1).squeeze() * actor_model.temperature()
            )
            transition = Transition(
                obs=obs,
                critic_obs=critic_obs,
                action=action,
                next_emb=next_emb,
                reward=reward,
                value=value,
                done=done,
                truncated=next_env_state.truncated,
                info=info,
                importance_weight=importance_weight,
            )
            return (
                key,
                next_env_state,
                train_state,
                next_obs,
                next_critic_obs,
            ), transition

        rollout_state, transitions = jax.lax.scan(
            f=step_env,
            init=(
                key,
                train_state.last_env_state,
                train_state,
                train_state.last_obs,
                train_state.last_critic_obs,
            ),
            length=cfg.num_steps,
        )
        _, last_env_state, train_state, last_obs, last_critic_obs = rollout_state
        train_state = train_state.replace(
            last_env_state=last_env_state,
            last_obs=last_obs,
            last_critic_obs=last_critic_obs,
            time_steps=train_state.time_steps + cfg.num_steps * cfg.num_envs,
        )

        return transitions, train_state

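    # A reading of the target computation in learn_step below (matching the
    # compute_nstep_lambda recursion, which is scanned in reverse over the
    # rollout):
    #
    #   G_t = r_t + gamma * ( V_{t+1}                                  if truncated,
    #                         (1 - done) * [ rho * lmbda * G_{t+1}
    #                                        + (1 - rho * lmbda) * V_{t+1} ]  otherwise )
    #
    # where rho = exp(importance_weight) is the clipped importance ratio from
    # collect_rollout, so heavily off-policy exploration noise shrinks the
    # effective lambda toward a one-step bootstrap.
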
    def learn_step(
        key: PRNGKey, train_state: SACTrainState, batch: Transition
    ) -> tuple[SACTrainState, dict[str, jax.Array]]:
        # compute n-step lambda estimates

        def compute_nstep_lambda(carry, transition):
            lambda_return, truncated, importance_weight = carry
            # combine importance weights with TD(lambda)
            done = transition.done
            reward = transition.reward
            value = transition.value
            lambda_sum = (
                jnp.exp(importance_weight) * cfg.lmbda * lambda_return
                + (1 - jnp.exp(importance_weight) * cfg.lmbda) * value
            )
            delta = cfg.gamma * jnp.where(truncated, value, (1.0 - done) * lambda_sum)
            lambda_return = reward + delta
            truncated = transition.truncated
            return (
                lambda_return,
                truncated,
                transition.importance_weight,
            ), lambda_return

        _, target_values = jax.lax.scan(
            compute_nstep_lambda,
            (
                batch.value[-1],
                jnp.ones_like(batch.truncated[0]),
                jnp.zeros_like(batch.importance_weight[0]),
            ),
            batch,
            reverse=True,
        )
        # Reshape data to (num_steps * num_envs, ...)
        data = (batch, target_values)
        data = jax.tree.map(
            lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
        )

        train_state = train_state.replace(
            actor_target=train_state.actor_target.replace(
                params=train_state.actor.params
            ),
        )
        actor_target_model = nnx.merge(
            train_state.actor_target.graphdef, train_state.actor_target.params
        )

        def update(train_state, key) -> tuple[SACTrainState, dict[str, jax.Array]]:
            def minibatch_update(carry, indices):
                idx, train_state = carry
                # Sample data at indices from the batch
                minibatch, target_values = jax.tree.map(
                    lambda x: jnp.take(x, indices, axis=0), data
                )

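                # critic_loss_fn below supports two regression targets: plain
                # squared error on scalar values, or (when cfg.hl_gauss is set)
                # cross-entropy against a histogram target. The utils.hl_gauss
                # helper is assumed to project a scalar return onto `num_bins`
                # bins spanning [vmin, vmax] by integrating a Gaussian centred
                # on that scalar (the usual HL-Gauss construction); see the
                # vmapped call inside.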
                def critic_loss_fn(params):
                    critic_model = nnx.merge(train_state.critic.graphdef, params)
                    critic_pred = critic_model.critic_cat(
                        minibatch.critic_obs, minibatch.action
                    ).squeeze()
                    if cfg.hl_gauss:
                        target_cat = jax.vmap(
                            utils.hl_gauss, in_axes=(0, None, None, None)
                        )(target_values, cfg.num_bins, cfg.vmin, cfg.vmax)
                        critic_update_loss = optax.softmax_cross_entropy(
                            critic_pred, target_cat
                        )
                    else:
                        critic_update_loss = optax.squared_error(
                            critic_pred,
                            target_values,
                        )

                    # Aux loss
                    pred, value = critic_model.forward(
                        minibatch.critic_obs, minibatch.action
                    )
                    aux_loss = jnp.mean(
                        (1 - minibatch.done.reshape(-1, 1))
                        * (pred - minibatch.next_emb) ** 2,
                        axis=-1,
                    )

                    # compute l2 error for logging
                    critic_loss = optax.squared_error(
                        value,
                        target_values,
                    )
                    critic_loss = jnp.mean(critic_loss)
                    loss = jnp.mean(
                        (1.0 - minibatch.truncated)
                        * (critic_update_loss + cfg.aux_loss_mult * aux_loss)
                    )
                    return loss, dict(
                        value_loss=critic_loss,
                        critic_update_loss=critic_update_loss,
                        loss=loss,
                        aux_loss=aux_loss,
                        q=critic_pred.mean(),
                        abs_batch_action=jnp.abs(minibatch.action).mean(),
                        reward_mean=minibatch.reward.mean(),
                        target_values=target_values.mean(),
                    )

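                # actor_loss below combines the SAC-style entropy-regularised
                # objective with a KL constraint against the frozen target
                # policy. cfg.actor_kl_clip_mode selects how the constraint
                # enters the loss:
                #   "full"    - always add the weighted KL penalty,
                #   "clipped" - optimise the SAC objective while kl < kl_bound,
                #               otherwise switch to pure KL minimisation,
                #   "value"   - ignore the KL term in the actor loss entirely.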
                def actor_loss(params):
                    critic_target_model = nnx.merge(
                        train_state.critic.graphdef,
                        train_state.critic.params,
                    )
                    actor_model = nnx.merge(train_state.actor.graphdef, params)

                    # SAC actor loss
                    pi = actor_model.actor(minibatch.obs)
                    pred_action, log_prob = pi.sample_and_log_prob(seed=key)
                    value = critic_target_model.critic(
                        minibatch.critic_obs, pred_action
                    )
                    log_prob = log_prob.sum(-1)
                    entropy = -log_prob

                    # policy KL constraint
                    if cfg.reverse_kl:
                        pi_action, pi_act_log_prob = pi.sample_and_log_prob(
                            sample_shape=(16,), seed=key
                        )
                        pi_action = jnp.clip(pi_action, -1 + 1e-4, 1 - 1e-4)

                        old_pi = actor_target_model.actor(minibatch.obs)

                        old_pi_act_log_prob = old_pi.log_prob(pi_action).sum(-1).mean(0)
                        pi_act_log_prob = pi_act_log_prob.sum(-1).mean(0)
                        kl = pi_act_log_prob - old_pi_act_log_prob
                    else:
                        old_pi_action, old_pi_act_log_prob = actor_target_model.actor(
                            minibatch.obs
                        ).sample_and_log_prob(sample_shape=(16,), seed=key)
                        old_pi_action = jnp.clip(old_pi_action, -1 + 1e-4, 1 - 1e-4)

                        old_pi_act_log_prob = old_pi_act_log_prob.sum(-1).mean(0)
                        pi_act_log_prob = pi.log_prob(old_pi_action).sum(-1).mean(0)

                        kl = old_pi_act_log_prob - pi_act_log_prob

                    lagrangian = actor_model.lagrangian()

                    if cfg.actor_kl_clip_mode == "full":
                        actor_loss = (
                            log_prob * jax.lax.stop_gradient(actor_model.temperature())
                            - value
                            + kl * jax.lax.stop_gradient(lagrangian) * cfg.reduce_kl
                        )
                    elif cfg.actor_kl_clip_mode == "clipped":
                        actor_loss = jnp.where(
                            kl < cfg.kl_bound,
                            log_prob * jax.lax.stop_gradient(actor_model.temperature())
                            - value,
                            kl * jax.lax.stop_gradient(lagrangian) * cfg.reduce_kl,
                        )
                    elif cfg.actor_kl_clip_mode == "value":
                        actor_loss = (
                            log_prob * jax.lax.stop_gradient(actor_model.temperature())
                            - value
                        )
                    else:
                        raise ValueError(
                            f"Unknown actor loss mode: {cfg.actor_kl_clip_mode}"
                        )

                    # SAC target entropy loss
                    target_entropy = action_size_target + entropy
                    target_entropy_loss = (
                        actor_model.temperature()
                        * jax.lax.stop_gradient(target_entropy)
                    )

                    # Lagrangian constraint (follows temperature update)
                    lagrangian_loss = -lagrangian * jax.lax.stop_gradient(
                        kl - cfg.kl_bound
                    )

                    # total loss
                    loss = jnp.mean(actor_loss)
                    if cfg.update_entropy_lagrangian:
                        loss += jnp.mean(target_entropy_loss)
                    if cfg.update_kl_lagrangian:
                        loss += jnp.mean(lagrangian_loss)

                    return loss, dict(
                        actor_loss=actor_loss,
                        loss=loss,
                        temp=actor_model.temperature(),
                        abs_batch_action=jnp.abs(minibatch.action).mean(),
                        abs_pred_action=jnp.abs(pred_action).mean(),
                        reward_mean=minibatch.reward.mean(),
                        kl=kl.mean(),
                        lagrangian=lagrangian,
                        lagrangian_loss=lagrangian_loss,
                        entropy=entropy,
                        entropy_loss=target_entropy_loss,
                        target_values=target_values.mean(),
                    )

                critic_grad_fn = jax.value_and_grad(critic_loss_fn, has_aux=True)
                output, grads = critic_grad_fn(train_state.critic.params)
                critic_train_state = train_state.critic.apply_gradients(grads)
                train_state = train_state.replace(
                    critic=critic_train_state,
                )
                critic_metrics = output[1]

                actor_grad_fn = jax.value_and_grad(actor_loss, has_aux=True)
                output, grads = actor_grad_fn(train_state.actor.params)
                actor_train_state = train_state.actor.apply_gradients(grads)
                train_state = train_state.replace(
                    actor=actor_train_state,
                )
                actor_metrics = output[1]
                return (idx + 1, train_state), {
                    **critic_metrics,
                    **actor_metrics,
                }

            # Shuffle data and split into mini-batches
            key, shuffle_key = jax.random.split(key)
            mini_batch_size = (cfg.num_steps * cfg.num_envs) // cfg.num_mini_batches
            indices = jax.random.permutation(shuffle_key, cfg.num_steps * cfg.num_envs)
            minibatch_idxs = jax.tree.map(
                lambda x: x.reshape(
                    (cfg.num_mini_batches, mini_batch_size, *x.shape[1:])
                ),
                indices,
            )

            # Run model update for each mini-batch
            train_state, metrics = jax.lax.scan(
                minibatch_update, train_state, minibatch_idxs
            )
            # Compute mean metrics across mini-batches
            metrics = jax.tree.map(lambda x: x.mean(0), metrics)
            return train_state, metrics

        # Update the model for a number of epochs
        key, train_key = jax.random.split(key)
        (_, train_state), update_metrics = jax.lax.scan(
            f=update,
            init=(1, train_state),
            xs=jax.random.split(train_key, cfg.num_epochs),
        )
        # Get metrics from the last epoch
        update_metrics = jax.tree.map(lambda x: x[-1], update_metrics)

        return train_state, update_metrics

    def train_fn(key: PRNGKey, cfg: ReppoConfig) -> tuple[SACTrainState, dict]:
        def train_eval_step(key, train_state):
            def train_step(
                state: SACTrainState, key: PRNGKey
            ) -> tuple[SACTrainState, dict[str, jax.Array]]:
                key, rollout_key, learn_key = jax.random.split(key, 3)
                transitions, state = collect_rollout(key=rollout_key, train_state=state)
                state, update_metrics = learn_step(
                    key=learn_key, train_state=state, batch=transitions
                )
                metrics = update_metrics
                state = state.replace(iteration=state.iteration + 1)
                return state, metrics

            train_key, eval_key = jax.random.split(key)
            eval_interval = int(
                (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
            )
            train_state, train_metrics = jax.lax.scan(
                f=train_step,
                init=train_state,
                xs=jax.random.split(train_key, eval_interval),
            )
            train_metrics = jax.tree.map(lambda x: x[-1], train_metrics)
            policy = make_policy(train_state)
            if cfg.normalize_env:
                norm_state = train_state.last_env_state
            else:
                norm_state = None
            eval_metrics = eval_fn(eval_key, policy, norm_state)
            train_returns = {
                "train/episode_return": train_state.last_env_state.info[
                    "returned_episode_returns"
                ].mean(),
                "train/episode_length": train_state.last_env_state.info[
                    "returned_episode_lengths"
                ].mean(),
            }
            metrics = {
                "time_step": train_state.time_steps,
                **utils.prefix_dict("train", train_metrics),
                **utils.prefix_dict("eval", eval_metrics),
                **train_returns,
            }
            return train_state, metrics

        def loop_body(
            train_state: SACTrainState, key: PRNGKey
        ) -> tuple[SACTrainState, dict]:
            key, subkey = jax.random.split(key)
            train_state, metrics = jax.vmap(train_eval_step)(
                jax.random.split(subkey, num_seeds), train_state
            )
            jax.debug.callback(log_callback, train_state, metrics)
            return train_state, metrics

        eval_interval = int(
            (cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
        )
        num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
        num_iterations = num_train_steps // eval_interval + int(
            num_train_steps % eval_interval != 0
        )
        key, init_key = jax.random.split(key)
        train_state = jax.vmap(make_init(cfg, env, env_params))(
            jax.random.split(init_key, num_seeds)
        )
        keys = jax.random.split(key, num_iterations)
        state, metrics = jax.lax.scan(f=loop_body, init=train_state, xs=keys)
        return state, metrics

    return train_fn


def plot_history(history: list[dict[str, jax.Array]]):
    steps = jnp.array([m["time_step"][0] for m in history])
    eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
    eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
    fig = go.Figure(
        [
            go.Scatter(
                x=steps,
                y=eval_return,
                name="Mean Episode Return",
                mode="lines",
                line=dict(color="blue"),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return + eval_return_std,
                name="Upper Bound",
                mode="lines",
                line=dict(width=0),
                showlegend=False,
            ),
            go.Scatter(
                x=steps,
                y=eval_return - eval_return_std,
                name="Lower Bound",
                mode="lines",
                line=dict(width=0),
                fill="tonexty",
                fillcolor="rgba(50, 127, 168, 0.3)",
                showlegend=False,
            ),
        ]
    )
    fig.update_layout(
        xaxis=dict(title=dict(text="Environment Steps")),
    )

    return fig


# Map a list of candidate values to the appropriate Optuna suggestion
# based on the element type.
def _get_optuna_type(trial: optuna.Trial, name, values: list):
    # NOTE: bool must be checked before int, since bool is a subclass of int
    # in Python and would otherwise be swallowed by the int branch.
    if all(isinstance(v, bool) for v in values):
        return trial.suggest_categorical(name, [True, False])
    elif all(isinstance(v, int) for v in values):
        return trial.suggest_int(name, low=min(values), high=max(values))
    elif all(isinstance(v, float) for v in values):
        return trial.suggest_float(name, low=min(values), high=max(values))
    elif all(isinstance(v, str) for v in values):
        return trial.suggest_categorical(name, values)
    else:
        raise ValueError("Values must all be of the same type (bool, int, float, or str).")


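# Example of the trial specification consumed below (hypothetical YAML — the
# exact schema of cfg.trial_spec comes from the Hydra config, not this file):
#
#   trial_spec:
#     lr: [1e-4, 1e-3]         # float range -> suggest_float
#     num_epochs: [1, 8]       # int range   -> suggest_int
#     hl_gauss: [true, false]  # bool        -> suggest_categorical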
def run(cfg: DictConfig, trial: optuna.Trial | None) -> float:
    """
    Run a single trial of the REPPO training process, optionally with
    Optuna hyperparameter tuning.

    Args:
        cfg (DictConfig): Configuration for the training run.
        trial (optuna.Trial | None): Optuna trial object for hyperparameter tuning.

    Returns:
        float: The mean episode return from the trial.
    """
    sweep_metrics = []

    if trial is not None:
        # Set hyperparameters from the trial
        for name, values in cfg.trial_spec.items():
            if name in cfg.hyperparameters:
                sampled_value = _get_optuna_type(trial, name, values)
                # TODO: Optuna can return np.float64 here; cast to a plain
                # float so the OmegaConf config accepts it.
                if isinstance(sampled_value, np.float64):
                    sampled_value = float(sampled_value)
                cfg.hyperparameters[name] = sampled_value
            else:
                raise ValueError(f"Hyperparameter {name} not found in config.")

    try:
        with open("completed_trials.txt", "r") as f:
            completed_trials = int(f.read())
    except FileNotFoundError:
        completed_trials = 0

    metric_history = []

    def log_callback(state, metrics):
        metrics["sys_time"] = time.perf_counter()
        if len(metric_history) > 0:
            num_env_steps = state.time_steps[0] - metric_history[-1]["time_step"][0]
            seconds = metrics["sys_time"] - metric_history[-1]["sys_time"]
            sps = num_env_steps / seconds
        else:
            sps = 0

        metric_history.append(metrics)
        episode_return = metrics["eval/episode_return"].mean()
        eval_length = metrics["eval/episode_length"].mean()
        logging.info(
            f"step={state.time_steps[0]} episode_return={episode_return:.3f}, "
            f"episode_length={eval_length:.3f} sps={sps:.2f}"
        )
        log_data = {
            "eval/episode_return": episode_return,
            "eval/episode_length": eval_length,
            **jax.tree.map(jnp.mean, utils.filter_prefix("train", metrics)),
        }
        wandb.log(log_data, step=state.time_steps[0])

    # Set up the experiment
    if cfg.env.type == "brax":
        env = BraxGymnaxWrapper(
            cfg.env.name,
            episode_length=cfg.env.max_episode_steps,
            reward_scaling=cfg.env.reward_scaling,
            terminate=cfg.env.terminate,
        )
    elif cfg.env.type == "mjx":
        env = MjxGymnaxWrapper(
            cfg.env.name,
            episode_length=cfg.env.max_episode_steps,
            reward_scale=cfg.env.reward_scaling,
            push_distractions=cfg.env.get("push_distractions", False),
            asymmetric_observation=cfg.env.get("asymmetric_observation", False),
        )
    else:
        raise ValueError(f"Unknown environment type: {cfg.env.type}")

    # build algo config with overrides
    train_fn = make_train_fn(
        cfg=ReppoConfig(**cfg.hyperparameters),
        env=env,
        log_callback=log_callback,
        num_seeds=cfg.num_seeds,
        reward_scale=1.0 / cfg.env.reward_scaling,
    )

    base_seed = cfg.seed
    for i in range(completed_trials, cfg.num_trials):
        # derive each trial's seed from the base seed rather than mutating
        # cfg.seed cumulatively across iterations
        cfg.seed = base_seed + i

        wandb.init(
            mode=cfg.wandb.mode,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            tags=[
                cfg.name,
                cfg.env.name,
                cfg.env.type,
                "hp_tune" if trial is not None else "val",
                *cfg.tags,
            ],
            config=OmegaConf.to_container(cfg),
            name=f"resampling-{cfg.name}-{cfg.env.name.lower()}",
            save_code=True,
        )

        logging.info(OmegaConf.to_yaml(cfg))

        key = jax.random.PRNGKey(cfg.seed)
        start = time.perf_counter()
        _, metrics = jax.jit(train_fn, static_argnums=(1,))(
            key, ReppoConfig(**cfg.hyperparameters)
        )
        jax.block_until_ready(metrics)
        duration = time.perf_counter() - start

        # Save metrics and finish the run
        logging.info(f"Training took {duration:.2f} seconds.")
        np.savez("metrics.npz", **metrics)
        wandb.finish()

        sweep_metrics.append(metrics["eval/episode_return"])

        with open("completed_trials.txt", "w") as f:
            f.write(str(i + 1))  # number of trials completed so far

    sweep_metrics_array = jnp.array(sweep_metrics)
    return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item()


@hydra.main(version_base=None, config_path="../../config", config_name="reppo")
def main(cfg: DictConfig):
    run(cfg, trial=None)


if __name__ == "__main__":
    main()
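
# Example invocation (hypothetical file name and override values — this uses
# standard Hydra override syntax with config fields read above, e.g.
# env.name / env.type / seed / num_seeds):
#
#   python reppo.py env.type=brax env.name=ant seed=0 num_seeds=4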