From b032dec5fe597d5802adcb71dfc72119eb2bb128 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 18 Jun 2023 14:23:59 +0200
Subject: [PATCH] Better handling of envs without defined max_steps

---
 fancy_gym/utils/make_env_helpers.py | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/fancy_gym/utils/make_env_helpers.py b/fancy_gym/utils/make_env_helpers.py
index 7f1878e..848c083 100644
--- a/fancy_gym/utils/make_env_helpers.py
+++ b/fancy_gym/utils/make_env_helpers.py
@@ -8,9 +8,10 @@ from typing import Iterable, Type, Union, Optional
 import gymnasium as gym
 import numpy as np
 from gymnasium.envs.registration import register, registry
-from gymnasium.wrappers import FlattenObservation
+from gymnasium.wrappers import TimeLimit
 
 from fancy_gym.utils.env_compatibility import EnvCompatibility
+from fancy_gym.utils.wrappers import FlattenObservation
 
 try:
     from dm_control import suite, manipulation
@@ -31,7 +32,7 @@ from fancy_gym.black_box.factory.controller_factory import get_controller
 from fancy_gym.black_box.factory.phase_generator_factory import get_phase_generator
 from fancy_gym.black_box.factory.trajectory_generator_factory import get_trajectory_generator
 from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
-from fancy_gym.utils.time_aware_observation import TimeAwareObservation
+from fancy_gym.utils.wrappers import TimeAwareObservation
 from fancy_gym.utils.utils import nested_update
 
 
@@ -114,7 +115,7 @@ def make(env_id: str, seed: int, **kwargs):
     return env
 
 
-def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs):
+def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs):
     """
     Helper function for creating a wrapped gym environment using MPs.
     It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
@@ -130,6 +131,8 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
     """
     # _env = gym.make(env_id)
     _env = make(env_id, seed, **kwargs)
+    if fallback_max_steps:
+        _env = ensure_finite_time(_env, fallback_max_steps)
     has_black_box_wrapper = False
     for w in wrappers:
         # only wrap the environment if not BlackBoxWrapper, e.g. for vision
@@ -144,7 +147,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
 def make_bb(
         env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
         controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed: int = 1,
-        **kwargs):
+        fallback_max_steps: int = None, **kwargs):
     """
     This can also be used standalone for manually building a custom DMP environment.
     Args:
@@ -172,7 +175,7 @@ def make_bb(
         # Add as first wrapper in order to alter observation
         wrappers.insert(0, TimeAwareObservation)
 
-    env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
+    env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, fallback_max_steps=fallback_max_steps, **kwargs)
 
     # BB expects a spaces.Box to be exposed, need to convert for dict-observations
     if type(env.observation_space) == gym.spaces.dict.Dict:
@@ -209,6 +212,26 @@ def make_bb(
     return bb_env
 
 
+def ensure_finite_time(env: gym.Env, fallback_max_steps=500):
+    """
+    Makes sure episodes of the provided environment end after a finite number of steps.
+    The env is returned unchanged if its spec already defines max_episode_steps. Otherwise it is wrapped in a
+    TimeLimit, preferring the max_path_length of the unwrapped env (if set) over fallback_max_steps.
+    Args:
+        env: environment to check
+        fallback_max_steps: step limit used when neither the spec nor the env itself provides one
+
+    Returns: env itself if it already has a step limit, else env wrapped in a TimeLimit
+
+    """
+    # env.spec may be None (e.g. for envs that were not created via gymnasium.make), so do not dereference it blindly
+    if getattr(env.spec, 'max_episode_steps', None):
+        return env
+    # A missing or unset (None/0) max_path_length must not end up as the limit of the TimeLimit
+    max_steps = getattr(env.unwrapped, 'max_path_length', None) or fallback_max_steps
+    return TimeLimit(env, max_steps)
+
+
 def get_env_duration(env: gym.Env):
     try:
         duration = env.spec.max_episode_steps * env.dt