From 8a078fb59e90563972c36e6792b05be011793b4e Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Thu, 7 Nov 2024 11:39:09 +0100
Subject: [PATCH] Fix: Issue with env wrapping (ensure batch dim)

---
 fancy_rl/algos/algo.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/fancy_rl/algos/algo.py b/fancy_rl/algos/algo.py
index 90eeee0..ce3a477 100644
--- a/fancy_rl/algos/algo.py
+++ b/fancy_rl/algos/algo.py
@@ -1,9 +1,12 @@
 import torch
 import gymnasium as gym
-from torchrl.envs.libs.gym import GymWrapper
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
+import pdb
 from tensordict import TensorDict
+from torchrl.envs import GymWrapper, TransformedEnv
+from torchrl.envs import BatchSizeTransform
 from fancy_rl.loggers import TerminalLogger
 
 
@@ -47,18 +50,33 @@ class Algo(ABC):
         self.eval_episodes = eval_episodes
 
     def make_env(self, eval=False):
-        """Creates an environment and wraps it if necessary."""
         env_spec = self.env_spec_eval if eval else self.env_spec
+        env = self._wrap_env(env_spec)
+        env.reset()
+        return env
+
+    def _wrap_env(self, env_spec):
         if isinstance(env_spec, str):
-            env = gym.make(env_spec)
-            env = GymWrapper(env).to(self.device)
+            env = GymEnv(env_spec, device=self.device)
+        elif isinstance(env_spec, gym.Env):
+            env = GymWrapper(env_spec, device=self.device)
+        elif isinstance(env_spec, GymEnv):
+            env = env_spec
         elif callable(env_spec):
-            env = env_spec()
-            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
-                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
-            env = GymWrapper(env).to(self.device)
+            base_env = env_spec()
+            return self._wrap_env(base_env)
         else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
+            raise ValueError(
+                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
+                f"got {type(env_spec)}"
+            )
+
+        if not env.batch_size:
+            env = TransformedEnv(
+                env,
+                BatchSizeTransform(batch_size=torch.Size([1]))
+            )
+
         return env
 
     def train_step(self, batch):
@@ -78,7 +96,7 @@ class Algo(ABC):
     ):
         with torch.no_grad():
             obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor}, batch_size=[1])
+            td = TensorDict({"observation": obs_tensor})
             action_td = self.prob_actor(td)
             action = action_td["action"]
 