From 04be117a95f76016488c34a6e893bae0d0e42d2e Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Wed, 22 Jan 2025 13:46:00 +0100
Subject: [PATCH] Various fixes

---
 fancy_rl/algos/algo.py | 54 ++++++++++++++++++++++++------------------
 fancy_rl/algos/ppo.py  | 10 ++++----
 fancy_rl/algos/trpl.py | 49 ++++++++++++++++++++++++++++++++------
 fancy_rl/policy.py     | 17 ++++++++++++-
 4 files changed, 94 insertions(+), 36 deletions(-)

diff --git a/fancy_rl/algos/algo.py b/fancy_rl/algos/algo.py
index ce3a477..4a4fbbc 100644
--- a/fancy_rl/algos/algo.py
+++ b/fancy_rl/algos/algo.py
@@ -1,9 +1,10 @@
 import torch
 import gymnasium as gym
-from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
+import numpy as np
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
 from torchrl.envs import BatchSizeTransform
@@ -56,27 +57,31 @@ class Algo(ABC):
         return env
 
     def _wrap_env(self, env_spec):
+        # If given an existing env, ensure it's properly batched
+        if isinstance(env_spec, (GymEnv, GymWrapper)):
+            if not env_spec.batch_size:
+                raise ValueError("Environment must be batched")
+            return env_spec
+
+        # Handle callable without wrapping the recursive call
+        if callable(env_spec):
+            return self._wrap_env(env_spec())
+
+        # Create new batched environment using SerialEnv
         if isinstance(env_spec, str):
-            env = GymEnv(env_spec, device=self.device)
+            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
         elif isinstance(env_spec, gym.Env):
-            env = GymWrapper(env_spec, device=self.device)
-        elif isinstance(env_spec, GymEnv):
-            env = env_spec
-        elif callable(env_spec):
-            base_env = env_spec()
-            return self._wrap_env(base_env)
+            wrapped_env = GymWrapper(env_spec, device=self.device)
+            if wrapped_env.batch_size:
+                env = wrapped_env
+            else:
+                env = SerialEnv(1, lambda: wrapped_env)
         else:
             raise ValueError(
                 f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                 f"got {type(env_spec)}"
             )
 
-        if not env.batch_size:
-            env = TransformedEnv(
-                env,
-                BatchSizeTransform(batch_size=torch.Size([1]))
-            )
-
         return env
 
     def train_step(self, batch):
@@ -90,18 +95,21 @@ class Algo(ABC):
 
     def predict(
         self,
-        observation,
+        tensordict,
         state=None,
         deterministic=False
    ):
         with torch.no_grad():
-            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor})
+            # If numpy array, convert to TensorDict
+            if isinstance(tensordict, np.ndarray):
+                tensordict = TensorDict(
+                    {"observation": torch.from_numpy(tensordict).float()},
+                    batch_size=[]
+                )
 
-            action_td = self.prob_actor(td)
-            action = action_td["action"]
+            # Move to device
+            tensordict = tensordict.to(self.device)
 
-            # We're not using recurrent policies, so we'll always return None for the state
-            next_state = None
-
-            return action.squeeze(0).cpu().numpy(), next_state
\ No newline at end of file
+            # Get action from policy
+            action_td = self.prob_actor(tensordict)
+            return action_td
diff --git a/fancy_rl/algos/ppo.py b/fancy_rl/algos/ppo.py
index 0529241..cf4f6cf 100644
--- a/fancy_rl/algos/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -43,13 +43,13 @@ class PPO(OnPolicy):
         env = self.make_env()
 
         # Get spaces from specs for parallel env
-        obs_space = env.observation_spec
-        act_space = env.action_spec
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec
 
-        self.discrete = isinstance(act_space, DiscreteTensorSpec)
+        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)
 
-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         if self.discrete:
             distribution_class = torch.distributions.Categorical
diff --git a/fancy_rl/algos/trpl.py b/fancy_rl/algos/trpl.py
index c483fff..b414584 100644
--- a/fancy_rl/algos/trpl.py
+++ b/fancy_rl/algos/trpl.py
@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
+
 
 class ProjectedActor(TensorDictModule):
     def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
         self.raw_actor = raw_actor
         self.old_actor = old_actor
         self.projection = projection
+        self.discrete = raw_actor.discrete
+        self.full_covariance = raw_actor.full_covariance
 
     class CombinedModule(nn.Module):
         def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
             self.projection = projection
 
         def forward(self, tensordict):
+            # Convert the tuple outputs to TensorDict
             raw_params = self.raw_actor(tensordict)
+            if isinstance(raw_params, tuple):
+                raw_params = TensorDict({
+                    "loc": raw_params[0],
+                    "scale": raw_params[1]
+                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+
             old_params = self.old_actor(tensordict)
-            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            if isinstance(old_params, tuple):
+                old_params = TensorDict({
+                    "loc": old_params[0],
+                    "scale": old_params[1]
+                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+
+            # Now combine them
+            combined_params = TensorDict({
+                **raw_params,
+                **{f"old_{key}": value for key, value in old_params.items()}
+            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size
+
             projected_params = self.projection(combined_params)
             return projected_params
 
+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+
 class TRPL(OnPolicy):
     def __init__(
         self,
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space
 
-        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        # Get spaces from specs for parallel env
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec
 
-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
-        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         # Handle projection_class
         if isinstance(projection_class, str):
diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py
index 6ce1590..9359f5b 100644
--- a/fancy_rl/policy.py
+++ b/fancy_rl/policy.py
@@ -4,6 +4,7 @@ from torchrl.modules import MLP
 from torchrl.data.tensor_specs import DiscreteTensorSpec
 from tensordict.nn.distributions import NormalParamExtractor
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
 
 class Actor(TensorDictModule):
     def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
             out_keys=out_keys
         )
 
+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+
 class FullCovarianceNormalParamExtractor(nn.Module):
     def __init__(self, action_dim):
         super().__init__()
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):
 
 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        obs_space = obs_space["observation"]
+        obs_space_shape = obs_space.shape[1:]
+
         critic_module = MLP(
-            in_features=obs_space.shape[-1],
+            in_features=obs_space_shape[0],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
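
Usage sketch (not part of the diff): the revised predict() now accepts either a raw numpy observation or a TensorDict and returns the actor's output TensorDict instead of an (action, state) tuple. The snippet below is illustrative only; `algo`, the observation size, and the "action" key follow the old call site and are assumptions rather than anything this patch guarantees.

    import numpy as np
    import torch
    from tensordict import TensorDict

    # `algo` is assumed to be an already-constructed Algo subclass (e.g. PPO).
    obs = np.zeros(4, dtype=np.float32)   # raw observation from a Gymnasium env
    action_td = algo.predict(obs)         # numpy input is wrapped into a TensorDict internally
    action = action_td["action"]          # the action now lives inside the returned TensorDict

    # A TensorDict can also be passed directly; predict() only moves it to the algo's device.
    td = TensorDict({"observation": torch.zeros(4)}, batch_size=[])
    action_td = algo.predict(td)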
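
Usage sketch (not part of the diff): the new get_dist() helper on Actor (and ProjectedActor) builds a torch.distributions object from the module's output keys: "logits" for discrete actors, "loc"/"scale_tril" with full covariance, "loc"/"scale" otherwise. The actor instance and observation size below are illustrative assumptions.

    import torch
    from tensordict import TensorDict

    # `actor` is assumed to be a continuous-action Actor from fancy_rl/policy.py;
    # obs_dim is a placeholder that must match its observation spec.
    obs_dim = 4
    td = TensorDict({"observation": torch.zeros(1, obs_dim)}, batch_size=[1])
    dist = actor.get_dist(td)             # Normal, or MultivariateNormal if full_covariance=True
    action = dist.sample()
    log_prob = dist.log_prob(action)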