Update example.py to new func

This commit is contained in:
Dominik Moritz Roth 2024-05-31 13:04:56 +02:00
parent 1d8d217ec0
commit 0808655136
2 changed files with 25 additions and 26 deletions

View File

@ -1,7 +1,10 @@
policy:
input_dim: 4
output_dim: 2
actor:
hidden_sizes: [64, 64]
activation_fn: "ReLU"
critic:
hidden_sizes: [64, 64]
activation_fn: "ReLU"
ppo:
learning_rate: 3e-4
@ -15,11 +18,11 @@ ppo:
eval_interval: 2048
eval_deterministic: true
eval_episodes: 10
seed: 42
loggers:
- type: terminal
- type: wandb
- backend: 'wandb'
logger_name: "ppo"
experiment_name: "PPO"
project: "PPO_project"
entity: "your_entity"
push_interval: 10

View File

@ -1,35 +1,31 @@
import yaml
import torch
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger, WandbLogger
from ppo import PPO
from torchrl.record.loggers import get_logger
import gymnasium as gym
def main(config_file):
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
env_fn = lambda: gym.make("CartPole-v1")
env = env_fn()
policy_config = config['policy']
policy = Policy(env=env, hidden_sizes=policy_config['hidden_sizes'])
env_spec = "CartPole-v1"
ppo_config = config['ppo']
loggers_config = config['loggers']
actor_config = config['actor']
critic_config = config['critic']
loggers_config = config.get('loggers', [])
loggers = []
for logger_config in loggers_config:
logger_type = logger_config.pop('type')
if logger_type == 'terminal':
loggers.append(TerminalLogger(**logger_config))
elif logger_type == 'wandb':
loggers.append(WandbLogger(**logger_config))
loggers = [get_logger(**logger_config) for logger_config in loggers_config]
ppo = PPO(policy=policy,
env_fn=env_fn,
ppo = PPO(
env_spec=env_spec,
loggers=loggers,
**ppo_config)
actor_hidden_sizes=actor_config['hidden_sizes'],
critic_hidden_sizes=critic_config['hidden_sizes'],
actor_activation_fn=actor_config['activation_fn'],
critic_activation_fn=critic_config['activation_fn'],
**ppo_config
)
ppo.train()