Fix: Issue with env wrapping (ensure batch dim)
This commit is contained in:
parent
52b3f3b71e
commit
8a078fb59e
@ -1,9 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from torchrl.envs.libs.gym import GymWrapper
|
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
|
||||||
from torchrl.record import VideoRecorder
|
from torchrl.record import VideoRecorder
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
import pdb
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
from torchrl.envs import GymWrapper, TransformedEnv
|
||||||
|
from torchrl.envs import BatchSizeTransform
|
||||||
|
|
||||||
from fancy_rl.loggers import TerminalLogger
|
from fancy_rl.loggers import TerminalLogger
|
||||||
|
|
||||||
@ -47,18 +50,33 @@ class Algo(ABC):
|
|||||||
self.eval_episodes = eval_episodes
|
self.eval_episodes = eval_episodes
|
||||||
|
|
||||||
def make_env(self, eval=False):
|
def make_env(self, eval=False):
|
||||||
"""Creates an environment and wraps it if necessary."""
|
|
||||||
env_spec = self.env_spec_eval if eval else self.env_spec
|
env_spec = self.env_spec_eval if eval else self.env_spec
|
||||||
|
env = self._wrap_env(env_spec)
|
||||||
|
env.reset()
|
||||||
|
return env
|
||||||
|
|
||||||
|
def _wrap_env(self, env_spec):
|
||||||
if isinstance(env_spec, str):
|
if isinstance(env_spec, str):
|
||||||
env = gym.make(env_spec)
|
env = GymEnv(env_spec, device=self.device)
|
||||||
env = GymWrapper(env).to(self.device)
|
elif isinstance(env_spec, gym.Env):
|
||||||
|
env = GymWrapper(env_spec, device=self.device)
|
||||||
|
elif isinstance(env_spec, GymEnv):
|
||||||
|
env = env_spec
|
||||||
elif callable(env_spec):
|
elif callable(env_spec):
|
||||||
env = env_spec()
|
base_env = env_spec()
|
||||||
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
|
return self._wrap_env(base_env)
|
||||||
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
|
|
||||||
env = GymWrapper(env).to(self.device)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
|
raise ValueError(
|
||||||
|
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
|
||||||
|
f"got {type(env_spec)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not env.batch_size:
|
||||||
|
env = TransformedEnv(
|
||||||
|
env,
|
||||||
|
BatchSizeTransform(batch_size=torch.Size([1]))
|
||||||
|
)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def train_step(self, batch):
|
def train_step(self, batch):
|
||||||
@ -78,7 +96,7 @@ class Algo(ABC):
|
|||||||
):
|
):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
|
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
|
||||||
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
|
td = TensorDict({"observation": obs_tensor})
|
||||||
|
|
||||||
action_td = self.prob_actor(td)
|
action_td = self.prob_actor(td)
|
||||||
action = action_td["action"]
|
action = action_td["action"]
|
||||||
|
Loading…
Reference in New Issue
Block a user