This commit is contained in:
Axel Brunnbauer 2025-07-15 22:40:47 -07:00
parent 011cbce7f8
commit 25bdedc780
4 changed files with 9 additions and 106 deletions

View File

@ -1,89 +0,0 @@
# Hydra configuration for the "sac" training run (deleted in this commit).
# NOTE(review): this content was recovered from a diff render that stripped
# leading whitespace — YAML nesting (e.g. which keys live under
# `hyperparameters` vs. at top level) must be confirmed against the original
# file before reuse.
defaults:
# Hydra defaults list: composes env/experiment/trial/platform config groups,
# with _self_ last so values below override the composed groups.
- env: brax
- experiment_overrides: default
- trial_spec: default
- platform: torch
- _self_
hyperparameters:
# env and run settings (mostly don't touch)
total_time_steps: 50_000_000
normalize_env: true
max_episode_steps: 1000
eval_interval: 2
num_eval: 20
# optimization settings (seem very stable)
lr: 3e-4
anneal_lr: false
max_grad_norm: 0.5
polyak: 1.0 # maybe ablate ? (1.0 = hard target copy, no Polyak averaging)
# problem discount settings (need tuning)
gamma: 0.99
lmbda: 0.95
lmbda_min: 0.50 # irrelevant if no exploration noise is added
# batch settings (need tuning for MJX humanoid)
num_steps: 128
num_mini_batches: 128
num_envs: 1024
num_epochs: 4
# exploration settings (currently not touched)
exploration_noise_max: 1.0
exploration_noise_min: 1.0
exploration_base_envs: 0
# critic architecture settings (need to be increased for MJX humanoid)
critic_hidden_dim: 512
actor_hidden_dim: 512
vmin: ${env.vmin}
vmax: ${env.vmax}
num_bins: 151
hl_gauss: true
use_critic_norm: true
num_critic_encoder_layers: 2
num_critic_head_layers: 2
num_critic_pred_layers: 2
use_simplical_embedding: False
# actor architecture settings (seem stable)
use_actor_norm: true
num_actor_layers: 3
actor_min_std: 0.0
# actor & critic loss settings (seem remarkably stable)
## kl settings
kl_start: 0.01
kl_bound: 0.1 # switched to tighter bounds for MJX
reduce_kl: true
reverse_kl: false # previous default "false"
update_kl_lagrangian: true
actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
## entropy settings
ent_start: 0.01
ent_target_mult: 0.5
update_entropy_lagrangian: true
## auxiliary loss settings
aux_loss_mult: 1.0
measure_burnin: 3
# run identification / experiment management
name: "sac"
seed: 0
num_seeds: 1
tune: false
checkpoint_dir: null
num_trials: 10
tags: ["experimental"]
wandb:
mode: "online" # set to online to activate wandb
entity: "viper_svg"
project: "online_sac"
hydra:
job:
# run each job from its own working directory (Hydra output dir)
chdir: True

View File

@ -565,7 +565,6 @@ def make_train_fn(
num_train_steps % eval_interval != 0
)
key, init_key = jax.random.split(key)
# TWK ??: We retain the same initial state for each of the seeds across all episodes?
train_state = jax.vmap(make_init(cfg, env, env_params))(
jax.random.split(init_key, num_seeds)
)
@ -578,9 +577,6 @@ def make_train_fn(
def plot_history(history: list[dict[str, jax.Array]]):
"""
TODO -- TWK: Possibly remove this...
"""
steps = jnp.array([m["time_step"][0] for m in history])
eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
@ -692,10 +688,6 @@ def run(cfg: DictConfig):
def tune(cfg: DictConfig):
"""
TODO: Signature + also adjusting to run tuning for Brax environments as well
"""
def log_callback(state, metrics):
episode_return = metrics["eval/episode_return"].mean()
t = state.time_steps[0]

View File

@ -29,7 +29,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from fast_sac_utils import (
from src.torchrl.reppo_util import (
EmpiricalNormalization,
PerTaskRewardNormalizer,
RewardNormalizer,
@ -90,7 +90,7 @@ def main():
print(f"Using device: {device}")
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
from reppo.env_utils.torch_wrappers.humanoid_bench_env import (
from src.env_utils.torch_wrappers.humanoid_bench_env import (
HumanoidBenchEnv,
)
@ -98,7 +98,7 @@ def main():
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
eval_envs = envs
elif args.env_name.startswith("Isaac-"):
from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
env_type = "isaaclab"
envs = IsaacLabEnv(
@ -110,14 +110,14 @@ def main():
)
eval_envs = envs
elif args.env_name.startswith("MTBench-"):
from reppo.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
env_name = "-".join(args.env_name.split("-")[1:])
env_type = "mtbench"
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
eval_envs = envs
else:
from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
# TODO: Check if re-using same envs for eval could reduce memory usage
env_type = "mujoco_playground"
@ -198,7 +198,7 @@ def main():
if args.agent == "fasttd3":
if env_type in ["mtbench"]:
from reppo.network_utils.fast_td3_nets import (
from src.network_utils.fast_td3_nets import (
MultiTaskActor,
MultiTaskCritic,
)
@ -206,7 +206,7 @@ def main():
actor_cls = MultiTaskActor
critic_cls = MultiTaskCritic
else:
from reppo.network_utils.fast_td3_nets import Actor, Critic
from src.network_utils.fast_td3_nets import Actor, Critic
actor_cls = Actor
critic_cls = Critic
@ -214,7 +214,7 @@ def main():
print("Using FastTD3")
elif args.agent == "fasttd3_simbav2":
if env_type in ["mtbench"]:
from reppo.network_utils.fast_td3_nets_simbav2 import (
from src.network_utils.fast_td3_nets_simbav2 import (
MultiTaskActor,
MultiTaskCritic,
)
@ -222,7 +222,7 @@ def main():
actor_cls = MultiTaskActor
critic_cls = MultiTaskCritic
else:
from reppo.network_utils.fast_td3_nets_simbav2 import Actor, Critic
from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
actor_cls = Actor
critic_cls = Critic