Better handling of envs without defined max_steps

This commit is contained in:
Dominik Moritz Roth 2023-06-18 14:23:59 +02:00
parent 60a4cf11d6
commit b032dec5fe

View File

@ -8,9 +8,10 @@ from typing import Iterable, Type, Union, Optional
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
from gymnasium.envs.registration import register, registry from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import FlattenObservation from gymnasium.wrappers import TimeLimit
from fancy_gym.utils.env_compatibility import EnvCompatibility from fancy_gym.utils.env_compatibility import EnvCompatibility
from fancy_gym.utils.wrappers import FlattenObservation
try: try:
from dm_control import suite, manipulation from dm_control import suite, manipulation
@ -31,7 +32,7 @@ from fancy_gym.black_box.factory.controller_factory import get_controller
from fancy_gym.black_box.factory.phase_generator_factory import get_phase_generator from fancy_gym.black_box.factory.phase_generator_factory import get_phase_generator
from fancy_gym.black_box.factory.trajectory_generator_factory import get_trajectory_generator from fancy_gym.black_box.factory.trajectory_generator_factory import get_trajectory_generator
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.utils.time_aware_observation import TimeAwareObservation from fancy_gym.utils.wrappers import TimeAwareObservation
from fancy_gym.utils.utils import nested_update from fancy_gym.utils.utils import nested_update
@ -114,7 +115,7 @@ def make(env_id: str, seed: int, **kwargs):
return env return env
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs): def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs):
""" """
Helper function for creating a wrapped gym environment using MPs. Helper function for creating a wrapped gym environment using MPs.
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
@ -130,6 +131,8 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
""" """
# _env = gym.make(env_id) # _env = gym.make(env_id)
_env = make(env_id, seed, **kwargs) _env = make(env_id, seed, **kwargs)
if fallback_max_steps:
_env = ensure_finite_time(_env, fallback_max_steps)
has_black_box_wrapper = False has_black_box_wrapper = False
for w in wrappers: for w in wrappers:
# only wrap the environment if not BlackBoxWrapper, e.g. for vision # only wrap the environment if not BlackBoxWrapper, e.g. for vision
@ -144,7 +147,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
def make_bb( def make_bb(
env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping, env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed: int = 1, controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed: int = 1,
**kwargs): fallback_max_steps: int = None, **kwargs):
""" """
This can also be used standalone for manually building a custom DMP environment. This can also be used standalone for manually building a custom DMP environment.
Args: Args:
@ -172,7 +175,7 @@ def make_bb(
# Add as first wrapper in order to alter observation # Add as first wrapper in order to alter observation
wrappers.insert(0, TimeAwareObservation) wrappers.insert(0, TimeAwareObservation)
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, fallback_max_steps=fallback_max_steps, **kwargs)
# BB expects a spaces.Box to be exposed, need to convert for dict-observations # BB expects a spaces.Box to be exposed, need to convert for dict-observations
if type(env.observation_space) == gym.spaces.dict.Dict: if type(env.observation_space) == gym.spaces.dict.Dict:
@ -209,6 +212,15 @@ def make_bb(
return bb_env return bb_env
def ensure_finite_time(env: gym.Env, fallback_max_steps=500):
cur_limit = env.spec.max_episode_steps
if not cur_limit:
if hasattr(env.unwrapped, 'max_path_length'):
return TimeLimit(env, env.unwrapped.__getattribute__('max_path_length'))
return TimeLimit(env, fallback_max_steps)
return env
def get_env_duration(env: gym.Env): def get_env_duration(env: gym.Env):
try: try:
duration = env.spec.max_episode_steps * env.dt duration = env.spec.max_episode_steps * env.dt