diff --git a/example/config.yaml b/example/config.yaml
index 89ca90a..24854cf 100644
--- a/example/config.yaml
+++ b/example/config.yaml
@@ -1,7 +1,10 @@
-policy:
-  input_dim: 4
-  output_dim: 2
+actor:
   hidden_sizes: [64, 64]
+  activation_fn: "ReLU"
+
+critic:
+  hidden_sizes: [64, 64]
+  activation_fn: "ReLU"
 
 ppo:
   learning_rate: 3e-4
@@ -15,11 +18,11 @@ ppo:
   eval_interval: 2048
   eval_deterministic: true
   eval_episodes: 10
-  seed: 42
 
 loggers:
-  - type: terminal
-  - type: wandb
+  - logger_type: 'wandb'
+    logger_name: "ppo"
+    experiment_name: "PPO"
     project: "PPO_project"
     entity: "your_entity"
     push_interval: 10
diff --git a/example/example.py b/example/example.py
index 5baf4ba..61ccbc3 100644
--- a/example/example.py
+++ b/example/example.py
@@ -1,35 +1,31 @@
 import yaml
 import torch
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import Policy
-from fancy_rl.loggers import TerminalLogger, WandbLogger
+from ppo import PPO
+from torchrl.record.loggers import get_logger
 import gymnasium as gym
 
 def main(config_file):
     with open(config_file, 'r') as file:
         config = yaml.safe_load(file)
 
-    env_fn = lambda: gym.make("CartPole-v1")
-    env = env_fn()
-
-    policy_config = config['policy']
-    policy = Policy(env=env, hidden_sizes=policy_config['hidden_sizes'])
+    env_spec = "CartPole-v1"
 
     ppo_config = config['ppo']
-    loggers_config = config['loggers']
+    actor_config = config['actor']
+    critic_config = config['critic']
+    loggers_config = config.get('loggers', [])
 
-    loggers = []
-    for logger_config in loggers_config:
-        logger_type = logger_config.pop('type')
-        if logger_type == 'terminal':
-            loggers.append(TerminalLogger(**logger_config))
-        elif logger_type == 'wandb':
-            loggers.append(WandbLogger(**logger_config))
+    loggers = [get_logger(**logger_config) for logger_config in loggers_config]
 
-    ppo = PPO(policy=policy,
-              env_fn=env_fn,
-              loggers=loggers,
-              **ppo_config)
+    ppo = PPO(
+        env_spec=env_spec,
+        loggers=loggers,
+        actor_hidden_sizes=actor_config['hidden_sizes'],
+        critic_hidden_sizes=critic_config['hidden_sizes'],
+        actor_activation_fn=actor_config['activation_fn'],
+        critic_activation_fn=critic_config['activation_fn'],
+        **ppo_config
+    )
 
     ppo.train()
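
For reference, the patched example imports PPO from a local ppo module, and the loggers entries in example/config.yaml are passed straight into torchrl's get_logger(logger_type, logger_name, experiment_name, **kwargs) factory, with leftover keys (project, entity, push_interval) forwarded to the backend. The call site above implies a PPO constructor shaped roughly like the stub below. This is an inferred sketch only: the parameter defaults and the **ppo_kwargs catch-all are assumptions read off config.yaml, not the module's actual code.

# Inferred interface sketch (assumption): the surface `from ppo import PPO`
# must expose for the call in example/example.py to succeed.
class PPO:
    def __init__(
        self,
        env_spec,                     # Gymnasium env id, e.g. "CartPole-v1"
        loggers=None,                 # torchrl logger instances from get_logger
        actor_hidden_sizes=(64, 64),  # assumed default, mirrors config.yaml
        critic_hidden_sizes=(64, 64),
        actor_activation_fn="ReLU",   # assumed default, mirrors config.yaml
        critic_activation_fn="ReLU",
        **ppo_kwargs,                 # learning_rate, eval_interval, ...
    ):
        ...

    def train(self):
        ...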