Compare commits

5 Commits: dc70d045ab ... 360d2569f0

Commits in this range (SHA1):

- 360d2569f0
- 0808655136
- 1d8d217ec0
- 1d1d9060f9
- 7ea0bdcec6
README.md

@@ -6,7 +6,7 @@
 <br><br>
 </h1>
 
-Fancy RL is a minimalistic and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Policy Layers (TRPL) using primitives from [torchrl](https://pypi.org/project/torchrl/). Future plans include implementing Soft Actor-Critic (SAC). This library focuses on providing clean, understandable code and reusable modules while leveraging the powerful functionalities of torchrl. We provide optional integration with wandb.
+Fancy RL provides a minimalistic and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Policy Layers (TRPL) using primitives from [torchrl](https://pypi.org/project/torchrl/). This library focuses on providing clean, understandable code and reusable modules while leveraging the powerful functionalities of torchrl.
 
 ## Installation
 
@@ -22,28 +22,19 @@ Fancy RL provides two main components:
 
 1. **Ready-to-use Classes for PPO / TRPL**: These classes allow you to quickly get started with reinforcement learning algorithms, enjoying the performance and hackability that comes with using TorchRL.
 
 ```python
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import Policy
+from ppo import PPO
 import gymnasium as gym
 
-def env_fn():
-    return gym.make("CartPole-v1")
-
-# Create policy
-env = env_fn()
-policy = Policy(env.observation_space, env.action_space)
-
-# Create PPO instance with default config
-ppo = PPO(policy=policy, env_fn=env_fn)
-
-# Train the agent
+env_spec = "CartPole-v1"
+ppo = PPO(env_spec)
 ppo.train()
 ```
 
-For environments, you can pass any torchrl environments, gymnasium environments (which we handle with a compatibility layer), or a string which we will interpret as a gymnasium ID.
+For environments, you can pass any gymnasium environment ID as a string, a function returning a gymnasium environment, or an already instantiated gymnasium environment. Future plans include supporting other torchrl environments.
+Check 'example/example.py' for a more complete usage example.
 
-2. **Additional Modules for TRPL**: Designed to integrate with torchrl's primitives-first approach, these modules are ideal for building custom algorithms with precise trust region projections. For detailed documentation, refer to the [docs](#).
+2. **Additional Modules for TRPL**: Designed to integrate with torchrl's primitives-first approach, these modules are ideal for building custom algorithms with precise trust region projections.
 
 ### Background on Trust Region Policy Layers (TRPL)
 
@@ -65,4 +56,4 @@ Contributions are welcome! Feel free to open issues or submit pull requests to e
 
 ## License
 
 This project is licensed under the MIT License.
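A minimal usage sketch of the environment-spec forms described in the new README text; the string and callable forms are the ones handled by `make_env()` later in this diff, and the import path follows the README snippet:

```python
from ppo import PPO  # import path as used in the new README snippet
import gymnasium as gym

# String form: interpreted as a gymnasium ID and wrapped for torchrl.
ppo = PPO("CartPole-v1")

# Callable form: the function is called and the returned gymnasium env is
# wrapped in torchrl's GymWrapper (see make_env() in the on_policy diff below).
ppo = PPO(env_spec=lambda: gym.make("CartPole-v1"))

ppo.train()
```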
@@ -1,7 +1,10 @@
-policy:
-  input_dim: 4
-  output_dim: 2
+actor:
   hidden_sizes: [64, 64]
+  activation_fn: "ReLU"
+
+critic:
+  hidden_sizes: [64, 64]
+  activation_fn: "ReLU"
 
 ppo:
   learning_rate: 3e-4
@@ -15,11 +18,11 @@ ppo:
   eval_interval: 2048
   eval_deterministic: true
   eval_episodes: 10
-  seed: 42
 
 loggers:
-  - type: terminal
-  - type: wandb
+  - backend: 'wandb'
+    logger_name: "ppo"
+    experiment_name: "PPO"
     project: "PPO_project"
     entity: "your_entity"
     push_interval: 10
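A short sketch of how the new actor/critic config keys are consumed; it mirrors ppo.py later in this diff, which resolves the activation name with `getattr` on `torch.nn`, and uses the values from the config above:

```python
import torch.nn as nn

# "ReLU" from the config is resolved to the torch.nn class and instantiated
# once per hidden layer when the Actor/Critic networks are built.
activation_cls = getattr(nn, "ReLU")   # -> nn.ReLU
hidden_sizes = [64, 64]                # forwarded to Actor/Critic unchanged
layers = [nn.Linear(4, hidden_sizes[0]), activation_cls()]  # 4 = CartPole obs dim
```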
@@ -1,35 +1,31 @@
 import yaml
 import torch
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import Policy
-from fancy_rl.loggers import TerminalLogger, WandbLogger
+from ppo import PPO
+from torchrl.record.loggers import get_logger
 import gymnasium as gym
 
 def main(config_file):
     with open(config_file, 'r') as file:
         config = yaml.safe_load(file)
 
-    env_fn = lambda: gym.make("CartPole-v1")
-    env = env_fn()
-
-    policy_config = config['policy']
-    policy = Policy(env=env, hidden_sizes=policy_config['hidden_sizes'])
+    env_spec = "CartPole-v1"
 
     ppo_config = config['ppo']
-    loggers_config = config['loggers']
+    actor_config = config['actor']
+    critic_config = config['critic']
+    loggers_config = config.get('loggers', [])
 
-    loggers = []
-    for logger_config in loggers_config:
-        logger_type = logger_config.pop('type')
-        if logger_type == 'terminal':
-            loggers.append(TerminalLogger(**logger_config))
-        elif logger_type == 'wandb':
-            loggers.append(WandbLogger(**logger_config))
+    loggers = [get_logger(**logger_config) for logger_config in loggers_config]
 
-    ppo = PPO(policy=policy,
-              env_fn=env_fn,
-              loggers=loggers,
-              **ppo_config)
+    ppo = PPO(
+        env_spec=env_spec,
+        loggers=loggers,
+        actor_hidden_sizes=actor_config['hidden_sizes'],
+        critic_hidden_sizes=critic_config['hidden_sizes'],
+        actor_activation_fn=actor_config['activation_fn'],
+        critic_activation_fn=critic_config['activation_fn'],
+        **ppo_config
+    )
 
     ppo.train()
 
@@ -1,36 +0,0 @@
-class Logger:
-    def __init__(self, push_interval=1):
-        self.data = {}
-        self.push_interval = push_interval
-
-    def log(self, key, value, epoch):
-        if key not in self.data:
-            self.data[key] = []
-        self.data[key].append((epoch, value))
-
-    def end_of_epoch(self, epoch):
-        if epoch % self.push_interval == 0:
-            self.push()
-
-    def push(self):
-        raise NotImplementedError("Push method should be implemented by subclasses")
-
-class TerminalLogger(Logger):
-    def push(self):
-        for key, values in self.data.items():
-            for epoch, value in values:
-                print(f"Epoch {epoch}: {key} = {value}")
-        self.data = {}
-
-class WandbLogger(Logger):
-    def __init__(self, project, entity, config, push_interval=1):
-        super().__init__(push_interval)
-        import wandb
-        self.wandb = wandb
-        self.wandb.init(project=project, entity=entity, config=config)
-
-    def push(self):
-        for key, values in self.data.items():
-            for epoch, value in values:
-                self.wandb.log({key: value, 'epoch': epoch})
-        self.data = {}
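The custom Logger classes removed above are replaced by torchrl's logger utilities (`get_logger` and `log_scalar`, as imported and called elsewhere in this diff). A minimal sketch, assuming torchrl's CSV backend; the exact `get_logger` signature should be checked against the installed torchrl version:

```python
from torchrl.record.loggers import get_logger

# Assumed backend "csv"; wandb/tensorboard backends follow the same pattern.
logger = get_logger(logger_type="csv", logger_name="ppo", experiment_name="PPO")
logger.log_scalar("loss", 0.5, step=1)
```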
@@ -1,13 +1,20 @@
 import torch
 from abc import ABC, abstractmethod
-from fancy_rl.loggers import Logger
+from torchrl.record.loggers import Logger
 from torch.optim import Adam
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from torchrl.envs import ExplorationType, set_exploration_type
+from torchrl.envs.libs.gym import GymWrapper
+from torchrl.record import VideoRecorder
+import gymnasium as gym
 
 class OnPolicy(ABC):
     def __init__(
         self,
         policy,
-        env_fn,
+        env_spec,
         loggers,
         learning_rate,
         n_steps,
@@ -21,11 +28,14 @@ class OnPolicy(ABC):
         entropy_coef,
         critic_coef,
         normalize_advantage,
+        clip_range=0.2,
         device=None,
-        **kwargs
+        eval_episodes=10,
+        env_spec_eval=None,
     ):
         self.policy = policy
-        self.env_fn = env_fn
+        self.env_spec = env_spec
+        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
@@ -39,93 +49,93 @@ class OnPolicy(ABC):
         self.entropy_coef = entropy_coef
         self.critic_coef = critic_coef
         self.normalize_advantage = normalize_advantage
+        self.clip_range = clip_range
         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
-        self.kwargs = kwargs
-        self.clip_range = 0.2
+        self.eval_episodes = eval_episodes
+
+        # Create collector
+        self.collector = SyncDataCollector(
+            create_env_fn=lambda: self.make_env(eval=False),
+            policy=self.policy,
+            frames_per_batch=self.n_steps,
+            total_frames=self.total_timesteps,
+            device=self.device,
+            storing_device=self.device,
+            max_frames_per_traj=-1,
+        )
+
+        # Create data buffer
+        self.sampler = SamplerWithoutReplacement()
+        self.data_buffer = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(self.n_steps),
+            sampler=self.sampler,
+            batch_size=self.batch_size,
+        )
+
+    def make_env(self, eval=False):
+        """Creates an environment and wraps it if necessary."""
+        env_spec = self.env_spec_eval if eval else self.env_spec
+        if isinstance(env_spec, str):
+            env = gym.make(env_spec)
+            env = GymWrapper(env)
+        elif callable(env_spec):
+            env = env_spec()
+            if isinstance(env, gym.Env):
+                env = GymWrapper(env)
+        else:
+            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+        return env
 
     def train(self):
-        self.env = self.env_fn()
-        self.env.reset(seed=self.kwargs.get("seed", None))
+        collected_frames = 0
 
-        state = self.env.reset(seed=self.kwargs.get("seed", None))
-        episode_return = 0
-        episode_length = 0
-        for t in range(self.total_timesteps):
-            rollout = self.collect_rollouts(state)
-            for batch in self.get_batches(rollout):
-                loss = self.train_step(batch)
-                for logger in self.loggers:
-                    logger.log({
-                        "loss": loss.item()
-                    }, epoch=t)
+        for t, data in enumerate(self.collector):
+            frames_in_batch = data.numel()
+            collected_frames += frames_in_batch
+
+            for _ in range(self.n_epochs):
+                with torch.no_grad():
+                    data = self.adv_module(data)
+                data_reshape = data.reshape(-1)
+                self.data_buffer.extend(data_reshape)
+
+                for batch in self.data_buffer:
+                    batch = batch.to(self.device)
+                    loss = self.train_step(batch)
+                    for logger in self.loggers:
+                        logger.log_scalar({"loss": loss.item()}, step=collected_frames)
 
             if (t + 1) % self.eval_interval == 0:
                 self.evaluate(t)
 
+            self.collector.update_policy_weights_()
+
     def evaluate(self, epoch):
-        eval_env = self.env_fn()
-        eval_env.reset(seed=self.kwargs.get("seed", None))
-        returns = []
-        for _ in range(self.kwargs.get("eval_episodes", 10)):
-            state = eval_env.reset(seed=self.kwargs.get("seed", None))
-            done = False
-            total_return = 0
-            while not done:
-                with torch.no_grad():
-                    action = (
-                        self.policy.act(state, deterministic=self.eval_deterministic)
-                        if self.eval_deterministic
-                        else self.policy.act(state)
-                    )
-                state, reward, done, _ = eval_env.step(action)
-                total_return += reward
-            returns.append(total_return)
+        eval_env = self.make_env(eval=True)
+        eval_env.eval()
 
-        avg_return = sum(returns) / len(returns)
+        test_rewards = []
+        for _ in range(self.eval_episodes):
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+                td_test = eval_env.rollout(
+                    policy=self.policy,
+                    auto_reset=True,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                    max_steps=10_000_000,
+                )
+            reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+            test_rewards.append(reward.cpu())
+        eval_env.apply(dump_video)
+
+        avg_return = torch.cat(test_rewards, 0).mean().item()
         for logger in self.loggers:
-            logger.log({"eval_avg_return": avg_return}, epoch=epoch)
-
-    def collect_rollouts(self, state):
-        # Collect rollouts logic
-        rollouts = []
-        for _ in range(self.n_steps):
-            action = self.policy.act(state)
-            next_state, reward, done, _ = self.env.step(action)
-            rollouts.append((state, action, reward, next_state, done))
-            state = next_state
-            if done:
-                state = self.env.reset(seed=self.kwargs.get("seed", None))
-        return rollouts
-
-    def get_batches(self, rollouts):
-        data = self.prepare_data(rollouts)
-        n_batches = len(data) // self.batch_size
-        batches = []
-        for _ in range(n_batches):
-            batch_indices = torch.randint(0, len(data), (self.batch_size,))
-            batch = data[batch_indices]
-            batches.append(batch)
-        return batches
-
-    def prepare_data(self, rollouts):
-        obs, actions, rewards, next_obs, dones = zip(*rollouts)
-        obs = torch.tensor(obs, dtype=torch.float32)
-        actions = torch.tensor(actions, dtype=torch.int64)
-        rewards = torch.tensor(rewards, dtype=torch.float32)
-        next_obs = torch.tensor(next_obs, dtype=torch.float32)
-        dones = torch.tensor(dones, dtype=torch.float32)
-
-        data = {
-            "obs": obs,
-            "actions": actions,
-            "rewards": rewards,
-            "next_obs": next_obs,
-            "dones": dones
-        }
-        data = self.adv_module(data)
-        return data
+            logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
 
     @abstractmethod
     def train_step(self, batch):
         pass
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
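A minimal sketch of what `make_env()` above produces, assuming a plain gymnasium task: the environment is wrapped so that torchrl's `SyncDataCollector` and `rollout()` can drive it through TensorDicts.

```python
import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper

env = GymWrapper(gym.make("CartPole-v1"))  # same wrapping as make_env()
td = env.rollout(max_steps=5)              # random actions when no policy is given
print(td["next", "reward"].shape)          # one reward entry per collected step
```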
@@ -1,27 +1,71 @@
 import torch
 from torch import nn
+from torch.distributions import Categorical, Normal
+import gymnasium as gym
 
-class Policy(nn.Module):
-    def __init__(self, input_dim, output_dim, hidden_sizes=[64, 64]):
+class Actor(nn.Module):
+    def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
         super().__init__()
+        self.continuous = isinstance(action_space, gym.spaces.Box)
+        input_dim = observation_space.shape[-1]
+        if self.continuous:
+            output_dim = action_space.shape[-1]
+        else:
+            output_dim = action_space.n
+
         layers = []
         last_dim = input_dim
         for size in hidden_sizes:
             layers.append(nn.Linear(last_dim, size))
-            layers.append(nn.ReLU())
+            layers.append(activation_fn())
             last_dim = size
-        layers.append(nn.Linear(last_dim, output_dim))
-        self.model = nn.Sequential(*layers)
+        if self.continuous:
+            self.mu_layer = nn.Linear(last_dim, output_dim)
+            self.log_std_layer = nn.Linear(last_dim, output_dim)
+        else:
+            layers.append(nn.Linear(last_dim, output_dim))
+        self.model = nn.Sequential(*layers)
 
     def forward(self, x):
-        return self.model(x)
+        if self.continuous:
+            mu = self.mu_layer(x)
+            log_std = self.log_std_layer(x)
+            return mu, log_std.exp()
+        else:
+            return self.model(x)
 
     def act(self, observation, deterministic=False):
         with torch.no_grad():
-            logits = self.forward(observation)
-            if deterministic:
-                action = logits.argmax(dim=-1)
+            if self.continuous:
+                mu, std = self.forward(observation)
+                if deterministic:
+                    action = mu
+                else:
+                    action_dist = Normal(mu, std)
+                    action = action_dist.sample()
             else:
-                action_dist = torch.distributions.Categorical(logits=logits)
-                action = action_dist.sample()
+                logits = self.forward(observation)
+                if deterministic:
+                    action = logits.argmax(dim=-1)
+                else:
+                    action_dist = Categorical(logits=logits)
+                    action = action_dist.sample()
             return action
+
+class Critic(nn.Module):
+    def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
+        super().__init__()
+        input_dim = observation_space.shape[-1]
+
+        layers = []
+        last_dim = input_dim
+        for size in hidden_sizes:
+            layers.append(nn.Linear(last_dim, size))
+            layers.append(activation_fn())
+            last_dim = size
+        layers.append(nn.Linear(last_dim, 1))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x).squeeze(-1)
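A short usage sketch of the new Actor and Critic classes above, built directly from a gymnasium environment's spaces (the `policy` module path follows the imports used later in this diff):

```python
import gymnasium as gym
import torch
from policy import Actor, Critic  # as defined in this diff

env = gym.make("CartPole-v1")
actor = Actor(env.observation_space, env.action_space)   # discrete head -> Categorical
critic = Critic(env.observation_space)

obs, _ = env.reset(seed=0)
obs_t = torch.as_tensor(obs, dtype=torch.float32)
action = actor.act(obs_t, deterministic=True)  # argmax over logits for discrete spaces
value = critic(obs_t)                          # scalar state-value estimate
```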
@@ -1,17 +1,21 @@
 import torch
-import gymnasium as gym
-from fancy_rl.policy import Policy
-from fancy_rl.loggers import TerminalLogger
-from fancy_rl.on_policy import OnPolicy
+import torch.nn as nn
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
+from torchrl.record.loggers import get_logger
+from on_policy import OnPolicy
+from policy import Actor, Critic
+import gymnasium as gym
 
 class PPO(OnPolicy):
     def __init__(
         self,
-        policy,
-        env_fn,
+        env_spec,
         loggers=None,
+        actor_hidden_sizes=[64, 64],
+        critic_hidden_sizes=[64, 64],
+        actor_activation_fn="ReLU",
+        critic_activation_fn="ReLU",
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
@@ -24,16 +28,25 @@ class PPO(OnPolicy):
         entropy_coef=0.01,
         critic_coef=0.5,
         normalize_advantage=True,
+        clip_range=0.2,
         device=None,
-        clip_epsilon=0.2,
-        **kwargs
+        env_spec_eval=None,
+        eval_episodes=10,
     ):
-        if loggers is None:
-            loggers = [TerminalLogger(push_interval=1)]
+        # Initialize environment to get observation and action space sizes
+        env = self.make_env(env_spec)
+        obs_space = env.observation_space
+        act_space = env.action_space
+
+        actor_activation_fn = getattr(nn, actor_activation_fn)
+        critic_activation_fn = getattr(nn, critic_activation_fn)
+
+        self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
+        self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
 
         super().__init__(
-            policy=policy,
-            env_fn=env_fn,
+            policy=self.actor,
+            env_spec=env_spec,
             loggers=loggers,
             learning_rate=learning_rate,
             n_steps=n_steps,
@@ -47,52 +60,37 @@ class PPO(OnPolicy):
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
             normalize_advantage=normalize_advantage,
+            clip_range=clip_range,
             device=device,
-            **kwargs
+            env_spec_eval=env_spec_eval,
+            eval_episodes=eval_episodes,
         )
 
-        self.clip_epsilon = clip_epsilon
         self.adv_module = GAE(
             gamma=self.gamma,
             lmbda=self.gae_lambda,
-            value_network=self.policy,
+            value_network=self.critic,
             average_gae=False,
         )
 
         self.loss_module = ClipPPOLoss(
-            actor_network=self.policy,
-            critic_network=self.policy,
-            clip_epsilon=self.clip_epsilon,
+            actor_network=self.actor,
+            critic_network=self.critic,
+            clip_epsilon=self.clip_range,
             loss_critic_type='MSELoss',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
        )
 
-        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)
+        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
+        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
 
     def train_step(self, batch):
-        self.optimizer.zero_grad()
+        self.actor_optimizer.zero_grad()
+        self.critic_optimizer.zero_grad()
         loss = self.loss_module(batch)
         loss.backward()
-        self.optimizer.step()
+        self.actor_optimizer.step()
+        self.critic_optimizer.step()
         return loss
-
-    def train(self):
-        self.env = self.env_fn()
-        self.env.reset(seed=self.kwargs.get("seed", None))
-
-        state = self.env.reset(seed=self.kwargs.get("seed", None))
-        episode_return = 0
-        episode_length = 0
-        for t in range(self.total_timesteps):
-            rollout = self.collect_rollouts(state)
-            for batch in self.get_batches(rollout):
-                loss = self.train_step(batch)
-                for logger in self.loggers:
-                    logger.log({
-                        "loss": loss.item()
-                    }, epoch=t)
-
-            if (t + 1) % self.eval_interval == 0:
-                self.evaluate(t)
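An end-to-end sketch of the constructor arguments introduced for PPO in this diff (`env_spec`, `env_spec_eval`, `eval_episodes`, `loggers`); the values here are illustrative only:

```python
from ppo import PPO  # as in this diff

ppo = PPO(
    env_spec="CartPole-v1",        # training env, resolved by OnPolicy.make_env()
    env_spec_eval="CartPole-v1",   # optional separate spec for evaluation rollouts
    loggers=[],                    # no logging; or torchrl loggers, e.g. via get_logger
    eval_episodes=5,
)
ppo.train()
```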
@@ -1,4 +0,0 @@
-import gymnasium as gym
-
-def make_env(env_name):
-    return lambda: gym.make(env_name)