cleanup
This commit is contained in:
parent
011cbce7f8
commit
25bdedc780
@ -1,89 +0,0 @@
|
|||||||
defaults:
|
|
||||||
- env: brax
|
|
||||||
- experiment_overrides: default
|
|
||||||
- trial_spec: default
|
|
||||||
- platform: torch
|
|
||||||
- _self_
|
|
||||||
|
|
||||||
hyperparameters:
|
|
||||||
# env and run settings (mostly don't touch)
|
|
||||||
total_time_steps: 50_000_000
|
|
||||||
normalize_env: true
|
|
||||||
max_episode_steps: 1000
|
|
||||||
eval_interval: 2
|
|
||||||
num_eval: 20
|
|
||||||
|
|
||||||
# optimization settings (seem very stable)
|
|
||||||
lr: 3e-4
|
|
||||||
anneal_lr: false
|
|
||||||
max_grad_norm: 0.5
|
|
||||||
polyak: 1.0 # maybe ablate ?
|
|
||||||
|
|
||||||
# problem discount settings (need tuning)
|
|
||||||
gamma: 0.99
|
|
||||||
lmbda: 0.95
|
|
||||||
lmbda_min: 0.50 # irrelevant if no exploration noise is added
|
|
||||||
|
|
||||||
# batch settings (need tuning for MJX humanoid)
|
|
||||||
num_steps: 128
|
|
||||||
num_mini_batches: 128
|
|
||||||
num_envs: 1024
|
|
||||||
num_epochs: 4
|
|
||||||
|
|
||||||
# exploration settings (currently not touched)
|
|
||||||
exploration_noise_max: 1.0
|
|
||||||
exploration_noise_min: 1.0
|
|
||||||
exploration_base_envs: 0
|
|
||||||
|
|
||||||
# critic architecture settings (need to be increased for MJX humanoid)
|
|
||||||
critic_hidden_dim: 512
|
|
||||||
actor_hidden_dim: 512
|
|
||||||
vmin: ${env.vmin}
|
|
||||||
vmax: ${env.vmax}
|
|
||||||
num_bins: 151
|
|
||||||
hl_gauss: true
|
|
||||||
use_critic_norm: true
|
|
||||||
num_critic_encoder_layers: 2
|
|
||||||
num_critic_head_layers: 2
|
|
||||||
num_critic_pred_layers: 2
|
|
||||||
use_simplical_embedding: False
|
|
||||||
|
|
||||||
# actor architecture settings (seem stable)
|
|
||||||
use_actor_norm: true
|
|
||||||
num_actor_layers: 3
|
|
||||||
actor_min_std: 0.0
|
|
||||||
|
|
||||||
# actor & critic loss settings (seem remarkably stable)
|
|
||||||
## kl settings
|
|
||||||
kl_start: 0.01
|
|
||||||
kl_bound: 0.1 # switched to tighter bounds for MJX
|
|
||||||
reduce_kl: true
|
|
||||||
reverse_kl: false # previous default "false"
|
|
||||||
update_kl_lagrangian: true
|
|
||||||
actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
|
|
||||||
## entropy settings
|
|
||||||
ent_start: 0.01
|
|
||||||
ent_target_mult: 0.5
|
|
||||||
update_entropy_lagrangian: true
|
|
||||||
## auxiliary loss settings
|
|
||||||
aux_loss_mult: 1.0
|
|
||||||
|
|
||||||
|
|
||||||
measure_burnin: 3
|
|
||||||
|
|
||||||
|
|
||||||
name: "sac"
|
|
||||||
seed: 0
|
|
||||||
num_seeds: 1
|
|
||||||
tune: false
|
|
||||||
checkpoint_dir: null
|
|
||||||
num_trials: 10
|
|
||||||
tags: ["experimental"]
|
|
||||||
wandb:
|
|
||||||
mode: "online" # set to online to activate wandb
|
|
||||||
entity: "viper_svg"
|
|
||||||
project: "online_sac"
|
|
||||||
|
|
||||||
hydra:
|
|
||||||
job:
|
|
||||||
chdir: True
|
|
@ -565,7 +565,6 @@ def make_train_fn(
|
|||||||
num_train_steps % eval_interval != 0
|
num_train_steps % eval_interval != 0
|
||||||
)
|
)
|
||||||
key, init_key = jax.random.split(key)
|
key, init_key = jax.random.split(key)
|
||||||
# TWK ??: We retain the same initial state for each of the seeds across all episodes?
|
|
||||||
train_state = jax.vmap(make_init(cfg, env, env_params))(
|
train_state = jax.vmap(make_init(cfg, env, env_params))(
|
||||||
jax.random.split(init_key, num_seeds)
|
jax.random.split(init_key, num_seeds)
|
||||||
)
|
)
|
||||||
@ -578,9 +577,6 @@ def make_train_fn(
|
|||||||
|
|
||||||
|
|
||||||
def plot_history(history: list[dict[str, jax.Array]]):
|
def plot_history(history: list[dict[str, jax.Array]]):
|
||||||
"""
|
|
||||||
TODO -- TWK: Possibly remove this...
|
|
||||||
"""
|
|
||||||
steps = jnp.array([m["time_step"][0] for m in history])
|
steps = jnp.array([m["time_step"][0] for m in history])
|
||||||
eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
|
eval_return = jnp.array([m["eval/episode_return"].mean() for m in history])
|
||||||
eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
|
eval_return_std = jnp.array([m["eval/episode_return"].std() for m in history])
|
||||||
@ -692,10 +688,6 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
|
|
||||||
def tune(cfg: DictConfig):
|
def tune(cfg: DictConfig):
|
||||||
"""
|
|
||||||
TODO: Signature + also adjusting to run tuning for Brax environments as well
|
|
||||||
"""
|
|
||||||
|
|
||||||
def log_callback(state, metrics):
|
def log_callback(state, metrics):
|
||||||
episode_return = metrics["eval/episode_return"].mean()
|
episode_return = metrics["eval/episode_return"].mean()
|
||||||
t = state.time_steps[0]
|
t = state.time_steps[0]
|
||||||
|
@ -29,7 +29,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from fast_sac_utils import (
|
from src.torchrl.reppo_util import (
|
||||||
EmpiricalNormalization,
|
EmpiricalNormalization,
|
||||||
PerTaskRewardNormalizer,
|
PerTaskRewardNormalizer,
|
||||||
RewardNormalizer,
|
RewardNormalizer,
|
||||||
@ -90,7 +90,7 @@ def main():
|
|||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
|
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
|
||||||
from reppo.env_utils.torch_wrappers.humanoid_bench_env import (
|
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||||
HumanoidBenchEnv,
|
HumanoidBenchEnv,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ def main():
|
|||||||
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
|
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
elif args.env_name.startswith("Isaac-"):
|
elif args.env_name.startswith("Isaac-"):
|
||||||
from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||||
|
|
||||||
env_type = "isaaclab"
|
env_type = "isaaclab"
|
||||||
envs = IsaacLabEnv(
|
envs = IsaacLabEnv(
|
||||||
@ -110,14 +110,14 @@ def main():
|
|||||||
)
|
)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
elif args.env_name.startswith("MTBench-"):
|
elif args.env_name.startswith("MTBench-"):
|
||||||
from reppo.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
||||||
|
|
||||||
env_name = "-".join(args.env_name.split("-")[1:])
|
env_name = "-".join(args.env_name.split("-")[1:])
|
||||||
env_type = "mtbench"
|
env_type = "mtbench"
|
||||||
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
else:
|
else:
|
||||||
from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||||
|
|
||||||
# TODO: Check if re-using same envs for eval could reduce memory usage
|
# TODO: Check if re-using same envs for eval could reduce memory usage
|
||||||
env_type = "mujoco_playground"
|
env_type = "mujoco_playground"
|
||||||
@ -198,7 +198,7 @@ def main():
|
|||||||
|
|
||||||
if args.agent == "fasttd3":
|
if args.agent == "fasttd3":
|
||||||
if env_type in ["mtbench"]:
|
if env_type in ["mtbench"]:
|
||||||
from reppo.network_utils.fast_td3_nets import (
|
from src.network_utils.fast_td3_nets import (
|
||||||
MultiTaskActor,
|
MultiTaskActor,
|
||||||
MultiTaskCritic,
|
MultiTaskCritic,
|
||||||
)
|
)
|
||||||
@ -206,7 +206,7 @@ def main():
|
|||||||
actor_cls = MultiTaskActor
|
actor_cls = MultiTaskActor
|
||||||
critic_cls = MultiTaskCritic
|
critic_cls = MultiTaskCritic
|
||||||
else:
|
else:
|
||||||
from reppo.network_utils.fast_td3_nets import Actor, Critic
|
from src.network_utils.fast_td3_nets import Actor, Critic
|
||||||
|
|
||||||
actor_cls = Actor
|
actor_cls = Actor
|
||||||
critic_cls = Critic
|
critic_cls = Critic
|
||||||
@ -214,7 +214,7 @@ def main():
|
|||||||
print("Using FastTD3")
|
print("Using FastTD3")
|
||||||
elif args.agent == "fasttd3_simbav2":
|
elif args.agent == "fasttd3_simbav2":
|
||||||
if env_type in ["mtbench"]:
|
if env_type in ["mtbench"]:
|
||||||
from reppo.network_utils.fast_td3_nets_simbav2 import (
|
from src.network_utils.fast_td3_nets_simbav2 import (
|
||||||
MultiTaskActor,
|
MultiTaskActor,
|
||||||
MultiTaskCritic,
|
MultiTaskCritic,
|
||||||
)
|
)
|
||||||
@ -222,7 +222,7 @@ def main():
|
|||||||
actor_cls = MultiTaskActor
|
actor_cls = MultiTaskActor
|
||||||
critic_cls = MultiTaskCritic
|
critic_cls = MultiTaskCritic
|
||||||
else:
|
else:
|
||||||
from reppo.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
||||||
|
|
||||||
actor_cls = Actor
|
actor_cls = Actor
|
||||||
critic_cls = Critic
|
critic_cls = Critic
|
||||||
|
Loading…
Reference in New Issue
Block a user