From 0c6e58634f90f507521839ce98178fae10f6ac88 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Mon, 21 Oct 2024 15:23:39 +0200
Subject: [PATCH] Rework algo impls

---
 fancy_rl/algos/algo.py      | 28 +++++++++---
 fancy_rl/algos/on_policy.py |  2 +-
 fancy_rl/algos/ppo.py       | 39 ++++++++++++-----
 fancy_rl/algos/trpl.py      | 86 ++++++++++++++++++++++++-------------
 4 files changed, 105 insertions(+), 50 deletions(-)

diff --git a/fancy_rl/algos/algo.py b/fancy_rl/algos/algo.py
index b3b150f..90eeee0 100644
--- a/fancy_rl/algos/algo.py
+++ b/fancy_rl/algos/algo.py
@@ -3,6 +3,7 @@ import gymnasium as gym
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.record import VideoRecorder
 from abc import ABC
+from tensordict import TensorDict
 
 from fancy_rl.loggers import TerminalLogger
 
@@ -53,12 +54,11 @@ class Algo(ABC):
             env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
-            if isinstance(env, gym.Env):
-                env = GymWrapper(env).to(self.device)
-        elif isinstance(env, gym.Env):
+            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
+                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
             env = GymWrapper(env).to(self.device)
         else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
         return env
 
     def train_step(self, batch):
@@ -70,6 +70,20 @@ class Algo(ABC):
     def evaluate(self, epoch):
         raise NotImplementedError("evaluate method must be implemented in subclass.")
 
-def dump_video(module):
-    if isinstance(module, VideoRecorder):
-        module.dump()
\ No newline at end of file
+    def predict(
+        self,
+        observation,
+        state=None,
+        deterministic=False
+    ):
+        with torch.no_grad():
+            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
+            td = TensorDict({"observation": obs_tensor}, batch_size=[1])
+
+            action_td = self.prob_actor(td)
+            action = action_td["action"]
+
+            # We're not using recurrent policies, so we'll always return None for the state
+            next_state = None
+
+            return action.squeeze(0).cpu().numpy(), next_state
\ No newline at end of file
diff --git a/fancy_rl/algos/on_policy.py b/fancy_rl/algos/on_policy.py
index 9da5355..8bc39c2 100644
--- a/fancy_rl/algos/on_policy.py
+++ b/fancy_rl/algos/on_policy.py
@@ -55,7 +55,7 @@ class OnPolicy(Algo):
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.actor,
+            policy=self.prob_actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
diff --git a/fancy_rl/algos/ppo.py b/fancy_rl/algos/ppo.py
index d776811..d462ff8 100644
--- a/fancy_rl/algos/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.projections import get_projection # Updated import
+from fancy_rl.utils import is_discrete_space
 
 class PPO(OnPolicy):
     def __init__(
@@ -31,7 +31,10 @@ class PPO(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
+        self.clip_range = clip_range
+
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
 
@@ -41,15 +44,29 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
+        self.discrete = is_discrete_space(act_space)
+
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["loc", "scale"],
-            out_keys=["action"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
-        )
+        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+
+        if self.discrete:
+            distribution_class = torch.distributions.Categorical
+            distribution_kwargs = {"logits": "action_logits"}
+        else:
+            if full_covariance:
+                distribution_class = torch.distributions.MultivariateNormal
+                in_keys = ["loc", "scale_tril"]
+            else:
+                distribution_class = torch.distributions.Normal
+                in_keys = ["loc", "scale"]
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=in_keys,
+            out_keys=["action"]
+        )
 
         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@@ -86,9 +103,9 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=clip_range,
+            clip_epsilon=self.clip_range,
             loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
-        )
+        )
\ No newline at end of file
diff --git a/fancy_rl/algos/trpl.py b/fancy_rl/algos/trpl.py
index 8160e2b..aa94a8b 100644
--- a/fancy_rl/algos/trpl.py
+++ b/fancy_rl/algos/trpl.py
@@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
+from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
+from tensordict.nn import TensorDictModule
+from tensordict import TensorDict
+
+class ProjectedActor(TensorDictModule):
+    def __init__(self, raw_actor, old_actor, projection):
+        combined_module = self.CombinedModule(raw_actor, old_actor, projection)
+        super().__init__(
+            module=combined_module,
+            in_keys=raw_actor.in_keys,
+            out_keys=raw_actor.out_keys
+        )
+        self.raw_actor = raw_actor
+        self.old_actor = old_actor
+        self.projection = projection
+
+    class CombinedModule(nn.Module):
+        def __init__(self, raw_actor, old_actor, projection):
+            super().__init__()
+            self.raw_actor = raw_actor
+            self.old_actor = old_actor
+            self.projection = projection
+
+        def forward(self, tensordict):
+            raw_params = self.raw_actor(tensordict)
+            old_params = self.old_actor(tensordict)
+            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            projected_params = self.projection(combined_params)
+            return projected_params
 
 class TRPL(OnPolicy):
     def __init__(
@@ -40,6 +69,7 @@ class TRPL(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
@@ -50,8 +80,11 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
+        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
+
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         # Handle projection_class
         if isinstance(projection_class, str):
@@ -60,20 +93,27 @@ class TRPL(OnPolicy):
             raise ValueError("projection_class must be a string or a subclass of BaseProjection")
 
         self.projection = projection_class(
-            in_keys=["loc", "scale"],
-            out_keys=["loc", "scale"],
-            trust_region_bound_mean=trust_region_bound_mean,
-            trust_region_bound_cov=trust_region_bound_cov
+            in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
+            out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
+            mean_bound=trust_region_bound_mean,
+            cov_bound=trust_region_bound_cov
         )
 
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["observation"],
-            out_keys=["loc", "scale"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
+        self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
+
+        if full_covariance:
+            distribution_class = torch.distributions.MultivariateNormal
+            distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
+        else:
+            distribution_class = torch.distributions.Normal
+            distribution_kwargs = {"loc": "loc", "scale": "scale"}
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=distribution_kwargs,
         )
-        self.old_actor = deepcopy(self.actor)
 
         self.trust_region_coef = trust_region_coef
         self.loss_module = TRPLLoss(
@@ -88,7 +128,7 @@ class TRPL(OnPolicy):
 
 
         optimizers = {
-            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
 
@@ -119,23 +159,7 @@ class TRPL(OnPolicy):
         )
 
     def update_old_policy(self):
-        self.old_actor.load_state_dict(self.actor.state_dict())
-
-    def project_policy(self, obs):
-        with torch.no_grad():
-            old_dist = self.old_actor(obs)
-            new_dist = self.actor(obs)
-            projected_params = self.projection.project(new_dist, old_dist)
-            return projected_params
-
-    def pre_update(self, tensordict):
-        obs = tensordict["observation"]
-        projected_dist = self.project_policy(obs)
-
-        # Update tensordict with projected distribution parameters
-        tensordict["projected_loc"] = projected_dist[0]
-        tensordict["projected_scale"] = projected_dist[1]
-        return tensordict
+        self.old_actor.load_state_dict(self.raw_actor.state_dict())
 
     def post_update(self):
-        self.update_old_policy()
+        self.update_old_policy()
\ No newline at end of file
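
Usage sketch of the reworked interface (illustrative only, not part of the patch; it assumes the environment spec is the first constructor argument of PPO, that PPO is exported at the package top level, and that train() is provided by OnPolicy -- none of which is shown in this diff):

    import gymnasium as gym
    from fancy_rl import PPO  # assumed top-level export

    # Construct PPO; full_covariance=True would switch the Gaussian policy from
    # independent Normal (loc/scale) to MultivariateNormal (loc/scale_tril).
    ppo = PPO("Pendulum-v1", full_covariance=False)

    # train() is assumed to come from OnPolicy; rollouts are collected with
    # self.prob_actor via the SyncDataCollector shown above.
    ppo.train()

    # The new Algo.predict() wraps the observation in a TensorDict, runs the
    # ProbabilisticActor, and returns a numpy action plus None (no recurrent state).
    env = gym.make("Pendulum-v1")
    obs, _ = env.reset()
    action, _ = ppo.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)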