# reppo/reppo_alg/jaxrl/reppo.py

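"""Single-file JAX implementation of the REPPO training loop.

The agent is a SAC-style actor-critic trained from on-policy rollouts with
importance-weighted TD(lambda) value targets, a KL constraint against the
previous policy and an entropy constraint (both enforced through learned
multipliers), an optional HL-Gauss categorical critic head, an auxiliary
latent self-prediction loss, and Muon optimizers. Entry point: `main` (hydra)
-> `run` -> `make_train_fn`.
"""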
import logging
import time
import typing
from typing import Callable
import hydra
import jax
import numpy as np
import optax
import optuna
import plotly.graph_objs as go
from flax import nnx, struct
from flax.struct import PyTreeNode
from gymnax.environments.environment import Environment, EnvParams, EnvState
from jax import numpy as jnp
from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf
import wandb
from reppo_alg.env_utils.jax_wrappers import (
BraxGymnaxWrapper,
ClipAction,
LogWrapper,
MjxGymnaxWrapper,
NormalizeVec,
)
from reppo_alg.jaxrl import utils, muon
from reppo_alg.network_utils.jax_models import (
CategoricalCriticNetwork,
CriticNetwork,
SACActorNetworks,
)
logging.basicConfig(level=logging.INFO)
class Policy(typing.Protocol):
def __call__(
self,
key: jax.random.PRNGKey,
obs: PyTreeNode,
) -> tuple[PyTreeNode, PyTreeNode]:
pass
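# A single environment step stored during rollout collection. `value` and
# `next_emb` are the critic's value estimate and latent embedding for the
# *next* state-action pair, and `importance_weight` is the clipped log-ratio
# between the unscaled policy and the exploration-noise-scaled policy that
# sampled the action.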
class Transition(struct.PyTreeNode):
obs: jax.Array
critic_obs: jax.Array
action: jax.Array
reward: jax.Array
next_emb: jax.Array
value: jax.Array
done: jax.Array
truncated: jax.Array
importance_weight: jax.Array
info: dict[str, jax.Array]
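# Algorithm hyperparameters, instantiated from `cfg.hyperparameters` in the
# hydra config. Kept as a flax PyTreeNode so it can be passed to the jitted
# train function as a static argument.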
class ReppoConfig(struct.PyTreeNode):
lr: float
gamma: float
total_time_steps: int
num_steps: int
lmbda: float
lmbda_min: float
num_mini_batches: int
num_envs: int
num_epochs: int
max_grad_norm: float | None
normalize_env: bool
polyak: float
exploration_noise_min: float
exploration_noise_max: float
exploration_base_envs: int
ent_start: float
ent_target_mult: float
kl_start: float
eval_interval: int = 10
num_eval: int = 25
max_episode_steps: int = 1000
critic_hidden_dim: int = 512
actor_hidden_dim: int = 512
vmin: int = -100
vmax: int = 100
num_bins: int = 250
hl_gauss: bool = False
kl_bound: float = 1.0
aux_loss_mult: float = 0.0
update_kl_lagrangian: bool = True
update_entropy_lagrangian: bool = True
use_critic_norm: bool = True
num_critic_encoder_layers: int = 1
num_critic_head_layers: int = 1
num_critic_pred_layers: int = 1
use_simplical_embedding: bool = False
use_actor_norm: bool = True
num_actor_layers: int = 2
actor_min_std: float = 0.05
reduce_kl: bool = True
reverse_kl: bool = False
anneal_lr: bool = False
actor_kl_clip_mode: str = "clipped"
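# Complete training state: actor and critic train states, a frozen copy of the
# actor used as the KL target, and the latest environment snapshot so each
# rollout resumes where the previous one stopped.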
class SACTrainState(struct.PyTreeNode):
critic: nnx.TrainState
actor: nnx.TrainState
actor_target: nnx.TrainState
iteration: int
time_steps: int
last_env_state: EnvState
last_obs: jax.Array
last_critic_obs: jax.Array
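# Build a deterministic evaluation policy from the current actor parameters.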
def make_policy(
train_state: SACTrainState,
) -> Callable[[jax.Array, jax.Array], tuple[jax.Array, dict]]:
def policy(key: PRNGKey, obs: jax.Array) -> tuple[jax.Array, dict]:
actor_model = nnx.merge(train_state.actor.graphdef, train_state.actor.params)
action: jax.Array = actor_model.det_action(obs)
return action, {}
return policy
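# Evaluation: roll the deterministic policy for `max_episode_steps` steps across
# the parallel envs and average returns/lengths over every episode that
# completed during the rollout.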
def make_eval_fn(
env: Environment, max_episode_steps: int, reward_scale: float = 1.0
) -> Callable[[jax.random.PRNGKey, Policy, PyTreeNode | None], dict[str, float]]:
def evaluation_fn(
key: jax.random.PRNGKey, policy: Policy, norm_state: PyTreeNode | None
):
def step_env(carry, _):
key, env_state, obs = carry
key, act_key, env_key = jax.random.split(key, 3)
action, _ = policy(act_key, obs)
step_key = jax.random.split(env_key, env.num_envs)
obs, _, env_state, reward, done, info = env.step(
step_key, env_state, action
)
return (key, env_state, obs), info
key, init_key = jax.random.split(key)
init_key = jax.random.split(init_key, env.num_envs)
obs, _, env_state = env.reset(init_key, norm_state)
# randomize initial steps
key, env_key = jax.random.split(key)
_, infos = jax.lax.scan(
f=step_env,
init=(key, env_state, obs),
xs=None,
length=max_episode_steps,
)
return {
"episode_return": infos["returned_episode_returns"].mean(
where=infos["returned_episode"]
)
* reward_scale,
"episode_return_std": infos["returned_episode_returns"].std(
where=infos["returned_episode"]
),
"episode_length": infos["returned_episode_lengths"].mean(
where=infos["returned_episode"]
),
"episode_length_std": infos["returned_episode_lengths"].std(
where=infos["returned_episode"]
),
"num_episodes": infos["returned_episode"].sum(),
}
return evaluation_fn
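# Initialization: construct actor, actor-target and critic networks with their
# optimizers, reset the environments, and randomize each env's step counter so
# the parallel envs do not all hit the time limit at the same step.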
def make_init(
cfg: ReppoConfig,
env: Environment,
env_params: EnvParams = None,
) -> Callable[[jax.Array], SACTrainState]:
def init(key: jax.random.PRNGKey) -> SACTrainState:
# Number of calls to train_step
key, model_key = jax.random.split(key)
actor_networks = SACActorNetworks(
obs_dim=env.observation_space(env_params)[0].shape[0],
action_dim=env.action_space(env_params).shape[0],
hidden_dim=cfg.actor_hidden_dim,
ent_start=cfg.ent_start,
kl_start=cfg.kl_start,
use_norm=cfg.use_actor_norm,
layers=cfg.num_actor_layers,
rngs=nnx.Rngs(model_key),
)
actor_target_networks = SACActorNetworks(
obs_dim=env.observation_space(env_params)[0].shape[0],
action_dim=env.action_space(env_params).shape[0],
hidden_dim=cfg.actor_hidden_dim,
ent_start=cfg.ent_start,
kl_start=cfg.kl_start,
use_norm=cfg.use_actor_norm,
layers=cfg.num_actor_layers,
rngs=nnx.Rngs(model_key),
)
if cfg.hl_gauss:
critic_networks: nnx.Module = CategoricalCriticNetwork(
obs_dim=env.observation_space(env_params)[1].shape[0],
action_dim=env.action_space(env_params).shape[0],
hidden_dim=cfg.critic_hidden_dim,
num_bins=cfg.num_bins,
vmin=cfg.vmin,
vmax=cfg.vmax,
use_norm=cfg.use_critic_norm,
encoder_layers=cfg.num_critic_encoder_layers,
use_simplical_embedding=cfg.use_simplical_embedding,
head_layers=cfg.num_critic_head_layers,
pred_layers=cfg.num_critic_pred_layers,
rngs=nnx.Rngs(model_key),
)
else:
critic_networks: nnx.Module = CriticNetwork(
obs_dim=env.observation_space(env_params)[1].shape[0],
action_dim=env.action_space(env_params).shape[0],
hidden_dim=cfg.critic_hidden_dim,
use_norm=cfg.use_critic_norm,
encoder_layers=cfg.num_critic_encoder_layers,
use_simplical_embedding=cfg.use_simplical_embedding,
head_layers=cfg.num_critic_head_layers,
pred_layers=cfg.num_critic_pred_layers,
rngs=nnx.Rngs(model_key),
)
if not cfg.anneal_lr:
lr = cfg.lr
else:
num_iterations = cfg.total_time_steps // cfg.num_steps // cfg.num_envs
num_updates = num_iterations * cfg.num_epochs * cfg.num_mini_batches
lr = optax.linear_schedule(cfg.lr, 0, num_updates)
        if cfg.max_grad_norm is not None:
            actor_optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                muon.muon(lr),  # alternatively: optax.adam(lr)
            )
            critic_optimizer = optax.chain(
                optax.clip_by_global_norm(cfg.max_grad_norm),
                muon.muon(lr),  # alternatively: optax.adam(lr)
            )
        else:
            actor_optimizer = muon.muon(lr)  # alternatively: optax.adam(lr)
            critic_optimizer = muon.muon(lr)  # alternatively: optax.adam(lr)
actor_trainstate = nnx.TrainState.create(
graphdef=nnx.graphdef(actor_networks),
params=nnx.state(actor_networks),
tx=actor_optimizer,
)
actor_target_trainstate = nnx.TrainState.create(
graphdef=nnx.graphdef(actor_target_networks),
params=nnx.state(actor_target_networks),
tx=optax.set_to_zero(),
)
critic_trainstate = nnx.TrainState.create(
graphdef=nnx.graphdef(critic_networks),
params=nnx.state(critic_networks),
tx=critic_optimizer,
)
key, env_key = jax.random.split(key)
env_key = jax.random.split(env_key, cfg.num_envs)
obs, critic_obs, env_state = env.reset(key=env_key, params=env_params)
# randomize initial time step to prevent all envs stepping in tandem
_env_state = env_state.unwrapped()
key, randomize_steps_key = jax.random.split(key)
_env_state.info["steps"] = jax.random.randint(
randomize_steps_key,
_env_state.info["steps"].shape,
0,
cfg.max_episode_steps,
).astype(jnp.float32)
env_state.set_env_state(_env_state)
return SACTrainState(
actor=actor_trainstate,
actor_target=actor_target_trainstate,
critic=critic_trainstate,
iteration=0,
time_steps=0,
last_env_state=env_state,
last_obs=obs,
last_critic_obs=critic_obs,
)
return init
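# Build the jit-able training function. Wraps the env with logging, action
# clipping and optional observation normalization, and closes over the rollout,
# update and evaluation logic defined below.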
def make_train_fn(
cfg: ReppoConfig,
env: Environment,
env_params: EnvParams = None,
log_callback: Callable[[SACTrainState, dict[str, jax.Array]], None] | None = None,
num_seeds: int = 1,
reward_scale: float = 1.0,
):
env_params = env_params # or env.default_params
env = LogWrapper(env, cfg.num_envs)
env = ClipAction(env)
# env = VecEnv(env, cfg.num_envs)
if cfg.normalize_env:
env = NormalizeVec(env)
eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
action_size_target = (
jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
)
def collect_rollout(
key: PRNGKey, train_state: SACTrainState
) -> tuple[Transition, SACTrainState]:
actor_model = nnx.merge(train_state.actor.graphdef, train_state.actor.params)
critic_model = nnx.merge(train_state.critic.graphdef, train_state.critic.params)
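        # Per-env exploration noise schedule: the first `exploration_base_envs`
        # envs act at the minimum noise scale; the remaining envs get scales
        # spread linearly from exploration_noise_min towards exploration_noise_max.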
offset = (
jnp.arange(cfg.num_envs - cfg.exploration_base_envs)[:, None]
* (cfg.exploration_noise_max - cfg.exploration_noise_min)
/ (cfg.num_envs - cfg.exploration_base_envs)
) + cfg.exploration_noise_min
offset = jnp.concatenate(
[
jnp.ones((cfg.exploration_base_envs, 1)) * cfg.exploration_noise_min,
offset,
],
axis=0,
)
def step_env(carry, _) -> tuple[tuple, Transition]:
key, env_state, train_state, obs, critic_obs = carry
key, act_key, step_key = jax.random.split(key, 3)
step_key = jax.random.split(step_key, cfg.num_envs)
# get policy action
og_pi = actor_model.actor(obs)
pi = actor_model.actor(obs, scale=offset)
action = pi.sample(seed=act_key)
next_obs, next_critic_obs, next_env_state, reward, done, info = env.step(
step_key, env_state, action
)
# compute importance weights
action = jnp.clip(action, -0.999, 0.999)
raw_importance_weight = jnp.nan_to_num(
og_pi.log_prob(action).sum(-1) - pi.log_prob(action).sum(-1),
nan=jnp.log(cfg.lmbda_min),
)
importance_weight = jnp.clip(
raw_importance_weight, min=jnp.log(cfg.lmbda_min), max=jnp.log(1.0)
)
# compute next state embedding and value
next_action, log_prob = actor_model.actor(next_obs).sample_and_log_prob(
seed=act_key
)
next_emb, value = critic_model.forward(next_critic_obs, next_action)
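            # Maximum-entropy correction: fold the discounted, temperature-scaled
            # log-prob of the next action into the reward so the stored reward
            # already carries the entropy bonus used in the soft value targets.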
reward = (
reward
- cfg.gamma * log_prob.sum(-1).squeeze() * actor_model.temperature()
)
transition = Transition(
obs=obs,
critic_obs=critic_obs,
action=action,
next_emb=next_emb,
reward=reward,
value=value,
done=done,
truncated=next_env_state.truncated,
info=info,
importance_weight=importance_weight,
)
return (
key,
next_env_state,
train_state,
next_obs,
next_critic_obs,
), transition
rollout_state, transitions = jax.lax.scan(
f=step_env,
init=(
key,
train_state.last_env_state,
train_state,
train_state.last_obs,
train_state.last_critic_obs,
),
length=cfg.num_steps,
)
_, last_env_state, train_state, last_obs, last_critic_obs = rollout_state
train_state = train_state.replace(
last_env_state=last_env_state,
last_obs=last_obs,
last_critic_obs=last_critic_obs,
time_steps=train_state.time_steps + cfg.num_steps * cfg.num_envs,
)
return transitions, train_state
def learn_step(
key: PRNGKey, train_state: SACTrainState, batch: Transition
) -> tuple[SACTrainState, dict[str, jax.Array]]:
# compute n-step lambda estimates
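        # Backward scan over the rollout: each step mixes the bootstrap value
        # V(s_{t+1}, a_{t+1}) with the propagated lambda-return, gated by
        # exp(importance_weight) * lmbda so heavily noise-scaled (off-policy)
        # actions shorten the effective horizon; truncated steps bootstrap from
        # the value alone.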
def compute_nstep_lambda(carry, transition):
lambda_return, truncated, importance_weight = carry
# combine importance_weights with TD lambda
done = transition.done
reward = transition.reward
value = transition.value
lambda_sum = (
jnp.exp(importance_weight) * cfg.lmbda * lambda_return
+ (1 - jnp.exp(importance_weight) * cfg.lmbda) * value
)
delta = cfg.gamma * jnp.where(truncated, value, (1.0 - done) * lambda_sum)
lambda_return = reward + delta
truncated = transition.truncated
return (
lambda_return,
truncated,
transition.importance_weight,
), lambda_return
_, target_values = jax.lax.scan(
compute_nstep_lambda,
(
batch.value[-1],
jnp.ones_like(batch.truncated[0]),
jnp.zeros_like(batch.importance_weight[0]),
),
batch,
reverse=True,
)
# Reshape data to (num_steps * num_envs, ...)
data = (batch, target_values)
data = jax.tree.map(
lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
)
train_state = train_state.replace(
actor_target=train_state.actor_target.replace(
params=train_state.actor.params
),
)
actor_target_model = nnx.merge(
train_state.actor_target.graphdef, train_state.actor_target.params
)
def update(train_state, key) -> tuple[SACTrainState, dict[str, jax.Array]]:
def minibatch_update(carry, indices):
idx, train_state = carry
# Sample data at indices from the batch
minibatch, target_values = jax.tree.map(
lambda x: jnp.take(x, indices, axis=0), data
)
def critic_loss_fn(params):
critic_model = nnx.merge(train_state.critic.graphdef, params)
critic_pred = critic_model.critic_cat(
minibatch.critic_obs, minibatch.action
).squeeze()
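                    # HL-Gauss / categorical critic: project scalar targets onto
                    # `num_bins` bins over [vmin, vmax] and train with cross-entropy;
                    # otherwise regress the scalar prediction with a squared error.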
if cfg.hl_gauss:
target_cat = jax.vmap(
utils.hl_gauss, in_axes=(0, None, None, None)
)(target_values, cfg.num_bins, cfg.vmin, cfg.vmax)
critic_update_loss = optax.softmax_cross_entropy(
critic_pred, target_cat
)
else:
critic_update_loss = optax.squared_error(
critic_pred,
target_values,
)
# Aux loss
pred, value = critic_model.forward(
minibatch.critic_obs, minibatch.action
)
aux_loss = jnp.mean(
(1 - minibatch.done.reshape(-1, 1))
* (pred - minibatch.next_emb) ** 2,
axis=-1,
)
# compute l2 error for logging
critic_loss = optax.squared_error(
value,
target_values,
)
critic_loss = jnp.mean(critic_loss)
loss = jnp.mean(
(1.0 - minibatch.truncated)
* (critic_update_loss + cfg.aux_loss_mult * aux_loss)
)
return loss, dict(
value_loss=critic_loss,
critic_update_loss=critic_update_loss,
loss=loss,
aux_loss=aux_loss,
q=critic_pred.mean(),
abs_batch_action=jnp.abs(minibatch.action).mean(),
reward_mean=minibatch.reward.mean(),
target_values=target_values.mean(),
)
def actor_loss(params):
critic_target_model = nnx.merge(
train_state.critic.graphdef,
train_state.critic.params,
)
actor_model = nnx.merge(train_state.actor.graphdef, params)
# SAC actor loss
pi = actor_model.actor(minibatch.obs)
pred_action, log_prob = pi.sample_and_log_prob(seed=key)
value = critic_target_model.critic(
minibatch.critic_obs, pred_action
)
log_prob = log_prob.sum(-1)
entropy = -log_prob
# policy KL constraint
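                    # Estimated against the frozen actor target (the policy at the
                    # start of this learn step) from 16 sampled actions:
                    # KL(pi || pi_old) when `reverse_kl`, otherwise KL(pi_old || pi).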
if cfg.reverse_kl:
pi_action, pi_act_log_prob = pi.sample_and_log_prob(
sample_shape=(16,), seed=key
)
pi_action = jnp.clip(pi_action, -1 + 1e-4, 1 - 1e-4)
old_pi = actor_target_model.actor(minibatch.obs)
old_pi_act_log_prob = old_pi.log_prob(pi_action).sum(-1).mean(0)
pi_act_log_prob = pi_act_log_prob.sum(-1).mean(0)
kl = pi_act_log_prob - old_pi_act_log_prob
else:
old_pi_action, old_pi_act_log_prob = actor_target_model.actor(
minibatch.obs
).sample_and_log_prob(sample_shape=(16,), seed=key)
old_pi_action = jnp.clip(old_pi_action, -1 + 1e-4, 1 - 1e-4)
old_pi_act_log_prob = old_pi_act_log_prob.sum(-1).mean(0)
pi_act_log_prob = pi.log_prob(old_pi_action).sum(-1).mean(0)
kl = old_pi_act_log_prob - pi_act_log_prob
lagrangian = actor_model.lagrangian()
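                    # KL handling: "full" always adds the Lagrangian-weighted KL
                    # penalty, "clipped" replaces the SAC objective with the penalty
                    # once the KL estimate exceeds kl_bound, and "value" drops the
                    # KL term from the actor loss entirely.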
if cfg.actor_kl_clip_mode == "full":
actor_loss = (
log_prob * jax.lax.stop_gradient(actor_model.temperature())
- value
+ kl * jax.lax.stop_gradient(lagrangian) * cfg.reduce_kl
)
elif cfg.actor_kl_clip_mode == "clipped":
actor_loss = jnp.where(
kl < cfg.kl_bound,
log_prob * jax.lax.stop_gradient(actor_model.temperature())
- value,
kl * jax.lax.stop_gradient(lagrangian) * cfg.reduce_kl,
)
elif cfg.actor_kl_clip_mode == "value":
actor_loss = (
log_prob * jax.lax.stop_gradient(actor_model.temperature())
- value
)
else:
raise ValueError(
f"Unknown actor loss mode: {cfg.actor_kl_clip_mode}"
)
# SAC target entropy loss
target_entropy = action_size_target + entropy
target_entropy_loss = (
actor_model.temperature()
* jax.lax.stop_gradient(target_entropy)
)
# Lagrangian constraint (follows temperature update)
lagrangian_loss = -lagrangian * jax.lax.stop_gradient(
kl - cfg.kl_bound
)
# total loss
loss = jnp.mean(actor_loss)
if cfg.update_entropy_lagrangian:
loss += jnp.mean(target_entropy_loss)
if cfg.update_kl_lagrangian:
loss += jnp.mean(lagrangian_loss)
return loss, dict(
actor_loss=actor_loss,
loss=loss,
temp=actor_model.temperature(),
abs_batch_action=jnp.abs(minibatch.action).mean(),
abs_pred_action=jnp.abs(pred_action).mean(),
reward_mean=minibatch.reward.mean(),
kl=kl.mean(),
lagrangian=lagrangian,
lagrangian_loss=lagrangian_loss,
entropy=entropy,
entropy_loss=target_entropy_loss,
target_values=target_values.mean(),
)
critic_grad_fn = jax.value_and_grad(critic_loss_fn, has_aux=True)
output, grads = critic_grad_fn(train_state.critic.params)
critic_train_state = train_state.critic.apply_gradients(grads)
train_state = train_state.replace(
critic=critic_train_state,
)
critic_metrics = output[1]
actor_grad_fn = jax.value_and_grad(actor_loss, has_aux=True)
output, grads = actor_grad_fn(train_state.actor.params)
actor_train_state = train_state.actor.apply_gradients(grads)
train_state = train_state.replace(
actor=actor_train_state,
)
actor_metrics = output[1]
return (idx + 1, train_state), {
**critic_metrics,
**actor_metrics,
}
# Shuffle data and split into mini-batches
key, shuffle_key = jax.random.split(key)
mini_batch_size = (cfg.num_steps * cfg.num_envs) // cfg.num_mini_batches
indices = jax.random.permutation(shuffle_key, cfg.num_steps * cfg.num_envs)
minibatch_idxs = jax.tree.map(
lambda x: x.reshape(
(cfg.num_mini_batches, mini_batch_size, *x.shape[1:])
),
indices,
)
# Run model update for each mini-batch
train_state, metrics = jax.lax.scan(
minibatch_update, train_state, minibatch_idxs
)
# Compute mean metrics across mini-batches
metrics = jax.tree.map(lambda x: x.mean(0), metrics)
return train_state, metrics
# Update the model for a number of epochs
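        # Note: the scan carry is an (epoch counter, train_state) tuple; `update`
        # receives it whole and threads it into the minibatch scan, which unpacks it.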
key, train_key = jax.random.split(key)
(_, train_state), update_metrics = jax.lax.scan(
f=update,
init=(1, train_state),
xs=jax.random.split(train_key, cfg.num_epochs),
)
# Get metrics from the last epoch
update_metrics = jax.tree.map(lambda x: x[-1], update_metrics)
return train_state, update_metrics
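    # Full training loop: an outer scan of roughly `num_eval` iterations, each of
    # which runs `eval_interval` rollout+update steps (vmapped over seeds), then a
    # deterministic evaluation and the host-side logging callback.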
def train_fn(key: PRNGKey, cfg: ReppoConfig) -> tuple[SACTrainState, dict]:
def train_eval_step(key, train_state):
def train_step(
state: SACTrainState, key: PRNGKey
) -> tuple[SACTrainState, dict[str, jax.Array]]:
key, rollout_key, learn_key = jax.random.split(key, 3)
transitions, state = collect_rollout(key=rollout_key, train_state=state)
state, update_metrics = learn_step(
key=learn_key, train_state=state, batch=transitions
)
                metrics = {**update_metrics}
state = state.replace(iteration=state.iteration + 1)
return state, metrics
train_key, eval_key = jax.random.split(key)
eval_interval = int(
(cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
)
train_state, train_metrics = jax.lax.scan(
f=train_step,
init=train_state,
xs=jax.random.split(train_key, eval_interval),
)
train_metrics = jax.tree.map(lambda x: x[-1], train_metrics)
policy = make_policy(train_state)
if cfg.normalize_env:
norm_state = train_state.last_env_state
else:
norm_state = None
eval_metrics = eval_fn(eval_key, policy, norm_state)
train_returns = {
"train/episode_return": train_state.last_env_state.info[
"returned_episode_returns"
].mean(),
"train/episode_length": train_state.last_env_state.info[
"returned_episode_lengths"
].mean(),
}
metrics = {
"time_step": train_state.time_steps,
**utils.prefix_dict("train", train_metrics),
**utils.prefix_dict("eval", eval_metrics),
**train_returns,
}
return train_state, metrics
def loop_body(
train_state: SACTrainState, key: PRNGKey
) -> tuple[SACTrainState, dict]:
key, subkey = jax.random.split(key)
train_state, metrics = jax.vmap(train_eval_step)(
jax.random.split(subkey, num_seeds), train_state
)
jax.debug.callback(log_callback, train_state, metrics)
return train_state, metrics
eval_interval = int(
(cfg.total_time_steps / (cfg.num_steps * cfg.num_envs)) // cfg.num_eval
)
num_train_steps = cfg.total_time_steps // (cfg.num_steps * cfg.num_envs)
num_iterations = num_train_steps // eval_interval + int(
num_train_steps % eval_interval != 0
)
key, init_key = jax.random.split(key)
train_state = jax.vmap(make_init(cfg, env, env_params))(
jax.random.split(init_key, num_seeds)
)
keys = jax.random.split(key, num_iterations)
state, metrics = jax.lax.scan(f=loop_body, init=train_state, xs=keys)
return state, metrics
return train_fn
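# Plot the mean evaluation return over environment steps with a +/- one
# standard deviation band.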
def plot_history(history: list[dict[str, jax.Array]]):
steps = jnp.array([m["time_step"][0] for m in history])
eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
fig = go.Figure(
[
go.Scatter(
x=steps,
y=eval_return,
name="Mean Episode Return",
mode="lines",
line=dict(color="blue"),
showlegend=False,
),
go.Scatter(
x=steps,
y=eval_return + eval_return_std,
name="Upper Bound",
mode="lines",
line=dict(width=0),
showlegend=False,
),
go.Scatter(
x=steps,
y=eval_return - eval_return_std,
name="Lower Bound",
mode="lines",
line=dict(width=0),
fill="tonexty",
fillcolor="rgba(50, 127, 168, 0.3)",
showlegend=False,
),
]
)
fig.update_layout(
xaxis=dict(title=dict(text="Environment Steps")),
)
return fig
def _get_optuna_type(trial: optuna.Trial, name, values: list):
    # Check bool before int: bool is a subclass of int in Python, so the int
    # branch would otherwise shadow boolean hyperparameters.
    if all(isinstance(v, bool) for v in values):
        return trial.suggest_categorical(name, [True, False])
    elif all(isinstance(v, int) for v in values):
        return trial.suggest_int(name, low=min(values), high=max(values))
    elif all(isinstance(v, float) for v in values):
        return trial.suggest_float(name, low=min(values), high=max(values))
    elif all(isinstance(v, str) for v in values):
        return trial.suggest_categorical(name, values)
    else:
        raise ValueError("Values must all be of the same type (bool, int, float, or str).")
def run(cfg: DictConfig, trial: optuna.Trial | None) -> float:
"""
Run a single trial of the SAC training process with hyperparameter tuning.
Args:
cfg (DictConfig): Configuration for the SAC training.
trial (optuna.Trial | None): Optuna trial object for hyperparameter tuning.
Returns:
float: The mean episode return from the trial.
"""
sweep_metrics = []
if trial is not None:
# Set hyperparameters from the trial
for name, values in cfg.trial_spec.items():
if name in cfg.hyperparameters:
sampled_value = _get_optuna_type(trial, name, values)
                # Optuna may return numpy scalar types; cast to a plain float so
                # the value can be stored back into the OmegaConf config.
                if isinstance(sampled_value, np.float64):
                    sampled_value = float(sampled_value)
cfg.hyperparameters[name] = sampled_value
else:
raise ValueError(f"Hyperparameter {name} not found in config.")
try:
with open("completed_trials.txt", "r") as f:
completed_trials = int(f.read())
except FileNotFoundError:
completed_trials = 0
metric_history = []
def log_callback(state, metrics):
metrics["sys_time"] = time.perf_counter()
if len(metric_history) > 0:
num_env_steps = state.time_steps[0] - metric_history[-1]["time_step"][0]
seconds = metrics["sys_time"] - metric_history[-1]["sys_time"]
sps = num_env_steps / seconds
else:
sps = 0
metric_history.append(metrics)
episode_return = metrics["eval/episode_return"].mean()
eval_length = metrics["eval/episode_length"].mean()
logging.info(
f"step={state.time_steps[0]} episode_return={episode_return:.3f}, episode_length={eval_length:.3f} sps={sps:.2f}"
)
log_data = {
"eval/episode_return": episode_return,
"eval/episode_length": eval_length,
**jax.tree.map(jnp.mean, utils.filter_prefix("train", metrics)),
}
wandb.log(log_data, step=state.time_steps[0])
# Set up the experiment
if cfg.env.type == "brax":
env = BraxGymnaxWrapper(
cfg.env.name,
episode_length=cfg.env.max_episode_steps,
reward_scaling=cfg.env.reward_scaling,
terminate=cfg.env.terminate,
)
elif cfg.env.type == "mjx":
env = MjxGymnaxWrapper(
cfg.env.name,
episode_length=cfg.env.max_episode_steps,
reward_scale=cfg.env.reward_scaling,
push_distractions=cfg.env.get("push_distractions", False),
asymmetric_observation=cfg.env.get("asymmetric_observation", False),
)
else:
raise ValueError(f"Unknown environment type: {cfg.env.type}")
# build algo config with overrides
train_fn = make_train_fn(
cfg=ReppoConfig(**cfg.hyperparameters),
env=env,
log_callback=log_callback,
num_seeds=cfg.num_seeds,
reward_scale=1.0 / cfg.env.reward_scaling,
)
    base_seed = cfg.seed
    for i in range(completed_trials, cfg.num_trials):
        cfg.seed = base_seed + i
wandb.init(
mode=cfg.wandb.mode,
project=cfg.wandb.project,
entity=cfg.wandb.entity,
tags=[
cfg.name,
cfg.env.name,
cfg.env.type,
"hp_tune" if trial is not None else "val",
*cfg.tags,
],
config=OmegaConf.to_container(cfg),
name=f"resampling-{cfg.name}-{cfg.env.name.lower()}",
save_code=True,
)
logging.info(OmegaConf.to_yaml(cfg))
key = jax.random.PRNGKey(cfg.seed)
start = time.perf_counter()
_, metrics = jax.jit(train_fn, static_argnums=(1,))(
key, ReppoConfig(**cfg.hyperparameters)
)
jax.block_until_ready(metrics)
duration = time.perf_counter() - start
# Save metrics and finish the run
logging.info(f"Training took {duration:.2f} seconds.")
jnp.savez("metrics.npz", **metrics)
wandb.finish()
sweep_metrics.append(metrics["eval/episode_return"])
        with open("completed_trials.txt", "w") as f:
            f.write(str(i + 1))
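    # Optuna objective: mean final evaluation return across trials plus 0.1x the
    # average evaluation return over the whole of training.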
sweep_metrics_array = jnp.array(sweep_metrics)
return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item()
@hydra.main(version_base=None, config_path="../../config", config_name="reppo")
def main(cfg: DictConfig):
run(cfg, trial=None)
if __name__ == "__main__":
main()
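# Example invocation (hydra overrides; the exact config groups and values depend
# on ../../config/reppo.yaml, so `env.name=ant` below is only illustrative):
#   python reppo_alg/jaxrl/reppo.py env.name=ant wandb.mode=disabled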