diff --git a/config/experiment_overrides/default.yaml b/config/experiment_overrides/default.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/config/sac.yaml b/config/sac.yaml deleted file mode 100644 index 7f38935..0000000 --- a/config/sac.yaml +++ /dev/null @@ -1,89 +0,0 @@ -defaults: - - env: brax - - experiment_overrides: default - - trial_spec: default - - platform: torch - - _self_ - -hyperparameters: - # env and run settings (mostly don't touch) - total_time_steps: 50_000_000 - normalize_env: true - max_episode_steps: 1000 - eval_interval: 2 - num_eval: 20 - - # optimization settings (seem very stable) - lr: 3e-4 - anneal_lr: false - max_grad_norm: 0.5 - polyak: 1.0 # maybe ablate ? - - # problem discount settings (need tuning) - gamma: 0.99 - lmbda: 0.95 - lmbda_min: 0.50 # irrelevant if no exploration noise is added - - # batch settings (need tuning for MJX humanoid) - num_steps: 128 - num_mini_batches: 128 - num_envs: 1024 - num_epochs: 4 - - # exploration settings (currently not touched) - exploration_noise_max: 1.0 - exploration_noise_min: 1.0 - exploration_base_envs: 0 - - # critic architecture settings (need to be increased for MJX humanoid) - critic_hidden_dim: 512 - actor_hidden_dim: 512 - vmin: ${env.vmin} - vmax: ${env.vmax} - num_bins: 151 - hl_gauss: true - use_critic_norm: true - num_critic_encoder_layers: 2 - num_critic_head_layers: 2 - num_critic_pred_layers: 2 - use_simplical_embedding: False - - # actor architecture settings (seem stable) - use_actor_norm: true - num_actor_layers: 3 - actor_min_std: 0.0 - - # actor & critic loss settings (seem remarkably stable) - ## kl settings - kl_start: 0.01 - kl_bound: 0.1 # switched to tighter bounds for MJX - reduce_kl: true - reverse_kl: false # previous default "false" - update_kl_lagrangian: true - actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value" - ## entropy settings - ent_start: 0.01 - ent_target_mult: 0.5 - update_entropy_lagrangian: true - ## auxiliary loss settings - aux_loss_mult: 1.0 - - -measure_burnin: 3 - - -name: "sac" -seed: 0 -num_seeds: 1 -tune: false -checkpoint_dir: null -num_trials: 10 -tags: ["experimental"] -wandb: - mode: "online" # set to online to activate wandb - entity: "viper_svg" - project: "online_sac" - -hydra: - job: - chdir: True diff --git a/src/jaxrl/ppo_mjx.py b/src/jaxrl/ppo_mjx.py index 0d55cd7..bbeb1b9 100644 --- a/src/jaxrl/ppo_mjx.py +++ b/src/jaxrl/ppo_mjx.py @@ -565,7 +565,6 @@ def make_train_fn( num_train_steps % eval_interval != 0 ) key, init_key = jax.random.split(key) - # TWK ??: We retain the same initial state for each of the seeds across all episodes? train_state = jax.vmap(make_init(cfg, env, env_params))( jax.random.split(init_key, num_seeds) ) @@ -578,9 +577,6 @@ def make_train_fn( def plot_history(history: list[dict[str, jax.Array]]): - """ - TODO -- TWK: Possibly remove this... - """ steps = jnp.array([m["time_step"][0] for m in history]) eval_return = jnp.array([m["eval/episode_return"].mean() for m in history]) eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history]) @@ -692,10 +688,6 @@ def run(cfg: DictConfig): def tune(cfg: DictConfig): - """ - TODO: Signature + also adjusting to run tuning for Brax environments as well - """ - def log_callback(state, metrics): episode_return = metrics["eval/episode_return"].mean() t = state.time_steps[0] diff --git a/src/torchrl/fast_td3.py b/src/torchrl/fast_td3.py index a1cd25a..aacdb2a 100644 --- a/src/torchrl/fast_td3.py +++ b/src/torchrl/fast_td3.py @@ -29,7 +29,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from fast_sac_utils import ( +from src.torchrl.reppo_util import ( EmpiricalNormalization, PerTaskRewardNormalizer, RewardNormalizer, @@ -90,7 +90,7 @@ def main(): print(f"Using device: {device}") if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"): - from reppo.env_utils.torch_wrappers.humanoid_bench_env import ( + from src.env_utils.torch_wrappers.humanoid_bench_env import ( HumanoidBenchEnv, ) @@ -98,7 +98,7 @@ def main(): envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device) eval_envs = envs elif args.env_name.startswith("Isaac-"): - from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv + from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv env_type = "isaaclab" envs = IsaacLabEnv( @@ -110,14 +110,14 @@ def main(): ) eval_envs = envs elif args.env_name.startswith("MTBench-"): - from reppo.env_utils.torch_wrappers.mtbench_env import MTBenchEnv + from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv env_name = "-".join(args.env_name.split("-")[1:]) env_type = "mtbench" envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed) eval_envs = envs else: - from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env + from src.env_utils.torch_wrappers.mujoco_playground_env import make_env # TODO: Check if re-using same envs for eval could reduce memory usage env_type = "mujoco_playground" @@ -198,7 +198,7 @@ def main(): if args.agent == "fasttd3": if env_type in ["mtbench"]: - from reppo.network_utils.fast_td3_nets import ( + from src.network_utils.fast_td3_nets import ( MultiTaskActor, MultiTaskCritic, ) @@ -206,7 +206,7 @@ def main(): actor_cls = MultiTaskActor critic_cls = MultiTaskCritic else: - from reppo.network_utils.fast_td3_nets import Actor, Critic + from src.network_utils.fast_td3_nets import Actor, Critic actor_cls = Actor critic_cls = Critic @@ -214,7 +214,7 @@ def main(): print("Using FastTD3") elif args.agent == "fasttd3_simbav2": if env_type in ["mtbench"]: - from reppo.network_utils.fast_td3_nets_simbav2 import ( + from src.network_utils.fast_td3_nets_simbav2 import ( MultiTaskActor, MultiTaskCritic, ) @@ -222,7 +222,7 @@ def main(): actor_cls = MultiTaskActor critic_cls = MultiTaskCritic else: - from reppo.network_utils.fast_td3_nets_simbav2 import Actor, Critic + from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic actor_cls = Actor critic_cls = Critic