From 1d8d217ec04142d42d1daa940315b4d2f98eec45 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Fri, 31 May 2024 13:04:41 +0200
Subject: [PATCH] Refactor env handling

---
 fancy_rl/on_policy.py | 172 ++++++++++++++++++++++--------------------
 fancy_rl/policy.py    |  66 +++++++++++++---
 fancy_rl/ppo.py       |  80 ++++++++++----------
 3 files changed, 185 insertions(+), 133 deletions(-)

diff --git a/fancy_rl/on_policy.py b/fancy_rl/on_policy.py
index 7152948..eeb22ac 100644
--- a/fancy_rl/on_policy.py
+++ b/fancy_rl/on_policy.py
@@ -1,13 +1,20 @@
 import torch
 from abc import ABC, abstractmethod
-from fancy_rl.loggers import Logger
+from torchrl.record.loggers import Logger
 from torch.optim import Adam
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from torchrl.envs import ExplorationType, set_exploration_type
+from torchrl.envs.libs.gym import GymWrapper
+from torchrl.record import VideoRecorder
+import gymnasium as gym
 
 class OnPolicy(ABC):
     def __init__(
         self,
         policy,
-        env_fn,
+        env_spec,
         loggers,
         learning_rate,
         n_steps,
@@ -21,11 +28,14 @@ class OnPolicy(ABC):
         entropy_coef,
         critic_coef,
         normalize_advantage,
+        clip_range=0.2,
         device=None,
-        **kwargs
+        eval_episodes=10,
+        env_spec_eval=None,
     ):
         self.policy = policy
-        self.env_fn = env_fn
+        self.env_spec = env_spec
+        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
@@ -39,93 +49,93 @@ class OnPolicy(ABC):
         self.entropy_coef = entropy_coef
         self.critic_coef = critic_coef
         self.normalize_advantage = normalize_advantage
+        self.clip_range = clip_range
         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+        self.eval_episodes = eval_episodes
 
-        self.kwargs = kwargs
-        self.clip_range = 0.2
+        # Create collector
+        self.collector = SyncDataCollector(
+            create_env_fn=lambda: self.make_env(eval=False),
+            policy=self.policy,
+            frames_per_batch=self.n_steps,
+            total_frames=self.total_timesteps,
+            device=self.device,
+            storing_device=self.device,
+            max_frames_per_traj=-1,
+        )
+
+        # Create data buffer
+        self.sampler = SamplerWithoutReplacement()
+        self.data_buffer = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(self.n_steps),
+            sampler=self.sampler,
+            batch_size=self.batch_size,
+        )
+
+    def make_env(self, eval=False):
+        """Creates an environment and wraps it if necessary."""
+        env_spec = self.env_spec_eval if eval else self.env_spec
+        if isinstance(env_spec, str):
+            env = gym.make(env_spec)
+            env = GymWrapper(env)
+        elif callable(env_spec):
+            env = env_spec()
+            if isinstance(env, gym.Env):
+                env = GymWrapper(env)
+        else:
+            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+        return env
 
     def train(self):
-        self.env = self.env_fn()
-        self.env.reset(seed=self.kwargs.get("seed", None))
+        collected_frames = 0
 
-        state = self.env.reset(seed=self.kwargs.get("seed", None))
-        episode_return = 0
-        episode_length = 0
-        for t in range(self.total_timesteps):
-            rollout = self.collect_rollouts(state)
-            for batch in self.get_batches(rollout):
-                loss = self.train_step(batch)
-                for logger in self.loggers:
-                    logger.log({
-                        "loss": loss.item()
-                    }, epoch=t)
-
-            if (t + 1) % self.eval_interval == 0:
-                self.evaluate(t)
+        for t, data in enumerate(self.collector):
+            frames_in_batch = data.numel()
+            collected_frames += frames_in_batch
+
+            for _ in range(self.n_epochs):
+                with torch.no_grad():
+                    data = self.adv_module(data)
+                data_reshape = data.reshape(-1)
+                self.data_buffer.extend(data_reshape)
+
+                for batch in self.data_buffer:
+                    batch = batch.to(self.device)
+                    loss = self.train_step(batch)
+                    for logger in self.loggers:
+                        logger.log_scalar("loss", loss.item(), step=collected_frames)
+
+            if (t + 1) % self.eval_interval == 0:
+                self.evaluate(t)
+
+            self.collector.update_policy_weights_()
 
     def evaluate(self, epoch):
-        eval_env = self.env_fn()
-        eval_env.reset(seed=self.kwargs.get("seed", None))
-        returns = []
-        for _ in range(self.kwargs.get("eval_episodes", 10)):
-            state = eval_env.reset(seed=self.kwargs.get("seed", None))
-            done = False
-            total_return = 0
-            while not done:
-                with torch.no_grad():
-                    action = (
-                        self.policy.act(state, deterministic=self.eval_deterministic)
-                        if self.eval_deterministic
-                        else self.policy.act(state)
-                    )
-                state, reward, done, _ = eval_env.step(action)
-                total_return += reward
-            returns.append(total_return)
+        eval_env = self.make_env(eval=True)
+        eval_env.eval()
 
-        avg_return = sum(returns) / len(returns)
+        test_rewards = []
+        for _ in range(self.eval_episodes):
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+                td_test = eval_env.rollout(
+                    policy=self.policy,
+                    auto_reset=True,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                    max_steps=10_000_000,
+                )
+                reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+                test_rewards.append(reward.cpu())
+            eval_env.apply(dump_video)
+
+        avg_return = torch.cat(test_rewards, 0).mean().item()
         for logger in self.loggers:
-            logger.log({"eval_avg_return": avg_return}, epoch=epoch)
-
-    def collect_rollouts(self, state):
-        # Collect rollouts logic
-        rollouts = []
-        for _ in range(self.n_steps):
-            action = self.policy.act(state)
-            next_state, reward, done, _ = self.env.step(action)
-            rollouts.append((state, action, reward, next_state, done))
-            state = next_state
-            if done:
-                state = self.env.reset(seed=self.kwargs.get("seed", None))
-        return rollouts
-
-    def get_batches(self, rollouts):
-        data = self.prepare_data(rollouts)
-        n_batches = len(data) // self.batch_size
-        batches = []
-        for _ in range(n_batches):
-            batch_indices = torch.randint(0, len(data), (self.batch_size,))
-            batch = data[batch_indices]
-            batches.append(batch)
-        return batches
-
-    def prepare_data(self, rollouts):
-        obs, actions, rewards, next_obs, dones = zip(*rollouts)
-        obs = torch.tensor(obs, dtype=torch.float32)
-        actions = torch.tensor(actions, dtype=torch.int64)
-        rewards = torch.tensor(rewards, dtype=torch.float32)
-        next_obs = torch.tensor(next_obs, dtype=torch.float32)
-        dones = torch.tensor(dones, dtype=torch.float32)
-
-        data = {
-            "obs": obs,
-            "actions": actions,
-            "rewards": rewards,
-            "next_obs": next_obs,
-            "dones": dones
-        }
-        data = self.adv_module(data)
-        return data
+            logger.log_scalar("eval_avg_return", avg_return, step=epoch)
 
     @abstractmethod
     def train_step(self, batch):
         pass
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py
index 40c6a48..5875986 100644
--- a/fancy_rl/policy.py
+++ b/fancy_rl/policy.py
@@ -1,27 +1,72 @@
 import torch
 from torch import nn
+from torch.distributions import Categorical, Normal
+import gymnasium as gym
 
-class Policy(nn.Module):
-    def __init__(self, input_dim, output_dim, hidden_sizes=[64, 64]):
+class Actor(nn.Module):
+    def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
         super().__init__()
+        self.continuous = isinstance(action_space, gym.spaces.Box)
+        input_dim = observation_space.shape[-1]
+        if self.continuous:
+            output_dim = action_space.shape[-1]
+        else:
+            output_dim = action_space.n
+
         layers = []
         last_dim = input_dim
         for size in hidden_sizes:
             layers.append(nn.Linear(last_dim, size))
-            layers.append(nn.ReLU())
+            layers.append(activation_fn())
             last_dim = size
-        layers.append(nn.Linear(last_dim, output_dim))
-        self.model = nn.Sequential(*layers)
+
+        if self.continuous:
+            self.mu_layer = nn.Linear(last_dim, output_dim)
+            self.log_std_layer = nn.Linear(last_dim, output_dim)
+        else:
+            layers.append(nn.Linear(last_dim, output_dim))
+        self.model = nn.Sequential(*layers)
 
     def forward(self, x):
-        return self.model(x)
+        if self.continuous:
+            hidden = self.model(x)
+            mu = self.mu_layer(hidden)
+            log_std = self.log_std_layer(hidden)
+            return mu, log_std.exp()
+        else:
+            return self.model(x)
 
     def act(self, observation, deterministic=False):
         with torch.no_grad():
-            logits = self.forward(observation)
-            if deterministic:
-                action = logits.argmax(dim=-1)
+            if self.continuous:
+                mu, std = self.forward(observation)
+                if deterministic:
+                    action = mu
+                else:
+                    action_dist = Normal(mu, std)
+                    action = action_dist.sample()
             else:
-                action_dist = torch.distributions.Categorical(logits=logits)
-                action = action_dist.sample()
+                logits = self.forward(observation)
+                if deterministic:
+                    action = logits.argmax(dim=-1)
+                else:
+                    action_dist = Categorical(logits=logits)
+                    action = action_dist.sample()
         return action
+
+class Critic(nn.Module):
+    def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
+        super().__init__()
+        input_dim = observation_space.shape[-1]
+
+        layers = []
+        last_dim = input_dim
+        for size in hidden_sizes:
+            layers.append(nn.Linear(last_dim, size))
+            layers.append(activation_fn())
+            last_dim = size
+        layers.append(nn.Linear(last_dim, 1))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x).squeeze(-1)
diff --git a/fancy_rl/ppo.py b/fancy_rl/ppo.py
index 7b4abe2..02f36da 100644
--- a/fancy_rl/ppo.py
+++ b/fancy_rl/ppo.py
@@ -1,17 +1,21 @@
 import torch
-import gymnasium as gym
-from fancy_rl.policy import Policy
-from fancy_rl.loggers import TerminalLogger
-from fancy_rl.on_policy import OnPolicy
+import torch.nn as nn
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
+from torchrl.record.loggers import get_logger
+from fancy_rl.on_policy import OnPolicy
+from fancy_rl.policy import Actor, Critic
+import gymnasium as gym
 
 class PPO(OnPolicy):
     def __init__(
         self,
-        policy,
-        env_fn,
+        env_spec,
         loggers=None,
+        actor_hidden_sizes=[64, 64],
+        critic_hidden_sizes=[64, 64],
+        actor_activation_fn="ReLU",
+        critic_activation_fn="ReLU",
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
@@ -24,16 +28,25 @@ class PPO(OnPolicy):
         entropy_coef=0.01,
         critic_coef=0.5,
         normalize_advantage=True,
+        clip_range=0.2,
         device=None,
-        clip_epsilon=0.2,
-        **kwargs
+        env_spec_eval=None,
+        eval_episodes=10,
     ):
-        if loggers is None:
-            loggers = [TerminalLogger(push_interval=1)]
+        # Build a plain gymnasium env once to read the observation and action spaces
+        env = gym.make(env_spec) if isinstance(env_spec, str) else env_spec()
+        obs_space = env.observation_space
+        act_space = env.action_space
+
+        actor_activation_fn = getattr(nn, actor_activation_fn)
+        critic_activation_fn = getattr(nn, critic_activation_fn)
+
+        self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
+        self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
 
         super().__init__(
-            policy=policy,
-            env_fn=env_fn,
+            policy=self.actor,
+            env_spec=env_spec,
             loggers=loggers,
             learning_rate=learning_rate,
             n_steps=n_steps,
@@ -47,52 +60,37 @@ class PPO(OnPolicy):
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
             normalize_advantage=normalize_advantage,
+            clip_range=clip_range,
             device=device,
-            **kwargs
+            env_spec_eval=env_spec_eval,
+            eval_episodes=eval_episodes,
         )
-
-        self.clip_epsilon = clip_epsilon
+
         self.adv_module = GAE(
             gamma=self.gamma,
             lmbda=self.gae_lambda,
-            value_network=self.policy,
+            value_network=self.critic,
             average_gae=False,
         )
 
         self.loss_module = ClipPPOLoss(
-            actor_network=self.policy,
-            critic_network=self.policy,
-            clip_epsilon=self.clip_epsilon,
+            actor_network=self.actor,
+            critic_network=self.critic,
+            clip_epsilon=self.clip_range,
             loss_critic_type='MSELoss',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
 
-        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)
+        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
+        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
 
     def train_step(self, batch):
-        self.optimizer.zero_grad()
+        self.actor_optimizer.zero_grad()
+        self.critic_optimizer.zero_grad()
         loss = self.loss_module(batch)
         loss.backward()
-        self.optimizer.step()
+        self.actor_optimizer.step()
+        self.critic_optimizer.step()
         return loss
-
-    def train(self):
-        self.env = self.env_fn()
-        self.env.reset(seed=self.kwargs.get("seed", None))
-
-        state = self.env.reset(seed=self.kwargs.get("seed", None))
-        episode_return = 0
-        episode_length = 0
-        for t in range(self.total_timesteps):
-            rollout = self.collect_rollouts(state)
-            for batch in self.get_batches(rollout):
-                loss = self.train_step(batch)
-                for logger in self.loggers:
-                    logger.log({
-                        "loss": loss.item()
-                    }, epoch=t)
-
-            if (t + 1) % self.eval_interval == 0:
-                self.evaluate(t)
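
Usage note (not part of the patch): with this refactor, env_spec replaces env_fn and may be either a Gymnasium environment id or a callable that returns an environment, with env_spec_eval optionally overriding it for evaluation. A minimal sketch of the intended call sites, assuming the package is importable as fancy_rl and the default hyperparameters are acceptable; the environment ids are only examples:

    from fancy_rl.ppo import PPO
    import gymnasium as gym

    # Pass a Gymnasium id string; OnPolicy.make_env wraps it in torchrl's GymWrapper.
    # loggers=[] keeps the sketch self-contained (no logger backend configured).
    ppo = PPO(env_spec="CartPole-v1", loggers=[])
    ppo.train()

    # Or pass a callable that builds the environment, with a separate spec for evaluation.
    ppo = PPO(
        env_spec=lambda: gym.make("HalfCheetah-v4"),
        env_spec_eval="HalfCheetah-v4",
        eval_episodes=5,
        loggers=[],
    )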
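
For reference, a short sketch of how the refactored Actor and Critic in fancy_rl/policy.py behave for discrete versus continuous action spaces; the space objects and input tensors below are illustrative only:

    import gymnasium as gym
    import torch
    from fancy_rl.policy import Actor, Critic

    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))

    # Discrete spaces: forward() returns logits, act() samples from a Categorical.
    discrete_actor = Actor(obs_space, gym.spaces.Discrete(2))
    action = discrete_actor.act(torch.zeros(1, 4))

    # Box spaces: forward() returns (mu, std), act() samples from a Normal
    # (or returns mu when deterministic=True).
    continuous_actor = Actor(obs_space, gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))
    action = continuous_actor.act(torch.zeros(1, 4), deterministic=True)

    # The Critic maps observations to a scalar value estimate.
    critic = Critic(obs_space)
    value = critic(torch.zeros(1, 4))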