Fix: Issue with env wrapping (ensure batch dim)

This commit is contained in:
Dominik Moritz Roth 2024-11-07 11:39:09 +01:00
parent 52b3f3b71e
commit 8a078fb59e

View File

@ -1,9 +1,12 @@
import torch
import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
from fancy_rl.loggers import TerminalLogger
@ -47,18 +50,33 @@ class Algo(ABC):
self.eval_episodes = eval_episodes
def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec
env = self._wrap_env(env_spec)
env.reset()
return env
def _wrap_env(self, env_spec):
if isinstance(env_spec, str):
env = gym.make(env_spec)
env = GymWrapper(env).to(self.device)
env = GymEnv(env_spec, device=self.device)
elif isinstance(env_spec, gym.Env):
env = GymWrapper(env_spec, device=self.device)
elif isinstance(env_spec, GymEnv):
env = env_spec
elif callable(env_spec):
env = env_spec()
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
env = GymWrapper(env).to(self.device)
base_env = env_spec()
return self._wrap_env(base_env)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
raise ValueError(
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
f"got {type(env_spec)}"
)
if not env.batch_size:
env = TransformedEnv(
env,
BatchSizeTransform(batch_size=torch.Size([1]))
)
return env
def train_step(self, batch):
@ -78,7 +96,7 @@ class Algo(ABC):
):
with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
td = TensorDict({"observation": obs_tensor})
action_td = self.prob_actor(td)
action = action_td["action"]