Update example.py to new func
This commit is contained in:
parent
1d8d217ec0
commit
0808655136
@ -1,7 +1,10 @@
|
||||
policy:
|
||||
input_dim: 4
|
||||
output_dim: 2
|
||||
actor:
|
||||
hidden_sizes: [64, 64]
|
||||
activation_fn: "ReLU"
|
||||
|
||||
critic:
|
||||
hidden_sizes: [64, 64]
|
||||
activation_fn: "ReLU"
|
||||
|
||||
ppo:
|
||||
learning_rate: 3e-4
|
||||
@ -15,11 +18,11 @@ ppo:
|
||||
eval_interval: 2048
|
||||
eval_deterministic: true
|
||||
eval_episodes: 10
|
||||
seed: 42
|
||||
|
||||
loggers:
|
||||
- type: terminal
|
||||
- type: wandb
|
||||
- backend: 'wandb'
|
||||
logger_name: "ppo"
|
||||
experiment_name: "PPO"
|
||||
project: "PPO_project"
|
||||
entity: "your_entity"
|
||||
push_interval: 10
|
||||
|
@ -1,35 +1,31 @@
|
||||
import yaml
|
||||
import torch
|
||||
from fancy_rl.ppo import PPO
|
||||
from fancy_rl.policy import Policy
|
||||
from fancy_rl.loggers import TerminalLogger, WandbLogger
|
||||
from ppo import PPO
|
||||
from torchrl.record.loggers import get_logger
|
||||
import gymnasium as gym
|
||||
|
||||
def main(config_file):
|
||||
with open(config_file, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
env_fn = lambda: gym.make("CartPole-v1")
|
||||
env = env_fn()
|
||||
|
||||
policy_config = config['policy']
|
||||
policy = Policy(env=env, hidden_sizes=policy_config['hidden_sizes'])
|
||||
env_spec = "CartPole-v1"
|
||||
|
||||
ppo_config = config['ppo']
|
||||
loggers_config = config['loggers']
|
||||
actor_config = config['actor']
|
||||
critic_config = config['critic']
|
||||
loggers_config = config.get('loggers', [])
|
||||
|
||||
loggers = []
|
||||
for logger_config in loggers_config:
|
||||
logger_type = logger_config.pop('type')
|
||||
if logger_type == 'terminal':
|
||||
loggers.append(TerminalLogger(**logger_config))
|
||||
elif logger_type == 'wandb':
|
||||
loggers.append(WandbLogger(**logger_config))
|
||||
loggers = [get_logger(**logger_config) for logger_config in loggers_config]
|
||||
|
||||
ppo = PPO(policy=policy,
|
||||
env_fn=env_fn,
|
||||
loggers=loggers,
|
||||
**ppo_config)
|
||||
ppo = PPO(
|
||||
env_spec=env_spec,
|
||||
loggers=loggers,
|
||||
actor_hidden_sizes=actor_config['hidden_sizes'],
|
||||
critic_hidden_sizes=critic_config['hidden_sizes'],
|
||||
actor_activation_fn=actor_config['activation_fn'],
|
||||
critic_activation_fn=critic_config['activation_fn'],
|
||||
**ppo_config
|
||||
)
|
||||
|
||||
ppo.train()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user