fancy_rl/example/example.py

34 lines
888 B
Python
Raw Permalink Normal View History

2024-05-29 21:21:43 +02:00
from pathlib import Path

import yaml
import torch
import gymnasium as gym
from torchrl.record.loggers import get_logger

from ppo import PPO

def main(config_file, env_spec="CartPole-v1"):
    """Build a PPO agent from a YAML config file and train it.

    The config's top level must be a mapping with these entries:

    * ``ppo``: keyword arguments forwarded verbatim to ``PPO``. A bare
      ``ppo:`` with no entries is treated as an empty mapping.
    * ``actor`` / ``critic``: each must provide ``hidden_sizes`` and
      ``activation_fn``.
    * ``loggers`` (optional): a list of keyword-argument mappings, each
      passed to torchrl's ``get_logger``.

    Args:
        config_file: Path to the YAML configuration file.
        env_spec: Environment spec handed to ``PPO``. Defaults to
            ``"CartPole-v1"``.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
        KeyError: If a required section or key is missing.
    """
    with open(config_file, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load yields None for an empty file (or a scalar/list for other
    # documents); fail with a readable message instead of a TypeError on
    # the first subscript below.
    if not isinstance(config, dict):
        raise ValueError(
            f"config file {config_file!r} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )

    # A key written with no value (e.g. `ppo:` or `loggers:`) parses as None;
    # treat it like an empty section so `**` / iteration below still work.
    ppo_config = config["ppo"] or {}
    actor_config = config["actor"]
    critic_config = config["critic"]
    loggers_config = config.get("loggers") or []
    loggers = [get_logger(**logger_config) for logger_config in loggers_config]

    ppo = PPO(
        env_spec=env_spec,
        loggers=loggers,
        actor_hidden_sizes=actor_config["hidden_sizes"],
        critic_hidden_sizes=critic_config["hidden_sizes"],
        actor_activation_fn=actor_config["activation_fn"],
        critic_activation_fn=critic_config["activation_fn"],
        **ppo_config,
    )
    ppo.train()


if __name__ == "__main__":
    # This script lives in `example/`; resolve its config relative to the
    # script itself rather than the current working directory, so the example
    # runs no matter where it is launched from.
    main(Path(__file__).resolve().parent / "config.yaml")