Update example.py to new func

Dominik Moritz Roth 2024-05-31 13:04:56 +02:00
parent 1d8d217ec0
commit 0808655136
2 changed files with 25 additions and 26 deletions

View File

@@ -1,7 +1,10 @@
-policy:
+actor:
   input_dim: 4
   output_dim: 2
   hidden_sizes: [64, 64]
   activation_fn: "ReLU"
+critic:
+  hidden_sizes: [64, 64]
+  activation_fn: "ReLU"
 ppo:
   learning_rate: 3e-4
@@ -15,11 +18,11 @@ ppo:
   eval_interval: 2048
   eval_deterministic: true
   eval_episodes: 10
+  seed: 42
 loggers:
-  - type: terminal
-  - type: wandb
-    experiment_name: "PPO"
+  - backend: 'wandb'
+    logger_name: "ppo"
     project: "PPO_project"
     entity: "your_entity"
     push_interval: 10
 
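Pieced together from the two hunks, the config file after this commit would read roughly as the sketch below. The file's name and the seven lines of further ppo settings between the hunks (new-file lines 11-17) are not visible in this view, so that stretch is only marked, not filled in:

actor:
  input_dim: 4
  output_dim: 2
  hidden_sizes: [64, 64]
  activation_fn: "ReLU"
critic:
  hidden_sizes: [64, 64]
  activation_fn: "ReLU"
ppo:
  learning_rate: 3e-4
  # ... seven further ppo settings (new-file lines 11-17) are outside the hunks shown ...
  eval_interval: 2048
  eval_deterministic: true
  eval_episodes: 10
  seed: 42
loggers:
  - backend: 'wandb'
    logger_name: "ppo"
    project: "PPO_project"
    entity: "your_entity"
    push_interval: 10

Each loggers entry is passed verbatim to torchrl's get_logger in the updated example.py below, so the keys in a loggers entry have to match whatever keyword arguments that call accepts.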

View File

@ -1,35 +1,31 @@
import yaml import yaml
import torch import torch
from fancy_rl.ppo import PPO from ppo import PPO
from fancy_rl.policy import Policy from torchrl.record.loggers import get_logger
from fancy_rl.loggers import TerminalLogger, WandbLogger
import gymnasium as gym import gymnasium as gym
def main(config_file): def main(config_file):
with open(config_file, 'r') as file: with open(config_file, 'r') as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
env_fn = lambda: gym.make("CartPole-v1") env_spec = "CartPole-v1"
env = env_fn()
policy_config = config['policy']
policy = Policy(env=env, hidden_sizes=policy_config['hidden_sizes'])
ppo_config = config['ppo'] ppo_config = config['ppo']
loggers_config = config['loggers'] actor_config = config['actor']
critic_config = config['critic']
loggers_config = config.get('loggers', [])
loggers = [] loggers = [get_logger(**logger_config) for logger_config in loggers_config]
for logger_config in loggers_config:
logger_type = logger_config.pop('type')
if logger_type == 'terminal':
loggers.append(TerminalLogger(**logger_config))
elif logger_type == 'wandb':
loggers.append(WandbLogger(**logger_config))
ppo = PPO(policy=policy, ppo = PPO(
env_fn=env_fn, env_spec=env_spec,
loggers=loggers, loggers=loggers,
**ppo_config) actor_hidden_sizes=actor_config['hidden_sizes'],
critic_hidden_sizes=critic_config['hidden_sizes'],
actor_activation_fn=actor_config['activation_fn'],
critic_activation_fn=critic_config['activation_fn'],
**ppo_config
)
ppo.train() ppo.train()
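Read off the right-hand side of the hunk, the updated example.py would look roughly like the sketch below. The hunk only covers new-file lines 1-31, so the entry point at the bottom is an assumption (marked as such), not part of the commit:

import yaml
import torch
from ppo import PPO
from torchrl.record.loggers import get_logger
import gymnasium as gym

def main(config_file):
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)

    # The environment is now named by a spec string instead of a factory lambda;
    # the new PPO class is responsible for constructing it.
    env_spec = "CartPole-v1"

    ppo_config = config['ppo']
    actor_config = config['actor']
    critic_config = config['critic']
    loggers_config = config.get('loggers', [])

    # Loggers are built through torchrl's get_logger instead of the old
    # TerminalLogger/WandbLogger classes.
    loggers = [get_logger(**logger_config) for logger_config in loggers_config]

    # PPO now builds actor and critic networks internally from the config
    # values, so no Policy object is created here.
    ppo = PPO(
        env_spec=env_spec,
        loggers=loggers,
        actor_hidden_sizes=actor_config['hidden_sizes'],
        critic_hidden_sizes=critic_config['hidden_sizes'],
        actor_activation_fn=actor_config['activation_fn'],
        critic_activation_fn=critic_config['activation_fn'],
        **ppo_config
    )

    ppo.train()

# The script's entry point lies outside the diff hunk; a typical guard would
# look like the following, but the actual file may differ:
# if __name__ == "__main__":
#     main("config.yaml")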