From c7f5fcbf0f9fbe76cf1b03d9123c942517e7936e Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 2 Jun 2024 13:56:54 +0200
Subject: [PATCH] I hate debugging tensordict weirdness

---
 fancy_rl/algos/on_policy.py | 23 +++++++++--------------
 fancy_rl/algos/ppo.py       | 35 +++++++++++------------------------
 fancy_rl/policy.py          | 36 +++++++-----------------------------
 3 files changed, 27 insertions(+), 67 deletions(-)

diff --git a/fancy_rl/algos/on_policy.py b/fancy_rl/algos/on_policy.py
index 556ffcb..f200c80 100644
--- a/fancy_rl/algos/on_policy.py
+++ b/fancy_rl/algos/on_policy.py
@@ -6,13 +6,12 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-from abc import ABC, abstractmethod
-
+from tensordict import LazyStackedTensorDict, TensorDict
+from abc import ABC
 
 class OnPolicy(ABC):
     def __init__(
         self,
-        policy,
         env_spec,
         loggers,
         optimizers,
@@ -21,19 +20,16 @@ class OnPolicy(ABC):
         batch_size,
         n_epochs,
         gamma,
-        gae_lambda,
         total_timesteps,
         eval_interval,
         eval_deterministic,
         entropy_coef,
         critic_coef,
         normalize_advantage,
-        clip_range=0.2,
         device=None,
         eval_episodes=10,
         env_spec_eval=None,
     ):
-        self.policy = policy
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
@@ -43,21 +39,19 @@ class OnPolicy(ABC):
         self.batch_size = batch_size
         self.n_epochs = n_epochs
         self.gamma = gamma
-        self.gae_lambda = gae_lambda
         self.total_timesteps = total_timesteps
         self.eval_interval = eval_interval
         self.eval_deterministic = eval_deterministic
         self.entropy_coef = entropy_coef
         self.critic_coef = critic_coef
         self.normalize_advantage = normalize_advantage
-        self.clip_range = clip_range
         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
         self.eval_episodes = eval_episodes
 
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.policy,
+            policy=self.actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
@@ -78,13 +72,13 @@ class OnPolicy(ABC):
         env_spec = self.env_spec_eval if eval else self.env_spec
         if isinstance(env_spec, str):
             env = gym.make(env_spec)
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
             if isinstance(env, gym.Env):
-                env = GymWrapper(env)
+                env = GymWrapper(env).to(self.device)
         elif isinstance(env, gym.Env):
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
         else:
             raise ValueError("env_spec must be a string or a callable that returns an environment.")
         return env
@@ -92,7 +86,8 @@ class OnPolicy(ABC):
     def train_step(self, batch):
         for optimizer in self.optimizers.values():
             optimizer.zero_grad()
-        loss = self.loss_module(batch)
+        losses = self.loss_module(batch)
+        loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
         loss.backward()
         for optimizer in self.optimizers.values():
             optimizer.step()
@@ -130,7 +125,7 @@ class OnPolicy(ABC):
         for _ in range(self.eval_episodes):
             with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 td_test = eval_env.rollout(
-                    policy=self.policy,
+                    policy=self.actor,
                     auto_reset=True,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
diff --git a/fancy_rl/algos/ppo.py b/fancy_rl/algos/ppo.py
index b70737e..d77f8b4 100644
--- a/fancy_rl/algos/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -3,13 +3,13 @@ from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
-from fancy_rl.policy import Actor, Critic, SharedModule
+from fancy_rl.policy import Actor, Critic
 
 class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=None,
+        loggers=[],
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
@@ -33,6 +33,7 @@ class PPO(OnPolicy):
         eval_episodes=10,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
 
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
@@ -40,21 +41,10 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
-        # Define the shared, actor, and critic modules
-        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
-        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
-
-        # Combine into an ActorValueOperator
-        self.ac_module = ActorValueOperator(
-            self.shared_module,
-            self.actor,
-            self.critic
-        )
-
-        # Define the policy as a ProbabilisticActor
-        policy = ProbabilisticActor(
-            module=self.ac_module.get_policy_operator(),
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.actor = ProbabilisticActor(
+            module=actor_net,
             in_keys=["loc", "scale"],
             out_keys=["action"],
             distribution_class=torch.distributions.Normal,
@@ -67,7 +57,6 @@ class PPO(OnPolicy):
         }
 
         super().__init__(
-            policy=policy,
             env_spec=env_spec,
             loggers=loggers,
             optimizers=optimizers,
@@ -76,14 +65,12 @@ class PPO(OnPolicy):
             batch_size=batch_size,
             n_epochs=n_epochs,
             gamma=gamma,
-            gae_lambda=gae_lambda,
             total_timesteps=total_timesteps,
             eval_interval=eval_interval,
             eval_deterministic=eval_deterministic,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
             normalize_advantage=normalize_advantage,
-            clip_range=clip_range,
             device=device,
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
@@ -91,7 +78,7 @@ class PPO(OnPolicy):
 
         self.adv_module = GAE(
             gamma=self.gamma,
-            lmbda=self.gae_lambda,
+            lmbda=gae_lambda,
             value_network=self.critic,
             average_gae=False,
         )
@@ -99,9 +86,9 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
+            clip_epsilon=clip_range,
+            loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
-        )
\ No newline at end of file
+        )
diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py
index 51d708b..aa96b63 100644
--- a/fancy_rl/policy.py
+++ b/fancy_rl/policy.py
@@ -4,30 +4,8 @@ from torchrl.modules import MLP
 from tensordict.nn.distributions import NormalParamExtractor
 from fancy_rl.utils import is_discrete_space, get_space_shape
 
-class SharedModule(TensorDictModule):
-    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
-        if hidden_sizes:
-            shared_module = MLP(
-                in_features=get_space_shape(obs_space)[-1],
-                out_features=hidden_sizes[-1],
-                num_cells=hidden_sizes[:-1],
-                activation_class=getattr(nn, activation_fn),
-                device=device
-            )
-            out_features = hidden_sizes[-1]
-        else:
-            shared_module = nn.Identity()
-            out_features = get_space_shape(obs_space)[-1]
-
-        super().__init__(
-            module=shared_module,
-            in_keys=["observation"],
-            out_keys=["shared"],
-        )
-        self.out_features = out_features
-
 class Actor(TensorDictModule):
-    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
         act_space_shape = get_space_shape(act_space)
         if is_discrete_space(act_space):
             out_features = act_space_shape[-1]
@@ -36,7 +14,7 @@ class Actor(TensorDictModule):
 
         actor_module = nn.Sequential(
             MLP(
-                in_features=shared_module.out_features,
+                in_features=get_space_shape(obs_space)[-1],
                 out_features=out_features,
                 num_cells=hidden_sizes,
                 activation_class=getattr(nn, activation_fn),
@@ -46,14 +24,14 @@ class Actor(TensorDictModule):
         ).to(device)
         super().__init__(
             module=actor_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
             out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
         )
 
 class Critic(TensorDictModule):
-    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
         critic_module = MLP(
-            in_features=shared_module.out_features,
+            in_features=get_space_shape(obs_space)[-1],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
@@ -61,6 +39,6 @@ class Critic(TensorDictModule):
         ).to(device)
         super().__init__(
             module=critic_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
             out_keys=["state_value"],
-        )
+        )
\ No newline at end of file
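
For reviewers, a minimal, self-contained sketch (not part of the commit) of the wiring this patch moves to: a plain TensorDictModule actor head wrapped in a ProbabilisticActor plus a separate critic, both reading "observation" directly instead of a "shared" key. It uses generic torchrl/tensordict components rather than the fancy_rl.policy classes, and the dimensions (obs_dim=4, act_dim=2) and layer sizes are made up for illustration.

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MLP, ProbabilisticActor

obs_dim, act_dim = 4, 2  # made-up sizes for illustration

# Actor head: the MLP emits 2 * act_dim features, NormalParamExtractor splits
# them into "loc" and "scale", mirroring the out_keys the patched Actor uses.
actor_net = TensorDictModule(
    nn.Sequential(
        MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64, 64]),
        NormalParamExtractor(),
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    module=actor_net,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)

# Separate critic, no shared stem: reads "observation", writes "state_value".
critic = TensorDictModule(
    MLP(in_features=obs_dim, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
    out_keys=["state_value"],
)

td = TensorDict({"observation": torch.randn(8, obs_dim)}, batch_size=[8])
td = actor(td)   # adds "loc", "scale", "action" (and the sample log-prob)
td = critic(td)  # adds "state_value"
print(td["action"].shape, td["state_value"].shape)  # torch.Size([8, 2]) torch.Size([8, 1])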