Update example.py to new func
This commit is contained in:
parent
1d8d217ec0
commit
0808655136
@ -1,7 +1,10 @@
|
|||||||
policy:
|
actor:
|
||||||
input_dim: 4
|
|
||||||
output_dim: 2
|
|
||||||
hidden_sizes: [64, 64]
|
hidden_sizes: [64, 64]
|
||||||
|
activation_fn: "ReLU"
|
||||||
|
|
||||||
|
critic:
|
||||||
|
hidden_sizes: [64, 64]
|
||||||
|
activation_fn: "ReLU"
|
||||||
|
|
||||||
ppo:
|
ppo:
|
||||||
learning_rate: 3e-4
|
learning_rate: 3e-4
|
||||||
@ -15,11 +18,11 @@ ppo:
|
|||||||
eval_interval: 2048
|
eval_interval: 2048
|
||||||
eval_deterministic: true
|
eval_deterministic: true
|
||||||
eval_episodes: 10
|
eval_episodes: 10
|
||||||
seed: 42
|
|
||||||
|
|
||||||
loggers:
|
loggers:
|
||||||
- type: terminal
|
- backend: 'wandb'
|
||||||
- type: wandb
|
logger_name: "ppo"
|
||||||
|
experiment_name: "PPO"
|
||||||
project: "PPO_project"
|
project: "PPO_project"
|
||||||
entity: "your_entity"
|
entity: "your_entity"
|
||||||
push_interval: 10
|
push_interval: 10
|
||||||
|
@ -1,35 +1,31 @@
|
|||||||
import yaml
|
import yaml
|
||||||
import torch
|
import torch
|
||||||
from fancy_rl.ppo import PPO
|
from ppo import PPO
|
||||||
from fancy_rl.policy import Policy
|
from torchrl.record.loggers import get_logger
|
||||||
from fancy_rl.loggers import TerminalLogger, WandbLogger
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
def main(config_file):
|
def main(config_file):
|
||||||
with open(config_file, 'r') as file:
|
with open(config_file, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
env_fn = lambda: gym.make("CartPole-v1")
|
env_spec = "CartPole-v1"
|
||||||
env = env_fn()
|
|
||||||
|
|
||||||
policy_config = config['policy']
|
|
||||||
policy = Policy(env=env, hidden_sizes=policy_config['hidden_sizes'])
|
|
||||||
|
|
||||||
ppo_config = config['ppo']
|
ppo_config = config['ppo']
|
||||||
loggers_config = config['loggers']
|
actor_config = config['actor']
|
||||||
|
critic_config = config['critic']
|
||||||
|
loggers_config = config.get('loggers', [])
|
||||||
|
|
||||||
loggers = []
|
loggers = [get_logger(**logger_config) for logger_config in loggers_config]
|
||||||
for logger_config in loggers_config:
|
|
||||||
logger_type = logger_config.pop('type')
|
|
||||||
if logger_type == 'terminal':
|
|
||||||
loggers.append(TerminalLogger(**logger_config))
|
|
||||||
elif logger_type == 'wandb':
|
|
||||||
loggers.append(WandbLogger(**logger_config))
|
|
||||||
|
|
||||||
ppo = PPO(policy=policy,
|
ppo = PPO(
|
||||||
env_fn=env_fn,
|
env_spec=env_spec,
|
||||||
loggers=loggers,
|
loggers=loggers,
|
||||||
**ppo_config)
|
actor_hidden_sizes=actor_config['hidden_sizes'],
|
||||||
|
critic_hidden_sizes=critic_config['hidden_sizes'],
|
||||||
|
actor_activation_fn=actor_config['activation_fn'],
|
||||||
|
critic_activation_fn=critic_config['activation_fn'],
|
||||||
|
**ppo_config
|
||||||
|
)
|
||||||
|
|
||||||
ppo.train()
|
ppo.train()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user