Fix: Issue with env wrapping (ensure batch dim)
This commit is contained in:
parent 52b3f3b71e
commit 8a078fb59e
@@ -1,9 +1,12 @@
 import torch
 import gymnasium as gym
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
+from torchrl.envs import BatchSizeTransform

 from fancy_rl.loggers import TerminalLogger

@@ -47,18 +50,33 @@ class Algo(ABC):
         self.eval_episodes = eval_episodes

     def make_env(self, eval=False):
         """Creates an environment and wraps it if necessary."""
         env_spec = self.env_spec_eval if eval else self.env_spec
         env = self._wrap_env(env_spec)
         env.reset()
         return env

     def _wrap_env(self, env_spec):
         if isinstance(env_spec, str):
-            env = gym.make(env_spec)
-            env = GymWrapper(env).to(self.device)
+            env = GymEnv(env_spec, device=self.device)
+        elif isinstance(env_spec, gym.Env):
+            env = GymWrapper(env_spec, device=self.device)
+        elif isinstance(env_spec, GymEnv):
+            env = env_spec
         elif callable(env_spec):
-            env = env_spec()
-            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
-                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
-            env = GymWrapper(env).to(self.device)
+            base_env = env_spec()
+            return self._wrap_env(base_env)
         else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
+            raise ValueError(
+                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
+                f"got {type(env_spec)}"
+            )
+
+        if not env.batch_size:
+            env = TransformedEnv(
+                env,
+                BatchSizeTransform(batch_size=torch.Size([1]))
+            )

         return env

     def train_step(self, batch):

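For context on the batch-size guard introduced above, a minimal standalone sketch of its effect (not part of this commit; it assumes only the TorchRL classes imported above and the Gymnasium id CartPole-v1):

import torch
from torchrl.envs import GymEnv, TransformedEnv, BatchSizeTransform

# A single, non-vectorized GymEnv reports an empty batch size.
env = GymEnv("CartPole-v1")
print(env.batch_size)  # torch.Size([])

# The guard in _wrap_env wraps such envs so they expose a leading batch dim of 1.
if not env.batch_size:
    env = TransformedEnv(env, BatchSizeTransform(batch_size=torch.Size([1])))

print(env.batch_size)           # expected: torch.Size([1])
td = env.reset()                # reset/step tensordicts now carry that batch dim
print(td["observation"].shape)  # e.g. torch.Size([1, 4]) for CartPole-v1
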
@@ -78,7 +96,7 @@ class Algo(ABC):
     ):
         with torch.no_grad():
             obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor}, batch_size=[1])
+            td = TensorDict({"observation": obs_tensor})

             action_td = self.prob_actor(td)
             action = action_td["action"]
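Note on the predict() change above: the leading batch dimension now lives in the tensor shape (added by unsqueeze(0)) rather than in the TensorDict metadata. A small sketch with a hypothetical 4-dimensional observation:

import torch
from tensordict import TensorDict

observation = [0.1, 0.2, 0.3, 0.4]                       # hypothetical raw observation
obs_tensor = torch.as_tensor(observation).unsqueeze(0)   # shape (1, 4): explicit batch dim

# Omitting batch_size leaves the TensorDict with an empty batch size;
# the actor sees the batch dimension through the tensor shapes themselves.
td = TensorDict({"observation": obs_tensor})
print(td.batch_size)            # torch.Size([])
print(td["observation"].shape)  # torch.Size([1, 4])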
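Taken together, the new _wrap_env accepts a string id, a Gymnasium environment, a TorchRL GymEnv, or a callable returning one of these, and always returns an env with a batch dimension. A simplified, standalone rendition of that logic (the free function wrap_env and the CartPole-v1 id are illustrative, not part of the repo):

import gymnasium as gym
import torch
from torchrl.envs import GymEnv, GymWrapper, TransformedEnv, BatchSizeTransform

def wrap_env(env_spec, device="cpu"):
    # Mirrors Algo._wrap_env above, without the surrounding class.
    if isinstance(env_spec, str):
        env = GymEnv(env_spec, device=device)
    elif isinstance(env_spec, gym.Env):
        env = GymWrapper(env_spec, device=device)
    elif isinstance(env_spec, GymEnv):
        env = env_spec
    elif callable(env_spec):
        return wrap_env(env_spec(), device)
    else:
        raise ValueError(f"unsupported env_spec: {type(env_spec)}")
    if not env.batch_size:
        env = TransformedEnv(env, BatchSizeTransform(batch_size=torch.Size([1])))
    return env

# All four spec styles end up with a batch dimension of 1.
for spec in ("CartPole-v1",
             gym.make("CartPole-v1"),
             GymEnv("CartPole-v1"),
             lambda: gym.make("CartPole-v1")):
    print(wrap_env(spec).batch_size)  # expected: torch.Size([1])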