Refactor env handling

This commit is contained in:
Dominik Moritz Roth 2024-05-31 13:04:41 +02:00
parent 1d1d9060f9
commit 1d8d217ec0
3 changed files with 185 additions and 133 deletions

View File

@ -1,13 +1,20 @@
import torch
from abc import ABC, abstractmethod
from fancy_rl.loggers import Logger
from torchrl.record.loggers import Logger
from torch.optim import Adam
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
import gymnasium as gym
class OnPolicy(ABC):
def __init__(
self,
policy,
env_fn,
env_spec,
loggers,
learning_rate,
n_steps,
@ -21,11 +28,14 @@ class OnPolicy(ABC):
entropy_coef,
critic_coef,
normalize_advantage,
clip_range=0.2,
device=None,
**kwargs
eval_episodes=10,
env_spec_eval=None,
):
self.policy = policy
self.env_fn = env_fn
self.env_spec = env_spec
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
self.loggers = loggers
self.learning_rate = learning_rate
self.n_steps = n_steps
@ -39,93 +49,93 @@ class OnPolicy(ABC):
self.entropy_coef = entropy_coef
self.critic_coef = critic_coef
self.normalize_advantage = normalize_advantage
self.clip_range = clip_range
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self.eval_episodes = eval_episodes
self.kwargs = kwargs
self.clip_range = 0.2
# Create collector
self.collector = SyncDataCollector(
create_env_fn=lambda: self.make_env(eval=False),
policy=self.policy,
frames_per_batch=self.n_steps,
total_frames=self.total_timesteps,
device=self.device,
storing_device=self.device,
max_frames_per_traj=-1,
)
# Create data buffer
self.sampler = SamplerWithoutReplacement()
self.data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(self.n_steps),
sampler=self.sampler,
batch_size=self.batch_size,
)
def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec
if isinstance(env_spec, str):
env = gym.make(env_spec)
env = GymWrapper(env)
elif callable(env_spec):
env = env_spec()
if isinstance(env, gym.Env):
env = GymWrapper(env)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment.")
return env
def train(self):
self.env = self.env_fn()
self.env.reset(seed=self.kwargs.get("seed", None))
collected_frames = 0
state = self.env.reset(seed=self.kwargs.get("seed", None))
episode_return = 0
episode_length = 0
for t in range(self.total_timesteps):
rollout = self.collect_rollouts(state)
for batch in self.get_batches(rollout):
loss = self.train_step(batch)
for logger in self.loggers:
logger.log({
"loss": loss.item()
}, epoch=t)
if (t + 1) % self.eval_interval == 0:
self.evaluate(t)
for t, data in enumerate(self.collector):
frames_in_batch = data.numel()
collected_frames += frames_in_batch
for _ in range(self.n_epochs):
with torch.no_grad():
data = self.adv_module(data)
data_reshape = data.reshape(-1)
self.data_buffer.extend(data_reshape)
for batch in self.data_buffer:
batch = batch.to(self.device)
loss = self.train_step(batch)
for logger in self.loggers:
logger.log_scalar({"loss": loss.item()}, step=collected_frames)
if (t + 1) % self.eval_interval == 0:
self.evaluate(t)
self.collector.update_policy_weights_()
def evaluate(self, epoch):
eval_env = self.env_fn()
eval_env.reset(seed=self.kwargs.get("seed", None))
returns = []
for _ in range(self.kwargs.get("eval_episodes", 10)):
state = eval_env.reset(seed=self.kwargs.get("seed", None))
done = False
total_return = 0
while not done:
with torch.no_grad():
action = (
self.policy.act(state, deterministic=self.eval_deterministic)
if self.eval_deterministic
else self.policy.act(state)
)
state, reward, done, _ = eval_env.step(action)
total_return += reward
returns.append(total_return)
eval_env = self.make_env(eval=True)
eval_env.eval()
avg_return = sum(returns) / len(returns)
test_rewards = []
for _ in range(self.eval_episodes):
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
td_test = eval_env.rollout(
policy=self.policy,
auto_reset=True,
auto_cast_to_device=True,
break_when_any_done=True,
max_steps=10_000_000,
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards.append(reward.cpu())
eval_env.apply(dump_video)
avg_return = torch.cat(test_rewards, 0).mean().item()
for logger in self.loggers:
logger.log({"eval_avg_return": avg_return}, epoch=epoch)
def collect_rollouts(self, state):
# Collect rollouts logic
rollouts = []
for _ in range(self.n_steps):
action = self.policy.act(state)
next_state, reward, done, _ = self.env.step(action)
rollouts.append((state, action, reward, next_state, done))
state = next_state
if done:
state = self.env.reset(seed=self.kwargs.get("seed", None))
return rollouts
def get_batches(self, rollouts):
data = self.prepare_data(rollouts)
n_batches = len(data) // self.batch_size
batches = []
for _ in range(n_batches):
batch_indices = torch.randint(0, len(data), (self.batch_size,))
batch = data[batch_indices]
batches.append(batch)
return batches
def prepare_data(self, rollouts):
obs, actions, rewards, next_obs, dones = zip(*rollouts)
obs = torch.tensor(obs, dtype=torch.float32)
actions = torch.tensor(actions, dtype=torch.int64)
rewards = torch.tensor(rewards, dtype=torch.float32)
next_obs = torch.tensor(next_obs, dtype=torch.float32)
dones = torch.tensor(dones, dtype=torch.float32)
data = {
"obs": obs,
"actions": actions,
"rewards": rewards,
"next_obs": next_obs,
"dones": dones
}
data = self.adv_module(data)
return data
logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
@abstractmethod
def train_step(self, batch):
pass
def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()

View File

@ -1,27 +1,71 @@
import torch
from torch import nn
from torch.distributions import Categorical, Normal
import gymnasium as gym
class Policy(nn.Module):
def __init__(self, input_dim, output_dim, hidden_sizes=[64, 64]):
class Actor(nn.Module):
def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
super().__init__()
self.continuous = isinstance(action_space, gym.spaces.Box)
input_dim = observation_space.shape[-1]
if self.continuous:
output_dim = action_space.shape[-1]
else:
output_dim = action_space.n
layers = []
last_dim = input_dim
for size in hidden_sizes:
layers.append(nn.Linear(last_dim, size))
layers.append(nn.ReLU())
layers.append(activation_fn())
last_dim = size
layers.append(nn.Linear(last_dim, output_dim))
self.model = nn.Sequential(*layers)
if self.continuous:
self.mu_layer = nn.Linear(last_dim, output_dim)
self.log_std_layer = nn.Linear(last_dim, output_dim)
else:
layers.append(nn.Linear(last_dim, output_dim))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
if self.continuous:
mu = self.mu_layer(x)
log_std = self.log_std_layer(x)
return mu, log_std.exp()
else:
return self.model(x)
def act(self, observation, deterministic=False):
with torch.no_grad():
logits = self.forward(observation)
if deterministic:
action = logits.argmax(dim=-1)
if self.continuous:
mu, std = self.forward(observation)
if deterministic:
action = mu
else:
action_dist = Normal(mu, std)
action = action_dist.sample()
else:
action_dist = torch.distributions.Categorical(logits=logits)
action = action_dist.sample()
logits = self.forward(observation)
if deterministic:
action = logits.argmax(dim=-1)
else:
action_dist = Categorical(logits=logits)
action = action_dist.sample()
return action
class Critic(nn.Module):
def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
super().__init__()
input_dim = observation_space.shape[-1]
layers = []
last_dim = input_dim
for size in hidden_sizes:
layers.append(nn.Linear(last_dim, size))
layers.append(activation_fn())
last_dim = size
layers.append(nn.Linear(last_dim, 1))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x).squeeze(-1)

View File

@ -1,17 +1,21 @@
import torch
import gymnasium as gym
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger
from fancy_rl.on_policy import OnPolicy
import torch.nn as nn
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import get_logger
from on_policy import OnPolicy
from policy import Actor, Critic
import gymnasium as gym
class PPO(OnPolicy):
def __init__(
self,
policy,
env_fn,
env_spec,
loggers=None,
actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64],
actor_activation_fn="ReLU",
critic_activation_fn="ReLU",
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
@ -24,16 +28,25 @@ class PPO(OnPolicy):
entropy_coef=0.01,
critic_coef=0.5,
normalize_advantage=True,
clip_range=0.2,
device=None,
clip_epsilon=0.2,
**kwargs
env_spec_eval=None,
eval_episodes=10,
):
if loggers is None:
loggers = [TerminalLogger(push_interval=1)]
# Initialize environment to get observation and action space sizes
env = self.make_env(env_spec)
obs_space = env.observation_space
act_space = env.action_space
actor_activation_fn = getattr(nn, actor_activation_fn)
critic_activation_fn = getattr(nn, critic_activation_fn)
self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
super().__init__(
policy=policy,
env_fn=env_fn,
policy=self.actor,
env_spec=env_spec,
loggers=loggers,
learning_rate=learning_rate,
n_steps=n_steps,
@ -47,52 +60,37 @@ class PPO(OnPolicy):
entropy_coef=entropy_coef,
critic_coef=critic_coef,
normalize_advantage=normalize_advantage,
clip_range=clip_range,
device=device,
**kwargs
env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes,
)
self.clip_epsilon = clip_epsilon
self.adv_module = GAE(
gamma=self.gamma,
lmbda=self.gae_lambda,
value_network=self.policy,
value_network=self.critic,
average_gae=False,
)
self.loss_module = ClipPPOLoss(
actor_network=self.policy,
critic_network=self.policy,
clip_epsilon=self.clip_epsilon,
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=self.clip_range,
loss_critic_type='MSELoss',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
normalize_advantage=self.normalize_advantage,
)
self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
def train_step(self, batch):
self.optimizer.zero_grad()
self.actor_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
loss = self.loss_module(batch)
loss.backward()
self.optimizer.step()
self.actor_optimizer.step()
self.critic_optimizer.step()
return loss
def train(self):
self.env = self.env_fn()
self.env.reset(seed=self.kwargs.get("seed", None))
state = self.env.reset(seed=self.kwargs.get("seed", None))
episode_return = 0
episode_length = 0
for t in range(self.total_timesteps):
rollout = self.collect_rollouts(state)
for batch in self.get_batches(rollout):
loss = self.train_step(batch)
for logger in self.loggers:
logger.log({
"loss": loss.item()
}, epoch=t)
if (t + 1) % self.eval_interval == 0:
self.evaluate(t)