cleanup
This commit is contained in:
parent
011cbce7f8
commit
25bdedc780
@ -1,89 +0,0 @@
|
||||
defaults:
|
||||
- env: brax
|
||||
- experiment_overrides: default
|
||||
- trial_spec: default
|
||||
- platform: torch
|
||||
- _self_
|
||||
|
||||
hyperparameters:
|
||||
# env and run settings (mostly don't touch)
|
||||
total_time_steps: 50_000_000
|
||||
normalize_env: true
|
||||
max_episode_steps: 1000
|
||||
eval_interval: 2
|
||||
num_eval: 20
|
||||
|
||||
# optimization settings (seem very stable)
|
||||
lr: 3e-4
|
||||
anneal_lr: false
|
||||
max_grad_norm: 0.5
|
||||
polyak: 1.0 # maybe ablate ?
|
||||
|
||||
# problem discount settings (need tuning)
|
||||
gamma: 0.99
|
||||
lmbda: 0.95
|
||||
lmbda_min: 0.50 # irrelevant if no exploration noise is added
|
||||
|
||||
# batch settings (need tuning for MJX humanoid)
|
||||
num_steps: 128
|
||||
num_mini_batches: 128
|
||||
num_envs: 1024
|
||||
num_epochs: 4
|
||||
|
||||
# exploration settings (currently not touched)
|
||||
exploration_noise_max: 1.0
|
||||
exploration_noise_min: 1.0
|
||||
exploration_base_envs: 0
|
||||
|
||||
# critic architecture settings (need to be increased for MJX humanoid)
|
||||
critic_hidden_dim: 512
|
||||
actor_hidden_dim: 512
|
||||
vmin: ${env.vmin}
|
||||
vmax: ${env.vmax}
|
||||
num_bins: 151
|
||||
hl_gauss: true
|
||||
use_critic_norm: true
|
||||
num_critic_encoder_layers: 2
|
||||
num_critic_head_layers: 2
|
||||
num_critic_pred_layers: 2
|
||||
use_simplical_embedding: False
|
||||
|
||||
# actor architecture settings (seem stable)
|
||||
use_actor_norm: true
|
||||
num_actor_layers: 3
|
||||
actor_min_std: 0.0
|
||||
|
||||
# actor & critic loss settings (seem remarkably stable)
|
||||
## kl settings
|
||||
kl_start: 0.01
|
||||
kl_bound: 0.1 # switched to tighter bounds for MJX
|
||||
reduce_kl: true
|
||||
reverse_kl: false # previous default "false"
|
||||
update_kl_lagrangian: true
|
||||
actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
|
||||
## entropy settings
|
||||
ent_start: 0.01
|
||||
ent_target_mult: 0.5
|
||||
update_entropy_lagrangian: true
|
||||
## auxiliary loss settings
|
||||
aux_loss_mult: 1.0
|
||||
|
||||
|
||||
measure_burnin: 3
|
||||
|
||||
|
||||
name: "sac"
|
||||
seed: 0
|
||||
num_seeds: 1
|
||||
tune: false
|
||||
checkpoint_dir: null
|
||||
num_trials: 10
|
||||
tags: ["experimental"]
|
||||
wandb:
|
||||
mode: "online" # set to online to activate wandb
|
||||
entity: "viper_svg"
|
||||
project: "online_sac"
|
||||
|
||||
hydra:
|
||||
job:
|
||||
chdir: True
|
@ -565,7 +565,6 @@ def make_train_fn(
|
||||
num_train_steps % eval_interval != 0
|
||||
)
|
||||
key, init_key = jax.random.split(key)
|
||||
# TWK ??: We retain the same initial state for each of the seeds across all episodes?
|
||||
train_state = jax.vmap(make_init(cfg, env, env_params))(
|
||||
jax.random.split(init_key, num_seeds)
|
||||
)
|
||||
@ -578,9 +577,6 @@ def make_train_fn(
|
||||
|
||||
|
||||
def plot_history(history: list[dict[str, jax.Array]]):
|
||||
"""
|
||||
TODO -- TWK: Possibly remove this...
|
||||
"""
|
||||
steps = jnp.array([m["time_step"][0] for m in history])
|
||||
eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
|
||||
eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
|
||||
@ -692,10 +688,6 @@ def run(cfg: DictConfig):
|
||||
|
||||
|
||||
def tune(cfg: DictConfig):
|
||||
"""
|
||||
TODO: Signature + also adjusting to run tuning for Brax environments as well
|
||||
"""
|
||||
|
||||
def log_callback(state, metrics):
|
||||
episode_return = metrics["eval/episode_return"].mean()
|
||||
t = state.time_steps[0]
|
||||
|
@ -29,7 +29,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from fast_sac_utils import (
|
||||
from src.torchrl.reppo_util import (
|
||||
EmpiricalNormalization,
|
||||
PerTaskRewardNormalizer,
|
||||
RewardNormalizer,
|
||||
@ -90,7 +90,7 @@ def main():
|
||||
print(f"Using device: {device}")
|
||||
|
||||
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
|
||||
from reppo.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
HumanoidBenchEnv,
|
||||
)
|
||||
|
||||
@ -98,7 +98,7 @@ def main():
|
||||
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
|
||||
eval_envs = envs
|
||||
elif args.env_name.startswith("Isaac-"):
|
||||
from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
|
||||
env_type = "isaaclab"
|
||||
envs = IsaacLabEnv(
|
||||
@ -110,14 +110,14 @@ def main():
|
||||
)
|
||||
eval_envs = envs
|
||||
elif args.env_name.startswith("MTBench-"):
|
||||
from reppo.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
||||
from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
||||
|
||||
env_name = "-".join(args.env_name.split("-")[1:])
|
||||
env_type = "mtbench"
|
||||
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
||||
eval_envs = envs
|
||||
else:
|
||||
from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
|
||||
# TODO: Check if re-using same envs for eval could reduce memory usage
|
||||
env_type = "mujoco_playground"
|
||||
@ -198,7 +198,7 @@ def main():
|
||||
|
||||
if args.agent == "fasttd3":
|
||||
if env_type in ["mtbench"]:
|
||||
from reppo.network_utils.fast_td3_nets import (
|
||||
from src.network_utils.fast_td3_nets import (
|
||||
MultiTaskActor,
|
||||
MultiTaskCritic,
|
||||
)
|
||||
@ -206,7 +206,7 @@ def main():
|
||||
actor_cls = MultiTaskActor
|
||||
critic_cls = MultiTaskCritic
|
||||
else:
|
||||
from reppo.network_utils.fast_td3_nets import Actor, Critic
|
||||
from src.network_utils.fast_td3_nets import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
critic_cls = Critic
|
||||
@ -214,7 +214,7 @@ def main():
|
||||
print("Using FastTD3")
|
||||
elif args.agent == "fasttd3_simbav2":
|
||||
if env_type in ["mtbench"]:
|
||||
from reppo.network_utils.fast_td3_nets_simbav2 import (
|
||||
from src.network_utils.fast_td3_nets_simbav2 import (
|
||||
MultiTaskActor,
|
||||
MultiTaskCritic,
|
||||
)
|
||||
@ -222,7 +222,7 @@ def main():
|
||||
actor_cls = MultiTaskActor
|
||||
critic_cls = MultiTaskCritic
|
||||
else:
|
||||
from reppo.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
||||
from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
critic_cls = Critic
|
||||
|
Loading…
Reference in New Issue
Block a user