Fix: Issue with env wrapping (ensure batch dim)

This commit is contained in:
Dominik Moritz Roth 2024-11-07 11:39:09 +01:00
parent 52b3f3b71e
commit 8a078fb59e

View File

@@ -1,9 +1,12 @@
import torch import torch
import gymnasium as gym import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.record import VideoRecorder from torchrl.record import VideoRecorder
from abc import ABC from abc import ABC
import pdb
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
from fancy_rl.loggers import TerminalLogger from fancy_rl.loggers import TerminalLogger
@@ -47,18 +50,33 @@ class Algo(ABC):
self.eval_episodes = eval_episodes self.eval_episodes = eval_episodes
def make_env(self, eval=False): def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec env_spec = self.env_spec_eval if eval else self.env_spec
env = self._wrap_env(env_spec)
env.reset()
return env
def _wrap_env(self, env_spec):
if isinstance(env_spec, str): if isinstance(env_spec, str):
env = gym.make(env_spec) env = GymEnv(env_spec, device=self.device)
env = GymWrapper(env).to(self.device) elif isinstance(env_spec, gym.Env):
env = GymWrapper(env_spec, device=self.device)
elif isinstance(env_spec, GymEnv):
env = env_spec
elif callable(env_spec): elif callable(env_spec):
env = env_spec() base_env = env_spec()
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)): return self._wrap_env(base_env)
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
env = GymWrapper(env).to(self.device)
else: else:
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec))) raise ValueError(
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
f"got {type(env_spec)}"
)
if not env.batch_size:
env = TransformedEnv(
env,
BatchSizeTransform(batch_size=torch.Size([1]))
)
return env return env
def train_step(self, batch): def train_step(self, batch):
@@ -78,7 +96,7 @@ class Algo(ABC):
): ):
with torch.no_grad(): with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0) obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor}, batch_size=[1]) td = TensorDict({"observation": obs_tensor})
action_td = self.prob_actor(td) action_td = self.prob_actor(td)
action = action_td["action"] action = action_td["action"]